提交 61645b1d 编写于 作者: C chenguowei01

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSeg into develop

......@@ -35,7 +35,7 @@ PaddleSeg是基于[PaddlePaddle](https://www.paddlepaddle.org.cn)开发的端到
- **高性能**
PaddleSeg支持多进程I/O、多卡并行、跨卡Batch Norm同步等训练加速策略,结合飞桨核心框架的显存优化功能,可大幅度减少分割模型的显存开销,让开发者更低成本、更高效地完成图像分割训练。
PaddleSeg支持多进程I/O、多卡并行等训练加速策略,结合飞桨核心框架的显存优化功能,可大幅度减少分割模型的显存开销,让开发者更低成本、更高效地完成图像分割训练。
- **工业级部署**
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os
......@@ -19,10 +33,10 @@ cfg.class_num = 20
# 均值, 图像预处理减去的均值
cfg.MEAN = 0.406, 0.456, 0.485
# 标准差,图像预处理除以标准差
cfg.STD = 0.225, 0.224, 0.229
cfg.STD = 0.225, 0.224, 0.229
# 多尺度预测时图像尺寸
cfg.multi_scales = (377,377), (473,473), (567,567)
cfg.multi_scales = (377, 377), (473, 473), (567, 567)
# 多尺度预测时图像是否水平翻转
cfg.flip = True
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
......@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg')
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os.environ['FLAGS_eager_delete_tensor_gb']='0.0'
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
import paddle.fluid as fluid
# 预测数据集类
class TestDataSet():
def __init__(self):
self.data_dir = cfg.data_dir
self.data_dir = cfg.data_dir
self.data_list_file = cfg.data_list_file
self.data_list = self.get_data_list()
self.data_num = len(self.data_list)
def get_data_list(self):
# 获取预测图像路径列表
data_list = []
......@@ -56,10 +71,10 @@ class TestDataSet():
img_path = self.data_list[index]
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
return img, img,img_path, None
return img, img, img_path, None
img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'')
name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2]
img_process = self.preprocess(img)
......@@ -90,39 +105,44 @@ def infer():
if image is None:
print(im_name, 'is None')
continue
# 预测
if cfg.example == 'ACE2P':
# ACE2P模型使用多尺度预测
reader = importlib.import_module('reader')
multi_scale_test = getattr(reader, 'multi_scale_test')
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape)
parsing, logits = multi_scale_test(exe, test_prog, feed_name,
fetch_list, image, im_shape)
else:
# HumanSeg,RoadLine模型单尺度预测
result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list)
result = exe.run(
program=test_prog,
feed={feed_name[0]: image},
fetch_list=fetch_list)
parsing = np.argmax(result[0][0], axis=0)
parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])
# 预测结果保存
result_path = os.path.join(cfg.vis_dir, im_name + '.png')
if cfg.example == 'HumanSeg':
logits = result[0][0][1]*255
logits = result[0][0][1] * 255
logits = cv2.resize(logits, im_shape[::-1])
ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
logits = 255 *(logits - thresh)/(255 - thresh)
logits = 255 * (logits - thresh) / (255 - thresh)
# 将分割结果添加到alpha通道
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2)
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
axis=2)
cv2.imwrite(result_path, rgba)
else:
else:
output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
output_im.putpalette(palette)
output_im.save(result_path)
if (idx + 1) % 100 == 0:
print('%d processd' % (idx + 1))
print('%d processd done' % (idx + 1))
print('%d processd done' % (idx + 1))
return 0
......
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
from config import cfg
import cv2
def get_affine_points(src_shape, dst_shape, rot_grad=0):
# 获取图像和仿射后图像的三组对应点坐标
# 三组点为仿射变换后图像的中心点, [w/2,0], [0,0],及对应原始图像的点
......@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
# 原始图像三组点
points = [[0, 0]] * 3
points[0] = (np.array([w, h]) - 1) * 0.5
points[0] = (np.array([w, h]) - 1) * 0.5
points[1] = points[0] + 0.5 * affine_shape[0] * np.array([sin_v, -cos_v])
points[2] = points[1] - 0.5 * affine_shape[1] * np.array([cos_v, sin_v])
......@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
return points, points_trans
def preprocess(im):
# ACE2P模型数据预处理
im_shape = im.shape[:2]
......@@ -42,13 +58,10 @@ def preprocess(im):
# 获取图像和仿射变换后图像的对应点坐标
points, points_trans = get_affine_points(im_shape, scale)
# 根据对应点集获得仿射矩阵
trans = cv2.getAffineTransform(np.float32(points),
np.float32(points_trans))
trans = cv2.getAffineTransform(
np.float32(points), np.float32(points_trans))
# 根据仿射矩阵对图像进行仿射
input = cv2.warpAffine(im,
trans,
scale[::-1],
flags=cv2.INTER_LINEAR)
input = cv2.warpAffine(im, trans, scale[::-1], flags=cv2.INTER_LINEAR)
# 减均值测,除以方差,转换数据格式为NCHW
input = input.astype(np.float32)
......@@ -66,19 +79,20 @@ def preprocess(im):
return input_images
def multi_scale_test(exe, test_prog, feed_name, fetch_list,
input_ims, im_shape):
def multi_scale_test(exe, test_prog, feed_name, fetch_list, input_ims,
im_shape):
# 由于部分类别分左右部位, flipped_idx为其水平翻转后对应的标签
flipped_idx = (15, 14, 17, 16, 19, 18)
ms_outputs = []
# 多尺度预测
for idx, scale in enumerate(cfg.multi_scales):
input_im = input_ims[idx]
parsing_output = exe.run(program=test_prog,
feed={feed_name[0]: input_im},
fetch_list=fetch_list)
parsing_output = exe.run(
program=test_prog,
feed={feed_name[0]: input_im},
fetch_list=fetch_list)
output = parsing_output[0][0]
if cfg.flip:
# 若水平翻转,对部分类别进行翻转,与原始预测结果取均值
......@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
# 仿射变换回图像原始尺寸
points, points_trans = get_affine_points(im_shape, scale)
M = cv2.getAffineTransform(np.float32(points_trans), np.float32(points))
logits_result = cv2.warpAffine(output, M, im_shape[::-1], flags=cv2.INTER_LINEAR)
logits_result = cv2.warpAffine(
output, M, im_shape[::-1], flags=cv2.INTER_LINEAR)
ms_outputs.append(logits_result)
# 多尺度预测结果求均值,求预测概率最大的类别
......@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
ms_fused_parsing_output = np.mean(ms_fused_parsing_output, axis=0)
parsing = np.argmax(ms_fused_parsing_output, axis=2)
return parsing, ms_fused_parsing_output
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -7,6 +7,7 @@
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu",
action="store_true",
help="Use gpu or cpu to test.")
parser.add_argument('--example',
type=str,
help='RoadLine, HumanSeg or ACE2P')
parser.add_argument(
"--use_gpu", action="store_true", help="Use gpu or cpu to test.")
parser.add_argument(
'--example', type=str, help='RoadLine, HumanSeg or ACE2P')
return parser.parse_args()
......@@ -34,6 +48,7 @@ class AttrDict(dict):
else:
self[name] = value
def merge_cfg_from_args(args, cfg):
"""Merge config keys, values in args into the global config."""
for k, v in vars(args).items():
......@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value = v
if value is not None:
cfg[k] = value
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# utils for memory management which is allocated on sharedmemory,
# note that these structures may not be thread-safe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import models
import argparse
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .humanseg import HumanSegMobile
from .humanseg import HumanSegServer
from .humanseg import HumanSegLite
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backbone import mobilenet_v2
from .backbone import xception
from .deeplabv3p import DeepLabv3p
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mobilenet_v2 import MobileNetV2
from .xception import Xception
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -10,6 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import os
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -205,11 +206,9 @@ def load_pretrained_weights(exe, main_prog, weights_dir, fuse_bn=False):
vars_to_load.append(var)
logging.debug("Weight {} will be load".format(var.name))
fluid.io.load_vars(
executor=exe,
dirname=weights_dir,
main_program=main_prog,
vars=vars_to_load)
params_dict = fluid.io.load_program_state(
weights_dir, var_list=vars_to_load)
fluid.io.set_program_state(main_prog, params_dict)
if len(vars_to_load) == 0:
logging.warning(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase
from pdseg.data_aug import get_random_scale, randomly_scale_image_and_label, random_rotation, \
rand_scale_aspect, hsv_color_jitter, rand_crop
def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
"""
改变图像及标签图像尺寸
......@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
if grt is not None:
grt = cv2.resize(grt, target_size, interpolation=cv2.INTER_NEAREST)
if grt_instance is not None:
grt_instance = cv2.resize(grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
grt_instance = cv2.resize(
grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
elif cfg.AUG.AUG_METHOD == 'stepscaling':
if mode == ModelPhase.TRAIN:
min_scale_factor = cfg.AUG.MIN_SCALE_FACTOR
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -122,7 +122,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if ckpt_dir is not None:
print('load test model:', ckpt_dir)
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
try:
fluid.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
except:
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
# Use streaming confusion matrix to calculate mean_iou
np.set_printoptions(
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,7 +18,6 @@ from __future__ import print_function
import paddle.fluid as fluid
from utils.config import cfg
from pdseg.models.libs.model_libs import scope, name_scope
from pdseg.models.libs.model_libs import bn, bn_relu, relu
......@@ -86,7 +85,12 @@ def bottleneck(inputs,
with scope('down_sample'):
inputs_shape = inputs.shape
with scope('main_max_pool'):
net_main = fluid.layers.conv2d(inputs, inputs_shape[1], filter_size=3, stride=2, padding='SAME')
net_main = fluid.layers.conv2d(
inputs,
inputs_shape[1],
filter_size=3,
stride=2,
padding='SAME')
#First get the difference in depth to pad, then pad with zeros only on the last dimension.
depth_to_pad = abs(inputs_shape[1] - output_depth)
......@@ -95,12 +99,16 @@ def bottleneck(inputs,
net_main = fluid.layers.pad(net_main, paddings=paddings)
with scope('block1'):
net = conv(inputs, reduced_depth, [2, 2], stride=2, padding='same')
net = conv(
inputs, reduced_depth, [2, 2], stride=2, padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
with scope('block2'):
net = conv(net, reduced_depth, [filter_size, filter_size], padding='same')
net = conv(
net,
reduced_depth, [filter_size, filter_size],
padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -137,13 +145,18 @@ def bottleneck(inputs,
# Second conv block --- apply dilated convolution here
with scope('block2'):
net = conv(net, reduced_depth, filter_size, padding='SAME', dilation=dilation_rate)
net = conv(
net,
reduced_depth,
filter_size,
padding='SAME',
dilation=dilation_rate)
net = bn(net)
net = prelu(net, decoder=decoder)
# Final projection with 1x1 kernel (Expansion)
with scope('block3'):
net = conv(net, output_depth, [1,1])
net = conv(net, output_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -172,9 +185,11 @@ def bottleneck(inputs,
# Second conv block --- apply asymmetric conv here
with scope('block2'):
with scope('asymmetric_conv2a'):
net = conv(net, reduced_depth, [filter_size, 1], padding='same')
net = conv(
net, reduced_depth, [filter_size, 1], padding='same')
with scope('asymmetric_conv2b'):
net = conv(net, reduced_depth, [1, filter_size], padding='same')
net = conv(
net, reduced_depth, [1, filter_size], padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -211,7 +226,8 @@ def bottleneck(inputs,
with scope('unpool'):
net_unpool = conv(inputs, output_depth, [1, 1])
net_unpool = bn(net_unpool)
net_unpool = fluid.layers.resize_bilinear(net_unpool, out_shape=output_shape[2:])
net_unpool = fluid.layers.resize_bilinear(
net_unpool, out_shape=output_shape[2:])
# First 1x1 projection to reduce depth
with scope('block1'):
......@@ -220,7 +236,12 @@ def bottleneck(inputs,
net = prelu(net, decoder=decoder)
with scope('block2'):
net = deconv(net, reduced_depth, filter_size=filter_size, stride=2, padding='same')
net = deconv(
net,
reduced_depth,
filter_size=filter_size,
stride=2,
padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -253,7 +274,10 @@ def bottleneck(inputs,
# Second conv block
with scope('block2'):
net = conv(net, reduced_depth, [filter_size, filter_size], padding='same')
net = conv(
net,
reduced_depth, [filter_size, filter_size],
padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'):
= bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.01, type=DOWNSAMPLING,
name_scope='bottleneck1_0')
with scope('bottleneck1_1'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_1')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_1')
with scope('bottleneck1_2'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_2')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_2')
with scope('bottleneck1_3'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_3')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_3')
with scope('bottleneck1_4'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_4')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_4')
return net, inputs_shape_1
......@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'):
name_scope='bottleneck2_0')
for i in range(2):
with scope('bottleneck2_{}'.format(str(4 * i + 1))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1,
name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
with scope('bottleneck2_{}'.format(str(4 * i + 2))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)),
name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 1)),
name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
with scope('bottleneck2_{}'.format(str(4 * i + 3))):
net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC,
name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
net = bottleneck(
net,
output_depth=128,
filter_size=5,
regularizer_prob=0.1,
type=ASYMMETRIC,
name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
with scope('bottleneck2_{}'.format(str(4 * i + 4))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)),
name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 2)),
name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
return net, inputs_shape_2
......@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'):
with scope(name_scope):
for i in range(2):
with scope('bottleneck3_{}'.format(str(4 * i + 0))):
net = bottleneck(inputs, output_depth=128, filter_size=3, regularizer_prob=0.1,
name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
net = bottleneck(
inputs,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
with scope('bottleneck3_{}'.format(str(4 * i + 1))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)),
name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 1)),
name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
with scope('bottleneck3_{}'.format(str(4 * i + 2))):
net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC,
name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
net = bottleneck(
net,
output_depth=128,
filter_size=5,
regularizer_prob=0.1,
type=ASYMMETRIC,
name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
with scope('bottleneck3_{}'.format(str(4 * i + 3))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)),
name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 2)),
name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
return net
def ENet_stage4(inputs, inputs_shape, connect_tensor,
skip_connections=True, name_scope='stage4_block'):
def ENet_stage4(inputs,
inputs_shape,
connect_tensor,
skip_connections=True,
name_scope='stage4_block'):
with scope(name_scope):
with scope('bottleneck4_0'):
net = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.1,
type=UPSAMPLING, decoder=True, output_shape=inputs_shape,
name_scope='bottleneck4_0')
net = bottleneck(
inputs,
output_depth=64,
filter_size=3,
regularizer_prob=0.1,
type=UPSAMPLING,
decoder=True,
output_shape=inputs_shape,
name_scope='bottleneck4_0')
if skip_connections:
net = fluid.layers.elementwise_add(net, connect_tensor)
with scope('bottleneck4_1'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck4_1')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck4_1')
with scope('bottleneck4_2'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck4_2')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck4_2')
return net
def ENet_stage5(inputs, inputs_shape, connect_tensor, skip_connections=True,
def ENet_stage5(inputs,
inputs_shape,
connect_tensor,
skip_connections=True,
name_scope='stage5_block'):
with scope(name_scope):
net = bottleneck(inputs, output_depth=16, filter_size=3, regularizer_prob=0.1, type=UPSAMPLING,
decoder=True, output_shape=inputs_shape,
name_scope='bottleneck5_0')
net = bottleneck(
inputs,
output_depth=16,
filter_size=3,
regularizer_prob=0.1,
type=UPSAMPLING,
decoder=True,
output_shape=inputs_shape,
name_scope='bottleneck5_0')
if skip_connections:
net = fluid.layers.elementwise_add(net, connect_tensor)
with scope('bottleneck5_1'):
net = bottleneck(net, output_depth=16, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck5_1')
net = bottleneck(
net,
output_depth=16,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck5_1')
return net
......@@ -378,14 +493,16 @@ def decoder(input, num_classes):
segStage3 = ENet_stage3(stage2)
segStage4 = ENet_stage4(segStage3, inputs_shape_2, stage1)
segStage5 = ENet_stage5(segStage4, inputs_shape_1, initial)
segLogits = deconv(segStage5, num_classes, filter_size=2, stride=2, padding='SAME')
segLogits = deconv(
segStage5, num_classes, filter_size=2, stride=2, padding='SAME')
# Embedding branch
with scope('LaneNetEm'):
emStage3 = ENet_stage3(stage2)
emStage4 = ENet_stage4(emStage3, inputs_shape_2, stage1)
emStage5 = ENet_stage5(emStage4, inputs_shape_1, initial)
emLogits = deconv(emStage5, 4, filter_size=2, stride=2, padding='SAME')
emLogits = deconv(
emStage5, 4, filter_size=2, stride=2, padding='SAME')
elif 'vgg' in cfg.MODEL.LANENET.BACKBONE:
encoder_list = ['pool5', 'pool4', 'pool3']
......@@ -396,14 +513,16 @@ def decoder(input, num_classes):
encoder_list = encoder_list[1:]
for i in range(len(encoder_list)):
with scope('deconv_{:d}'.format(i + 1)):
deconv_out = deconv(score, 64, filter_size=4, stride=2, padding='SAME')
deconv_out = deconv(
score, 64, filter_size=4, stride=2, padding='SAME')
input_tensor = input[encoder_list[i]]
with scope('score_{:d}'.format(i + 1)):
score = conv(input_tensor, 64, 1)
score = fluid.layers.elementwise_add(deconv_out, score)
with scope('deconv_final'):
emLogits = deconv(score, 64, filter_size=16, stride=8, padding='SAME')
emLogits = deconv(
score, 64, filter_size=16, stride=8, padding='SAME')
with scope('score_final'):
segLogits = conv(emLogits, num_classes, 1)
emLogits = relu(conv(emLogits, 4, 1))
......@@ -415,7 +534,8 @@ def encoder(input):
model = vgg_backbone(layers=16)
#output = model.net(input)
_, encode_feature_dict = model.net(input, end_points=13, decode_points=[7, 10, 13])
_, encode_feature_dict = model.net(
input, end_points=13, decode_points=[7, 10, 13])
output = {}
output['pool3'] = encode_feature_dict[7]
output['pool4'] = encode_feature_dict[10]
......@@ -427,8 +547,9 @@ def encoder(input):
stage2, inputs_shape_2 = ENet_stage2(stage1)
output = (initial, stage1, stage2, inputs_shape_1, inputs_shape_2)
else:
raise Exception("LaneNet expect enet and vgg backbone, but received {}".
format(cfg.MODEL.LANENET.BACKBONE))
raise Exception(
"LaneNet expect enet and vgg backbone, but received {}".format(
cfg.MODEL.LANENET.BACKBONE))
return output
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -58,7 +58,8 @@ class LaneNetDataset():
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
......@@ -86,7 +87,8 @@ class LaneNetDataset():
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // self.num_trainers
self.lines = self.all_lines[num_lines * self.trainer_id: num_lines * (self.trainer_id + 1)]
self.lines = self.all_lines[num_lines * self.trainer_id:num_lines *
(self.trainer_id + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
......@@ -118,7 +120,8 @@ class LaneNetDataset():
def batch_reader(is_test=False, drop_last=drop_last):
if is_test:
imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
for img, grt, grt_instance, img_name, valid_shape, org_shape in reader():
for img, grt, grt_instance, img_name, valid_shape, org_shape in reader(
):
imgs.append(img)
grts.append(grt)
grts_instance.append(grt_instance)
......@@ -126,14 +129,15 @@ class LaneNetDataset():
valid_shapes.append(valid_shape)
org_shapes.append(org_shape)
if len(imgs) == batch_size:
yield np.array(imgs), np.array(
grts), np.array(grts_instance), img_names, np.array(valid_shapes), np.array(
org_shapes)
yield np.array(imgs), np.array(grts), np.array(
grts_instance), img_names, np.array(
valid_shapes), np.array(org_shapes)
imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
if not drop_last and len(imgs) > 0:
yield np.array(imgs), np.array(grts), np.array(grts_instance), img_names, np.array(
valid_shapes), np.array(org_shapes)
yield np.array(imgs), np.array(grts), np.array(
grts_instance), img_names, np.array(
valid_shapes), np.array(org_shapes)
else:
imgs, labs, labs_instance, ignore = [], [], [], []
bs = 0
......@@ -144,12 +148,14 @@ class LaneNetDataset():
ignore.append(ig)
bs += 1
if bs == batch_size:
yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore)
yield np.array(imgs), np.array(labs), np.array(
labs_instance), np.array(ignore)
bs = 0
imgs, labs, labs_instance, ignore = [], [], [], []
if not drop_last and bs > 0:
yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore)
yield np.array(imgs), np.array(labs), np.array(
labs_instance), np.array(ignore)
return batch_reader(is_test, drop_last)
......@@ -299,10 +305,12 @@ class LaneNetDataset():
img, grt = aug.rand_crop(img, grt, mode=mode)
elif ModelPhase.is_eval(mode):
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode)
img, grt, grt_instance = aug.resize(
img, grt, grt_instance, mode=mode)
elif ModelPhase.is_visual(mode):
ori_img = img.copy()
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode)
img, grt, grt_instance = aug.resize(
img, grt, grt_instance, mode=mode)
valid_shape = [img.shape[0], img.shape[1]]
else:
raise ValueError("Dataset mode={} Error!".format(mode))
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -40,10 +40,10 @@ from pdseg.utils.timer import Timer, calculate_eta
from reader import LaneNetDataset
from models.model_builder import build_model
from models.model_builder import ModelPhase
from models.model_builder import parse_shape_from_file
from eval import evaluate
from vis import visualize
from utils import dist_utils
from utils.load_model_utils import load_pretrained_weights
def parse_args():
......@@ -101,37 +101,6 @@ def parse_args():
return parser.parse_args()
def save_vars(executor, dirname, program=None, vars=None):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program = fluid.Program()
save_block = save_program.global_block()
for each_var in vars:
# NOTE: don't save the variable which type is RAW
if each_var.type == fluid.core.VarDesc.VarType.RAW:
continue
new_var = save_block.create_var(
name=each_var.name,
shape=each_var.shape,
dtype=each_var.dtype,
type=each_var.type,
lod_level=each_var.lod_level,
persistable=True)
file_path = os.path.join(dirname, new_var.name)
file_path = os.path.normpath(file_path)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': file_path})
executor.run(save_program)
def save_checkpoint(exe, program, ckpt_name):
"""
Save checkpoint for evaluation or resume training
......@@ -141,29 +110,22 @@ def save_checkpoint(exe, program, ckpt_name):
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
save_vars(
exe,
ckpt_dir,
program,
vars=list(filter(fluid.io.is_persistable, program.list_vars())))
fluid.save(program, os.path.join(ckpt_dir, 'model'))
return ckpt_dir
def load_checkpoint(exe, program):
"""
Load checkpoiont from pretrained model directory for resume training
Load checkpoiont for resuming training
"""
print('Resume model training from:', cfg.TRAIN.RESUME_MODEL_DIR)
if not os.path.exists(cfg.TRAIN.RESUME_MODEL_DIR):
raise ValueError("TRAIN.PRETRAIN_MODEL {} not exist!".format(
cfg.TRAIN.RESUME_MODEL_DIR))
fluid.io.load_persistables(
exe, cfg.TRAIN.RESUME_MODEL_DIR, main_program=program)
model_path = cfg.TRAIN.RESUME_MODEL_DIR
print('Resume model training from:', model_path)
if not os.path.exists(model_path):
raise ValueError(
"TRAIN.PRETRAIN_MODEL {} not exist!".format(model_path))
fluid.load(program, os.path.join(model_path, 'model'), exe)
# Check is path ended by path spearator
if model_path[-1] == os.sep:
model_path = model_path[0:-1]
......@@ -178,7 +140,6 @@ def load_checkpoint(exe, program):
else:
raise ValueError("Resume model path is not valid!")
print("Model checkpoint loaded successfully!")
return begin_epoch
......@@ -271,44 +232,7 @@ def train(cfg):
begin_epoch = load_checkpoint(exe, train_prog)
# Load pretrained model
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_vars = []
load_fail_vars = []
def var_shape_matched(var, shape):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist = os.path.exists(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_shape != shape:
print(var.name, var_shape, shape)
return var_shape == shape
return False
for x in train_prog.list_vars():
if isinstance(x, fluid.framework.Parameter):
shape = tuple(fluid.global_scope().find_var(
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print_info(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print_info("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
load_pretrained_weights(exe, train_prog, cfg.TRAIN.PRETRAINED_MODEL_DIR)
else:
print_info(
'Pretrained model dir {} not exists, training from scratch...'.
......@@ -393,8 +317,7 @@ def train(cfg):
avg_emb_loss, avg_acc, avg_fp, avg_fn, speed,
calculate_eta(all_step - step, speed)))
if args.use_vdl:
log_writer.add_scalar('Train/loss', avg_loss,
step)
log_writer.add_scalar('Train/loss', avg_loss, step)
log_writer.add_scalar('Train/lr', lr[0], step)
log_writer.add_scalar('Train/speed', speed, step)
sys.stdout.flush()
......@@ -423,8 +346,7 @@ def train(cfg):
use_gpu=args.use_gpu,
use_mpio=args.use_mpio)
if args.use_vdl:
log_writer.add_scalar('Evaluate/accuracy', accuracy,
step)
log_writer.add_scalar('Evaluate/accuracy', accuracy, step)
log_writer.add_scalar('Evaluate/fp', fp, step)
log_writer.add_scalar('Evaluate/fn', fn, step)
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3
cfg.DATASET.SEPARATOR = ' '
# 忽略的像素标签值, 默认为255,一般无需改动
cfg.DATASET.IGNORE_INDEX = 255
# 数据增强是图像的padding值
cfg.DATASET.PADDING_VALUE = [127.5,127.5,127.5]
# 数据增强是图像的padding值
cfg.DATASET.PADDING_VALUE = [127.5, 127.5, 127.5]
########################### 数据增强配置 ######################################
# 图像镜像左右翻转
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
generate tusimple training dataset
"""
......@@ -14,12 +28,16 @@ import numpy as np
def init_args():
parser = argparse.ArgumentParser()
parser.add_argument('--src_dir', type=str, help='The origin path of unzipped tusimple dataset')
parser.add_argument(
'--src_dir',
type=str,
help='The origin path of unzipped tusimple dataset')
return parser.parse_args()
def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, instance_dst_dir):
def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir,
instance_dst_dir):
assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)
......@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
h_samples = info_dict['h_samples']
lanes = info_dict['lanes']
image_name_new = '{:s}.png'.format('{:d}'.format(line_index + image_nums).zfill(4))
image_name_new = '{:s}.png'.format(
'{:d}'.format(line_index + image_nums).zfill(4))
src_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
dst_binary_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8)
dst_instance_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8)
dst_binary_image = np.zeros(
[src_image.shape[0], src_image.shape[1]], np.uint8)
dst_instance_image = np.zeros(
[src_image.shape[0], src_image.shape[1]], np.uint8)
for lane_index, lane in enumerate(lanes):
assert len(h_samples) == len(lane)
......@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
lane_pts = np.vstack((lane_x, lane_y)).transpose()
lane_pts = np.array([lane_pts], np.int64)
cv2.polylines(dst_binary_image, lane_pts, isClosed=False,
color=255, thickness=5)
cv2.polylines(dst_instance_image, lane_pts, isClosed=False,
color=lane_index * 50 + 20, thickness=5)
dst_binary_image_path = ops.join(src_dir, binary_dst_dir, image_name_new)
dst_instance_image_path = ops.join(src_dir, instance_dst_dir, image_name_new)
cv2.polylines(
dst_binary_image,
lane_pts,
isClosed=False,
color=255,
thickness=5)
cv2.polylines(
dst_instance_image,
lane_pts,
isClosed=False,
color=lane_index * 50 + 20,
thickness=5)
dst_binary_image_path = ops.join(src_dir, binary_dst_dir,
image_name_new)
dst_instance_image_path = ops.join(src_dir, instance_dst_dir,
image_name_new)
dst_rgb_image_path = ops.join(src_dir, ori_dst_dir, image_name_new)
cv2.imwrite(dst_binary_image_path, dst_binary_image)
......@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
print('Process {:s} success'.format(image_name))
def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train', split=False):
def gen_sample(src_dir,
b_gt_image_dir,
i_gt_image_dir,
image_dir,
phase='train',
split=False):
label_list = []
with open('{:s}/{}ing/{}.txt'.format(src_dir, phase, phase), 'w') as file:
......@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
image_path = ops.join(image_dir, image_name)
assert ops.exists(image_path), '{:s} not exist'.format(image_path)
assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(instance_gt_image_path)
assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(
instance_gt_image_path)
b_gt_image = cv2.imread(binary_gt_image_path, cv2.IMREAD_COLOR)
i_gt_image = cv2.imread(instance_gt_image_path, cv2.IMREAD_COLOR)
......@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
print('image: {:s} corrupt'.format(image_name))
continue
else:
info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path, instance_gt_image_path)
info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path,
instance_gt_image_path)
file.write(info + '\n')
label_list.append(info)
if phase == 'train' and split:
......@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
val_list_len = len(label_list) // 10
val_label_list = label_list[:val_list_len]
train_label_list = label_list[val_list_len:]
with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase), 'w') as file:
with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase),
'w') as file:
for info in train_label_list:
file.write(info + '\n')
with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase), 'w') as file:
with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase),
'w') as file:
for info in val_label_list:
file.write(info + '\n')
return
......@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir):
for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1]
shutil.copyfile(json_label_path, ops.join(traing_folder_path, json_label_name))
shutil.copyfile(json_label_path,
ops.join(traing_folder_path, json_label_name))
for json_label_path in glob.glob('{:s}/test_label.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1]
shutil.copyfile(json_label_path, ops.join(testing_folder_path, json_label_name))
shutil.copyfile(json_label_path,
ops.join(testing_folder_path, json_label_name))
train_gt_image_dir = ops.join('training', 'gt_image')
train_gt_binary_dir = ops.join('training', 'gt_binary_image')
......@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir):
os.makedirs(os.path.join(src_dir, test_gt_instance_dir), exist_ok=True)
for json_label_path in glob.glob('{:s}/*.json'.format(traing_folder_path)):
process_json_file(json_label_path, src_dir, train_gt_image_dir, train_gt_binary_dir, train_gt_instance_dir)
process_json_file(json_label_path, src_dir, train_gt_image_dir,
train_gt_binary_dir, train_gt_instance_dir)
gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir, train_gt_image_dir, 'train', True)
gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir,
train_gt_image_dir, 'train', True)
if __name__ == '__main__':
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code heavily base on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
"""
LaneNet model post process
......@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5):
:return:
"""
if len(image.shape) == 3:
raise ValueError('Binary segmentation result image should be a single channel image')
raise ValueError(
'Binary segmentation result image should be a single channel image')
if image.dtype is not np.uint8:
image = np.array(image, np.uint8)
kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
kernel = cv2.getStructuringElement(
shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
# close operation fille hole
closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
......@@ -46,13 +61,15 @@ def _connect_components_analysis(image):
else:
gray_image = image
return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S)
return cv2.connectedComponentsWithStats(
gray_image, connectivity=8, ltype=cv2.CV_32S)
class _LaneFeat(object):
"""
"""
def __init__(self, feat, coord, class_id=-1):
"""
lane feat object
......@@ -108,18 +125,21 @@ class _LaneNetCluster(object):
"""
Instance segmentation result cluster
"""
def __init__(self):
"""
"""
self._color_map = [np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])]
self._color_map = [
np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])
]
@staticmethod
def _embedding_feats_dbscan_cluster(embedding_image_feats):
......@@ -186,15 +206,16 @@ class _LaneNetCluster(object):
# get embedding feats and coords
get_lane_embedding_feats_result = self._get_lane_embedding_feats(
binary_seg_ret=binary_seg_result,
instance_seg_ret=instance_seg_result
)
instance_seg_ret=instance_seg_result)
# dbscan cluster
dbscan_cluster_result = self._embedding_feats_dbscan_cluster(
embedding_image_feats=get_lane_embedding_feats_result['lane_embedding_feats']
)
embedding_image_feats=get_lane_embedding_feats_result[
'lane_embedding_feats'])
mask = np.zeros(shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3], dtype=np.uint8)
mask = np.zeros(
shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3],
dtype=np.uint8)
db_labels = dbscan_cluster_result['db_labels']
unique_labels = dbscan_cluster_result['unique_labels']
coord = get_lane_embedding_feats_result['lane_coordinates']
......@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object):
"""
lanenet post process for lane generation
"""
def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'):
"""
convert front car view to bird view
"""
assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(ipm_remap_file_path)
assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(
ipm_remap_file_path)
self._cluster = _LaneNetCluster()
self._ipm_remap_file_path = ipm_remap_file_path
......@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object):
self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x']
self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y']
self._color_map = [np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])]
self._color_map = [
np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])
]
def _load_remap_matrix(self):
fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ)
......@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object):
return ret
def postprocess(self, binary_seg_result, instance_seg_result=None,
min_area_threshold=100, source_image=None,
def postprocess(self,
binary_seg_result,
instance_seg_result=None,
min_area_threshold=100,
source_image=None,
data_source='tusimple'):
# convert binary_seg_result
binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8)
# apply image morphology operation to fill in the hold and reduce the small area
morphological_ret = _morphological_process(binary_seg_result, kernel_size=5)
connect_components_analysis_ret = _connect_components_analysis(image=morphological_ret)
morphological_ret = _morphological_process(
binary_seg_result, kernel_size=5)
connect_components_analysis_ret = _connect_components_analysis(
image=morphological_ret)
labels = connect_components_analysis_ret[1]
stats = connect_components_analysis_ret[2]
......@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object):
# apply embedding features cluster
mask_image, lane_coords = self._cluster.apply_lane_feats_cluster(
binary_seg_result=morphological_ret,
instance_seg_result=instance_seg_result
)
instance_seg_result=instance_seg_result)
if mask_image is None:
return {
......@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object):
for lane_index, coords in enumerate(lane_coords):
if data_source == 'tusimple':
tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256), np.int_(coords[:, 0] * 1280 / 512)))] = 255
tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256),
np.int_(coords[:, 0] * 1280 / 512)))] = 255
else:
raise ValueError('Wrong data source now only support tusimple')
tmp_ipm_mask = cv2.remap(
tmp_mask,
self._remap_to_ipm_x,
self._remap_to_ipm_y,
interpolation=cv2.INTER_NEAREST
)
interpolation=cv2.INTER_NEAREST)
nonzero_y = np.array(tmp_ipm_mask.nonzero()[0])
nonzero_x = np.array(tmp_ipm_mask.nonzero()[1])
......@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object):
[ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
fit_x = fit_param[0] * plot_y ** 2 + fit_param[1] * plot_y + fit_param[2]
fit_x = fit_param[0] * plot_y**2 + fit_param[
1] * plot_y + fit_param[2]
lane_pts = []
for index in range(0, plot_y.shape[0], 5):
src_x = self._remap_to_ipm_x[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
int(plot_y[index]),
int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
if src_x <= 0:
continue
src_y = self._remap_to_ipm_y[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
int(plot_y[index]),
int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
src_y = src_y if src_y > 0 else 0
lane_pts.append([src_x, src_y])
......@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object):
continue
lane_color = self._color_map[index].tolist()
cv2.circle(source_image, (int(interpolation_src_pt_x),
int(interpolation_src_pt_y)), 5, lane_color, -1)
cv2.circle(
source_image,
(int(interpolation_src_pt_x), int(interpolation_src_pt_y)),
5, lane_color, -1)
ret = {
'mask_image': mask_image,
'fit_params': fit_params,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import six
import numpy as np
def parse_param_file(param_file, return_shape=True):
from paddle.fluid.proto.framework_pb2 import VarType
f = open(param_file, 'rb')
version = np.fromstring(f.read(4), dtype='int32')
lod_level = np.fromstring(f.read(8), dtype='int64')
for i in range(int(lod_level)):
_size = np.fromstring(f.read(8), dtype='int64')
_ = f.read(_size)
version = np.fromstring(f.read(4), dtype='int32')
tensor_desc = VarType.TensorDesc()
tensor_desc_size = np.fromstring(f.read(4), dtype='int32')
tensor_desc.ParseFromString(f.read(int(tensor_desc_size)))
tensor_shape = tuple(tensor_desc.dims)
if return_shape:
f.close()
return tuple(tensor_desc.dims)
if tensor_desc.data_type != 5:
raise Exception(
"Unexpected data type while parse {}".format(param_file))
data_size = 4
for i in range(len(tensor_shape)):
data_size *= tensor_shape[i]
weight = np.fromstring(f.read(data_size), dtype='float32')
f.close()
return np.reshape(weight, tensor_shape)
def load_pdparams(exe, main_prog, model_dir):
import paddle.fluid as fluid
from paddle.fluid.proto.framework_pb2 import VarType
from paddle.fluid.framework import Program
vars_to_load = list()
vars_not_load = list()
import pickle
with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
params_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
unused_vars = list()
for var in main_prog.list_vars():
if not isinstance(var, fluid.framework.Parameter):
continue
if var.name not in params_dict:
print("{} is not in saved model".format(var.name))
vars_not_load.append(var.name)
continue
if var.shape != params_dict[var.name].shape:
unused_vars.append(var.name)
vars_not_load.append(var.name)
print(
"[SKIP] Shape of pretrained weight {} doesn't match.(Pretrained: {}, Actual: {})"
.format(var.name, params_dict[var.name].shape, var.shape))
continue
vars_to_load.append(var)
for var_name in unused_vars:
del params_dict[var_name]
fluid.io.set_program_state(main_prog, params_dict)
if len(vars_to_load) == 0:
print(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else:
print("There are {}/{} varaibles in {} are loaded.".format(
len(vars_to_load),
len(vars_to_load) + len(vars_not_load), model_dir))
def load_pretrained_weights(exe, main_prog, weights_dir):
if not osp.exists(weights_dir):
raise Exception("Path {} not exists.".format(weights_dir))
if osp.exists(osp.join(weights_dir, "model.pdparams")):
return load_pdparams(exe, main_prog, weights_dir)
import paddle.fluid as fluid
vars_to_load = list()
vars_not_load = list()
for var in main_prog.list_vars():
if not isinstance(var, fluid.framework.Parameter):
continue
if not osp.exists(osp.join(weights_dir, var.name)):
print("[SKIP] Pretrained weight {}/{} doesn't exist".format(
weights_dir, var.name))
vars_not_load.append(var)
continue
pretrained_shape = parse_param_file(osp.join(weights_dir, var.name))
actual_shape = tuple(var.shape)
if pretrained_shape != actual_shape:
print(
"[SKIP] Shape of pretrained weight {}/{} doesn't match.(Pretrained: {}, Actual: {})"
.format(weights_dir, var.name, pretrained_shape, actual_shape))
vars_not_load.append(var)
continue
vars_to_load.append(var)
params_dict = fluid.io.load_program_state(
weights_dir, var_list=vars_to_load)
fluid.io.set_program_state(main_prog, params_dict)
if len(vars_to_load) == 0:
print(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else:
print("There are {}/{} varaibles in {} are loaded.".format(
len(vars_to_load),
len(vars_to_load) + len(vars_not_load), weights_dir))
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -45,6 +45,7 @@ from models.model_builder import ModelPhase
from utils import lanenet_postprocess
import matplotlib.pyplot as plt
def parse_args():
parser = argparse.ArgumentParser(description='PaddeSeg visualization tools')
parser.add_argument(
......@@ -106,7 +107,6 @@ def minmax_scale(input_arr):
return output_arr
def visualize(cfg,
vis_file_list=None,
use_gpu=False,
......@@ -119,7 +119,6 @@ def visualize(cfg,
if vis_file_list is None:
vis_file_list = cfg.DATASET.TEST_FILE_LIST
dataset = LaneNetDataset(
file_list=vis_file_list,
mode=ModelPhase.VISUAL,
......@@ -139,7 +138,12 @@ def visualize(cfg,
ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
if ckpt_dir is not None:
print('load test model:', ckpt_dir)
try:
fluid.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
except:
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
save_dir = os.path.join(vis_dir, 'visual_results')
makedirs(save_dir)
......@@ -161,22 +165,26 @@ def visualize(cfg,
for i in range(num_imgs):
gt_image = org_imgs[i]
binary_seg_image, instance_seg_image = segLogits[i].squeeze(-1), emLogits[i].transpose((1,2,0))
binary_seg_image, instance_seg_image = segLogits[i].squeeze(
-1), emLogits[i].transpose((1, 2, 0))
postprocess_result = postprocessor.postprocess(
binary_seg_result=binary_seg_image,
instance_seg_result=instance_seg_image,
source_image=gt_image
)
pred_binary_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_binary'))
pred_lane_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_lane'))
pred_instance_fn = os.path.join(save_dir, to_png_fn(img_names[i], name='_pred_instance'))
source_image=gt_image)
pred_binary_fn = os.path.join(
save_dir, to_png_fn(img_names[i], name='_pred_binary'))
pred_lane_fn = os.path.join(
save_dir, to_png_fn(img_names[i], name='_pred_lane'))
pred_instance_fn = os.path.join(
save_dir, to_png_fn(img_names[i], name='_pred_instance'))
dirname = os.path.dirname(pred_binary_fn)
makedirs(dirname)
mask_image = postprocess_result['mask_image']
for i in range(4):
instance_seg_image[:, :, i] = minmax_scale(instance_seg_image[:, :, i])
instance_seg_image[:, :, i] = minmax_scale(
instance_seg_image[:, :, i])
embedding_image = np.array(instance_seg_image).astype(np.uint8)
plt.figure('mask_image')
......@@ -189,13 +197,13 @@ def visualize(cfg,
plt.imshow(binary_seg_image * 255, cmap='gray')
plt.show()
cv2.imwrite(pred_binary_fn, np.array(binary_seg_image * 255).astype(np.uint8))
cv2.imwrite(pred_binary_fn,
np.array(binary_seg_image * 255).astype(np.uint8))
cv2.imwrite(pred_lane_fn, postprocess_result['source_image'])
cv2.imwrite(pred_instance_fn, mask_image)
print(pred_lane_fn, 'saved!')
if __name__ == '__main__':
args = parse_args()
if args.cfg_file is not None:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress
if __name__ == "__main__":
download_file_and_uncompress(
url='https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar',
url=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar',
savepath=LOCAL_PATH,
extrapath=LOCAL_PATH)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .load_model import *
from .unet import *
from .hrnet import *
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import paddle.fluid as fluid
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import numpy as np
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet import UNet
from .hrnet import HRNet
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import sys
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import argparse
import transforms.transforms as T
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -201,11 +202,9 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
vars_to_load.append(var)
logging.debug("Weight {} will be load".format(var.name))
fluid.io.load_vars(
executor=exe,
dirname=weights_dir,
main_program=main_prog,
vars=vars_to_load)
params_dict = fluid.io.load_program_state(
weights_dir, var_list=vars_to_load)
fluid.io.set_program_state(main_prog, params_dict)
if len(vars_to_load) == 0:
logging.warning(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os
......@@ -6,20 +20,20 @@ args = get_arguments()
cfg = AttrDict()
# 待预测图像所在路径
cfg.data_dir = os.path.join(args.example , "data", "test_images")
cfg.data_dir = os.path.join(args.example, "data", "test_images")
# 待预测图像名称列表
cfg.data_list_file = os.path.join(args.example , "data", "test.txt")
cfg.data_list_file = os.path.join(args.example, "data", "test.txt")
# 模型加载路径
cfg.model_path = os.path.join(args.example , "model")
cfg.model_path = os.path.join(args.example, "model")
# 预测结果保存路径
cfg.vis_dir = os.path.join(args.example , "result")
cfg.vis_dir = os.path.join(args.example, "result")
# 预测类别数
cfg.class_num = 2
# 均值, 图像预处理减去的均值
cfg.MEAN = 127.5, 127.5, 127.5
# 标准差,图像预处理除以标准差
cfg.STD = 127.5, 127.5, 127.5
cfg.STD = 127.5, 127.5, 127.5
# 待预测图像输入尺寸
cfg.input_size = 1536, 576
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
......@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg')
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os.environ['FLAGS_eager_delete_tensor_gb']='0.0'
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
import paddle.fluid as fluid
# 预测数据集类
class TestDataSet():
def __init__(self):
self.data_dir = cfg.data_dir
self.data_dir = cfg.data_dir
self.data_list_file = cfg.data_list_file
self.data_list = self.get_data_list()
self.data_num = len(self.data_list)
def get_data_list(self):
# 获取预测图像路径列表
data_list = []
......@@ -40,7 +55,7 @@ class TestDataSet():
def preprocess(self, img):
# 图像预处理
if cfg.example == 'ACE2P':
reader = importlib.import_module(args.example+'.reader')
reader = importlib.import_module(args.example + '.reader')
ACE2P_preprocess = getattr(reader, 'preprocess')
img = ACE2P_preprocess(img)
else:
......@@ -56,10 +71,10 @@ class TestDataSet():
img_path = self.data_list[index]
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
return img, img,img_path, None
return img, img, img_path, None
img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'')
name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2]
img_process = self.preprocess(img)
......@@ -90,39 +105,44 @@ def infer():
if image is None:
print(im_name, 'is None')
continue
# 预测
if cfg.example == 'ACE2P':
# ACE2P模型使用多尺度预测
reader = importlib.import_module(args.example+'.reader')
reader = importlib.import_module(args.example + '.reader')
multi_scale_test = getattr(reader, 'multi_scale_test')
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape)
parsing, logits = multi_scale_test(exe, test_prog, feed_name,
fetch_list, image, im_shape)
else:
# HumanSeg,RoadLine模型单尺度预测
result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list)
result = exe.run(
program=test_prog,
feed={feed_name[0]: image},
fetch_list=fetch_list)
parsing = np.argmax(result[0][0], axis=0)
parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])
# 预测结果保存
result_path = os.path.join(cfg.vis_dir, im_name + '.png')
if cfg.example == 'HumanSeg':
logits = result[0][0][1]*255
logits = result[0][0][1] * 255
logits = cv2.resize(logits, im_shape[::-1])
ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
logits = 255 *(logits - thresh)/(255 - thresh)
logits = 255 * (logits - thresh) / (255 - thresh)
# 将分割结果添加到alpha通道
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2)
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
axis=2)
cv2.imwrite(result_path, rgba)
else:
else:
output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
output_im.putpalette(palette)
output_im.save(result_path)
if (idx + 1) % 100 == 0:
print('%d processd' % (idx + 1))
print('%d processd done' % (idx + 1))
print('%d processd done' % (idx + 1))
return 0
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: RainbowSecret
## Microsoft Research
## yuyua@microsoft.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu",
action="store_true",
help="Use gpu or cpu to test.")
parser.add_argument('--example',
type=str,
help='RoadLine, HumanSeg or ACE2P')
parser.add_argument(
"--use_gpu", action="store_true", help="Use gpu or cpu to test.")
parser.add_argument(
'--example', type=str, help='RoadLine, HumanSeg or ACE2P')
return parser.parse_args()
......@@ -34,6 +48,7 @@ class AttrDict(dict):
else:
self[name] = value
def merge_cfg_from_args(args, cfg):
"""Merge config keys, values in args into the global config."""
for k, v in vars(args).items():
......@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value = v
if value is not None:
cfg[k] = value
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -20,6 +21,8 @@ from PIL import Image
import glob
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
def remove_colormap(filename):
gray_anno = np.array(Image.open(filename))
return gray_anno
......@@ -30,6 +33,7 @@ def save_annotation(annotation, filename):
annotation = Image.fromarray(annotation)
annotation.save(filename)
def convert_list(origin_file, seg_file, output_folder):
with open(seg_file, 'w') as fid_seg:
with open(origin_file) as fid_ori:
......@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder):
new_line = ' '.join([img_name, anno_name])
fid_seg.write(new_line + "\n")
if __name__ == "__main__":
pascal_root = "./VOCtrainval_11-May-2012/VOC2012"
pascal_root = os.path.join(LOCAL_PATH, pascal_root)
......@@ -54,7 +59,7 @@ if __name__ == "__main__":
# 标注图转换后存储目录
output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert")
if not os.path.exists(os.path.join(LOCAL_PATH, output_folder)):
os.mkdir(os.path.join(LOCAL_PATH, output_folder))
......@@ -67,5 +72,5 @@ if __name__ == "__main__":
convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'),
output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap
from convert_voc2012 import save_annotation
def download_VOC_dataset(savepath, extrapath):
url = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
download_file_and_uncompress(
url=url, savepath=savepath, extrapath=extrapath)
if __name__ == "__main__":
download_VOC_dataset(LOCAL_PATH, LOCAL_PATH)
print("Dataset download finish!")
......@@ -45,10 +46,10 @@ if __name__ == "__main__":
train_path = os.path.join(txt_folder, "train.txt")
val_path = os.path.join(txt_folder, "val.txt")
trainval_path = os.path.join(txt_folder, "trainval.txt")
# 标注图转换后存储目录
output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert")
if not os.path.exists(output_folder):
os.mkdir(output_folder)
......@@ -61,5 +62,5 @@ if __name__ == "__main__":
convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'),
output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -29,9 +29,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
gflags.DEFINE_string("conf", default="", help="Configuration File Path")
gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images")
gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
gflags.DEFINE_string("trt_mode", default="", help="Use optimized model")
gflags.DEFINE_string("ext", default=".jpeg|.jpg", help="Input Image File Extensions")
gflags.DEFINE_string(
"ext", default=".jpeg|.jpg", help="Input Image File Extensions")
gflags.FLAGS = gflags.FLAGS
......@@ -103,6 +103,9 @@ class DeployConfig:
self.batch_size = deploy_conf["BATCH_SIZE"]
# 9. channels
self.channels = deploy_conf["CHANNELS"]
# 10. use_pr
self.use_pr = deploy_conf["USE_PR"]
class ImageReader:
......@@ -257,23 +260,24 @@ class Predictor:
# record starting time point
total_start = time.time()
batch_size = self.config.batch_size
use_pr = self.config.use_pr
for i in range(0, len(images), batch_size):
real_batch_size = batch_size
if i + batch_size >= len(images):
real_batch_size = len(images) - i
reader_start = time.time()
img_datas = self.image_reader.process(images[i:i + real_batch_size],
gflags.FLAGS.use_pr)
use_pr)
input_data = np.concatenate([item[1] for item in img_datas])
input_data = self.create_tensor(
input_data, real_batch_size, use_pr=gflags.FLAGS.use_pr)
input_data, real_batch_size, use_pr=use_pr)
reader_end = time.time()
infer_start = time.time()
output_data = self.predictor.run(input_data)[0]
infer_end = time.time()
output_data = output_data.as_ndarray()
post_start = time.time()
self.output_result(img_datas, output_data, gflags.FLAGS.use_pr)
self.output_result(img_datas, output_data, use_pr)
post_end = time.time()
reader_time += (reader_end - reader_start)
infer_time += (infer_end - infer_start)
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,4 +14,4 @@
# limitations under the License.
import models
import utils
from . import tools
\ No newline at end of file
from . import tools
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册