未验证 提交 02c880d3 编写于 作者: L LielinJiang 提交者: GitHub

Merge pull request #146 from LielinJiang/laneNet_1125

Add LaneNet and weighted cross entropy
# LaneNet 模型训练教程
* 本教程旨在介绍如何通过使用PaddleSeg进行车道线检测
* 在阅读本教程前,请确保您已经了解过PaddleSeg的[快速入门](../README.md#快速入门)[基础功能](../README.md#基础功能)等章节,以便对PaddleSeg有一定的了解
## 环境依赖
* PaddlePaddle >= 1.7.0 或develop版本
* Python 3.5+
通过以下命令安装python包依赖,请确保在该分支上至少执行过一次以下命令
```shell
$ pip install -r requirements.txt
```
## 一. 准备待训练数据
我们提前准备好了一份处理好的数据集,通过以下代码进行下载,该数据集由图森车道线检测数据集转换而来,你也可以在这个[页面](https://github.com/TuSimple/tusimple-benchmark/issues/3)下载原始数据集。
```shell
python dataset/download_tusimple.py
```
数据目录结构
```
LaneNet
|-- dataset
|-- tusimple_lane_detection
|-- training
|-- gt_binary_image
|-- gt_image
|-- gt_instance_image
|-- train_part.txt
|-- val_part.txt
```
## 二. 下载预训练模型
下载[vgg预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/VGG16_pretrained.tar),放在```pretrained_models```文件夹下。
## 三. 准备配置
接着我们需要确定相关配置,从本教程的角度,配置分为三部分:
* 数据集
* 训练集主目录
* 训练集文件列表
* 测试集文件列表
* 评估集文件列表
* 预训练模型
* 预训练模型名称
* 预训练模型的backbone网络
* 预训练模型路径
* 其他
* 学习率
* Batch大小
* ...
在三者中,预训练模型的配置尤为重要,如果模型或者BACKBONE配置错误,会导致预训练的参数没有加载,进而影响收敛速度。预训练模型相关的配置如第二步所展示。
数据集的配置和数据路径有关,在本教程中,数据存放在`dataset/tusimple_lane_detection`
其他配置则根据数据集和机器环境的情况进行调节,最终我们保存一个如下内容的yaml配置文件,存放路径为**configs/lanenet.yaml**
```yaml
# 数据集配置
DATASET:
DATA_DIR: "./dataset/tusimple_lane_detection"
IMAGE_TYPE: "rgb" # choice rgb or rgba
NUM_CLASSES: 2
TEST_FILE_LIST: "./dataset/tusimple_lane_detection/training/val_part.txt"
TRAIN_FILE_LIST: "./dataset/tusimple_lane_detection/training/train_part.txt"
VAL_FILE_LIST: "./dataset/tusimple_lane_detection/training/val_part.txt"
SEPARATOR: " "
# 预训练模型配置
MODEL:
MODEL_NAME: "lanenet"
# 其他配置
EVAL_CROP_SIZE: (512, 256)
TRAIN_CROP_SIZE: (512, 256)
AUG:
AUG_METHOD: u"unpadding" # choice unpadding rangescaling and stepscaling
FIX_RESIZE_SIZE: (512, 256) # (width, height), for unpadding
MIRROR: False
RICH_CROP:
ENABLE: False
BATCH_SIZE: 4
TEST:
TEST_MODEL: "./saved_model/lanenet/final/"
TRAIN:
MODEL_SAVE_DIR: "./saved_model/lanenet/"
PRETRAINED_MODEL_DIR: "./pretrained_models/VGG16_pretrained"
SNAPSHOT_EPOCH: 5
SOLVER:
NUM_EPOCHS: 100
LR: 0.0005
LR_POLICY: "poly"
OPTIMIZER: "sgd"
WEIGHT_DECAY: 0.001
```
## 五. 开始训练
使用下述命令启动训练
```shell
CUDA_VISIBLE_DEVICES=0 python -u train.py --cfg configs/lanenet.yaml --use_gpu --use_mpio --do_eval
```
## 六. 进行评估
模型训练完成,使用下述命令启动评估
```shell
CUDA_VISIBLE_DEVICES=0 python -u eval.py --use_gpu --cfg configs/lanenet.yaml
```
## 七. 可视化
需要先下载一个车前视角和鸟瞰图视角转换所需文件,点击[链接](https://paddleseg.bj.bcebos.com/resources/tusimple_ipm_remap.tar),下载后放在```./utils```下。同时我们提供了一个训练好的模型,点击[链接](https://paddleseg.bj.bcebos.com/models/lanenet_vgg_tusimple.tar),下载后放在```./pretrained_models/```下,使用如下命令进行可视化
```shell
CUDA_VISIBLE_DEVICES=0 python -u ./vis.py --cfg configs/lanenet.yaml --use_gpu --vis_dir vis_result \
TEST.TEST_MODEL pretrained_models/LaneNet_vgg_tusimple/
```
可视化结果示例:
预测结果:<br/>
![](imgs/0005_pred_lane.png)
分割结果:<br/>
![](imgs/0005_pred_binary.png)<br/>
车道线实例预测结果:<br/>
![](imgs/0005_pred_instance.png)
EVAL_CROP_SIZE: (512, 256) # (width, height), for unpadding rangescaling and stepscaling
TRAIN_CROP_SIZE: (512, 256) # (width, height), for unpadding rangescaling and stepscaling
AUG:
AUG_METHOD: u"unpadding" # choice unpadding rangescaling and stepscaling
FIX_RESIZE_SIZE: (512, 256) # (width, height), for unpadding
INF_RESIZE_VALUE: 500 # for rangescaling
MAX_RESIZE_VALUE: 600 # for rangescaling
MIN_RESIZE_VALUE: 400 # for rangescaling
MAX_SCALE_FACTOR: 2.0 # for stepscaling
MIN_SCALE_FACTOR: 0.5 # for stepscaling
SCALE_STEP_SIZE: 0.25 # for stepscaling
MIRROR: False
RICH_CROP:
ENABLE: False
BATCH_SIZE: 4
DATALOADER:
BUF_SIZE: 256
NUM_WORKERS: 4
DATASET:
DATA_DIR: "./dataset/tusimple_lane_detection"
IMAGE_TYPE: "rgb" # choice rgb or rgba
NUM_CLASSES: 2
TEST_FILE_LIST: "./dataset/tusimple_lane_detection/training/val_part.txt"
TEST_TOTAL_IMAGES: 362
TRAIN_FILE_LIST: "./dataset/tusimple_lane_detection/training/train_part.txt"
TRAIN_TOTAL_IMAGES: 3264
VAL_FILE_LIST: "./dataset/tusimple_lane_detection/training/val_part.txt"
VAL_TOTAL_IMAGES: 362
SEPARATOR: " "
IGNORE_INDEX: 255
FREEZE:
MODEL_FILENAME: "__model__"
PARAMS_FILENAME: "__params__"
MODEL:
MODEL_NAME: "lanenet"
DEFAULT_NORM_TYPE: "bn"
TEST:
TEST_MODEL: "./saved_model/lanenet/final/"
TRAIN:
MODEL_SAVE_DIR: "./saved_model/lanenet/"
PRETRAINED_MODEL_DIR: "./pretrained_models/VGG16_pretrained"
SNAPSHOT_EPOCH: 1
SOLVER:
NUM_EPOCHS: 100
LR: 0.0005
LR_POLICY: "poly"
OPTIMIZER: "sgd"
WEIGHT_DECAY: 0.001
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import cv2
import numpy as np
from utils.config import cfg
from models.model_builder import ModelPhase
from pdseg.data_aug import get_random_scale, randomly_scale_image_and_label, random_rotation, \
rand_scale_aspect, hsv_color_jitter, rand_crop
def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
"""
改变图像及标签图像尺寸
AUG.AUG_METHOD为unpadding,所有模式均直接resize到AUG.FIX_RESIZE_SIZE的尺寸
AUG.AUG_METHOD为stepscaling, 按比例resize,训练时比例范围AUG.MIN_SCALE_FACTOR到AUG.MAX_SCALE_FACTOR,间隔为AUG.SCALE_STEP_SIZE,其他模式返回原图
AUG.AUG_METHOD为rangescaling,长边对齐,短边按比例变化,训练时长边对齐范围AUG.MIN_RESIZE_VALUE到AUG.MAX_RESIZE_VALUE,其他模式长边对齐AUG.INF_RESIZE_VALUE
Args:
img(numpy.ndarray): 输入图像
grt(numpy.ndarray): 标签图像,默认为None
mode(string): 模式, 默认训练模式,即ModelPhase.TRAIN
Returns:
resize后的图像和标签图
"""
if cfg.AUG.AUG_METHOD == 'unpadding':
target_size = cfg.AUG.FIX_RESIZE_SIZE
img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
if grt is not None:
grt = cv2.resize(grt, target_size, interpolation=cv2.INTER_NEAREST)
if grt_instance is not None:
grt_instance = cv2.resize(grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
elif cfg.AUG.AUG_METHOD == 'stepscaling':
if mode == ModelPhase.TRAIN:
min_scale_factor = cfg.AUG.MIN_SCALE_FACTOR
max_scale_factor = cfg.AUG.MAX_SCALE_FACTOR
step_size = cfg.AUG.SCALE_STEP_SIZE
scale_factor = get_random_scale(min_scale_factor, max_scale_factor,
step_size)
img, grt = randomly_scale_image_and_label(
img, grt, scale=scale_factor)
elif cfg.AUG.AUG_METHOD == 'rangescaling':
min_resize_value = cfg.AUG.MIN_RESIZE_VALUE
max_resize_value = cfg.AUG.MAX_RESIZE_VALUE
if mode == ModelPhase.TRAIN:
if min_resize_value == max_resize_value:
random_size = min_resize_value
else:
random_size = int(
np.random.uniform(min_resize_value, max_resize_value) + 0.5)
else:
random_size = cfg.AUG.INF_RESIZE_VALUE
value = max(img.shape[0], img.shape[1])
scale = float(random_size) / float(value)
img = cv2.resize(
img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
if grt is not None:
grt = cv2.resize(
grt, (0, 0),
fx=scale,
fy=scale,
interpolation=cv2.INTER_NEAREST)
else:
raise Exception("Unexpect data augmention method: {}".format(
cfg.AUG.AUG_METHOD))
return img, grt, grt_instance
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TEST_PATH = os.path.join(LOCAL_PATH, "../../../", "test")
sys.path.append(TEST_PATH)
from test_utils import download_file_and_uncompress
def download_tusimple_dataset(savepath, extrapath):
url = "https://paddleseg.bj.bcebos.com/dataset/tusimple_lane_detection.tar"
download_file_and_uncompress(
url=url, savepath=savepath, extrapath=extrapath)
if __name__ == "__main__":
download_tusimple_dataset(LOCAL_PATH, LOCAL_PATH)
print("Dataset download finish!")
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# GPU memory garbage collection optimization flags
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
import sys
cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(os.path.split(cur_path)[0])[0]
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
SEG_PATH = os.path.join(LOCAL_PATH, "../../../")
sys.path.append(SEG_PATH)
sys.path.append(root_path)
import time
import argparse
import functools
import pprint
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
from utils.config import cfg
from pdseg.utils.timer import Timer, calculate_eta
from models.model_builder import build_model
from models.model_builder import ModelPhase
from reader import LaneNetDataset
def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg model evalution')
parser.add_argument(
'--cfg',
dest='cfg_file',
help='Config file for training (and optionally testing)',
default=None,
type=str)
parser.add_argument(
'--use_gpu',
dest='use_gpu',
help='Use gpu or cpu',
action='store_true',
default=False)
parser.add_argument(
'--use_mpio',
dest='use_mpio',
help='Use multiprocess IO or not',
action='store_true',
default=False)
parser.add_argument(
'opts',
help='See utils/config.py for all options',
default=None,
nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
np.set_printoptions(precision=5, suppress=True)
startup_prog = fluid.Program()
test_prog = fluid.Program()
dataset = LaneNetDataset(
file_list=cfg.DATASET.VAL_FILE_LIST,
mode=ModelPhase.TRAIN,
shuffle=True,
data_dir=cfg.DATASET.DATA_DIR)
def data_generator():
#TODO: check is batch reader compatitable with Windows
if use_mpio:
data_gen = dataset.multiprocess_generator(
num_processes=cfg.DATALOADER.NUM_WORKERS,
max_queue_size=cfg.DATALOADER.BUF_SIZE)
else:
data_gen = dataset.generator()
for b in data_gen:
yield b
py_reader, pred, grts, masks, accuracy, fp, fn = build_model(
test_prog, startup_prog, phase=ModelPhase.EVAL)
py_reader.decorate_sample_generator(
data_generator, drop_last=False, batch_size=cfg.BATCH_SIZE)
# Get device environment
places = fluid.cuda_places() if use_gpu else fluid.cpu_places()
place = places[0]
dev_count = len(places)
print("#Device count: {}".format(dev_count))
exe = fluid.Executor(place)
exe.run(startup_prog)
test_prog = test_prog.clone(for_test=True)
ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir
if ckpt_dir is not None:
print('load test model:', ckpt_dir)
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
# Use streaming confusion matrix to calculate mean_iou
np.set_printoptions(
precision=4, suppress=True, linewidth=160, floatmode="fixed")
fetch_list = [pred.name, grts.name, masks.name, accuracy.name, fp.name, fn.name]
num_images = 0
step = 0
avg_acc = 0.0
avg_fp = 0.0
avg_fn = 0.0
# cur_images = 0
all_step = cfg.DATASET.TEST_TOTAL_IMAGES // cfg.BATCH_SIZE + 1
timer = Timer()
timer.start()
py_reader.start()
while True:
try:
step += 1
pred, grts, masks, out_acc, out_fp, out_fn = exe.run(
test_prog, fetch_list=fetch_list, return_numpy=True)
avg_acc += np.mean(out_acc) * pred.shape[0]
avg_fp += np.mean(out_fp) * pred.shape[0]
avg_fn += np.mean(out_fn) * pred.shape[0]
num_images += pred.shape[0]
speed = 1.0 / timer.elapsed_time()
print(
"[EVAL]step={} accuracy={:.4f} fp={:.4f} fn={:.4f} step/sec={:.2f} | ETA {}"
.format(step, avg_acc / num_images, avg_fp / num_images, avg_fn / num_images, speed,
calculate_eta(all_step - step, speed)))
timer.restart()
sys.stdout.flush()
except fluid.core.EOFException:
break
print("[EVAL]#image={} accuracy={:.4f} fp={:.4f} fn={:.4f}".format(
num_images, avg_acc / num_images, avg_fp / num_images, avg_fn / num_images))
return avg_acc / num_images, avg_fp / num_images, avg_fn / num_images
def main():
args = parse_args()
if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file)
if args.opts:
cfg.update_from_list(args.opts)
cfg.check_and_infer()
print(pprint.pformat(cfg))
evaluate(cfg, **args.__dict__)
if __name__ == '__main__':
main()
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy as np
from utils.config import cfg
def unsorted_segment_sum(data, segment_ids, unique_labels, feature_dims):
zeros = fluid.layers.fill_constant_batch_size_like(unique_labels, shape=[1, feature_dims],
dtype='float32', value=0)
segment_ids = fluid.layers.unsqueeze(segment_ids, axes=[1])
segment_ids.stop_gradient = True
segment_sum = fluid.layers.scatter_nd_add(zeros, segment_ids, data)
zeros.stop_gradient = True
return segment_sum
def norm(x, axis=-1):
distance = fluid.layers.reduce_sum(fluid.layers.abs(x), dim=axis, keep_dim=True)
return distance
def discriminative_loss_single(
prediction,
correct_label,
feature_dim,
label_shape,
delta_v,
delta_d,
param_var,
param_dist,
param_reg):
correct_label = fluid.layers.reshape(
correct_label, [
label_shape[1] * label_shape[0]])
prediction = fluid.layers.transpose(prediction, [1, 2, 0])
reshaped_pred = fluid.layers.reshape(
prediction, [
label_shape[1] * label_shape[0], feature_dim])
unique_labels, unique_id, counts = fluid.layers.unique_with_counts(correct_label)
correct_label.stop_gradient = True
counts = fluid.layers.cast(counts, 'float32')
num_instances = fluid.layers.shape(unique_labels)
segmented_sum = unsorted_segment_sum(
reshaped_pred, unique_id, unique_labels, feature_dims=feature_dim)
counts_rsp = fluid.layers.reshape(counts, (-1, 1))
mu = fluid.layers.elementwise_div(segmented_sum, counts_rsp)
counts_rsp.stop_gradient = True
mu_expand = fluid.layers.gather(mu, unique_id)
tmp = fluid.layers.elementwise_sub(mu_expand, reshaped_pred)
distance = norm(tmp)
distance = distance - delta_v
distance_pos = fluid.layers.greater_equal(distance, fluid.layers.zeros_like(distance))
distance_pos = fluid.layers.cast(distance_pos, 'float32')
distance = distance * distance_pos
distance = fluid.layers.square(distance)
l_var = unsorted_segment_sum(distance, unique_id, unique_labels, feature_dims=1)
l_var = fluid.layers.elementwise_div(l_var, counts_rsp)
l_var = fluid.layers.reduce_sum(l_var)
l_var = l_var / fluid.layers.cast(num_instances * (num_instances - 1), 'float32')
mu_interleaved_rep = fluid.layers.expand(mu, [num_instances, 1])
mu_band_rep = fluid.layers.expand(mu, [1, num_instances])
mu_band_rep = fluid.layers.reshape(mu_band_rep, (num_instances * num_instances, feature_dim))
mu_diff = fluid.layers.elementwise_sub(mu_band_rep, mu_interleaved_rep)
intermediate_tensor = fluid.layers.reduce_sum(fluid.layers.abs(mu_diff), dim=1)
intermediate_tensor.stop_gradient = True
zero_vector = fluid.layers.zeros([1], 'float32')
bool_mask = fluid.layers.not_equal(intermediate_tensor, zero_vector)
temp = fluid.layers.where(bool_mask)
mu_diff_bool = fluid.layers.gather(mu_diff, temp)
mu_norm = norm(mu_diff_bool)
mu_norm = 2. * delta_d - mu_norm
mu_norm_pos = fluid.layers.greater_equal(mu_norm, fluid.layers.zeros_like(mu_norm))
mu_norm_pos = fluid.layers.cast(mu_norm_pos, 'float32')
mu_norm = mu_norm * mu_norm_pos
mu_norm_pos.stop_gradient = True
mu_norm = fluid.layers.square(mu_norm)
l_dist = fluid.layers.reduce_mean(mu_norm)
l_reg = fluid.layers.reduce_mean(norm(mu, axis=1))
l_var = param_var * l_var
l_dist = param_dist * l_dist
l_reg = param_reg * l_reg
loss = l_var + l_dist + l_reg
return loss, l_var, l_dist, l_reg
def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
delta_v, delta_d, param_var, param_dist, param_reg):
batch_size = int(cfg.BATCH_SIZE_PER_DEV)
output_ta_loss = 0.
output_ta_var = 0.
output_ta_dist = 0.
output_ta_reg = 0.
for i in range(batch_size):
disc_loss_single, l_var_single, l_dist_single, l_reg_single = discriminative_loss_single(
prediction[i], correct_label[i], feature_dim, image_shape, delta_v, delta_d, param_var, param_dist,
param_reg)
output_ta_loss += disc_loss_single
output_ta_var += l_var_single
output_ta_dist += l_dist_single
output_ta_reg += l_reg_single
disc_loss = output_ta_loss / batch_size
l_var = output_ta_var / batch_size
l_dist = output_ta_dist / batch_size
l_reg = output_ta_reg / batch_size
return disc_loss, l_var, l_dist, l_reg
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import models.modeling
#import models.backbone
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("..")
import struct
import paddle.fluid as fluid
from paddle.fluid.proto.framework_pb2 import VarType
from pdseg import solver
from utils.config import cfg
from pdseg.loss import multi_softmax_with_loss
from loss import discriminative_loss
from models.modeling import lanenet
class ModelPhase(object):
"""
Standard name for model phase in PaddleSeg
The following standard keys are defined:
* `TRAIN`: training mode.
* `EVAL`: testing/evaluation mode.
* `PREDICT`: prediction/inference mode.
* `VISUAL` : visualization mode
"""
TRAIN = 'train'
EVAL = 'eval'
PREDICT = 'predict'
VISUAL = 'visual'
@staticmethod
def is_train(phase):
return phase == ModelPhase.TRAIN
@staticmethod
def is_predict(phase):
return phase == ModelPhase.PREDICT
@staticmethod
def is_eval(phase):
return phase == ModelPhase.EVAL
@staticmethod
def is_visual(phase):
return phase == ModelPhase.VISUAL
@staticmethod
def is_valid_phase(phase):
""" Check valid phase """
if ModelPhase.is_train(phase) or ModelPhase.is_predict(phase) \
or ModelPhase.is_eval(phase) or ModelPhase.is_visual(phase):
return True
return False
def seg_model(image, class_num):
model_name = cfg.MODEL.MODEL_NAME
if model_name == 'lanenet':
logits = lanenet.lanenet(image, class_num)
else:
raise Exception(
"unknow model name, only support unet, deeplabv3p, icnet, pspnet, hrnet"
)
return logits
def softmax(logit):
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
logit = fluid.layers.softmax(logit)
logit = fluid.layers.transpose(logit, [0, 3, 1, 2])
return logit
def sigmoid_to_softmax(logit):
"""
one channel to two channel
"""
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
logit = fluid.layers.sigmoid(logit)
logit_back = 1 - logit
logit = fluid.layers.concat([logit_back, logit], axis=-1)
logit = fluid.layers.transpose(logit, [0, 3, 1, 2])
return logit
def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
if not ModelPhase.is_valid_phase(phase):
raise ValueError("ModelPhase {} is not valid!".format(phase))
if ModelPhase.is_train(phase):
width = cfg.TRAIN_CROP_SIZE[0]
height = cfg.TRAIN_CROP_SIZE[1]
else:
width = cfg.EVAL_CROP_SIZE[0]
height = cfg.EVAL_CROP_SIZE[1]
image_shape = [cfg.DATASET.DATA_DIM, height, width]
grt_shape = [1, height, width]
class_num = cfg.DATASET.NUM_CLASSES
with fluid.program_guard(main_prog, start_prog):
with fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=grt_shape, dtype='int32')
if cfg.MODEL.MODEL_NAME == 'lanenet':
label_instance = fluid.layers.data(
name='label_instance', shape=grt_shape, dtype='int32')
mask = fluid.layers.data(
name='mask', shape=grt_shape, dtype='int32')
# use PyReader when doing traning and evaluation
if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
py_reader = fluid.io.PyReader(
feed_list=[image, label, label_instance, mask],
capacity=cfg.DATALOADER.BUF_SIZE,
iterable=False,
use_double_buffer=True)
loss_type = cfg.SOLVER.LOSS
if not isinstance(loss_type, list):
loss_type = list(loss_type)
logits = seg_model(image, class_num)
if ModelPhase.is_train(phase):
loss_valid = False
valid_loss = []
if cfg.MODEL.MODEL_NAME == 'lanenet':
embeding_logit = logits[1]
logits = logits[0]
disc_loss, _, _, l_reg = discriminative_loss(embeding_logit, label_instance, 4,
image_shape[1:], 0.5, 3.0, 1.0, 1.0, 0.001)
if "softmax_loss" in loss_type:
weight = None
if cfg.MODEL.MODEL_NAME == 'lanenet':
weight = get_dynamic_weight(label)
seg_loss = multi_softmax_with_loss(logits, label, mask, class_num, weight)
loss_valid = True
valid_loss.append("softmax_loss")
if not loss_valid:
raise Exception("SOLVER.LOSS: {} is set wrong. it should "
"include one of (softmax_loss, bce_loss, dice_loss) at least"
" example: ['softmax_loss']".format(cfg.SOLVER.LOSS))
invalid_loss = [x for x in loss_type if x not in valid_loss]
if len(invalid_loss) > 0:
print("Warning: the loss {} you set is invalid. it will not be included in loss computed.".format(invalid_loss))
avg_loss = disc_loss + 0.00001 * l_reg + seg_loss
#get pred result in original size
if isinstance(logits, tuple):
logit = logits[0]
else:
logit = logits
if logit.shape[2:] != label.shape[2:]:
logit = fluid.layers.resize_bilinear(logit, label.shape[2:])
# return image input and logit output for inference graph prune
if ModelPhase.is_predict(phase):
if class_num == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = softmax(logit)
return image, logit
if class_num == 1:
out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if ModelPhase.is_visual(phase):
if cfg.MODEL.MODEL_NAME == 'lanenet':
return pred, logits[1]
if class_num == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = softmax(logit)
return pred, logit
accuracy, fp, fn = compute_metric(pred, label)
if ModelPhase.is_eval(phase):
return py_reader, pred, label, mask, accuracy, fp, fn
if ModelPhase.is_train(phase):
optimizer = solver.Solver(main_prog, start_prog)
decayed_lr = optimizer.optimise(avg_loss)
return py_reader, avg_loss, decayed_lr, pred, label, mask, disc_loss, seg_loss, accuracy, fp, fn
def compute_metric(pred, label):
label = fluid.layers.transpose(label, [0, 2, 3, 1])
idx = fluid.layers.where(pred == 1)
pix_cls_ret = fluid.layers.gather_nd(label, idx)
correct_num = fluid.layers.reduce_sum(fluid.layers.cast(pix_cls_ret, 'float32'))
gt_num = fluid.layers.cast(fluid.layers.shape(fluid.layers.gather_nd(label,
fluid.layers.where(label == 1)))[0], 'int64')
pred_num = fluid.layers.cast(fluid.layers.shape(fluid.layers.gather_nd(pred, idx))[0], 'int64')
accuracy = correct_num / gt_num
false_pred = pred_num - correct_num
fp = fluid.layers.cast(false_pred, 'float32') / fluid.layers.cast(fluid.layers.shape(pix_cls_ret)[0], 'int64')
label_cls_ret = fluid.layers.gather_nd(label, fluid.layers.where(label == 1))
mis_pred = fluid.layers.cast(fluid.layers.shape(label_cls_ret)[0], 'int64') - correct_num
fn = fluid.layers.cast(mis_pred, 'float32') / fluid.layers.cast(fluid.layers.shape(label_cls_ret)[0], 'int64')
accuracy.stop_gradient = True
fp.stop_gradient = True
fn.stop_gradient = True
return accuracy, fp, fn
def get_dynamic_weight(label):
label = fluid.layers.reshape(label, [-1])
unique_labels, unique_id, counts = fluid.layers.unique_with_counts(label)
counts = fluid.layers.cast(counts, 'float32')
weight = 1.0 / fluid.layers.log((counts / fluid.layers.reduce_sum(counts) + 1.02))
return weight
def to_int(string, dest="I"):
return struct.unpack(dest, string)[0]
def parse_shape_from_file(filename):
with open(filename, "rb") as file:
version = file.read(4)
lod_level = to_int(file.read(8), dest="Q")
for i in range(lod_level):
_size = to_int(file.read(8), dest="Q")
_ = file.read(_size)
version = file.read(4)
tensor_desc_size = to_int(file.read(4))
tensor_desc = VarType.TensorDesc()
tensor_desc.ParseFromString(file.read(tensor_desc_size))
return tuple(tensor_desc.dims)
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from utils.config import cfg
from pdseg.models.libs.model_libs import scope, name_scope
from pdseg.models.libs.model_libs import bn, bn_relu, relu
from pdseg.models.libs.model_libs import conv, max_pool, deconv
from pdseg.models.backbone.vgg import VGGNet as vgg_backbone
#from models.backbone.vgg import VGGNet as vgg_backbone
# Bottleneck type
REGULAR = 1
DOWNSAMPLING = 2
UPSAMPLING = 3
DILATED = 4
ASYMMETRIC = 5
def prelu(x, decoder=False):
# If decoder, then perform relu else perform prelu
if decoder:
return fluid.layers.relu(x)
return fluid.layers.prelu(x, 'channel')
def iniatial_block(inputs, name_scope='iniatial_block'):
'''
The initial block for Enet has 2 branches: The convolution branch and Maxpool branch.
The conv branch has 13 filters, while the maxpool branch gives 3 channels corresponding to the RGB channels.
Both output layers are then concatenated to give an output of 16 channels.
:param inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels]
:return net_concatenated(Tensor): a 4D Tensor of new shape [batch_size, height, width, channels]
'''
# Convolutional branch
with scope(name_scope):
net_conv = conv(inputs, 13, 3, stride=2, padding=1)
net_conv = bn(net_conv)
net_conv = fluid.layers.prelu(net_conv, 'channel')
# Max pool branch
net_pool = max_pool(inputs, [2, 2], stride=2, padding='SAME')
# Concatenated output - does it matter max pool comes first or conv comes first? probably not.
net_concatenated = fluid.layers.concat([net_conv, net_pool], axis=1)
return net_concatenated
def bottleneck(inputs,
output_depth,
filter_size,
regularizer_prob,
projection_ratio=4,
type=REGULAR,
seed=0,
output_shape=None,
dilation_rate=None,
decoder=False,
name_scope='bottleneck'):
# Calculate the depth reduction based on the projection ratio used in 1x1 convolution.
reduced_depth = int(inputs.shape[1] / projection_ratio)
# DOWNSAMPLING BOTTLENECK
if type == DOWNSAMPLING:
#=============MAIN BRANCH=============
#Just perform a max pooling
with scope('down_sample'):
inputs_shape = inputs.shape
with scope('main_max_pool'):
net_main = fluid.layers.conv2d(inputs, inputs_shape[1], filter_size=3, stride=2, padding='SAME')
#First get the difference in depth to pad, then pad with zeros only on the last dimension.
depth_to_pad = abs(inputs_shape[1] - output_depth)
paddings = [0, 0, 0, depth_to_pad, 0, 0, 0, 0]
with scope('main_padding'):
net_main = fluid.layers.pad(net_main, paddings=paddings)
with scope('block1'):
net = conv(inputs, reduced_depth, [2, 2], stride=2, padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
with scope('block2'):
net = conv(net, reduced_depth, [filter_size, filter_size], padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
with scope('block3'):
net = conv(net, output_depth, [1, 1], padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
# Regularizer
net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
# Finally, combine the two branches together via an element-wise addition
net = fluid.layers.elementwise_add(net, net_main)
net = prelu(net, decoder=decoder)
return net, inputs_shape
# DILATION CONVOLUTION BOTTLENECK
# Everything is the same as a regular bottleneck except for the dilation rate argument
elif type == DILATED:
#Check if dilation rate is given
if not dilation_rate:
raise ValueError('Dilation rate is not given.')
with scope('dilated'):
# Save the main branch for addition later
net_main = inputs
# First projection with 1x1 kernel (dimensionality reduction)
with scope('block1'):
net = conv(inputs, reduced_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
# Second conv block --- apply dilated convolution here
with scope('block2'):
net = conv(net, reduced_depth, filter_size, padding='SAME', dilation=dilation_rate)
net = bn(net)
net = prelu(net, decoder=decoder)
# Final projection with 1x1 kernel (Expansion)
with scope('block3'):
net = conv(net, output_depth, [1,1])
net = bn(net)
net = prelu(net, decoder=decoder)
# Regularizer
net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
net = prelu(net, decoder=decoder)
# Add the main branch
net = fluid.layers.elementwise_add(net_main, net)
net = prelu(net, decoder=decoder)
return net
# ASYMMETRIC CONVOLUTION BOTTLENECK
# Everything is the same as a regular bottleneck except for a [5,5] kernel decomposed into two [5,1] then [1,5]
elif type == ASYMMETRIC:
# Save the main branch for addition later
with scope('asymmetric'):
net_main = inputs
# First projection with 1x1 kernel (dimensionality reduction)
with scope('block1'):
net = conv(inputs, reduced_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
# Second conv block --- apply asymmetric conv here
with scope('block2'):
with scope('asymmetric_conv2a'):
net = conv(net, reduced_depth, [filter_size, 1], padding='same')
with scope('asymmetric_conv2b'):
net = conv(net, reduced_depth, [1, filter_size], padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
# Final projection with 1x1 kernel
with scope('block3'):
net = conv(net, output_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
# Regularizer
net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
net = prelu(net, decoder=decoder)
# Add the main branch
net = fluid.layers.elementwise_add(net_main, net)
net = prelu(net, decoder=decoder)
return net
# UPSAMPLING BOTTLENECK
# Everything is the same as a regular one, except convolution becomes transposed.
elif type == UPSAMPLING:
#Check if pooling indices is given
#Check output_shape given or not
if output_shape is None:
raise ValueError('Output depth is not given')
#=======MAIN BRANCH=======
#Main branch to upsample. output shape must match with the shape of the layer that was pooled initially, in order
#for the pooling indices to work correctly. However, the initial pooled layer was padded, so need to reduce dimension
#before unpooling. In the paper, padding is replaced with convolution for this purpose of reducing the depth!
with scope('upsampling'):
with scope('unpool'):
net_unpool = conv(inputs, output_depth, [1, 1])
net_unpool = bn(net_unpool)
net_unpool = fluid.layers.resize_bilinear(net_unpool, out_shape=output_shape[2:])
# First 1x1 projection to reduce depth
with scope('block1'):
net = conv(inputs, reduced_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
with scope('block2'):
net = deconv(net, reduced_depth, filter_size=filter_size, stride=2, padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
# Final projection with 1x1 kernel
with scope('block3'):
net = conv(net, output_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
# Regularizer
net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
net = prelu(net, decoder=decoder)
# Finally, add the unpooling layer and the sub branch together
net = fluid.layers.elementwise_add(net, net_unpool)
net = prelu(net, decoder=decoder)
return net
# REGULAR BOTTLENECK
else:
with scope('regular'):
net_main = inputs
# First projection with 1x1 kernel
with scope('block1'):
net = conv(inputs, reduced_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
# Second conv block
with scope('block2'):
net = conv(net, reduced_depth, [filter_size, filter_size], padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
# Final projection with 1x1 kernel
with scope('block3'):
net = conv(net, output_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
# Regularizer
net = fluid.layers.dropout(net, regularizer_prob, seed=seed)
net = prelu(net, decoder=decoder)
# Add the main branch
net = fluid.layers.elementwise_add(net_main, net)
net = prelu(net, decoder=decoder)
return net
def ENet_stage1(inputs, name_scope='stage1_block'):
with scope(name_scope):
with scope('bottleneck1_0'):
net, inputs_shape_1 \
= bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.01, type=DOWNSAMPLING,
name_scope='bottleneck1_0')
with scope('bottleneck1_1'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_1')
with scope('bottleneck1_2'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_2')
with scope('bottleneck1_3'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_3')
with scope('bottleneck1_4'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_4')
return net, inputs_shape_1
def ENet_stage2(inputs, name_scope='stage2_block'):
with scope(name_scope):
net, inputs_shape_2 \
= bottleneck(inputs, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DOWNSAMPLING,
name_scope='bottleneck2_0')
for i in range(2):
with scope('bottleneck2_{}'.format(str(4 * i + 1))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1,
name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
with scope('bottleneck2_{}'.format(str(4 * i + 2))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)),
name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
with scope('bottleneck2_{}'.format(str(4 * i + 3))):
net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC,
name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
with scope('bottleneck2_{}'.format(str(4 * i + 4))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)),
name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
return net, inputs_shape_2
def ENet_stage3(inputs, name_scope='stage3_block'):
with scope(name_scope):
for i in range(2):
with scope('bottleneck3_{}'.format(str(4 * i + 0))):
net = bottleneck(inputs, output_depth=128, filter_size=3, regularizer_prob=0.1,
name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
with scope('bottleneck3_{}'.format(str(4 * i + 1))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)),
name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
with scope('bottleneck3_{}'.format(str(4 * i + 2))):
net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC,
name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
with scope('bottleneck3_{}'.format(str(4 * i + 3))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)),
name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
return net
def ENet_stage4(inputs, inputs_shape, connect_tensor,
skip_connections=True, name_scope='stage4_block'):
with scope(name_scope):
with scope('bottleneck4_0'):
net = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.1,
type=UPSAMPLING, decoder=True, output_shape=inputs_shape,
name_scope='bottleneck4_0')
if skip_connections:
net = fluid.layers.elementwise_add(net, connect_tensor)
with scope('bottleneck4_1'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck4_1')
with scope('bottleneck4_2'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck4_2')
return net
def ENet_stage5(inputs, inputs_shape, connect_tensor, skip_connections=True,
name_scope='stage5_block'):
with scope(name_scope):
net = bottleneck(inputs, output_depth=16, filter_size=3, regularizer_prob=0.1, type=UPSAMPLING,
decoder=True, output_shape=inputs_shape,
name_scope='bottleneck5_0')
if skip_connections:
net = fluid.layers.elementwise_add(net, connect_tensor)
with scope('bottleneck5_1'):
net = bottleneck(net, output_depth=16, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck5_1')
return net
def decoder(input, num_classes):
if 'enet' in cfg.MODEL.LANENET.BACKBONE:
# Segmentation branch
with scope('LaneNetSeg'):
initial, stage1, stage2, inputs_shape_1, inputs_shape_2 = input
segStage3 = ENet_stage3(stage2)
segStage4 = ENet_stage4(segStage3, inputs_shape_2, stage1)
segStage5 = ENet_stage5(segStage4, inputs_shape_1, initial)
segLogits = deconv(segStage5, num_classes, filter_size=2, stride=2, padding='SAME')
# Embedding branch
with scope('LaneNetEm'):
emStage3 = ENet_stage3(stage2)
emStage4 = ENet_stage4(emStage3, inputs_shape_2, stage1)
emStage5 = ENet_stage5(emStage4, inputs_shape_1, initial)
emLogits = deconv(emStage5, 4, filter_size=2, stride=2, padding='SAME')
elif 'vgg' in cfg.MODEL.LANENET.BACKBONE:
encoder_list = ['pool5', 'pool4', 'pool3']
# score stage
input_tensor = input[encoder_list[0]]
with scope('score_origin'):
score = conv(input_tensor, 64, 1)
encoder_list = encoder_list[1:]
for i in range(len(encoder_list)):
with scope('deconv_{:d}'.format(i + 1)):
deconv_out = deconv(score, 64, filter_size=4, stride=2, padding='SAME')
input_tensor = input[encoder_list[i]]
with scope('score_{:d}'.format(i + 1)):
score = conv(input_tensor, 64, 1)
score = fluid.layers.elementwise_add(deconv_out, score)
with scope('deconv_final'):
emLogits = deconv(score, 64, filter_size=16, stride=8, padding='SAME')
with scope('score_final'):
segLogits = conv(emLogits, num_classes, 1)
emLogits = relu(conv(emLogits, 4, 1))
return segLogits, emLogits
def encoder(input):
if 'vgg' in cfg.MODEL.LANENET.BACKBONE:
model = vgg_backbone(layers=16)
#output = model.net(input)
_, encode_feature_dict = model.net(input, end_points=13, decode_points=[7, 10, 13])
output = {}
output['pool3'] = encode_feature_dict[7]
output['pool4'] = encode_feature_dict[10]
output['pool5'] = encode_feature_dict[13]
elif 'enet' in cfg.MODEL.LANET.BACKBONE:
with scope('LaneNetBase'):
initial = iniatial_block(input)
stage1, inputs_shape_1 = ENet_stage1(initial)
stage2, inputs_shape_2 = ENet_stage2(stage1)
output = (initial, stage1, stage2, inputs_shape_1, inputs_shape_2)
else:
raise Exception("LaneNet expect enet and vgg backbone, but received {}".
format(cfg.MODEL.LANENET.BACKBONE))
return output
def lanenet(img, num_classes):
output = encoder(img)
segLogits, emLogits = decoder(output, num_classes)
return segLogits, emLogits
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
import os
import time
import codecs
import numpy as np
import cv2
from utils.config import cfg
import data_aug as aug
from pdseg.data_utils import GeneratorEnqueuer
from models.model_builder import ModelPhase
import copy
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
# resolve cv2.imread open Chinese file path issues on Windows Platform.
return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
class LaneNetDataset():
def __init__(self,
file_list,
data_dir,
shuffle=False,
mode=ModelPhase.TRAIN):
self.mode = mode
self.shuffle = shuffle
self.data_dir = data_dir
self.shuffle_seed = 0
# NOTE: Please ensure file list was save in UTF-8 coding format
with codecs.open(file_list, 'r', 'utf-8') as flist:
self.lines = [line.strip() for line in flist]
self.all_lines = copy.deepcopy(self.lines)
if shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
elif shuffle:
np.random.shuffle(self.lines)
def generator(self):
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
for line in self.lines:
yield self.process_image(line, self.data_dir, self.mode)
def sharding_generator(self, pid=0, num_processes=1):
"""
Use line id as shard key for multiprocess io
It's a normal generator if pid=0, num_processes=1
"""
for index, line in enumerate(self.lines):
# Use index and pid to shard file list
if index % num_processes == pid:
yield self.process_image(line, self.data_dir, self.mode)
def batch_reader(self, batch_size):
br = self.batch(self.reader, batch_size)
for batch in br:
yield batch[0], batch[1], batch[2]
def multiprocess_generator(self, max_queue_size=32, num_processes=8):
# Re-shuffle file list
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // self.num_trainers
self.lines = self.all_lines[num_lines * self.trainer_id: num_lines * (self.trainer_id + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
# Create multiple sharding generators according to num_processes for multiple processes
generators = []
for pid in range(num_processes):
generators.append(self.sharding_generator(pid, num_processes))
try:
enqueuer = GeneratorEnqueuer(generators)
enqueuer.start(max_queue_size=max_queue_size, workers=num_processes)
while True:
generator_out = None
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_out = enqueuer.queue.get(timeout=5)
break
else:
time.sleep(0.01)
if generator_out is None:
break
yield generator_out
finally:
if enqueuer is not None:
enqueuer.stop()
def batch(self, reader, batch_size, is_test=False, drop_last=False):
def batch_reader(is_test=False, drop_last=drop_last):
if is_test:
imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
for img, grt, grt_instance, img_name, valid_shape, org_shape in reader():
imgs.append(img)
grts.append(grt)
grts_instance.append(grt_instance)
img_names.append(img_name)
valid_shapes.append(valid_shape)
org_shapes.append(org_shape)
if len(imgs) == batch_size:
yield np.array(imgs), np.array(
grts), np.array(grts_instance), img_names, np.array(valid_shapes), np.array(
org_shapes)
imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
if not drop_last and len(imgs) > 0:
yield np.array(imgs), np.array(grts), np.array(grts_instance), img_names, np.array(
valid_shapes), np.array(org_shapes)
else:
imgs, labs, labs_instance, ignore = [], [], [], []
bs = 0
for img, lab, lab_instance, ig in reader():
imgs.append(img)
labs.append(lab)
labs_instance.append(lab_instance)
ignore.append(ig)
bs += 1
if bs == batch_size:
yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore)
bs = 0
imgs, labs, labs_instance, ignore = [], [], [], []
if not drop_last and bs > 0:
yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore)
return batch_reader(is_test, drop_last)
def load_image(self, line, src_dir, mode=ModelPhase.TRAIN):
# original image cv2.imread flag setting
cv2_imread_flag = cv2.IMREAD_COLOR
if cfg.DATASET.IMAGE_TYPE == "rgba":
# If use RBGA 4 channel ImageType, use IMREAD_UNCHANGED flags to
# reserver alpha channel
cv2_imread_flag = cv2.IMREAD_UNCHANGED
parts = line.strip().split(cfg.DATASET.SEPARATOR)
if len(parts) != 3:
if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
raise Exception("File list format incorrect! It should be"
" image_name{}label_name\\n".format(
cfg.DATASET.SEPARATOR))
img_name, grt_name, grt_instance_name = parts[0], None, None
else:
img_name, grt_name, grt_instance_name = parts[0], parts[1], parts[2]
img_path = os.path.join(src_dir, img_name)
img = cv2_imread(img_path, cv2_imread_flag)
if grt_name is not None:
grt_path = os.path.join(src_dir, grt_name)
grt_instance_path = os.path.join(src_dir, grt_instance_name)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
grt[grt == 255] = 1
grt[grt != 1] = 0
grt_instance = cv2_imread(grt_instance_path, cv2.IMREAD_GRAYSCALE)
else:
grt = None
grt_instance = None
if img is None:
raise Exception(
"Empty image, src_dir: {}, img: {} & lab: {}".format(
src_dir, img_path, grt_path))
img_height = img.shape[0]
img_width = img.shape[1]
if grt is not None:
grt_height = grt.shape[0]
grt_width = grt.shape[1]
if img_height != grt_height or img_width != grt_width:
raise Exception(
"source img and label img must has the same size")
else:
if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
raise Exception(
"Empty image, src_dir: {}, img: {} & lab: {}".format(
src_dir, img_path, grt_path))
if len(img.shape) < 3:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img_channels = img.shape[2]
if img_channels < 3:
raise Exception("PaddleSeg only supports gray, rgb or rgba image")
if img_channels != cfg.DATASET.DATA_DIM:
raise Exception(
"Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
.format(img_channels, cfg.DATASET.DATADIM, img_name))
if img_channels != len(cfg.MEAN):
raise Exception(
"img name {}, img chns {} mean size {}, size unequal".format(
img_name, img_channels, len(cfg.MEAN)))
if img_channels != len(cfg.STD):
raise Exception(
"img name {}, img chns {} std size {}, size unequal".format(
img_name, img_channels, len(cfg.STD)))
return img, grt, grt_instance, img_name, grt_name
def normalize_image(self, img):
""" 像素归一化后减均值除方差 """
img = img.transpose((2, 0, 1)).astype('float32') / 255.0
img_mean = np.array(cfg.MEAN).reshape((len(cfg.MEAN), 1, 1))
img_std = np.array(cfg.STD).reshape((len(cfg.STD), 1, 1))
img -= img_mean
img /= img_std
return img
def process_image(self, line, data_dir, mode):
""" process_image """
img, grt, grt_instance, img_name, grt_name = self.load_image(
line, data_dir, mode=mode)
if mode == ModelPhase.TRAIN:
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode)
if cfg.AUG.RICH_CROP.ENABLE:
if cfg.AUG.RICH_CROP.BLUR:
if cfg.AUG.RICH_CROP.BLUR_RATIO <= 0:
n = 0
elif cfg.AUG.RICH_CROP.BLUR_RATIO >= 1:
n = 1
else:
n = int(1.0 / cfg.AUG.RICH_CROP.BLUR_RATIO)
if n > 0:
if np.random.randint(0, n) == 0:
radius = np.random.randint(3, 10)
if radius % 2 != 1:
radius = radius + 1
if radius > 9:
radius = 9
img = cv2.GaussianBlur(img, (radius, radius), 0, 0)
img, grt = aug.random_rotation(
img,
grt,
rich_crop_max_rotation=cfg.AUG.RICH_CROP.MAX_ROTATION,
mean_value=cfg.DATASET.PADDING_VALUE)
img, grt = aug.rand_scale_aspect(
img,
grt,
rich_crop_min_scale=cfg.AUG.RICH_CROP.MIN_AREA_RATIO,
rich_crop_aspect_ratio=cfg.AUG.RICH_CROP.ASPECT_RATIO)
img = aug.hsv_color_jitter(
img,
brightness_jitter_ratio=cfg.AUG.RICH_CROP.
BRIGHTNESS_JITTER_RATIO,
saturation_jitter_ratio=cfg.AUG.RICH_CROP.
SATURATION_JITTER_RATIO,
contrast_jitter_ratio=cfg.AUG.RICH_CROP.
CONTRAST_JITTER_RATIO)
if cfg.AUG.FLIP:
if cfg.AUG.FLIP_RATIO <= 0:
n = 0
elif cfg.AUG.FLIP_RATIO >= 1:
n = 1
else:
n = int(1.0 / cfg.AUG.FLIP_RATIO)
if n > 0:
if np.random.randint(0, n) == 0:
img = img[::-1, :, :]
grt = grt[::-1, :]
if cfg.AUG.MIRROR:
if np.random.randint(0, 2) == 1:
img = img[:, ::-1, :]
grt = grt[:, ::-1]
img, grt = aug.rand_crop(img, grt, mode=mode)
elif ModelPhase.is_eval(mode):
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode)
elif ModelPhase.is_visual(mode):
ori_img = img.copy()
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode)
valid_shape = [img.shape[0], img.shape[1]]
else:
raise ValueError("Dataset mode={} Error!".format(mode))
# Normalize image
img = self.normalize_image(img)
if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode):
grt = np.expand_dims(np.array(grt).astype('int32'), axis=0)
ignore = (grt != cfg.DATASET.IGNORE_INDEX).astype('int32')
if ModelPhase.is_train(mode):
return (img, grt, grt_instance, ignore)
elif ModelPhase.is_eval(mode):
return (img, grt, grt_instance, ignore)
elif ModelPhase.is_visual(mode):
return (img, grt, grt_instance, img_name, valid_shape, ori_img)
pre-commit
yapf == 0.26.0
flake8
pyyaml >= 5.1
tb-paddle
tensorboard >= 1.15.0
Pillow
numpy
six
opencv-python
tqdm
requests
sklearn
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# GPU memory garbage collection optimization flags
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
import sys
cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(os.path.split(cur_path)[0])[0]
SEG_PATH = os.path.join(cur_path, "../../../")
sys.path.append(SEG_PATH)
sys.path.append(root_path)
import argparse
import pprint
import numpy as np
import paddle.fluid as fluid
from utils.config import cfg
from pdseg.utils.timer import Timer, calculate_eta
from reader import LaneNetDataset
from models.model_builder import build_model
from models.model_builder import ModelPhase
from models.model_builder import parse_shape_from_file
from eval import evaluate
from vis import visualize
from utils import dist_utils
def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg training')
parser.add_argument(
'--cfg',
dest='cfg_file',
help='Config file for training (and optionally testing)',
default=None,
type=str)
parser.add_argument(
'--use_gpu',
dest='use_gpu',
help='Use gpu or cpu',
action='store_true',
default=False)
parser.add_argument(
'--use_mpio',
dest='use_mpio',
help='Use multiprocess I/O or not',
action='store_true',
default=False)
parser.add_argument(
'--log_steps',
dest='log_steps',
help='Display logging information at every log_steps',
default=10,
type=int)
parser.add_argument(
'--debug',
dest='debug',
help='debug mode, display detail information of training',
action='store_true')
parser.add_argument(
'--use_tb',
dest='use_tb',
help='whether to record the data during training to Tensorboard',
action='store_true')
parser.add_argument(
'--tb_log_dir',
dest='tb_log_dir',
help='Tensorboard logging directory',
default=None,
type=str)
parser.add_argument(
'--do_eval',
dest='do_eval',
help='Evaluation models result on every new checkpoint',
action='store_true')
parser.add_argument(
'opts',
help='See utils/config.py for all options',
default=None,
nargs=argparse.REMAINDER)
return parser.parse_args()
def save_vars(executor, dirname, program=None, vars=None):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program = fluid.Program()
save_block = save_program.global_block()
for each_var in vars:
# NOTE: don't save the variable which type is RAW
if each_var.type == fluid.core.VarDesc.VarType.RAW:
continue
new_var = save_block.create_var(
name=each_var.name,
shape=each_var.shape,
dtype=each_var.dtype,
type=each_var.type,
lod_level=each_var.lod_level,
persistable=True)
file_path = os.path.join(dirname, new_var.name)
file_path = os.path.normpath(file_path)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': file_path})
executor.run(save_program)
def save_checkpoint(exe, program, ckpt_name):
"""
Save checkpoint for evaluation or resume training
"""
ckpt_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, str(ckpt_name))
print("Save model checkpoint to {}".format(ckpt_dir))
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
save_vars(
exe,
ckpt_dir,
program,
vars=list(filter(fluid.io.is_persistable, program.list_vars())))
return ckpt_dir
def load_checkpoint(exe, program):
"""
Load checkpoiont from pretrained model directory for resume training
"""
print('Resume model training from:', cfg.TRAIN.RESUME_MODEL_DIR)
if not os.path.exists(cfg.TRAIN.RESUME_MODEL_DIR):
raise ValueError("TRAIN.PRETRAIN_MODEL {} not exist!".format(
cfg.TRAIN.RESUME_MODEL_DIR))
fluid.io.load_persistables(
exe, cfg.TRAIN.RESUME_MODEL_DIR, main_program=program)
model_path = cfg.TRAIN.RESUME_MODEL_DIR
# Check is path ended by path spearator
if model_path[-1] == os.sep:
model_path = model_path[0:-1]
epoch_name = os.path.basename(model_path)
# If resume model is final model
if epoch_name == 'final':
begin_epoch = cfg.SOLVER.NUM_EPOCHS
# If resume model path is end of digit, restore epoch status
elif epoch_name.isdigit():
epoch = int(epoch_name)
begin_epoch = epoch + 1
else:
raise ValueError("Resume model path is not valid!")
print("Model checkpoint loaded successfully!")
return begin_epoch
def print_info(*msg):
if cfg.TRAINER_ID == 0:
print(*msg)
def train(cfg):
startup_prog = fluid.Program()
train_prog = fluid.Program()
drop_last = True
dataset = LaneNetDataset(
file_list=cfg.DATASET.TRAIN_FILE_LIST,
mode=ModelPhase.TRAIN,
shuffle=True,
data_dir=cfg.DATASET.DATA_DIR)
def data_generator():
if args.use_mpio:
data_gen = dataset.multiprocess_generator(
num_processes=cfg.DATALOADER.NUM_WORKERS,
max_queue_size=cfg.DATALOADER.BUF_SIZE)
else:
data_gen = dataset.generator()
batch_data = []
for b in data_gen:
batch_data.append(b)
if len(batch_data) == (cfg.BATCH_SIZE // cfg.NUM_TRAINERS):
for item in batch_data:
yield item
batch_data = []
# Get device environment
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
# Get number of GPU
dev_count = cfg.NUM_TRAINERS if cfg.NUM_TRAINERS > 1 else len(places)
print_info("#Device count: {}".format(dev_count))
# Make sure BATCH_SIZE can divided by GPU cards
assert cfg.BATCH_SIZE % dev_count == 0, (
'BATCH_SIZE:{} not divisble by number of GPUs:{}'.format(
cfg.BATCH_SIZE, dev_count))
# If use multi-gpu training mode, batch data will allocated to each GPU evenly
batch_size_per_dev = cfg.BATCH_SIZE // dev_count
cfg.BATCH_SIZE_PER_DEV = batch_size_per_dev
print_info("batch_size_per_dev: {}".format(batch_size_per_dev))
py_reader, avg_loss, lr, pred, grts, masks, emb_loss, seg_loss, accuracy, fp, fn = build_model(
train_prog, startup_prog, phase=ModelPhase.TRAIN)
py_reader.decorate_sample_generator(
data_generator, batch_size=batch_size_per_dev, drop_last=drop_last)
exe = fluid.Executor(place)
exe.run(startup_prog)
exec_strategy = fluid.ExecutionStrategy()
# Clear temporary variables every 100 iteration
if args.use_gpu:
exec_strategy.num_threads = fluid.core.get_cuda_device_count()
exec_strategy.num_iteration_per_drop_scope = 100
build_strategy = fluid.BuildStrategy()
if cfg.NUM_TRAINERS > 1 and args.use_gpu:
dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
exec_strategy.num_threads = 1
if cfg.TRAIN.SYNC_BATCH_NORM and args.use_gpu:
if dev_count > 1:
# Apply sync batch norm strategy
print_info("Sync BatchNorm strategy is effective.")
build_strategy.sync_batch_norm = True
else:
print_info(
"Sync BatchNorm strategy will not be effective if GPU device"
" count <= 1")
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=avg_loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
# Resume training
begin_epoch = cfg.SOLVER.BEGIN_EPOCH
if cfg.TRAIN.RESUME_MODEL_DIR:
begin_epoch = load_checkpoint(exe, train_prog)
# Load pretrained model
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_vars = []
load_fail_vars = []
def var_shape_matched(var, shape):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist = os.path.exists(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_shape != shape:
print(var.name, var_shape, shape)
return var_shape == shape
return False
for x in train_prog.list_vars():
if isinstance(x, fluid.framework.Parameter):
shape = tuple(fluid.global_scope().find_var(
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print_info(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print_info("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else:
print_info(
'Pretrained model dir {} not exists, training from scratch...'.
format(cfg.TRAIN.PRETRAINED_MODEL_DIR))
# fetch_list = [avg_loss.name, lr.name, accuracy.name, precision.name, recall.name]
fetch_list = [avg_loss.name, lr.name, seg_loss.name, emb_loss.name, accuracy.name, fp.name, fn.name]
if args.debug:
# Fetch more variable info and use streaming confusion matrix to
# calculate IoU results if in debug mode
np.set_printoptions(
precision=4, suppress=True, linewidth=160, floatmode="fixed")
fetch_list.extend([pred.name, grts.name, masks.name])
# cm = ConfusionMatrix(cfg.DATASET.NUM_CLASSES, streaming=True)
if args.use_tb:
if not args.tb_log_dir:
print_info("Please specify the log directory by --tb_log_dir.")
exit(1)
from tb_paddle import SummaryWriter
log_writer = SummaryWriter(args.tb_log_dir)
# trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
# num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
global_step = 0
all_step = cfg.DATASET.TRAIN_TOTAL_IMAGES // cfg.BATCH_SIZE
if cfg.DATASET.TRAIN_TOTAL_IMAGES % cfg.BATCH_SIZE and drop_last != True:
all_step += 1
all_step *= (cfg.SOLVER.NUM_EPOCHS - begin_epoch + 1)
avg_loss = 0.0
avg_seg_loss = 0.0
avg_emb_loss = 0.0
avg_acc = 0.0
avg_fp = 0.0
avg_fn = 0.0
timer = Timer()
timer.start()
if begin_epoch > cfg.SOLVER.NUM_EPOCHS:
raise ValueError(
("begin epoch[{}] is larger than cfg.SOLVER.NUM_EPOCHS[{}]").format(
begin_epoch, cfg.SOLVER.NUM_EPOCHS))
if args.use_mpio:
print_info("Use multiprocess reader")
else:
print_info("Use multi-thread reader")
for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1):
py_reader.start()
while True:
try:
# If not in debug mode, avoid unnessary log and calculate
loss, lr, out_seg_loss, out_emb_loss, out_acc, out_fp, out_fn = exe.run(
program=compiled_train_prog,
fetch_list=fetch_list,
return_numpy=True)
avg_loss += np.mean(np.array(loss))
avg_seg_loss += np.mean(np.array(out_seg_loss))
avg_emb_loss += np.mean(np.array(out_emb_loss))
avg_acc += np.mean(out_acc)
avg_fp += np.mean(out_fp)
avg_fn += np.mean(out_fn)
global_step += 1
if global_step % args.log_steps == 0 and cfg.TRAINER_ID == 0:
avg_loss /= args.log_steps
avg_seg_loss /= args.log_steps
avg_emb_loss /= args.log_steps
avg_acc /= args.log_steps
avg_fp /= args.log_steps
avg_fn /= args.log_steps
speed = args.log_steps / timer.elapsed_time()
print((
"epoch={} step={} lr={:.5f} loss={:.4f} seg_loss={:.4f} emb_loss={:.4f} accuracy={:.4} fp={:.4} fn={:.4} step/sec={:.3f} | ETA {}"
).format(epoch, global_step, lr[0], avg_loss, avg_seg_loss, avg_emb_loss, avg_acc, avg_fp, avg_fn, speed,
calculate_eta(all_step - global_step, speed)))
if args.use_tb:
log_writer.add_scalar('Train/loss', avg_loss,
global_step)
log_writer.add_scalar('Train/lr', lr[0],
global_step)
log_writer.add_scalar('Train/speed', speed,
global_step)
sys.stdout.flush()
avg_loss = 0.0
avg_seg_loss = 0.0
avg_emb_loss = 0.0
avg_acc = 0.0
avg_fp = 0.0
avg_fn = 0.0
timer.restart()
except fluid.core.EOFException:
py_reader.reset()
break
except Exception as e:
print(e)
if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 and cfg.TRAINER_ID == 0:
ckpt_dir = save_checkpoint(exe, train_prog, epoch)
if args.do_eval:
print("Evaluation start")
accuracy, fp, fn = evaluate(
cfg=cfg,
ckpt_dir=ckpt_dir,
use_gpu=args.use_gpu,
use_mpio=args.use_mpio)
if args.use_tb:
log_writer.add_scalar('Evaluate/accuracy', accuracy,
global_step)
log_writer.add_scalar('Evaluate/fp', fp,
global_step)
log_writer.add_scalar('Evaluate/fn', fn,
global_step)
# Use Tensorboard to visualize results
if args.use_tb and cfg.DATASET.VIS_FILE_LIST is not None:
visualize(
cfg=cfg,
use_gpu=args.use_gpu,
vis_file_list=cfg.DATASET.VIS_FILE_LIST,
vis_dir="visual",
ckpt_dir=ckpt_dir,
log_writer=log_writer)
# save final model
if cfg.TRAINER_ID == 0:
save_checkpoint(exe, train_prog, 'final')
def main(args):
if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file)
if args.opts:
cfg.update_from_list(args.opts)
cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0))
cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
cfg.check_and_infer()
print_info(pprint.pformat(cfg))
train(cfg)
if __name__ == '__main__':
args = parse_args()
if fluid.core.is_compiled_with_cuda() != True and args.use_gpu == True:
print(
"You can not set use_gpu = True in the model because you are using paddlepaddle-cpu."
)
print(
"Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_gpu=False to run models on CPU."
)
sys.exit(1)
main(args)
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import unicode_literals
import os
import sys
# LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
# PDSEG_PATH = os.path.join(LOCAL_PATH, "../../../", "pdseg")
# print(PDSEG_PATH)
# sys.path.insert(0, PDSEG_PATH)
# print(sys.path)
from pdseg.utils.collect import SegConfig
import numpy as np
cfg = SegConfig()
########################## 基本配置 ###########################################
# 均值,图像预处理减去的均值
cfg.MEAN = [0.5, 0.5, 0.5]
# 标准差,图像预处理除以标准差·
cfg.STD = [0.5, 0.5, 0.5]
# 批处理大小
cfg.BATCH_SIZE = 1
# 验证时图像裁剪尺寸(宽,高)
cfg.EVAL_CROP_SIZE = tuple()
# 训练时图像裁剪尺寸(宽,高)
cfg.TRAIN_CROP_SIZE = tuple()
# 多进程训练总进程数
cfg.NUM_TRAINERS = 1
# 多进程训练进程ID
cfg.TRAINER_ID = 0
# 每张gpu上的批大小,无需设置,程序会自动根据batch调整
cfg.BATCH_SIZE_PER_DEV = 1
########################## 数据载入配置 #######################################
# 数据载入时的并发数, 建议值8
cfg.DATALOADER.NUM_WORKERS = 8
# 数据载入时缓存队列大小, 建议值256
cfg.DATALOADER.BUF_SIZE = 256
########################## 数据集配置 #########################################
# 数据主目录目录
cfg.DATASET.DATA_DIR = './dataset/cityscapes/'
# 训练集列表
cfg.DATASET.TRAIN_FILE_LIST = './dataset/cityscapes/train.list'
# 训练集数量
cfg.DATASET.TRAIN_TOTAL_IMAGES = 2975
# 验证集列表
cfg.DATASET.VAL_FILE_LIST = './dataset/cityscapes/val.list'
# 验证数据数量
cfg.DATASET.VAL_TOTAL_IMAGES = 500
# 测试数据列表
cfg.DATASET.TEST_FILE_LIST = './dataset/cityscapes/test.list'
# 测试数据数量
cfg.DATASET.TEST_TOTAL_IMAGES = 500
# Tensorboard 可视化的数据集
cfg.DATASET.VIS_FILE_LIST = None
# 类别数(需包括背景类)
cfg.DATASET.NUM_CLASSES = 19
# 输入图像类型, 支持三通道'rgb',四通道'rgba',单通道灰度图'gray'
cfg.DATASET.IMAGE_TYPE = 'rgb'
# 输入图片的通道数
cfg.DATASET.DATA_DIM = 3
# 数据列表分割符, 默认为空格
cfg.DATASET.SEPARATOR = ' '
# 忽略的像素标签值, 默认为255,一般无需改动
cfg.DATASET.IGNORE_INDEX = 255
# 数据增强是图像的padding值
cfg.DATASET.PADDING_VALUE = [127.5,127.5,127.5]
########################### 数据增强配置 ######################################
# 图像镜像左右翻转
cfg.AUG.MIRROR = True
# 图像上下翻转开关,True/False
cfg.AUG.FLIP = False
# 图像启动上下翻转的概率,0-1
cfg.AUG.FLIP_RATIO = 0.5
# 图像resize的固定尺寸(宽,高),非负
cfg.AUG.FIX_RESIZE_SIZE = tuple()
# 图像resize的方式有三种:
# unpadding(固定尺寸),stepscaling(按比例resize),rangescaling(长边对齐)
cfg.AUG.AUG_METHOD = 'rangescaling'
# 图像resize方式为stepscaling,resize最小尺度,非负
cfg.AUG.MIN_SCALE_FACTOR = 0.5
# 图像resize方式为stepscaling,resize最大尺度,不小于MIN_SCALE_FACTOR
cfg.AUG.MAX_SCALE_FACTOR = 2.0
# 图像resize方式为stepscaling,resize尺度范围间隔,非负
cfg.AUG.SCALE_STEP_SIZE = 0.25
# 图像resize方式为rangescaling,训练时长边resize的范围最小值,非负
cfg.AUG.MIN_RESIZE_VALUE = 400
# 图像resize方式为rangescaling,训练时长边resize的范围最大值,
# 不小于MIN_RESIZE_VALUE
cfg.AUG.MAX_RESIZE_VALUE = 600
# 图像resize方式为rangescaling, 测试验证可视化模式下长边resize的长度,
# 在MIN_RESIZE_VALUE到MAX_RESIZE_VALUE范围内
cfg.AUG.INF_RESIZE_VALUE = 500
# RichCrop数据增广开关,用于提升模型鲁棒性
cfg.AUG.RICH_CROP.ENABLE = False
# 图像旋转最大角度,0-90
cfg.AUG.RICH_CROP.MAX_ROTATION = 15
# 裁取图像与原始图像面积比,0-1
cfg.AUG.RICH_CROP.MIN_AREA_RATIO = 0.5
# 裁取图像宽高比范围,非负
cfg.AUG.RICH_CROP.ASPECT_RATIO = 0.33
# 亮度调节范围,0-1
cfg.AUG.RICH_CROP.BRIGHTNESS_JITTER_RATIO = 0.5
# 饱和度调节范围,0-1
cfg.AUG.RICH_CROP.SATURATION_JITTER_RATIO = 0.5
# 对比度调节范围,0-1
cfg.AUG.RICH_CROP.CONTRAST_JITTER_RATIO = 0.5
# 图像模糊开关,True/False
cfg.AUG.RICH_CROP.BLUR = False
# 图像启动模糊百分比,0-1
cfg.AUG.RICH_CROP.BLUR_RATIO = 0.1
########################### 训练配置 ##########################################
# 模型保存路径
cfg.TRAIN.MODEL_SAVE_DIR = ''
# 预训练模型路径
cfg.TRAIN.PRETRAINED_MODEL_DIR = ''
# 是否resume,继续训练
cfg.TRAIN.RESUME_MODEL_DIR = ''
# 是否使用多卡间同步BatchNorm均值和方差
cfg.TRAIN.SYNC_BATCH_NORM = False
# 模型参数保存的epoch间隔数,可用来继续训练中断的模型
cfg.TRAIN.SNAPSHOT_EPOCH = 10
########################### 模型优化相关配置 ##################################
# 初始学习率
cfg.SOLVER.LR = 0.1
# 学习率下降方法, 支持poly piecewise cosine 三种
cfg.SOLVER.LR_POLICY = "poly"
# 优化算法, 支持SGD和Adam两种算法
cfg.SOLVER.OPTIMIZER = "sgd"
# 动量参数
cfg.SOLVER.MOMENTUM = 0.9
# 二阶矩估计的指数衰减率
cfg.SOLVER.MOMENTUM2 = 0.999
# 学习率Poly下降指数
cfg.SOLVER.POWER = 0.9
# step下降指数
cfg.SOLVER.GAMMA = 0.1
# step下降间隔
cfg.SOLVER.DECAY_EPOCH = [10, 20]
# 学习率权重衰减,0-1
cfg.SOLVER.WEIGHT_DECAY = 0.00004
# 训练开始epoch数,默认为1
cfg.SOLVER.BEGIN_EPOCH = 1
# 训练epoch数,正整数
cfg.SOLVER.NUM_EPOCHS = 30
# loss的选择,支持softmax_loss, bce_loss, dice_loss
cfg.SOLVER.LOSS = ["softmax_loss"]
# cross entropy weight, 默认为None,如果设置为'dynamic',会根据每个batch中各个类别的数目,
# 动态调整类别权重。
# 也可以设置一个静态权重(list的方式),比如有3类,每个类别权重可以设置为[0.1, 2.0, 0.9]
cfg.SOLVER.CROSS_ENTROPY_WEIGHT = None
########################## 测试配置 ###########################################
# 测试模型路径
cfg.TEST.TEST_MODEL = ''
########################## 模型通用配置 #######################################
# 模型名称, 支持deeplab, unet, icnet三种
cfg.MODEL.MODEL_NAME = ''
# BatchNorm类型: bn、gn(group_norm)
cfg.MODEL.DEFAULT_NORM_TYPE = 'bn'
# 多路损失加权值
cfg.MODEL.MULTI_LOSS_WEIGHT = [1.0]
# DEFAULT_NORM_TYPE为gn时group数
cfg.MODEL.DEFAULT_GROUP_NUMBER = 32
# 极小值, 防止分母除0溢出,一般无需改动
cfg.MODEL.DEFAULT_EPSILON = 1e-5
# BatchNorm动量, 一般无需改动
cfg.MODEL.BN_MOMENTUM = 0.99
# 是否使用FP16训练
cfg.MODEL.FP16 = False
# 混合精度训练需对LOSS进行scale, 默认为动态scale,静态scale可以设置为512.0
cfg.MODEL.SCALE_LOSS = "DYNAMIC"
########################## DeepLab模型配置 ####################################
# DeepLab backbone 配置, 可选项xception_65, mobilenetv2
cfg.MODEL.DEEPLAB.BACKBONE = "xception_65"
# DeepLab output stride
cfg.MODEL.DEEPLAB.OUTPUT_STRIDE = 16
# MobileNet backbone scale 设置
cfg.MODEL.DEEPLAB.DEPTH_MULTIPLIER = 1.0
# MobileNet backbone scale 设置
cfg.MODEL.DEEPLAB.ENCODER_WITH_ASPP = True
# MobileNet backbone scale 设置
cfg.MODEL.DEEPLAB.ENABLE_DECODER = True
# ASPP是否使用可分离卷积
cfg.MODEL.DEEPLAB.ASPP_WITH_SEP_CONV = True
# 解码器是否使用可分离卷积
cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV = True
########################## UNET模型配置 #######################################
# 上采样方式, 默认为双线性插值
cfg.MODEL.UNET.UPSAMPLE_MODE = 'bilinear'
########################## ICNET模型配置 ######################################
# RESNET backbone scale 设置
cfg.MODEL.ICNET.DEPTH_MULTIPLIER = 0.5
# RESNET 层数 设置
cfg.MODEL.ICNET.LAYERS = 50
########################## PSPNET模型配置 ######################################
# Lannet backbone name
cfg.MODEL.LANENET.BACKBONE = "vgg"
########################## LaneNet模型配置 ######################################
########################## 预测部署模型配置 ###################################
# 预测保存的模型名称
cfg.FREEZE.MODEL_FILENAME = '__model__'
# 预测保存的参数名称
cfg.FREEZE.PARAMS_FILENAME = '__params__'
# 预测模型参数保存的路径
cfg.FREEZE.SAVE_DIR = 'freeze_model'
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(args, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
envs = args.dist_env
t.transpile(
envs["trainer_id"],
trainers=','.join(envs["trainer_endpoints"]),
current_endpoint=envs["current_endpoint"],
startup_program=startup_prog,
program=main_prog)
def pserver_prepare(args, train_prog, startup_prog):
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = args.split_var
t = fluid.DistributeTranspiler(config=config)
envs = args.dist_env
training_role = envs["training_role"]
t.transpile(
envs["trainer_id"],
program=train_prog,
pservers=envs["pserver_endpoints"],
trainers=envs["num_trainers"],
sync_mode=not args.async_mode,
startup_program=startup_prog)
if training_role == "PSERVER":
pserver_program = t.get_pserver_program(envs["current_endpoint"])
pserver_startup_program = t.get_startup_program(
envs["current_endpoint"],
pserver_program,
startup_program=startup_prog)
return pserver_program, pserver_startup_program
elif training_role == "TRAINER":
train_program = t.get_trainer_program()
return train_program, startup_prog
else:
raise ValueError(
'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def nccl2_prepare_paddle(trainer_id, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
startup_program=startup_prog,
program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
# prepare for multi-process
trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers < 2: return
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
# NOTE(zcd): use multi processes to train the model,
# and each process use one GPU card.
startup_prog = fluid.Program()
nccl2_prepare_paddle(trainer_id, startup_prog, train_prog)
# the startup_prog are run two times, but it doesn't matter.
exe.run(startup_prog)
"""
generate tusimple training dataset
"""
import argparse
import glob
import json
import os
import os.path as ops
import shutil
import cv2
import numpy as np
def init_args():
parser = argparse.ArgumentParser()
parser.add_argument('--src_dir', type=str, help='The origin path of unzipped tusimple dataset')
return parser.parse_args()
def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, instance_dst_dir):
assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)
image_nums = len(os.listdir(os.path.join(src_dir, ori_dst_dir)))
with open(json_file_path, 'r') as file:
for line_index, line in enumerate(file):
info_dict = json.loads(line)
image_dir = ops.split(info_dict['raw_file'])[0]
image_dir_split = image_dir.split('/')[1:]
image_dir_split.append(ops.split(info_dict['raw_file'])[1])
image_name = '_'.join(image_dir_split)
image_path = ops.join(src_dir, info_dict['raw_file'])
assert ops.exists(image_path), '{:s} not exist'.format(image_path)
h_samples = info_dict['h_samples']
lanes = info_dict['lanes']
image_name_new = '{:s}.png'.format('{:d}'.format(line_index + image_nums).zfill(4))
src_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
dst_binary_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8)
dst_instance_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8)
for lane_index, lane in enumerate(lanes):
assert len(h_samples) == len(lane)
lane_x = []
lane_y = []
for index in range(len(lane)):
if lane[index] == -2:
continue
else:
ptx = lane[index]
pty = h_samples[index]
lane_x.append(ptx)
lane_y.append(pty)
if not lane_x:
continue
lane_pts = np.vstack((lane_x, lane_y)).transpose()
lane_pts = np.array([lane_pts], np.int64)
cv2.polylines(dst_binary_image, lane_pts, isClosed=False,
color=255, thickness=5)
cv2.polylines(dst_instance_image, lane_pts, isClosed=False,
color=lane_index * 50 + 20, thickness=5)
dst_binary_image_path = ops.join(src_dir, binary_dst_dir, image_name_new)
dst_instance_image_path = ops.join(src_dir, instance_dst_dir, image_name_new)
dst_rgb_image_path = ops.join(src_dir, ori_dst_dir, image_name_new)
cv2.imwrite(dst_binary_image_path, dst_binary_image)
cv2.imwrite(dst_instance_image_path, dst_instance_image)
cv2.imwrite(dst_rgb_image_path, src_image)
print('Process {:s} success'.format(image_name))
def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train', split=False):
label_list = []
with open('{:s}/{}ing/{}.txt'.format(src_dir, phase, phase), 'w') as file:
for image_name in os.listdir(b_gt_image_dir):
if not image_name.endswith('.png'):
continue
binary_gt_image_path = ops.join(b_gt_image_dir, image_name)
instance_gt_image_path = ops.join(i_gt_image_dir, image_name)
image_path = ops.join(image_dir, image_name)
assert ops.exists(image_path), '{:s} not exist'.format(image_path)
assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(instance_gt_image_path)
b_gt_image = cv2.imread(binary_gt_image_path, cv2.IMREAD_COLOR)
i_gt_image = cv2.imread(instance_gt_image_path, cv2.IMREAD_COLOR)
image = cv2.imread(image_path, cv2.IMREAD_COLOR)
if b_gt_image is None or image is None or i_gt_image is None:
print('image: {:s} corrupt'.format(image_name))
continue
else:
info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path, instance_gt_image_path)
file.write(info + '\n')
label_list.append(info)
if phase == 'train' and split:
np.random.RandomState(0).shuffle(label_list)
val_list_len = len(label_list) // 10
val_label_list = label_list[:val_list_len]
train_label_list = label_list[val_list_len:]
with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase), 'w') as file:
for info in train_label_list:
file.write(info + '\n')
with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase), 'w') as file:
for info in val_label_list:
file.write(info + '\n')
return
def process_tusimple_dataset(src_dir):
traing_folder_path = ops.join(src_dir, 'training')
testing_folder_path = ops.join(src_dir, 'testing')
os.makedirs(traing_folder_path, exist_ok=True)
os.makedirs(testing_folder_path, exist_ok=True)
for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1]
shutil.copyfile(json_label_path, ops.join(traing_folder_path, json_label_name))
for json_label_path in glob.glob('{:s}/test_label.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1]
shutil.copyfile(json_label_path, ops.join(testing_folder_path, json_label_name))
train_gt_image_dir = ops.join('training', 'gt_image')
train_gt_binary_dir = ops.join('training', 'gt_binary_image')
train_gt_instance_dir = ops.join('training', 'gt_instance_image')
test_gt_image_dir = ops.join('testing', 'gt_image')
test_gt_binary_dir = ops.join('testing', 'gt_binary_image')
test_gt_instance_dir = ops.join('testing', 'gt_instance_image')
os.makedirs(os.path.join(src_dir, train_gt_image_dir), exist_ok=True)
os.makedirs(os.path.join(src_dir, train_gt_binary_dir), exist_ok=True)
os.makedirs(os.path.join(src_dir, train_gt_instance_dir), exist_ok=True)
os.makedirs(os.path.join(src_dir, test_gt_image_dir), exist_ok=True)
os.makedirs(os.path.join(src_dir, test_gt_binary_dir), exist_ok=True)
os.makedirs(os.path.join(src_dir, test_gt_instance_dir), exist_ok=True)
for json_label_path in glob.glob('{:s}/*.json'.format(traing_folder_path)):
process_json_file(json_label_path, src_dir, train_gt_image_dir, train_gt_binary_dir, train_gt_instance_dir)
gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir, train_gt_image_dir, 'train', True)
if __name__ == '__main__':
args = init_args()
process_tusimple_dataset(args.src_dir)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# this code heavily base on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
"""
LaneNet model post process
"""
import os.path as ops
import math
import cv2
import time
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
def _morphological_process(image, kernel_size=5):
"""
morphological process to fill the hole in the binary segmentation result
:param image:
:param kernel_size:
:return:
"""
if len(image.shape) == 3:
raise ValueError('Binary segmentation result image should be a single channel image')
if image.dtype is not np.uint8:
image = np.array(image, np.uint8)
kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
# close operation fille hole
closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
return closing
def _connect_components_analysis(image):
"""
connect components analysis to remove the small components
:param image:
:return:
"""
if len(image.shape) == 3:
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
else:
gray_image = image
return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S)
class _LaneFeat(object):
"""
"""
def __init__(self, feat, coord, class_id=-1):
"""
lane feat object
:param feat: lane embeddng feats [feature_1, feature_2, ...]
:param coord: lane coordinates [x, y]
:param class_id: lane class id
"""
self._feat = feat
self._coord = coord
self._class_id = class_id
@property
def feat(self):
return self._feat
@feat.setter
def feat(self, value):
if not isinstance(value, np.ndarray):
value = np.array(value, dtype=np.float64)
if value.dtype != np.float32:
value = np.array(value, dtype=np.float64)
self._feat = value
@property
def coord(self):
return self._coord
@coord.setter
def coord(self, value):
if not isinstance(value, np.ndarray):
value = np.array(value)
if value.dtype != np.int32:
value = np.array(value, dtype=np.int32)
self._coord = value
@property
def class_id(self):
return self._class_id
@class_id.setter
def class_id(self, value):
if not isinstance(value, np.int64):
raise ValueError('Class id must be integer')
self._class_id = value
class _LaneNetCluster(object):
"""
Instance segmentation result cluster
"""
def __init__(self):
"""
"""
self._color_map = [np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])]
@staticmethod
def _embedding_feats_dbscan_cluster(embedding_image_feats):
"""
dbscan cluster
"""
db = DBSCAN(eps=0.4, min_samples=500)
try:
features = StandardScaler().fit_transform(embedding_image_feats)
db.fit(features)
except Exception as err:
print(err)
ret = {
'origin_features': None,
'cluster_nums': 0,
'db_labels': None,
'unique_labels': None,
'cluster_center': None
}
return ret
db_labels = db.labels_
unique_labels = np.unique(db_labels)
num_clusters = len(unique_labels)
cluster_centers = db.components_
ret = {
'origin_features': features,
'cluster_nums': num_clusters,
'db_labels': db_labels,
'unique_labels': unique_labels,
'cluster_center': cluster_centers
}
return ret
@staticmethod
def _get_lane_embedding_feats(binary_seg_ret, instance_seg_ret):
"""
get lane embedding features according the binary seg result
"""
idx = np.where(binary_seg_ret == 255)
lane_embedding_feats = instance_seg_ret[idx]
lane_coordinate = np.vstack((idx[1], idx[0])).transpose()
assert lane_embedding_feats.shape[0] == lane_coordinate.shape[0]
ret = {
'lane_embedding_feats': lane_embedding_feats,
'lane_coordinates': lane_coordinate
}
return ret
def apply_lane_feats_cluster(self, binary_seg_result, instance_seg_result):
"""
:param binary_seg_result:
:param instance_seg_result:
:return:
"""
# get embedding feats and coords
get_lane_embedding_feats_result = self._get_lane_embedding_feats(
binary_seg_ret=binary_seg_result,
instance_seg_ret=instance_seg_result
)
# dbscan cluster
dbscan_cluster_result = self._embedding_feats_dbscan_cluster(
embedding_image_feats=get_lane_embedding_feats_result['lane_embedding_feats']
)
mask = np.zeros(shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3], dtype=np.uint8)
db_labels = dbscan_cluster_result['db_labels']
unique_labels = dbscan_cluster_result['unique_labels']
coord = get_lane_embedding_feats_result['lane_coordinates']
if db_labels is None:
return None, None
lane_coords = []
for index, label in enumerate(unique_labels.tolist()):
if label == -1:
continue
idx = np.where(db_labels == label)
pix_coord_idx = tuple((coord[idx][:, 1], coord[idx][:, 0]))
mask[pix_coord_idx] = self._color_map[index]
lane_coords.append(coord[idx])
return mask, lane_coords
class LaneNetPostProcessor(object):
"""
lanenet post process for lane generation
"""
def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'):
"""
convert front car view to bird view
"""
assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(ipm_remap_file_path)
self._cluster = _LaneNetCluster()
self._ipm_remap_file_path = ipm_remap_file_path
remap_file_load_ret = self._load_remap_matrix()
self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x']
self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y']
self._color_map = [np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])]
def _load_remap_matrix(self):
fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ)
remap_to_ipm_x = fs.getNode('remap_ipm_x').mat()
remap_to_ipm_y = fs.getNode('remap_ipm_y').mat()
ret = {
'remap_to_ipm_x': remap_to_ipm_x,
'remap_to_ipm_y': remap_to_ipm_y,
}
fs.release()
return ret
def postprocess(self, binary_seg_result, instance_seg_result=None,
min_area_threshold=100, source_image=None,
data_source='tusimple'):
# convert binary_seg_result
binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8)
# apply image morphology operation to fill in the hold and reduce the small area
morphological_ret = _morphological_process(binary_seg_result, kernel_size=5)
connect_components_analysis_ret = _connect_components_analysis(image=morphological_ret)
labels = connect_components_analysis_ret[1]
stats = connect_components_analysis_ret[2]
for index, stat in enumerate(stats):
if stat[4] <= min_area_threshold:
idx = np.where(labels == index)
morphological_ret[idx] = 0
# apply embedding features cluster
mask_image, lane_coords = self._cluster.apply_lane_feats_cluster(
binary_seg_result=morphological_ret,
instance_seg_result=instance_seg_result
)
if mask_image is None:
return {
'mask_image': None,
'fit_params': None,
'source_image': None,
}
# lane line fit
fit_params = []
src_lane_pts = []
for lane_index, coords in enumerate(lane_coords):
if data_source == 'tusimple':
tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256), np.int_(coords[:, 0] * 1280 / 512)))] = 255
else:
raise ValueError('Wrong data source now only support tusimple')
tmp_ipm_mask = cv2.remap(
tmp_mask,
self._remap_to_ipm_x,
self._remap_to_ipm_y,
interpolation=cv2.INTER_NEAREST
)
nonzero_y = np.array(tmp_ipm_mask.nonzero()[0])
nonzero_x = np.array(tmp_ipm_mask.nonzero()[1])
fit_param = np.polyfit(nonzero_y, nonzero_x, 2)
fit_params.append(fit_param)
[ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
fit_x = fit_param[0] * plot_y ** 2 + fit_param[1] * plot_y + fit_param[2]
lane_pts = []
for index in range(0, plot_y.shape[0], 5):
src_x = self._remap_to_ipm_x[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
if src_x <= 0:
continue
src_y = self._remap_to_ipm_y[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
src_y = src_y if src_y > 0 else 0
lane_pts.append([src_x, src_y])
src_lane_pts.append(lane_pts)
# tusimple test data sample point along y axis every 10 pixels
source_image_width = source_image.shape[1]
for index, single_lane_pts in enumerate(src_lane_pts):
single_lane_pt_x = np.array(single_lane_pts, dtype=np.float32)[:, 0]
single_lane_pt_y = np.array(single_lane_pts, dtype=np.float32)[:, 1]
if data_source == 'tusimple':
start_plot_y = 240
end_plot_y = 720
else:
raise ValueError('Wrong data source now only support tusimple')
step = int(math.floor((end_plot_y - start_plot_y) / 10))
for plot_y in np.linspace(start_plot_y, end_plot_y, step):
diff = single_lane_pt_y - plot_y
fake_diff_bigger_than_zero = diff.copy()
fake_diff_smaller_than_zero = diff.copy()
fake_diff_bigger_than_zero[np.where(diff <= 0)] = float('inf')
fake_diff_smaller_than_zero[np.where(diff > 0)] = float('-inf')
idx_low = np.argmax(fake_diff_smaller_than_zero)
idx_high = np.argmin(fake_diff_bigger_than_zero)
previous_src_pt_x = single_lane_pt_x[idx_low]
previous_src_pt_y = single_lane_pt_y[idx_low]
last_src_pt_x = single_lane_pt_x[idx_high]
last_src_pt_y = single_lane_pt_y[idx_high]
if previous_src_pt_y < start_plot_y or last_src_pt_y < start_plot_y or \
fake_diff_smaller_than_zero[idx_low] == float('-inf') or \
fake_diff_bigger_than_zero[idx_high] == float('inf'):
continue
interpolation_src_pt_x = (abs(previous_src_pt_y - plot_y) * previous_src_pt_x +
abs(last_src_pt_y - plot_y) * last_src_pt_x) / \
(abs(previous_src_pt_y - plot_y) + abs(last_src_pt_y - plot_y))
interpolation_src_pt_y = (abs(previous_src_pt_y - plot_y) * previous_src_pt_y +
abs(last_src_pt_y - plot_y) * last_src_pt_y) / \
(abs(previous_src_pt_y - plot_y) + abs(last_src_pt_y - plot_y))
if interpolation_src_pt_x > source_image_width or interpolation_src_pt_x < 10:
continue
lane_color = self._color_map[index].tolist()
cv2.circle(source_image, (int(interpolation_src_pt_x),
int(interpolation_src_pt_y)), 5, lane_color, -1)
ret = {
'mask_image': mask_image,
'fit_params': fit_params,
'source_image': source_image,
}
return ret
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# GPU memory garbage collection optimization flags
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
import sys
cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(os.path.split(cur_path)[0])[0]
SEG_PATH = os.path.join(cur_path, "../../../")
sys.path.append(SEG_PATH)
sys.path.append(root_path)
import matplotlib
matplotlib.use('Agg')
import time
import argparse
import pprint
import cv2
import numpy as np
import paddle.fluid as fluid
from utils.config import cfg
from reader import LaneNetDataset
from models.model_builder import build_model
from models.model_builder import ModelPhase
from utils import lanenet_postprocess
import matplotlib.pyplot as plt
def parse_args():
parser = argparse.ArgumentParser(description='PaddeSeg visualization tools')
parser.add_argument(
'--cfg',
dest='cfg_file',
help='Config file for training (and optionally testing)',
default=None,
type=str)
parser.add_argument(
'--use_gpu', dest='use_gpu', help='Use gpu or cpu', action='store_true')
parser.add_argument(
'--vis_dir',
dest='vis_dir',
help='visual save dir',
type=str,
default='visual')
parser.add_argument(
'--also_save_raw_results',
dest='also_save_raw_results',
help='whether to save raw result',
action='store_true')
parser.add_argument(
'--local_test',
dest='local_test',
help='if in local test mode, only visualize 5 images for testing',
action='store_true')
parser.add_argument(
'opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def makedirs(directory):
if not os.path.exists(directory):
os.makedirs(directory)
def to_png_fn(fn, name=""):
"""
Append png as filename postfix
"""
directory, filename = os.path.split(fn)
basename, ext = os.path.splitext(filename)
return basename + name + ".png"
def minmax_scale(input_arr):
min_val = np.min(input_arr)
max_val = np.max(input_arr)
output_arr = (input_arr - min_val) * 255.0 / (max_val - min_val)
return output_arr
def visualize(cfg,
vis_file_list=None,
use_gpu=False,
vis_dir="visual",
also_save_raw_results=False,
ckpt_dir=None,
log_writer=None,
local_test=False,
**kwargs):
if vis_file_list is None:
vis_file_list = cfg.DATASET.TEST_FILE_LIST
dataset = LaneNetDataset(
file_list=vis_file_list,
mode=ModelPhase.VISUAL,
shuffle=True,
data_dir=cfg.DATASET.DATA_DIR)
startup_prog = fluid.Program()
test_prog = fluid.Program()
pred, logit = build_model(test_prog, startup_prog, phase=ModelPhase.VISUAL)
# Clone forward graph
test_prog = test_prog.clone(for_test=True)
# Get device environment
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
save_dir = os.path.join(vis_dir, 'visual_results')
makedirs(save_dir)
if also_save_raw_results:
raw_save_dir = os.path.join(vis_dir, 'raw_results')
makedirs(raw_save_dir)
fetch_list = [pred.name, logit.name]
test_reader = dataset.batch(dataset.generator, batch_size=1, is_test=True)
postprocessor = lanenet_postprocess.LaneNetPostProcessor()
for imgs, grts, grts_instance, img_names, valid_shapes, org_imgs in test_reader:
segLogits, emLogits = exe.run(
program=test_prog,
feed={'image': imgs},
fetch_list=fetch_list,
return_numpy=True)
num_imgs = segLogits.shape[0]
for i in range(num_imgs):
gt_image = org_imgs[i]
binary_seg_image, instance_seg_image = segLogits[i].squeeze(-1), emLogits[i].transpose((1,2,0))
postprocess_result = postprocessor.postprocess(
binary_seg_result=binary_seg_image,
instance_seg_result=instance_seg_image,
source_image=gt_image
)
pred_binary_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_binary'))
pred_lane_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_lane'))
pred_instance_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_instance'))
dirname = os.path.dirname(pred_binary_fn)
makedirs(dirname)
mask_image = postprocess_result['mask_image']
for i in range(4):
instance_seg_image[:, :, i] = minmax_scale(instance_seg_image[:, :, i])
embedding_image = np.array(instance_seg_image).astype(np.uint8)
plt.figure('mask_image')
plt.imshow(mask_image[:, :, (2, 1, 0)])
plt.figure('src_image')
plt.imshow(gt_image[:, :, (2, 1, 0)])
plt.figure('instance_image')
plt.imshow(embedding_image[:, :, (2, 1, 0)])
plt.figure('binary_image')
plt.imshow(binary_seg_image * 255, cmap='gray')
plt.show()
cv2.imwrite(pred_binary_fn, np.array(binary_seg_image * 255).astype(np.uint8))
cv2.imwrite(pred_lane_fn, postprocess_result['source_image'])
cv2.imwrite(pred_instance_fn, mask_image)
print(pred_lane_fn, 'saved!')
if __name__ == '__main__':
args = parse_args()
if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file)
if args.opts:
cfg.update_from_list(args.opts)
cfg.check_and_infer()
print(pprint.pformat(cfg))
visualize(cfg, **args.__dict__)
......@@ -20,7 +20,7 @@ import importlib
from utils.config import cfg
def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2):
def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2, weight=None):
ignore_mask = fluid.layers.cast(ignore_mask, 'float32')
label = fluid.layers.elementwise_min(
label, fluid.layers.assign(np.array([num_classes - 1], dtype=np.int32)))
......@@ -29,12 +29,40 @@ def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2):
label = fluid.layers.reshape(label, [-1, 1])
label = fluid.layers.cast(label, 'int64')
ignore_mask = fluid.layers.reshape(ignore_mask, [-1, 1])
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=cfg.DATASET.IGNORE_INDEX,
return_softmax=True)
if weight is None:
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=cfg.DATASET.IGNORE_INDEX,
return_softmax=True)
else:
label_one_hot = fluid.layers.one_hot(input=label, depth=num_classes)
if isinstance(weight, list):
assert len(weight) == num_classes, "weight length must equal num of classes"
weight = fluid.layers.assign(np.array([weight], dtype='float32'))
elif isinstance(weight, str):
assert weight.lower() == 'dynamic', 'if weight is string, must be dynamic!'
tmp = []
total_num = fluid.layers.cast(fluid.layers.shape(label)[0], 'float32')
for i in range(num_classes):
cls_pixel_num = fluid.layers.reduce_sum(label_one_hot[:, i])
ratio = total_num / (cls_pixel_num + 1)
tmp.append(ratio)
weight = fluid.layers.concat(tmp)
weight = weight / fluid.layers.reduce_sum(weight) * num_classes
elif isinstance(weight, fluid.layers.Variable):
pass
else:
raise ValueError('Expect weight is a list, string or Variable, but receive {}'.format(type(weight)))
weight = fluid.layers.reshape(weight, [1, num_classes])
weighted_label_one_hot = fluid.layers.elementwise_mul(label_one_hot, weight)
probs = fluid.layers.softmax(logit)
loss = fluid.layers.cross_entropy(
probs,
weighted_label_one_hot,
soft_label=True,
ignore_index=cfg.DATASET.IGNORE_INDEX)
weighted_label_one_hot.stop_gradient = True
loss = loss * ignore_mask
avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
......@@ -80,7 +108,7 @@ def bce_loss(logit, label, ignore_mask=None):
return loss
def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2):
def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2, weight=None):
if isinstance(logits, tuple):
avg_loss = 0
for i, logit in enumerate(logits):
......@@ -91,7 +119,7 @@ def multi_softmax_with_loss(logits, label, ignore_mask=None, num_classes=2):
num_classes)
avg_loss += cfg.MODEL.MULTI_LOSS_WEIGHT[i] * loss
else:
avg_loss = softmax_with_loss(logits, label, ignore_mask, num_classes)
avg_loss = softmax_with_loss(logits, label, ignore_mask, num_classes, weight=weight)
return avg_loss
def multi_dice_loss(logits, label, ignore_mask=None):
......
......@@ -14,5 +14,3 @@
# limitations under the License.
import models.modeling
import models.libs
import models.backbone
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
__all__ = ["VGGNet"]
def check_points(count, points):
if points is None:
return False
else:
if isinstance(points, list):
return (True if count in points else False)
else:
return (True if count == points else False)
class VGGNet():
def __init__(self, layers=16):
self.layers = layers
def net(self, input, class_dim=1000, end_points=None, decode_points=None):
short_cuts = dict()
layers_count = 0
layers = self.layers
vgg_spec = {
11: ([1, 1, 2, 2, 2]),
13: ([2, 2, 2, 2, 2]),
16: ([2, 2, 3, 3, 3]),
19: ([2, 2, 4, 4, 4])
}
assert layers in vgg_spec.keys(), \
"supported layers are {} but input layer is {}".format(vgg_spec.keys(), layers)
nums = vgg_spec[layers]
channels = [64, 128, 256, 512, 512]
conv = input
for i in range(len(nums)):
conv = self.conv_block(conv, channels[i], nums[i], name="conv" + str(i + 1) + "_")
layers_count += nums[i]
if check_points(layers_count, decode_points):
short_cuts[layers_count] = conv
if check_points(layers_count, end_points):
return conv, short_cuts
return conv
def conv_block(self, input, num_filter, groups, name=None):
conv = input
for i in range(groups):
conv = fluid.layers.conv2d(
input=conv,
num_filters=num_filter,
filter_size=3,
stride=1,
padding=1,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
name=name + str(i + 1) + "_weights"),
bias_attr=False)
return fluid.layers.pool2d(
input=conv, pool_size=2, pool_type='max', pool_stride=2)
......@@ -223,8 +223,9 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
avg_loss_list = []
valid_loss = []
if "softmax_loss" in loss_type:
weight = cfg.SOLVER.CROSS_ENTROPY_WEIGHT
avg_loss_list.append(
multi_softmax_with_loss(logits, label, mask, class_num))
multi_softmax_with_loss(logits, label, mask, class_num, weight))
loss_valid = True
valid_loss.append("softmax_loss")
if "dice_loss" in loss_type:
......
......@@ -158,7 +158,10 @@ cfg.SOLVER.LOSS = ["softmax_loss"]
cfg.SOLVER.LR_WARMUP = False
# warmup的迭代次数
cfg.SOLVER.LR_WARMUP_STEPS = 2000
# cross entropy weight, 默认为None,如果设置为'dynamic',会根据每个batch中各个类别的数目,
# 动态调整类别权重。
# 也可以设置一个静态权重(list的方式),比如有3类,每个类别权重可以设置为[0.1, 2.0, 0.9]
cfg.SOLVER.CROSS_ENTROPY_WEIGHT = None
########################## 测试配置 ###########################################
# 测试模型路径
cfg.TEST.TEST_MODEL = ''
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册