Unverified commit 664dc115, authored by wuyefeilin, committed by GitHub

Unify copyright format (#261)

* update model save load

* first add

* update model save and load

* update train.py

* update LaneNet model saving and loading

* adapt slim to paddle-1.8

* update distillation save and load

* update nas model save and load

* update model load op

* update utils.py

* update load_model_utils.py

* update model saving and loading

* update copyright

* update palette.py
Parent 020d1072
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os
@@ -19,10 +33,10 @@ cfg.class_num = 20
# Mean values subtracted from the image during preprocessing
cfg.MEAN = 0.406, 0.456, 0.485
# Standard deviation values the image is divided by during preprocessing
cfg.STD = 0.225, 0.224, 0.229
# Image sizes used for multi-scale inference
cfg.multi_scales = (377, 377), (473, 473), (567, 567)
# Whether to horizontally flip the image during multi-scale inference
cfg.flip = True
......
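For reference, a minimal sketch (not part of this commit) of how the MEAN, STD and multi_scales values above are typically consumed when preparing a network input; the helper name and the NCHW layout here are assumptions:

import cv2
import numpy as np

def make_network_input(img_bgr, scale,
                       mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
    # Resize to one of cfg.multi_scales (given as (H, W)), normalize per channel,
    # then convert HWC -> NCHW and add a batch dimension.
    resized = cv2.resize(img_bgr, scale[::-1], interpolation=cv2.INTER_LINEAR)
    data = resized.astype(np.float32) / 255.0
    data = (data - np.array(mean, np.float32)) / np.array(std, np.float32)
    return np.transpose(data, (2, 0, 1))[np.newaxis, ...]  # shape (1, 3, H, W)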
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg')

# Paddle garbage collection FLAG; the ACE2P model is large, so enable this when GPU memory is insufficient
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'

import paddle.fluid as fluid


# Dataset class for inference
class TestDataSet():
    def __init__(self):
        self.data_dir = cfg.data_dir
        self.data_list_file = cfg.data_list_file
        self.data_list = self.get_data_list()
        self.data_num = len(self.data_list)

    def get_data_list(self):
        # Get the list of image paths for inference
        data_list = []
@@ -56,10 +71,10 @@ class TestDataSet():
        img_path = self.data_list[index]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return img, img, img_path, None

        img_name = img_path.split(os.sep)[-1]
        name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
        img_shape = img.shape[:2]
        img_process = self.preprocess(img)
@@ -90,39 +105,44 @@ def infer():
        if image is None:
            print(im_name, 'is None')
            continue

        # Inference
        if cfg.example == 'ACE2P':
            # The ACE2P model uses multi-scale inference
            reader = importlib.import_module('reader')
            multi_scale_test = getattr(reader, 'multi_scale_test')
            parsing, logits = multi_scale_test(exe, test_prog, feed_name,
                                               fetch_list, image, im_shape)
        else:
            # HumanSeg and RoadLine models use single-scale inference
            result = exe.run(
                program=test_prog,
                feed={feed_name[0]: image},
                fetch_list=fetch_list)
            parsing = np.argmax(result[0][0], axis=0)
            parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])

        # Save the inference result
        result_path = os.path.join(cfg.vis_dir, im_name + '.png')
        if cfg.example == 'HumanSeg':
            logits = result[0][0][1] * 255
            logits = cv2.resize(logits, im_shape[::-1])
            ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
            logits = 255 * (logits - thresh) / (255 - thresh)
            # Add the segmentation result as the alpha channel
            rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
                                  axis=2)
            cv2.imwrite(result_path, rgba)
        else:
            output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
            output_im.putpalette(palette)
            output_im.save(result_path)

        if (idx + 1) % 100 == 0:
            print('%d processed' % (idx + 1))

    print('%d processed done' % (idx + 1))
    return 0
......
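As a rough, standalone illustration of the alpha-channel rescaling in the HumanSeg branch above (the threshold value here is an assumption): values below thresh are zeroed by cv2.THRESH_TOZERO and the remaining range is stretched back toward [0, 255]:

import numpy as np

thresh = 120.0                                   # assumed threshold
logits = np.array([30.0, 120.0, 160.0, 255.0])   # foreground scores scaled to [0, 255]
logits[logits < thresh] = 0.0                    # equivalent of cv2.THRESH_TOZERO
alpha = 255 * (logits - thresh) / (255 - thresh)
# -> roughly [-226.7, 0.0, 75.6, 255.0]; below-threshold pixels end up <= 0,
#    i.e. fully transparent once clipped to the uint8 alpha range.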
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
from config import cfg
import cv2


def get_affine_points(src_shape, dst_shape, rot_grad=0):
    # Get three pairs of corresponding points between the original image and
    # the affine-transformed image: the center of the transformed image,
    # [w/2, 0], [0, 0], and the matching points in the original image
@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
    # The three points in the original image
    points = [[0, 0]] * 3
    points[0] = (np.array([w, h]) - 1) * 0.5
    points[1] = points[0] + 0.5 * affine_shape[0] * np.array([sin_v, -cos_v])
    points[2] = points[1] - 0.5 * affine_shape[1] * np.array([cos_v, sin_v])
@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
    return points, points_trans


def preprocess(im):
    # Data preprocessing for the ACE2P model
    im_shape = im.shape[:2]
@@ -42,13 +58,10 @@ def preprocess(im):
        # Get the corresponding points of the image and the affine-transformed image
        points, points_trans = get_affine_points(im_shape, scale)
        # Compute the affine matrix from the corresponding point sets
        trans = cv2.getAffineTransform(
            np.float32(points), np.float32(points_trans))
        # Apply the affine transformation to the image
        input = cv2.warpAffine(im, trans, scale[::-1], flags=cv2.INTER_LINEAR)

        # Subtract the mean, divide by the std, and convert the layout to NCHW
        input = input.astype(np.float32)
@@ -66,19 +79,20 @@ def preprocess(im):
    return input_images


def multi_scale_test(exe, test_prog, feed_name, fetch_list, input_ims,
                     im_shape):
    # Some classes distinguish left and right parts; flipped_idx gives the
    # matching labels after horizontal flipping
    flipped_idx = (15, 14, 17, 16, 19, 18)
    ms_outputs = []

    # Multi-scale inference
    for idx, scale in enumerate(cfg.multi_scales):
        input_im = input_ims[idx]
        parsing_output = exe.run(
            program=test_prog,
            feed={feed_name[0]: input_im},
            fetch_list=fetch_list)
        output = parsing_output[0][0]
        if cfg.flip:
            # If flipping is enabled, swap the paired classes and average with
            # the original prediction
@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
        # Affine-transform back to the original image size
        points, points_trans = get_affine_points(im_shape, scale)
        M = cv2.getAffineTransform(np.float32(points_trans), np.float32(points))
        logits_result = cv2.warpAffine(
            output, M, im_shape[::-1], flags=cv2.INTER_LINEAR)
        ms_outputs.append(logits_result)

    # Average the multi-scale predictions and take the most probable class
@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
    ms_fused_parsing_output = np.mean(ms_fused_parsing_output, axis=0)
    parsing = np.argmax(ms_fused_parsing_output, axis=2)
    return parsing, ms_fused_parsing_output
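A toy numpy sketch (shapes and names assumed) of the left/right swap and averaging that the flipped-prediction comment above describes; the actual fusion code sits in the collapsed part of the diff:

import numpy as np

logits = np.random.rand(20, 4, 4).astype(np.float32)    # (classes, H, W) scores
flipped_idx = (15, 14, 17, 16, 19, 18)                   # left/right paired classes

# Pretend `flipped` came from running the horizontally flipped image and
# flipping the output back, then swap the paired channels so left matches right.
flipped = logits[:, :, ::-1].copy()
flipped[14:20, :, :] = flipped[flipped_idx, :, :]
fused = 0.5 * (logits + flipped)                         # average with the original
parsing = np.argmax(fused, axis=0)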
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -7,6 +7,7 @@
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_gpu", action="store_true", help="Use gpu or cpu to test.")
    parser.add_argument(
        '--example', type=str, help='RoadLine, HumanSeg or ACE2P')
    return parser.parse_args()
@@ -34,6 +48,7 @@ class AttrDict(dict):
        else:
            self[name] = value


def merge_cfg_from_args(args, cfg):
    """Merge config keys, values in args into the global config."""
    for k, v in vars(args).items():
@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
        value = v
        if value is not None:
            cfg[k] = value
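A self-contained, simplified sketch of how the two helpers above are meant to interact; the AttrDict body here just mirrors the partially shown class, and the argument values are made up:

import argparse

class AttrDict(dict):
    # Simplified mirror of the AttrDict above: attribute access maps to dict keys.
    def __getattr__(self, name):
        return self[name]

    def __setattr__(self, name, value):
        self[name] = value

def merge_cfg_from_args(args, cfg):
    """Merge config keys, values in args into the global config."""
    for k, v in vars(args).items():
        if v is not None:
            cfg[k] = v

cfg = AttrDict()
cfg.example = 'ACE2P'                                        # default from config.py
args = argparse.Namespace(use_gpu=True, example='HumanSeg')  # stand-in for get_arguments()
merge_cfg_from_args(args, cfg)
print(cfg.example, cfg.use_gpu)                              # -> HumanSeg True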
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# utils for memory management which is allocated on sharedmemory,
# note that these structures may not be thread-safe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import models
import argparse
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .humanseg import HumanSegMobile
from .humanseg import HumanSegServer
from .humanseg import HumanSegLite
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backbone import mobilenet_v2
from .backbone import xception
from .deeplabv3p import DeepLabv3p
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mobilenet_v2 import MobileNetV2
from .xception import Xception
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -10,6 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import os
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase
from pdseg.data_aug import get_random_scale, randomly_scale_image_and_label, random_rotation, \
    rand_scale_aspect, hsv_color_jitter, rand_crop


def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
    """
    Resize the image and its label images
@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
        if grt is not None:
            grt = cv2.resize(grt, target_size, interpolation=cv2.INTER_NEAREST)
        if grt_instance is not None:
            grt_instance = cv2.resize(
                grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
    elif cfg.AUG.AUG_METHOD == 'stepscaling':
        if mode == ModelPhase.TRAIN:
            min_scale_factor = cfg.AUG.MIN_SCALE_FACTOR
......
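A small standalone example of why the label maps above are resized with cv2.INTER_NEAREST while the image itself can use smooth interpolation (the class ids here are made up):

import cv2
import numpy as np

label = np.array([[0, 0, 4, 4],
                  [0, 0, 4, 4]], dtype=np.uint8)   # class ids, not intensities
nearest = cv2.resize(label, (8, 4), interpolation=cv2.INTER_NEAREST)
linear = cv2.resize(label, (8, 4), interpolation=cv2.INTER_LINEAR)
print(np.unique(nearest))   # [0 4] -> still valid class ids
print(np.unique(linear))    # typically adds blended values such as 1-3, which are
                            # not valid classes, so labels must use NEAREST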
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ from __future__ import print_function
import paddle.fluid as fluid
from utils.config import cfg
from pdseg.models.libs.model_libs import scope, name_scope
from pdseg.models.libs.model_libs import bn, bn_relu, relu
@@ -86,7 +85,12 @@ def bottleneck(inputs,
        with scope('down_sample'):
            inputs_shape = inputs.shape
            with scope('main_max_pool'):
                net_main = fluid.layers.conv2d(
                    inputs,
                    inputs_shape[1],
                    filter_size=3,
                    stride=2,
                    padding='SAME')

            #First get the difference in depth to pad, then pad with zeros only on the last dimension.
            depth_to_pad = abs(inputs_shape[1] - output_depth)
@@ -95,12 +99,16 @@ def bottleneck(inputs,
                net_main = fluid.layers.pad(net_main, paddings=paddings)

            with scope('block1'):
                net = conv(
                    inputs, reduced_depth, [2, 2], stride=2, padding='same')
                net = bn(net)
                net = prelu(net, decoder=decoder)

            with scope('block2'):
                net = conv(
                    net,
                    reduced_depth, [filter_size, filter_size],
                    padding='same')
                net = bn(net)
                net = prelu(net, decoder=decoder)
@@ -137,13 +145,18 @@ def bottleneck(inputs,
            # Second conv block --- apply dilated convolution here
            with scope('block2'):
                net = conv(
                    net,
                    reduced_depth,
                    filter_size,
                    padding='SAME',
                    dilation=dilation_rate)
                net = bn(net)
                net = prelu(net, decoder=decoder)

            # Final projection with 1x1 kernel (Expansion)
            with scope('block3'):
                net = conv(net, output_depth, [1, 1])
                net = bn(net)
                net = prelu(net, decoder=decoder)
@@ -172,9 +185,11 @@ def bottleneck(inputs,
            # Second conv block --- apply asymmetric conv here
            with scope('block2'):
                with scope('asymmetric_conv2a'):
                    net = conv(
                        net, reduced_depth, [filter_size, 1], padding='same')
                with scope('asymmetric_conv2b'):
                    net = conv(
                        net, reduced_depth, [1, filter_size], padding='same')
                net = bn(net)
                net = prelu(net, decoder=decoder)
@@ -211,7 +226,8 @@ def bottleneck(inputs,
            with scope('unpool'):
                net_unpool = conv(inputs, output_depth, [1, 1])
                net_unpool = bn(net_unpool)
                net_unpool = fluid.layers.resize_bilinear(
                    net_unpool, out_shape=output_shape[2:])

            # First 1x1 projection to reduce depth
            with scope('block1'):
@@ -220,7 +236,12 @@ def bottleneck(inputs,
                net = prelu(net, decoder=decoder)

            with scope('block2'):
                net = deconv(
                    net,
                    reduced_depth,
                    filter_size=filter_size,
                    stride=2,
                    padding='same')
                net = bn(net)
                net = prelu(net, decoder=decoder)
@@ -253,7 +274,10 @@ def bottleneck(inputs,
            # Second conv block
            with scope('block2'):
                net = conv(
                    net,
                    reduced_depth, [filter_size, filter_size],
                    padding='same')
                net = bn(net)
                net = prelu(net, decoder=decoder)
@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'):
            = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.01, type=DOWNSAMPLING,
                         name_scope='bottleneck1_0')
        with scope('bottleneck1_1'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.01,
                name_scope='bottleneck1_1')
        with scope('bottleneck1_2'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.01,
                name_scope='bottleneck1_2')
        with scope('bottleneck1_3'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.01,
                name_scope='bottleneck1_3')
        with scope('bottleneck1_4'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.01,
                name_scope='bottleneck1_4')
    return net, inputs_shape_1
@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'):
                         name_scope='bottleneck2_0')
        for i in range(2):
            with scope('bottleneck2_{}'.format(str(4 * i + 1))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
            with scope('bottleneck2_{}'.format(str(4 * i + 2))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    type=DILATED,
                    dilation_rate=(2**(2 * i + 1)),
                    name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
            with scope('bottleneck2_{}'.format(str(4 * i + 3))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=5,
                    regularizer_prob=0.1,
                    type=ASYMMETRIC,
                    name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
            with scope('bottleneck2_{}'.format(str(4 * i + 4))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    type=DILATED,
                    dilation_rate=(2**(2 * i + 2)),
                    name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
    return net, inputs_shape_2
@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'):
    with scope(name_scope):
        for i in range(2):
            with scope('bottleneck3_{}'.format(str(4 * i + 0))):
                net = bottleneck(
                    inputs,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
            with scope('bottleneck3_{}'.format(str(4 * i + 1))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    type=DILATED,
                    dilation_rate=(2**(2 * i + 1)),
                    name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
            with scope('bottleneck3_{}'.format(str(4 * i + 2))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=5,
                    regularizer_prob=0.1,
                    type=ASYMMETRIC,
                    name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
            with scope('bottleneck3_{}'.format(str(4 * i + 3))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    type=DILATED,
                    dilation_rate=(2**(2 * i + 2)),
                    name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
    return net
def ENet_stage4(inputs,
                inputs_shape,
                connect_tensor,
                skip_connections=True,
                name_scope='stage4_block'):
    with scope(name_scope):
        with scope('bottleneck4_0'):
            net = bottleneck(
                inputs,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.1,
                type=UPSAMPLING,
                decoder=True,
                output_shape=inputs_shape,
                name_scope='bottleneck4_0')
        if skip_connections:
            net = fluid.layers.elementwise_add(net, connect_tensor)
        with scope('bottleneck4_1'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.1,
                decoder=True,
                name_scope='bottleneck4_1')
        with scope('bottleneck4_2'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.1,
                decoder=True,
                name_scope='bottleneck4_2')
    return net
def ENet_stage5(inputs,
                inputs_shape,
                connect_tensor,
                skip_connections=True,
                name_scope='stage5_block'):
    with scope(name_scope):
        net = bottleneck(
            inputs,
            output_depth=16,
            filter_size=3,
            regularizer_prob=0.1,
            type=UPSAMPLING,
            decoder=True,
            output_shape=inputs_shape,
            name_scope='bottleneck5_0')
        if skip_connections:
            net = fluid.layers.elementwise_add(net, connect_tensor)
        with scope('bottleneck5_1'):
            net = bottleneck(
                net,
                output_depth=16,
                filter_size=3,
                regularizer_prob=0.1,
                decoder=True,
                name_scope='bottleneck5_1')
    return net
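The two decoder stages above follow the same pattern: an UPSAMPLING bottleneck restores spatial resolution and, when skip_connections is enabled, the matching encoder feature map is fused by element-wise addition. A minimal NumPy sketch of that fusion (the shapes are illustrative, not taken from the model):

import numpy as np

def upsample2x_nchw(x):
    # Nearest-neighbour 2x upsampling of an NCHW tensor.
    return x.repeat(2, axis=2).repeat(2, axis=3)

decoder_feat = np.random.rand(1, 64, 32, 64).astype('float32')   # output of an UPSAMPLING bottleneck
encoder_feat = np.random.rand(1, 64, 64, 128).astype('float32')  # skip tensor from the matching encoder stage

fused = upsample2x_nchw(decoder_feat) + encoder_feat  # element-wise add, as fluid.layers.elementwise_add does
print(fused.shape)  # (1, 64, 64, 128)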
@@ -378,14 +493,16 @@ def decoder(input, num_classes):
            segStage3 = ENet_stage3(stage2)
            segStage4 = ENet_stage4(segStage3, inputs_shape_2, stage1)
            segStage5 = ENet_stage5(segStage4, inputs_shape_1, initial)
            segLogits = deconv(
                segStage5, num_classes, filter_size=2, stride=2, padding='SAME')

        # Embedding branch
        with scope('LaneNetEm'):
            emStage3 = ENet_stage3(stage2)
            emStage4 = ENet_stage4(emStage3, inputs_shape_2, stage1)
            emStage5 = ENet_stage5(emStage4, inputs_shape_1, initial)
            emLogits = deconv(
                emStage5, 4, filter_size=2, stride=2, padding='SAME')

    elif 'vgg' in cfg.MODEL.LANENET.BACKBONE:
        encoder_list = ['pool5', 'pool4', 'pool3']
@@ -396,14 +513,16 @@ def decoder(input, num_classes):
        encoder_list = encoder_list[1:]
        for i in range(len(encoder_list)):
            with scope('deconv_{:d}'.format(i + 1)):
                deconv_out = deconv(
                    score, 64, filter_size=4, stride=2, padding='SAME')
            input_tensor = input[encoder_list[i]]
            with scope('score_{:d}'.format(i + 1)):
                score = conv(input_tensor, 64, 1)
            score = fluid.layers.elementwise_add(deconv_out, score)
        with scope('deconv_final'):
            emLogits = deconv(
                score, 64, filter_size=16, stride=8, padding='SAME')
        with scope('score_final'):
            segLogits = conv(emLogits, num_classes, 1)
            emLogits = relu(conv(emLogits, 4, 1))
@@ -415,7 +534,8 @@ def encoder(input):
        model = vgg_backbone(layers=16)
        #output = model.net(input)

        _, encode_feature_dict = model.net(
            input, end_points=13, decode_points=[7, 10, 13])
        output = {}
        output['pool3'] = encode_feature_dict[7]
        output['pool4'] = encode_feature_dict[10]
@@ -427,8 +547,9 @@ def encoder(input):
        stage2, inputs_shape_2 = ENet_stage2(stage1)
        output = (initial, stage1, stage2, inputs_shape_1, inputs_shape_2)
    else:
        raise Exception(
            "LaneNet expect enet and vgg backbone, but received {}".format(
                cfg.MODEL.LANENET.BACKBONE))
    return output
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -58,7 +58,8 @@ class LaneNetDataset():
        if self.shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
            self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
                                        (cfg.TRAINER_ID + 1)]
            self.shuffle_seed += 1
        elif self.shuffle:
            np.random.shuffle(self.lines)
@@ -86,7 +87,8 @@ class LaneNetDataset():
        if self.shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            num_lines = len(self.all_lines) // self.num_trainers
            self.lines = self.all_lines[num_lines * self.trainer_id:num_lines *
                                        (self.trainer_id + 1)]
            self.shuffle_seed += 1
        elif self.shuffle:
            np.random.shuffle(self.lines)
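The slicing above gives each trainer an equal, contiguous shard of the shuffled file list; any remainder left over after the floor division is simply not assigned. A small stand-alone illustration (all names here are stand-ins for the attributes used above):

all_lines = ['sample_{}.txt'.format(i) for i in range(10)]
num_trainers, trainer_id = 3, 1

num_lines = len(all_lines) // num_trainers                        # 3 lines per trainer
shard = all_lines[num_lines * trainer_id:num_lines * (trainer_id + 1)]
print(shard)  # ['sample_3.txt', 'sample_4.txt', 'sample_5.txt']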
@@ -118,7 +120,8 @@ class LaneNetDataset():
        def batch_reader(is_test=False, drop_last=drop_last):
            if is_test:
                imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
                for img, grt, grt_instance, img_name, valid_shape, org_shape in reader(
                ):
                    imgs.append(img)
                    grts.append(grt)
                    grts_instance.append(grt_instance)
@@ -126,14 +129,15 @@ class LaneNetDataset():
                    valid_shapes.append(valid_shape)
                    org_shapes.append(org_shape)
                    if len(imgs) == batch_size:
                        yield np.array(imgs), np.array(grts), np.array(
                            grts_instance), img_names, np.array(
                                valid_shapes), np.array(org_shapes)
                        imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []

                if not drop_last and len(imgs) > 0:
                    yield np.array(imgs), np.array(grts), np.array(
                        grts_instance), img_names, np.array(
                            valid_shapes), np.array(org_shapes)
            else:
                imgs, labs, labs_instance, ignore = [], [], [], []
                bs = 0
@@ -144,12 +148,14 @@ class LaneNetDataset():
                    ignore.append(ig)
                    bs += 1
                    if bs == batch_size:
                        yield np.array(imgs), np.array(labs), np.array(
                            labs_instance), np.array(ignore)
                        bs = 0
                        imgs, labs, labs_instance, ignore = [], [], [], []

                if not drop_last and bs > 0:
                    yield np.array(imgs), np.array(labs), np.array(
                        labs_instance), np.array(ignore)

        return batch_reader(is_test, drop_last)
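Stripped of the image-specific unpacking, the batching logic above reduces to grouping samples into fixed-size lists and optionally emitting the final partial batch. A minimal sketch of that pattern:

def make_batch_reader(reader, batch_size, drop_last=False):
    # Groups samples from `reader` into lists of length `batch_size`;
    # mirrors the structure above without the per-field lists.
    def _impl():
        batch = []
        for sample in reader():
            batch.append(sample)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if not drop_last and batch:
            yield batch
    return _impl

samples = lambda: iter(range(7))
for b in make_batch_reader(samples, batch_size=3)():
    print(b)   # [0, 1, 2], [3, 4, 5], [6]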
@@ -299,10 +305,12 @@ class LaneNetDataset():
                img, grt = aug.rand_crop(img, grt, mode=mode)
            elif ModelPhase.is_eval(mode):
                img, grt, grt_instance = aug.resize(
                    img, grt, grt_instance, mode=mode)
            elif ModelPhase.is_visual(mode):
                ori_img = img.copy()
                img, grt, grt_instance = aug.resize(
                    img, grt, grt_instance, mode=mode)
                valid_shape = [img.shape[0], img.shape[1]]
            else:
                raise ValueError("Dataset mode={} Error!".format(mode))
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3
cfg.DATASET.SEPARATOR = ' '
# Ignored pixel label value; defaults to 255 and normally does not need changing
cfg.DATASET.IGNORE_INDEX = 255
# Padding value applied to images during data augmentation
cfg.DATASET.PADDING_VALUE = [127.5, 127.5, 127.5]

########################### Data augmentation configuration ######################################
# Horizontal (left-right) image flipping
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
generate tusimple training dataset generate tusimple training dataset
""" """
...@@ -14,12 +28,16 @@ import numpy as np ...@@ -14,12 +28,16 @@ import numpy as np
def init_args(): def init_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--src_dir', type=str, help='The origin path of unzipped tusimple dataset') parser.add_argument(
'--src_dir',
type=str,
help='The origin path of unzipped tusimple dataset')
return parser.parse_args() return parser.parse_args()
def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir,
                      instance_dst_dir):

    assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)
@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
            h_samples = info_dict['h_samples']
            lanes = info_dict['lanes']

            image_name_new = '{:s}.png'.format(
                '{:d}'.format(line_index + image_nums).zfill(4))

            src_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            dst_binary_image = np.zeros(
                [src_image.shape[0], src_image.shape[1]], np.uint8)
            dst_instance_image = np.zeros(
                [src_image.shape[0], src_image.shape[1]], np.uint8)

            for lane_index, lane in enumerate(lanes):
                assert len(h_samples) == len(lane)
@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
                lane_pts = np.vstack((lane_x, lane_y)).transpose()
                lane_pts = np.array([lane_pts], np.int64)

                cv2.polylines(
                    dst_binary_image,
                    lane_pts,
                    isClosed=False,
                    color=255,
                    thickness=5)
                cv2.polylines(
                    dst_instance_image,
                    lane_pts,
                    isClosed=False,
                    color=lane_index * 50 + 20,
                    thickness=5)

            dst_binary_image_path = ops.join(src_dir, binary_dst_dir,
                                             image_name_new)
            dst_instance_image_path = ops.join(src_dir, instance_dst_dir,
                                               image_name_new)
            dst_rgb_image_path = ops.join(src_dir, ori_dst_dir, image_name_new)

            cv2.imwrite(dst_binary_image_path, dst_binary_image)
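The two cv2.polylines calls above rasterize each lane into a shared binary mask (value 255) and into an instance mask where every lane gets its own gray level (lane_index * 50 + 20). A self-contained sketch with made-up lane points (int32 is used here, which OpenCV's drawing functions accept):

import cv2
import numpy as np

h, w = 720, 1280
binary_mask = np.zeros((h, w), np.uint8)
instance_mask = np.zeros((h, w), np.uint8)

# Two toy lanes given as (x, y) point lists; real lanes come from the TuSimple JSON labels.
lanes = [np.array([[[200, 700], [400, 400], [600, 200]]], np.int32),
         np.array([[[900, 700], [750, 400], [650, 200]]], np.int32)]

for lane_index, lane_pts in enumerate(lanes):
    # All lanes share one value in the binary mask ...
    cv2.polylines(binary_mask, lane_pts, isClosed=False, color=255, thickness=5)
    # ... and get a distinct gray level per lane in the instance mask.
    cv2.polylines(instance_mask, lane_pts, isClosed=False,
                  color=lane_index * 50 + 20, thickness=5)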
@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
            print('Process {:s} success'.format(image_name))


def gen_sample(src_dir,
               b_gt_image_dir,
               i_gt_image_dir,
               image_dir,
               phase='train',
               split=False):

    label_list = []
    with open('{:s}/{}ing/{}.txt'.format(src_dir, phase, phase), 'w') as file:
@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
            image_path = ops.join(image_dir, image_name)

            assert ops.exists(image_path), '{:s} not exist'.format(image_path)
            assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(
                instance_gt_image_path)

            b_gt_image = cv2.imread(binary_gt_image_path, cv2.IMREAD_COLOR)
            i_gt_image = cv2.imread(instance_gt_image_path, cv2.IMREAD_COLOR)
@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
                print('image: {:s} corrupt'.format(image_name))
                continue
            else:
                info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path,
                                               instance_gt_image_path)
                file.write(info + '\n')
                label_list.append(info)
    if phase == 'train' and split:
@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
        val_list_len = len(label_list) // 10
        val_label_list = label_list[:val_list_len]
        train_label_list = label_list[val_list_len:]
        with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase),
                  'w') as file:
            for info in train_label_list:
                file.write(info + '\n')
        with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase),
                  'w') as file:
            for info in val_label_list:
                file.write(info + '\n')
    return
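When split is enabled, gen_sample holds out one tenth of the shuffled label list for validation. The arithmetic, in isolation:

import random

label_list = ['line {}'.format(i) for i in range(100)]
random.shuffle(label_list)

val_list_len = len(label_list) // 10        # 10% held out for validation
val_part = label_list[:val_list_len]
train_part = label_list[val_list_len:]
print(len(train_part), len(val_part))       # 90 10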
@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir):
    for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)):
        json_label_name = ops.split(json_label_path)[1]
        shutil.copyfile(json_label_path,
                        ops.join(traing_folder_path, json_label_name))

    for json_label_path in glob.glob('{:s}/test_label.json'.format(src_dir)):
        json_label_name = ops.split(json_label_path)[1]
        shutil.copyfile(json_label_path,
                        ops.join(testing_folder_path, json_label_name))

    train_gt_image_dir = ops.join('training', 'gt_image')
    train_gt_binary_dir = ops.join('training', 'gt_binary_image')
@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir):
    os.makedirs(os.path.join(src_dir, test_gt_instance_dir), exist_ok=True)

    for json_label_path in glob.glob('{:s}/*.json'.format(traing_folder_path)):
        process_json_file(json_label_path, src_dir, train_gt_image_dir,
                          train_gt_binary_dir, train_gt_instance_dir)

    gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir,
               train_gt_image_dir, 'train', True)


if __name__ == '__main__':
......
#!/usr/bin/env python3 # coding: utf8
# -*- coding: utf-8 -*- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code is heavily based on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
"""
LaneNet model post process
"""
@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5):
    :return:
    """
    if len(image.shape) == 3:
        raise ValueError(
            'Binary segmentation result image should be a single channel image')

    if image.dtype is not np.uint8:
        image = np.array(image, np.uint8)

    kernel = cv2.getStructuringElement(
        shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))

    # close operation to fill holes
    closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
@@ -46,13 +61,15 @@ def _connect_components_analysis(image):
    else:
        gray_image = image

    return cv2.connectedComponentsWithStats(
        gray_image, connectivity=8, ltype=cv2.CV_32S)
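The helpers above first close small holes in the binary prediction and then run connected-component analysis; downstream code uses the returned stats to drop blobs below min_area_threshold. A compact sketch of that pipeline (the synthetic blobs and the area threshold of 100 are illustrative):

import cv2
import numpy as np

binary = np.zeros((100, 200), np.uint8)
cv2.rectangle(binary, (10, 40), (180, 60), 255, -1)   # a lane-like blob
cv2.circle(binary, (190, 10), 2, 255, -1)             # a tiny noise blob

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=1)

num, labels, stats, centroids = cv2.connectedComponentsWithStats(
    closed, connectivity=8, ltype=cv2.CV_32S)
for idx, stat in enumerate(stats):
    # index 0 is the background component; small foreground blobs are erased
    if idx and stat[cv2.CC_STAT_AREA] < 100:
        closed[labels == idx] = 0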
class _LaneFeat(object): class _LaneFeat(object):
""" """
""" """
def __init__(self, feat, coord, class_id=-1): def __init__(self, feat, coord, class_id=-1):
""" """
lane feat object lane feat object
...@@ -108,18 +125,21 @@ class _LaneNetCluster(object): ...@@ -108,18 +125,21 @@ class _LaneNetCluster(object):
""" """
Instance segmentation result cluster Instance segmentation result cluster
""" """
def __init__(self): def __init__(self):
""" """
""" """
self._color_map = [np.array([255, 0, 0]), self._color_map = [
np.array([0, 255, 0]), np.array([255, 0, 0]),
np.array([0, 0, 255]), np.array([0, 255, 0]),
np.array([125, 125, 0]), np.array([0, 0, 255]),
np.array([0, 125, 125]), np.array([125, 125, 0]),
np.array([125, 0, 125]), np.array([0, 125, 125]),
np.array([50, 100, 50]), np.array([125, 0, 125]),
np.array([100, 50, 100])] np.array([50, 100, 50]),
np.array([100, 50, 100])
]
@staticmethod @staticmethod
def _embedding_feats_dbscan_cluster(embedding_image_feats): def _embedding_feats_dbscan_cluster(embedding_image_feats):
...@@ -186,15 +206,16 @@ class _LaneNetCluster(object): ...@@ -186,15 +206,16 @@ class _LaneNetCluster(object):
# get embedding feats and coords # get embedding feats and coords
get_lane_embedding_feats_result = self._get_lane_embedding_feats( get_lane_embedding_feats_result = self._get_lane_embedding_feats(
binary_seg_ret=binary_seg_result, binary_seg_ret=binary_seg_result,
instance_seg_ret=instance_seg_result instance_seg_ret=instance_seg_result)
)
# dbscan cluster # dbscan cluster
dbscan_cluster_result = self._embedding_feats_dbscan_cluster( dbscan_cluster_result = self._embedding_feats_dbscan_cluster(
embedding_image_feats=get_lane_embedding_feats_result['lane_embedding_feats'] embedding_image_feats=get_lane_embedding_feats_result[
) 'lane_embedding_feats'])
mask = np.zeros(shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3], dtype=np.uint8) mask = np.zeros(
shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3],
dtype=np.uint8)
db_labels = dbscan_cluster_result['db_labels'] db_labels = dbscan_cluster_result['db_labels']
unique_labels = dbscan_cluster_result['unique_labels'] unique_labels = dbscan_cluster_result['unique_labels']
coord = get_lane_embedding_feats_result['lane_coordinates'] coord = get_lane_embedding_feats_result['lane_coordinates']
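The clustering step itself is outside this hunk; conceptually, the embedding vectors of the pixels flagged as lane by the binary branch are grouped with DBSCAN, and each resulting cluster becomes one lane instance. A hypothetical sketch with scikit-learn (the eps/min_samples values and the random features are assumptions, not the repository's settings):

import numpy as np
from sklearn.cluster import DBSCAN

# Hypothetical 4-D embedding vectors for the pixels the binary branch marked as lane.
lane_embedding_feats = np.random.rand(500, 4).astype('float32')

db = DBSCAN(eps=0.35, min_samples=100).fit(lane_embedding_feats)
db_labels = db.labels_                      # -1 marks noise, 0..K-1 mark lane instances
unique_labels = np.unique(db_labels)
print(unique_labels)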
...@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object): ...@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object):
""" """
lanenet post process for lane generation lanenet post process for lane generation
""" """
def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'): def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'):
""" """
convert front car view to bird view convert front car view to bird view
""" """
assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(ipm_remap_file_path) assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(
ipm_remap_file_path)
self._cluster = _LaneNetCluster() self._cluster = _LaneNetCluster()
self._ipm_remap_file_path = ipm_remap_file_path self._ipm_remap_file_path = ipm_remap_file_path
...@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object): ...@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object):
self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x'] self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x']
self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y'] self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y']
self._color_map = [np.array([255, 0, 0]), self._color_map = [
np.array([0, 255, 0]), np.array([255, 0, 0]),
np.array([0, 0, 255]), np.array([0, 255, 0]),
np.array([125, 125, 0]), np.array([0, 0, 255]),
np.array([0, 125, 125]), np.array([125, 125, 0]),
np.array([125, 0, 125]), np.array([0, 125, 125]),
np.array([50, 100, 50]), np.array([125, 0, 125]),
np.array([100, 50, 100])] np.array([50, 100, 50]),
np.array([100, 50, 100])
]
def _load_remap_matrix(self): def _load_remap_matrix(self):
fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ) fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ)
...@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object): ...@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object):
return ret return ret
def postprocess(self, binary_seg_result, instance_seg_result=None, def postprocess(self,
min_area_threshold=100, source_image=None, binary_seg_result,
instance_seg_result=None,
min_area_threshold=100,
source_image=None,
data_source='tusimple'): data_source='tusimple'):
# convert binary_seg_result # convert binary_seg_result
binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8) binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8)
# apply image morphology operation to fill in the hold and reduce the small area # apply image morphology operation to fill in the hold and reduce the small area
morphological_ret = _morphological_process(binary_seg_result, kernel_size=5) morphological_ret = _morphological_process(
connect_components_analysis_ret = _connect_components_analysis(image=morphological_ret) binary_seg_result, kernel_size=5)
connect_components_analysis_ret = _connect_components_analysis(
image=morphological_ret)
labels = connect_components_analysis_ret[1] labels = connect_components_analysis_ret[1]
stats = connect_components_analysis_ret[2] stats = connect_components_analysis_ret[2]
...@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object): ...@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object):
# apply embedding features cluster # apply embedding features cluster
mask_image, lane_coords = self._cluster.apply_lane_feats_cluster( mask_image, lane_coords = self._cluster.apply_lane_feats_cluster(
binary_seg_result=morphological_ret, binary_seg_result=morphological_ret,
instance_seg_result=instance_seg_result instance_seg_result=instance_seg_result)
)
if mask_image is None: if mask_image is None:
return { return {
...@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object): ...@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object):
        for lane_index, coords in enumerate(lane_coords):
            if data_source == 'tusimple':
                tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
                tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256),
                                np.int_(coords[:, 0] * 1280 / 512)))] = 255
            else:
                raise ValueError('Wrong data source now only support tusimple')
            tmp_ipm_mask = cv2.remap(
                tmp_mask,
                self._remap_to_ipm_x,
                self._remap_to_ipm_y,
                interpolation=cv2.INTER_NEAREST)
            nonzero_y = np.array(tmp_ipm_mask.nonzero()[0])
            nonzero_x = np.array(tmp_ipm_mask.nonzero()[1])
@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object):
            [ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
            plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
            fit_x = fit_param[0] * plot_y**2 + fit_param[
                1] * plot_y + fit_param[2]

            lane_pts = []
            for index in range(0, plot_y.shape[0], 5):
                src_x = self._remap_to_ipm_x[
                    int(plot_y[index]),
                    int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
                if src_x <= 0:
                    continue
                src_y = self._remap_to_ipm_y[
                    int(plot_y[index]),
                    int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
                src_y = src_y if src_y > 0 else 0

                lane_pts.append([src_x, src_y])
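fit_param above holds the coefficients of a second-order polynomial x = a*y^2 + b*y + c fitted to the lane pixels in bird's-eye (IPM) coordinates; the loop then samples that curve every five rows and maps the points back to the source view. The fit itself is not shown in this hunk, but a quadratic fit like the one evaluated above can be obtained with NumPy (toy data):

import numpy as np

# Toy lane pixels in IPM coordinates; x is modelled as a quadratic in y.
nonzero_y = np.linspace(100, 500, 50)
nonzero_x = 0.001 * nonzero_y**2 - 0.3 * nonzero_y + 400 + np.random.randn(50)

fit_param = np.polyfit(nonzero_y, nonzero_x, 2)     # 2nd-order fit, highest power first
plot_y = np.linspace(10, 600, 590)
fit_x = fit_param[0] * plot_y**2 + fit_param[1] * plot_y + fit_param[2]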
...@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object): ...@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object):
continue continue
lane_color = self._color_map[index].tolist() lane_color = self._color_map[index].tolist()
cv2.circle(source_image, (int(interpolation_src_pt_x), cv2.circle(
int(interpolation_src_pt_y)), 5, lane_color, -1) source_image,
(int(interpolation_src_pt_x), int(interpolation_src_pt_y)),
5, lane_color, -1)
ret = { ret = {
'mask_image': mask_image, 'mask_image': mask_image,
'fit_params': fit_params, 'fit_params': fit_params,
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress ...@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress
if __name__ == "__main__": if __name__ == "__main__":
download_file_and_uncompress( download_file_and_uncompress(
url='https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar', url=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar',
savepath=LOCAL_PATH, savepath=LOCAL_PATH,
extrapath=LOCAL_PATH) extrapath=LOCAL_PATH)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .load_model import * from .load_model import *
from .unet import * from .unet import *
from .hrnet import * from .hrnet import *
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
import paddle.fluid as fluid import paddle.fluid as fluid
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np import numpy as np
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet import UNet from .unet import UNet
from .hrnet import HRNet from .hrnet import HRNet
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import os.path as osp import os.path as osp
import sys import sys
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp import os.path as osp
import argparse import argparse
import transforms.transforms as T import transforms.transforms as T
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp import os.path as osp
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os import os
@@ -6,20 +20,20 @@ args = get_arguments()
cfg = AttrDict()

# Directory containing the images to be predicted
cfg.data_dir = os.path.join(args.example, "data", "test_images")
# List file with the names of the images to be predicted
cfg.data_list_file = os.path.join(args.example, "data", "test.txt")
# Path to load the model from
cfg.model_path = os.path.join(args.example, "model")
# Directory where prediction results are saved
cfg.vis_dir = os.path.join(args.example, "result")

# Number of prediction classes
cfg.class_num = 2
# Mean subtracted from the image during preprocessing
cfg.MEAN = 127.5, 127.5, 127.5
# Standard deviation the image is divided by during preprocessing
cfg.STD = 127.5, 127.5, 127.5

# Input size of the images to be predicted
cfg.input_size = 1536, 576
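With MEAN and STD both set to 127.5, a common way such constants are applied is to map 0-255 pixel values roughly into [-1, 1]. A sketch of that preprocessing step (the normalization call itself lives elsewhere in the example code, so this is an assumption about usage):

import numpy as np

MEAN = np.array([127.5, 127.5, 127.5], dtype='float32')
STD = np.array([127.5, 127.5, 127.5], dtype='float32')

img = np.random.randint(0, 256, (576, 1536, 3)).astype('float32')  # HWC image, 0-255
normalized = (img - MEAN) / STD        # roughly [-1, 1] with these constants
print(normalized.min(), normalized.max())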
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import cv2 import cv2
import numpy as np import numpy as np
@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg')

# Paddle garbage-collection FLAG; the ACE2P model is large, so enabling this is recommended when GPU memory is insufficient
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'

import paddle.fluid as fluid


# Dataset class for prediction
class TestDataSet():
    def __init__(self):
        self.data_dir = cfg.data_dir
        self.data_list_file = cfg.data_list_file
        self.data_list = self.get_data_list()
        self.data_num = len(self.data_list)

    def get_data_list(self):
        # Collect the list of image paths to predict
        data_list = []
@@ -40,7 +55,7 @@ class TestDataSet():
    def preprocess(self, img):
        # Image preprocessing
        if cfg.example == 'ACE2P':
            reader = importlib.import_module(args.example + '.reader')
            ACE2P_preprocess = getattr(reader, 'preprocess')
            img = ACE2P_preprocess(img)
        else:
@@ -56,10 +71,10 @@ class TestDataSet():
        img_path = self.data_list[index]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return img, img, img_path, None

        img_name = img_path.split(os.sep)[-1]
        name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
        img_shape = img.shape[:2]
        img_process = self.preprocess(img)
...@@ -90,39 +105,44 @@ def infer(): ...@@ -90,39 +105,44 @@ def infer():
if image is None: if image is None:
print(im_name, 'is None') print(im_name, 'is None')
continue continue
# 预测 # 预测
if cfg.example == 'ACE2P': if cfg.example == 'ACE2P':
# ACE2P模型使用多尺度预测 # ACE2P模型使用多尺度预测
reader = importlib.import_module(args.example+'.reader') reader = importlib.import_module(args.example + '.reader')
multi_scale_test = getattr(reader, 'multi_scale_test') multi_scale_test = getattr(reader, 'multi_scale_test')
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape) parsing, logits = multi_scale_test(exe, test_prog, feed_name,
fetch_list, image, im_shape)
else: else:
# HumanSeg,RoadLine模型单尺度预测 # HumanSeg,RoadLine模型单尺度预测
result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list) result = exe.run(
program=test_prog,
feed={feed_name[0]: image},
fetch_list=fetch_list)
parsing = np.argmax(result[0][0], axis=0) parsing = np.argmax(result[0][0], axis=0)
parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1]) parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])
# 预测结果保存 # 预测结果保存
result_path = os.path.join(cfg.vis_dir, im_name + '.png') result_path = os.path.join(cfg.vis_dir, im_name + '.png')
if cfg.example == 'HumanSeg': if cfg.example == 'HumanSeg':
logits = result[0][0][1]*255 logits = result[0][0][1] * 255
logits = cv2.resize(logits, im_shape[::-1]) logits = cv2.resize(logits, im_shape[::-1])
ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO) ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
logits = 255 *(logits - thresh)/(255 - thresh) logits = 255 * (logits - thresh) / (255 - thresh)
# 将分割结果添加到alpha通道 # 将分割结果添加到alpha通道
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2) rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
axis=2)
cv2.imwrite(result_path, rgba) cv2.imwrite(result_path, rgba)
else: else:
output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8)) output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
output_im.putpalette(palette) output_im.putpalette(palette)
output_im.save(result_path) output_im.save(result_path)
if (idx + 1) % 100 == 0: if (idx + 1) % 100 == 0:
print('%d processd' % (idx + 1)) print('%d processd' % (idx + 1))
print('%d processd done' % (idx + 1)) print('%d processd done' % (idx + 1))
return 0 return 0
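The hunk above begins inside infer(), so the step that creates exe, test_prog, feed_name and fetch_list is not visible here. A minimal sketch of how such a program/feed/fetch triple is usually obtained with fluid, assuming a saved inference model directory; the function name and directory layout are assumptions:

import paddle.fluid as fluid


def load_model(model_dir, use_gpu=False):
    """Create an executor and load a saved inference program from model_dir."""
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # Returns the inference program, the names of its feed variables and its fetch targets.
    test_prog, feed_name, fetch_list = fluid.io.load_inference_model(
        dirname=model_dir, executor=exe)
    return exe, test_prog, feed_name, fetch_list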
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_gpu", action="store_true", help="Use gpu or cpu to test.")
    parser.add_argument(
        '--example', type=str, help='RoadLine, HumanSeg or ACE2P')
    return parser.parse_args()

@@ -34,6 +48,7 @@ class AttrDict(dict):
        else:
            self[name] = value


def merge_cfg_from_args(args, cfg):
    """Merge config keys, values in args into the global config."""
    for k, v in vars(args).items():
@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
        value = v
        if value is not None:
            cfg[k] = value
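Taken together, the helpers in this file are presumably used along these lines; the defaults and the command line are made up, and the sketch assumes AttrDict also exposes its keys as attributes for reads, which the cfg.* accesses in infer.py rely on:

# Assumes AttrDict, get_arguments and merge_cfg_from_args from this module are in scope.
cfg = AttrDict()
cfg.use_gpu = False             # default, may be overridden by --use_gpu
cfg.example = None              # filled in from --example

args = get_arguments()          # e.g. invoked with: --example HumanSeg --use_gpu
merge_cfg_from_args(args, cfg)  # only non-None argparse values are copied into cfg

print(cfg.example, cfg.use_gpu)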
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -20,6 +21,8 @@ from PIL import Image
import glob

LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))


def remove_colormap(filename):
    gray_anno = np.array(Image.open(filename))
    return gray_anno
@@ -30,6 +33,7 @@ def save_annotation(annotation, filename):
    annotation = Image.fromarray(annotation)
    annotation.save(filename)


def convert_list(origin_file, seg_file, output_folder):
    with open(seg_file, 'w') as fid_seg:
        with open(origin_file) as fid_ori:
@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder):
                new_line = ' '.join([img_name, anno_name])
                fid_seg.write(new_line + "\n")


if __name__ == "__main__":
    pascal_root = "./VOCtrainval_11-May-2012/VOC2012"
    pascal_root = os.path.join(LOCAL_PATH, pascal_root)
@@ -54,7 +59,7 @@ if __name__ == "__main__":
    # directory where the converted annotation images are stored
    output_folder = os.path.join(pascal_root, "SegmentationClassAug")

    print("annotation convert and file list convert")
    if not os.path.exists(os.path.join(LOCAL_PATH, output_folder)):
        os.mkdir(os.path.join(LOCAL_PATH, output_folder))
@@ -67,5 +72,5 @@ if __name__ == "__main__":
    convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
    convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
    convert_list(trainval_path, trainval_path.replace('txt', 'list'),
                 output_folder)
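For a single annotation, the two helpers above amount to the following; the file names are made up, and the module name is taken from the import in the next file:

from convert_voc2012 import remove_colormap, save_annotation

# PIL opens the VOC palette ("colormap") PNG as an index image, so np.array() already
# yields per-pixel class ids; saving that array back produces a plain label PNG.
anno = remove_colormap("SegmentationClass/2007_000033.png")
save_annotation(anno, "SegmentationClassAug/2007_000033.png")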
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap
from convert_voc2012 import save_annotation


def download_VOC_dataset(savepath, extrapath):
    url = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
    download_file_and_uncompress(
        url=url, savepath=savepath, extrapath=extrapath)


if __name__ == "__main__":
    download_VOC_dataset(LOCAL_PATH, LOCAL_PATH)
    print("Dataset download finish!")
@@ -45,10 +46,10 @@ if __name__ == "__main__":
    train_path = os.path.join(txt_folder, "train.txt")
    val_path = os.path.join(txt_folder, "val.txt")
    trainval_path = os.path.join(txt_folder, "trainval.txt")

    # directory where the converted annotation images are stored
    output_folder = os.path.join(pascal_root, "SegmentationClassAug")

    print("annotation convert and file list convert")
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)
@@ -61,5 +62,5 @@ if __name__ == "__main__":
    convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
    convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
    convert_list(trainval_path, trainval_path.replace('txt', 'list'),
                 output_folder)
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
(Three more files receive the identical header replacement; their diffs repeat the block above verbatim.)
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,8 @@ gflags.DEFINE_string("conf", default="", help="Configuration File Path")
gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images")
gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
gflags.DEFINE_string("trt_mode", default="", help="Use optimized model")
gflags.DEFINE_string(
    "ext", default=".jpeg|.jpg", help="Input Image File Extensions")
gflags.FLAGS = gflags.FLAGS
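A minimal sketch of how flags defined this way are parsed and read with python-gflags, assuming it runs in the same module after the DEFINE_* calls above; the example values in the comments are made up:

import sys

import gflags

# Parsing consumes the --flag arguments; positional arguments are returned.
remaining_argv = gflags.FLAGS(sys.argv)
print(gflags.FLAGS.conf)       # e.g. "./humanseg.yaml"
print(gflags.FLAGS.input_dir)  # e.g. "./test_images"
print(gflags.FLAGS.ext)        # ".jpeg|.jpg" unless overridden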
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,4 +14,4 @@
# limitations under the License.
import models
import utils
from . import tools
\ No newline at end of file
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
@@ -427,12 +440,17 @@ def max_img_size_statistics():
    logger.info("max width and max height of images are ({},{})".format(
        max_width, max_height))


def num_classes_loss_matching_check():
    loss_type = cfg.SOLVER.LOSS
    num_classes = cfg.DATASET.NUM_CLASSES
    if num_classes > 2 and (("dice_loss" in loss_type) or
                            ("bce_loss" in loss_type)):
        logger.info(
            error_print(
                "loss check."
                " Dice loss and bce loss is only applicable to binary classfication"
            ))
    else:
        logger.info(correct_print("loss check"))
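Stated on its own, the rule this check enforces is simple; a standalone restatement, not code from the repository:

def loss_matches_num_classes(loss_types, num_classes):
    """dice_loss and bce_loss only support binary segmentation; other losses are unrestricted."""
    binary_only = {"dice_loss", "bce_loss"}
    return num_classes <= 2 or not binary_only.intersection(loss_types)


assert loss_matches_num_classes(["softmax_loss"], 21)
assert loss_matches_num_classes(["dice_loss", "bce_loss"], 2)
assert not loss_matches_num_classes(["dice_loss"], 21)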
...
(The remaining file diffs are collapsed.)