未验证 提交 664dc115 编写于 作者: W wuyefeilin 提交者: GitHub

Unifiy copyright format (#261)

* update model save load

* first add

* update model save and load

* update train.py

* update LaneNet model saving and loading

* adapt slim to paddle-1.8

* update distillation save and load

* update nas model save and load

* update model load op

* update utils.py

* update load_model_utils.py

* update model saving and loading

* update copyright

* update palette.py
上级 020d1072
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os import os
...@@ -19,10 +33,10 @@ cfg.class_num = 20 ...@@ -19,10 +33,10 @@ cfg.class_num = 20
# 均值, 图像预处理减去的均值 # 均值, 图像预处理减去的均值
cfg.MEAN = 0.406, 0.456, 0.485 cfg.MEAN = 0.406, 0.456, 0.485
# 标准差,图像预处理除以标准差 # 标准差,图像预处理除以标准差
cfg.STD = 0.225, 0.224, 0.229 cfg.STD = 0.225, 0.224, 0.229
# 多尺度预测时图像尺寸 # 多尺度预测时图像尺寸
cfg.multi_scales = (377,377), (473,473), (567,567) cfg.multi_scales = (377, 377), (473, 473), (567, 567)
# 多尺度预测时图像是否水平翻转 # 多尺度预测时图像是否水平翻转
cfg.flip = True cfg.flip = True
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import cv2 import cv2
import numpy as np import numpy as np
...@@ -12,18 +26,19 @@ config = importlib.import_module('config') ...@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg') cfg = getattr(config, 'cfg')
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启 # paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os.environ['FLAGS_eager_delete_tensor_gb']='0.0' os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
import paddle.fluid as fluid import paddle.fluid as fluid
# 预测数据集类 # 预测数据集类
class TestDataSet(): class TestDataSet():
def __init__(self): def __init__(self):
self.data_dir = cfg.data_dir self.data_dir = cfg.data_dir
self.data_list_file = cfg.data_list_file self.data_list_file = cfg.data_list_file
self.data_list = self.get_data_list() self.data_list = self.get_data_list()
self.data_num = len(self.data_list) self.data_num = len(self.data_list)
def get_data_list(self): def get_data_list(self):
# 获取预测图像路径列表 # 获取预测图像路径列表
data_list = [] data_list = []
...@@ -56,10 +71,10 @@ class TestDataSet(): ...@@ -56,10 +71,10 @@ class TestDataSet():
img_path = self.data_list[index] img_path = self.data_list[index]
img = cv2.imread(img_path, cv2.IMREAD_COLOR) img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None: if img is None:
return img, img,img_path, None return img, img, img_path, None
img_name = img_path.split(os.sep)[-1] img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'') name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2] img_shape = img.shape[:2]
img_process = self.preprocess(img) img_process = self.preprocess(img)
...@@ -90,39 +105,44 @@ def infer(): ...@@ -90,39 +105,44 @@ def infer():
if image is None: if image is None:
print(im_name, 'is None') print(im_name, 'is None')
continue continue
# 预测 # 预测
if cfg.example == 'ACE2P': if cfg.example == 'ACE2P':
# ACE2P模型使用多尺度预测 # ACE2P模型使用多尺度预测
reader = importlib.import_module('reader') reader = importlib.import_module('reader')
multi_scale_test = getattr(reader, 'multi_scale_test') multi_scale_test = getattr(reader, 'multi_scale_test')
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape) parsing, logits = multi_scale_test(exe, test_prog, feed_name,
fetch_list, image, im_shape)
else: else:
# HumanSeg,RoadLine模型单尺度预测 # HumanSeg,RoadLine模型单尺度预测
result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list) result = exe.run(
program=test_prog,
feed={feed_name[0]: image},
fetch_list=fetch_list)
parsing = np.argmax(result[0][0], axis=0) parsing = np.argmax(result[0][0], axis=0)
parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1]) parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])
# 预测结果保存 # 预测结果保存
result_path = os.path.join(cfg.vis_dir, im_name + '.png') result_path = os.path.join(cfg.vis_dir, im_name + '.png')
if cfg.example == 'HumanSeg': if cfg.example == 'HumanSeg':
logits = result[0][0][1]*255 logits = result[0][0][1] * 255
logits = cv2.resize(logits, im_shape[::-1]) logits = cv2.resize(logits, im_shape[::-1])
ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO) ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
logits = 255 *(logits - thresh)/(255 - thresh) logits = 255 * (logits - thresh) / (255 - thresh)
# 将分割结果添加到alpha通道 # 将分割结果添加到alpha通道
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2) rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
axis=2)
cv2.imwrite(result_path, rgba) cv2.imwrite(result_path, rgba)
else: else:
output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8)) output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
output_im.putpalette(palette) output_im.putpalette(palette)
output_im.save(result_path) output_im.save(result_path)
if (idx + 1) % 100 == 0: if (idx + 1) % 100 == 0:
print('%d processd' % (idx + 1)) print('%d processd' % (idx + 1))
print('%d processd done' % (idx + 1)) print('%d processd done' % (idx + 1))
return 0 return 0
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from config import cfg from config import cfg
import cv2 import cv2
def get_affine_points(src_shape, dst_shape, rot_grad=0): def get_affine_points(src_shape, dst_shape, rot_grad=0):
# 获取图像和仿射后图像的三组对应点坐标 # 获取图像和仿射后图像的三组对应点坐标
# 三组点为仿射变换后图像的中心点, [w/2,0], [0,0],及对应原始图像的点 # 三组点为仿射变换后图像的中心点, [w/2,0], [0,0],及对应原始图像的点
...@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0): ...@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
# 原始图像三组点 # 原始图像三组点
points = [[0, 0]] * 3 points = [[0, 0]] * 3
points[0] = (np.array([w, h]) - 1) * 0.5 points[0] = (np.array([w, h]) - 1) * 0.5
points[1] = points[0] + 0.5 * affine_shape[0] * np.array([sin_v, -cos_v]) points[1] = points[0] + 0.5 * affine_shape[0] * np.array([sin_v, -cos_v])
points[2] = points[1] - 0.5 * affine_shape[1] * np.array([cos_v, sin_v]) points[2] = points[1] - 0.5 * affine_shape[1] * np.array([cos_v, sin_v])
...@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0): ...@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
return points, points_trans return points, points_trans
def preprocess(im): def preprocess(im):
# ACE2P模型数据预处理 # ACE2P模型数据预处理
im_shape = im.shape[:2] im_shape = im.shape[:2]
...@@ -42,13 +58,10 @@ def preprocess(im): ...@@ -42,13 +58,10 @@ def preprocess(im):
# 获取图像和仿射变换后图像的对应点坐标 # 获取图像和仿射变换后图像的对应点坐标
points, points_trans = get_affine_points(im_shape, scale) points, points_trans = get_affine_points(im_shape, scale)
# 根据对应点集获得仿射矩阵 # 根据对应点集获得仿射矩阵
trans = cv2.getAffineTransform(np.float32(points), trans = cv2.getAffineTransform(
np.float32(points_trans)) np.float32(points), np.float32(points_trans))
# 根据仿射矩阵对图像进行仿射 # 根据仿射矩阵对图像进行仿射
input = cv2.warpAffine(im, input = cv2.warpAffine(im, trans, scale[::-1], flags=cv2.INTER_LINEAR)
trans,
scale[::-1],
flags=cv2.INTER_LINEAR)
# 减均值测,除以方差,转换数据格式为NCHW # 减均值测,除以方差,转换数据格式为NCHW
input = input.astype(np.float32) input = input.astype(np.float32)
...@@ -66,19 +79,20 @@ def preprocess(im): ...@@ -66,19 +79,20 @@ def preprocess(im):
return input_images return input_images
def multi_scale_test(exe, test_prog, feed_name, fetch_list, def multi_scale_test(exe, test_prog, feed_name, fetch_list, input_ims,
input_ims, im_shape): im_shape):
# 由于部分类别分左右部位, flipped_idx为其水平翻转后对应的标签 # 由于部分类别分左右部位, flipped_idx为其水平翻转后对应的标签
flipped_idx = (15, 14, 17, 16, 19, 18) flipped_idx = (15, 14, 17, 16, 19, 18)
ms_outputs = [] ms_outputs = []
# 多尺度预测 # 多尺度预测
for idx, scale in enumerate(cfg.multi_scales): for idx, scale in enumerate(cfg.multi_scales):
input_im = input_ims[idx] input_im = input_ims[idx]
parsing_output = exe.run(program=test_prog, parsing_output = exe.run(
feed={feed_name[0]: input_im}, program=test_prog,
fetch_list=fetch_list) feed={feed_name[0]: input_im},
fetch_list=fetch_list)
output = parsing_output[0][0] output = parsing_output[0][0]
if cfg.flip: if cfg.flip:
# 若水平翻转,对部分类别进行翻转,与原始预测结果取均值 # 若水平翻转,对部分类别进行翻转,与原始预测结果取均值
...@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list, ...@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
# 仿射变换回图像原始尺寸 # 仿射变换回图像原始尺寸
points, points_trans = get_affine_points(im_shape, scale) points, points_trans = get_affine_points(im_shape, scale)
M = cv2.getAffineTransform(np.float32(points_trans), np.float32(points)) M = cv2.getAffineTransform(np.float32(points_trans), np.float32(points))
logits_result = cv2.warpAffine(output, M, im_shape[::-1], flags=cv2.INTER_LINEAR) logits_result = cv2.warpAffine(
output, M, im_shape[::-1], flags=cv2.INTER_LINEAR)
ms_outputs.append(logits_result) ms_outputs.append(logits_result)
# 多尺度预测结果求均值,求预测概率最大的类别 # 多尺度预测结果求均值,求预测概率最大的类别
...@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list, ...@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
ms_fused_parsing_output = np.mean(ms_fused_parsing_output, axis=0) ms_fused_parsing_output = np.mean(ms_fused_parsing_output, axis=0)
parsing = np.argmax(ms_fused_parsing_output, axis=2) parsing = np.argmax(ms_fused_parsing_output, axis=2)
return parsing, ms_fused_parsing_output return parsing, ms_fused_parsing_output
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
## This source code is licensed under the MIT-style license found in the ## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree ## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import argparse import argparse
import os import os
def get_arguments(): def get_arguments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", parser.add_argument(
action="store_true", "--use_gpu", action="store_true", help="Use gpu or cpu to test.")
help="Use gpu or cpu to test.") parser.add_argument(
parser.add_argument('--example', '--example', type=str, help='RoadLine, HumanSeg or ACE2P')
type=str,
help='RoadLine, HumanSeg or ACE2P')
return parser.parse_args() return parser.parse_args()
...@@ -34,6 +48,7 @@ class AttrDict(dict): ...@@ -34,6 +48,7 @@ class AttrDict(dict):
else: else:
self[name] = value self[name] = value
def merge_cfg_from_args(args, cfg): def merge_cfg_from_args(args, cfg):
"""Merge config keys, values in args into the global config.""" """Merge config keys, values in args into the global config."""
for k, v in vars(args).items(): for k, v in vars(args).items():
...@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg): ...@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value = v value = v
if value is not None: if value is not None:
cfg[k] = value cfg[k] = value
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,9 +13,6 @@ ...@@ -12,9 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# utils for memory management which is allocated on sharedmemory,
# note that these structures may not be thread-safe
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import models import models
import argparse import argparse
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
import os import os
import os.path as osp import os.path as osp
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .humanseg import HumanSegMobile from .humanseg import HumanSegMobile
from .humanseg import HumanSegServer from .humanseg import HumanSegServer
from .humanseg import HumanSegLite from .humanseg import HumanSegLite
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backbone import mobilenet_v2 from .backbone import mobilenet_v2
from .backbone import xception from .backbone import xception
from .deeplabv3p import DeepLabv3p from .deeplabv3p import DeepLabv3p
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mobilenet_v2 import MobileNetV2 from .mobilenet_v2 import MobileNetV2
from .xception import Xception from .xception import Xception
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -10,6 +11,7 @@ ...@@ -10,6 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
from datasets.dataset import Dataset from datasets.dataset import Dataset
import transforms import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
from datasets.dataset import Dataset from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
from datasets.dataset import Dataset from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np import numpy as np
import cv2 import cv2
import os import os
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
from datasets.dataset import Dataset from datasets.dataset import Dataset
import transforms import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
import os import os
import os.path as osp import os.path as osp
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase ...@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase
from pdseg.data_aug import get_random_scale, randomly_scale_image_and_label, random_rotation, \ from pdseg.data_aug import get_random_scale, randomly_scale_image_and_label, random_rotation, \
rand_scale_aspect, hsv_color_jitter, rand_crop rand_scale_aspect, hsv_color_jitter, rand_crop
def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN): def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
""" """
改变图像及标签图像尺寸 改变图像及标签图像尺寸
...@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN): ...@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
if grt is not None: if grt is not None:
grt = cv2.resize(grt, target_size, interpolation=cv2.INTER_NEAREST) grt = cv2.resize(grt, target_size, interpolation=cv2.INTER_NEAREST)
if grt_instance is not None: if grt_instance is not None:
grt_instance = cv2.resize(grt_instance, target_size, interpolation=cv2.INTER_NEAREST) grt_instance = cv2.resize(
grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
elif cfg.AUG.AUG_METHOD == 'stepscaling': elif cfg.AUG.AUG_METHOD == 'stepscaling':
if mode == ModelPhase.TRAIN: if mode == ModelPhase.TRAIN:
min_scale_factor = cfg.AUG.MIN_SCALE_FACTOR min_scale_factor = cfg.AUG.MIN_SCALE_FACTOR
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,7 +18,6 @@ from __future__ import print_function ...@@ -18,7 +18,6 @@ from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
from utils.config import cfg from utils.config import cfg
from pdseg.models.libs.model_libs import scope, name_scope from pdseg.models.libs.model_libs import scope, name_scope
from pdseg.models.libs.model_libs import bn, bn_relu, relu from pdseg.models.libs.model_libs import bn, bn_relu, relu
...@@ -86,7 +85,12 @@ def bottleneck(inputs, ...@@ -86,7 +85,12 @@ def bottleneck(inputs,
with scope('down_sample'): with scope('down_sample'):
inputs_shape = inputs.shape inputs_shape = inputs.shape
with scope('main_max_pool'): with scope('main_max_pool'):
net_main = fluid.layers.conv2d(inputs, inputs_shape[1], filter_size=3, stride=2, padding='SAME') net_main = fluid.layers.conv2d(
inputs,
inputs_shape[1],
filter_size=3,
stride=2,
padding='SAME')
#First get the difference in depth to pad, then pad with zeros only on the last dimension. #First get the difference in depth to pad, then pad with zeros only on the last dimension.
depth_to_pad = abs(inputs_shape[1] - output_depth) depth_to_pad = abs(inputs_shape[1] - output_depth)
...@@ -95,12 +99,16 @@ def bottleneck(inputs, ...@@ -95,12 +99,16 @@ def bottleneck(inputs,
net_main = fluid.layers.pad(net_main, paddings=paddings) net_main = fluid.layers.pad(net_main, paddings=paddings)
with scope('block1'): with scope('block1'):
net = conv(inputs, reduced_depth, [2, 2], stride=2, padding='same') net = conv(
inputs, reduced_depth, [2, 2], stride=2, padding='same')
net = bn(net) net = bn(net)
net = prelu(net, decoder=decoder) net = prelu(net, decoder=decoder)
with scope('block2'): with scope('block2'):
net = conv(net, reduced_depth, [filter_size, filter_size], padding='same') net = conv(
net,
reduced_depth, [filter_size, filter_size],
padding='same')
net = bn(net) net = bn(net)
net = prelu(net, decoder=decoder) net = prelu(net, decoder=decoder)
...@@ -137,13 +145,18 @@ def bottleneck(inputs, ...@@ -137,13 +145,18 @@ def bottleneck(inputs,
# Second conv block --- apply dilated convolution here # Second conv block --- apply dilated convolution here
with scope('block2'): with scope('block2'):
net = conv(net, reduced_depth, filter_size, padding='SAME', dilation=dilation_rate) net = conv(
net,
reduced_depth,
filter_size,
padding='SAME',
dilation=dilation_rate)
net = bn(net) net = bn(net)
net = prelu(net, decoder=decoder) net = prelu(net, decoder=decoder)
# Final projection with 1x1 kernel (Expansion) # Final projection with 1x1 kernel (Expansion)
with scope('block3'): with scope('block3'):
net = conv(net, output_depth, [1,1]) net = conv(net, output_depth, [1, 1])
net = bn(net) net = bn(net)
net = prelu(net, decoder=decoder) net = prelu(net, decoder=decoder)
...@@ -172,9 +185,11 @@ def bottleneck(inputs, ...@@ -172,9 +185,11 @@ def bottleneck(inputs,
# Second conv block --- apply asymmetric conv here # Second conv block --- apply asymmetric conv here
with scope('block2'): with scope('block2'):
with scope('asymmetric_conv2a'): with scope('asymmetric_conv2a'):
net = conv(net, reduced_depth, [filter_size, 1], padding='same') net = conv(
net, reduced_depth, [filter_size, 1], padding='same')
with scope('asymmetric_conv2b'): with scope('asymmetric_conv2b'):
net = conv(net, reduced_depth, [1, filter_size], padding='same') net = conv(
net, reduced_depth, [1, filter_size], padding='same')
net = bn(net) net = bn(net)
net = prelu(net, decoder=decoder) net = prelu(net, decoder=decoder)
...@@ -211,7 +226,8 @@ def bottleneck(inputs, ...@@ -211,7 +226,8 @@ def bottleneck(inputs,
with scope('unpool'): with scope('unpool'):
net_unpool = conv(inputs, output_depth, [1, 1]) net_unpool = conv(inputs, output_depth, [1, 1])
net_unpool = bn(net_unpool) net_unpool = bn(net_unpool)
net_unpool = fluid.layers.resize_bilinear(net_unpool, out_shape=output_shape[2:]) net_unpool = fluid.layers.resize_bilinear(
net_unpool, out_shape=output_shape[2:])
# First 1x1 projection to reduce depth # First 1x1 projection to reduce depth
with scope('block1'): with scope('block1'):
...@@ -220,7 +236,12 @@ def bottleneck(inputs, ...@@ -220,7 +236,12 @@ def bottleneck(inputs,
net = prelu(net, decoder=decoder) net = prelu(net, decoder=decoder)
with scope('block2'): with scope('block2'):
net = deconv(net, reduced_depth, filter_size=filter_size, stride=2, padding='same') net = deconv(
net,
reduced_depth,
filter_size=filter_size,
stride=2,
padding='same')
net = bn(net) net = bn(net)
net = prelu(net, decoder=decoder) net = prelu(net, decoder=decoder)
...@@ -253,7 +274,10 @@ def bottleneck(inputs, ...@@ -253,7 +274,10 @@ def bottleneck(inputs,
# Second conv block # Second conv block
with scope('block2'): with scope('block2'):
net = conv(net, reduced_depth, [filter_size, filter_size], padding='same') net = conv(
net,
reduced_depth, [filter_size, filter_size],
padding='same')
net = bn(net) net = bn(net)
net = prelu(net, decoder=decoder) net = prelu(net, decoder=decoder)
...@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'): ...@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'):
= bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.01, type=DOWNSAMPLING, = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.01, type=DOWNSAMPLING,
name_scope='bottleneck1_0') name_scope='bottleneck1_0')
with scope('bottleneck1_1'): with scope('bottleneck1_1'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, net = bottleneck(
name_scope='bottleneck1_1') net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_1')
with scope('bottleneck1_2'): with scope('bottleneck1_2'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, net = bottleneck(
name_scope='bottleneck1_2') net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_2')
with scope('bottleneck1_3'): with scope('bottleneck1_3'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, net = bottleneck(
name_scope='bottleneck1_3') net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_3')
with scope('bottleneck1_4'): with scope('bottleneck1_4'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, net = bottleneck(
name_scope='bottleneck1_4') net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_4')
return net, inputs_shape_1 return net, inputs_shape_1
...@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'): ...@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'):
name_scope='bottleneck2_0') name_scope='bottleneck2_0')
for i in range(2): for i in range(2):
with scope('bottleneck2_{}'.format(str(4 * i + 1))): with scope('bottleneck2_{}'.format(str(4 * i + 1))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, net = bottleneck(
name_scope='bottleneck2_{}'.format(str(4 * i + 1))) net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
with scope('bottleneck2_{}'.format(str(4 * i + 2))): with scope('bottleneck2_{}'.format(str(4 * i + 2))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)), net = bottleneck(
name_scope='bottleneck2_{}'.format(str(4 * i + 2))) net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 1)),
name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
with scope('bottleneck2_{}'.format(str(4 * i + 3))): with scope('bottleneck2_{}'.format(str(4 * i + 3))):
net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC, net = bottleneck(
name_scope='bottleneck2_{}'.format(str(4 * i + 3))) net,
output_depth=128,
filter_size=5,
regularizer_prob=0.1,
type=ASYMMETRIC,
name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
with scope('bottleneck2_{}'.format(str(4 * i + 4))): with scope('bottleneck2_{}'.format(str(4 * i + 4))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)), net = bottleneck(
name_scope='bottleneck2_{}'.format(str(4 * i + 4))) net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 2)),
name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
return net, inputs_shape_2 return net, inputs_shape_2
...@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'): ...@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'):
with scope(name_scope): with scope(name_scope):
for i in range(2): for i in range(2):
with scope('bottleneck3_{}'.format(str(4 * i + 0))): with scope('bottleneck3_{}'.format(str(4 * i + 0))):
net = bottleneck(inputs, output_depth=128, filter_size=3, regularizer_prob=0.1, net = bottleneck(
name_scope='bottleneck3_{}'.format(str(4 * i + 0))) inputs,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
with scope('bottleneck3_{}'.format(str(4 * i + 1))): with scope('bottleneck3_{}'.format(str(4 * i + 1))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)), net = bottleneck(
name_scope='bottleneck3_{}'.format(str(4 * i + 1))) net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 1)),
name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
with scope('bottleneck3_{}'.format(str(4 * i + 2))): with scope('bottleneck3_{}'.format(str(4 * i + 2))):
net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC, net = bottleneck(
name_scope='bottleneck3_{}'.format(str(4 * i + 2))) net,
output_depth=128,
filter_size=5,
regularizer_prob=0.1,
type=ASYMMETRIC,
name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
with scope('bottleneck3_{}'.format(str(4 * i + 3))): with scope('bottleneck3_{}'.format(str(4 * i + 3))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)), net = bottleneck(
name_scope='bottleneck3_{}'.format(str(4 * i + 3))) net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 2)),
name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
return net return net
def ENet_stage4(inputs, inputs_shape, connect_tensor, def ENet_stage4(inputs,
skip_connections=True, name_scope='stage4_block'): inputs_shape,
connect_tensor,
skip_connections=True,
name_scope='stage4_block'):
with scope(name_scope): with scope(name_scope):
with scope('bottleneck4_0'): with scope('bottleneck4_0'):
net = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.1, net = bottleneck(
type=UPSAMPLING, decoder=True, output_shape=inputs_shape, inputs,
name_scope='bottleneck4_0') output_depth=64,
filter_size=3,
regularizer_prob=0.1,
type=UPSAMPLING,
decoder=True,
output_shape=inputs_shape,
name_scope='bottleneck4_0')
if skip_connections: if skip_connections:
net = fluid.layers.elementwise_add(net, connect_tensor) net = fluid.layers.elementwise_add(net, connect_tensor)
with scope('bottleneck4_1'): with scope('bottleneck4_1'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True, net = bottleneck(
name_scope='bottleneck4_1') net,
output_depth=64,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck4_1')
with scope('bottleneck4_2'): with scope('bottleneck4_2'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True, net = bottleneck(
name_scope='bottleneck4_2') net,
output_depth=64,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck4_2')
return net return net
def ENet_stage5(inputs, inputs_shape, connect_tensor, skip_connections=True, def ENet_stage5(inputs,
inputs_shape,
connect_tensor,
skip_connections=True,
name_scope='stage5_block'): name_scope='stage5_block'):
with scope(name_scope): with scope(name_scope):
net = bottleneck(inputs, output_depth=16, filter_size=3, regularizer_prob=0.1, type=UPSAMPLING, net = bottleneck(
decoder=True, output_shape=inputs_shape, inputs,
name_scope='bottleneck5_0') output_depth=16,
filter_size=3,
regularizer_prob=0.1,
type=UPSAMPLING,
decoder=True,
output_shape=inputs_shape,
name_scope='bottleneck5_0')
if skip_connections: if skip_connections:
net = fluid.layers.elementwise_add(net, connect_tensor) net = fluid.layers.elementwise_add(net, connect_tensor)
with scope('bottleneck5_1'): with scope('bottleneck5_1'):
net = bottleneck(net, output_depth=16, filter_size=3, regularizer_prob=0.1, decoder=True, net = bottleneck(
name_scope='bottleneck5_1') net,
output_depth=16,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck5_1')
return net return net
...@@ -378,14 +493,16 @@ def decoder(input, num_classes): ...@@ -378,14 +493,16 @@ def decoder(input, num_classes):
segStage3 = ENet_stage3(stage2) segStage3 = ENet_stage3(stage2)
segStage4 = ENet_stage4(segStage3, inputs_shape_2, stage1) segStage4 = ENet_stage4(segStage3, inputs_shape_2, stage1)
segStage5 = ENet_stage5(segStage4, inputs_shape_1, initial) segStage5 = ENet_stage5(segStage4, inputs_shape_1, initial)
segLogits = deconv(segStage5, num_classes, filter_size=2, stride=2, padding='SAME') segLogits = deconv(
segStage5, num_classes, filter_size=2, stride=2, padding='SAME')
# Embedding branch # Embedding branch
with scope('LaneNetEm'): with scope('LaneNetEm'):
emStage3 = ENet_stage3(stage2) emStage3 = ENet_stage3(stage2)
emStage4 = ENet_stage4(emStage3, inputs_shape_2, stage1) emStage4 = ENet_stage4(emStage3, inputs_shape_2, stage1)
emStage5 = ENet_stage5(emStage4, inputs_shape_1, initial) emStage5 = ENet_stage5(emStage4, inputs_shape_1, initial)
emLogits = deconv(emStage5, 4, filter_size=2, stride=2, padding='SAME') emLogits = deconv(
emStage5, 4, filter_size=2, stride=2, padding='SAME')
elif 'vgg' in cfg.MODEL.LANENET.BACKBONE: elif 'vgg' in cfg.MODEL.LANENET.BACKBONE:
encoder_list = ['pool5', 'pool4', 'pool3'] encoder_list = ['pool5', 'pool4', 'pool3']
...@@ -396,14 +513,16 @@ def decoder(input, num_classes): ...@@ -396,14 +513,16 @@ def decoder(input, num_classes):
encoder_list = encoder_list[1:] encoder_list = encoder_list[1:]
for i in range(len(encoder_list)): for i in range(len(encoder_list)):
with scope('deconv_{:d}'.format(i + 1)): with scope('deconv_{:d}'.format(i + 1)):
deconv_out = deconv(score, 64, filter_size=4, stride=2, padding='SAME') deconv_out = deconv(
score, 64, filter_size=4, stride=2, padding='SAME')
input_tensor = input[encoder_list[i]] input_tensor = input[encoder_list[i]]
with scope('score_{:d}'.format(i + 1)): with scope('score_{:d}'.format(i + 1)):
score = conv(input_tensor, 64, 1) score = conv(input_tensor, 64, 1)
score = fluid.layers.elementwise_add(deconv_out, score) score = fluid.layers.elementwise_add(deconv_out, score)
with scope('deconv_final'): with scope('deconv_final'):
emLogits = deconv(score, 64, filter_size=16, stride=8, padding='SAME') emLogits = deconv(
score, 64, filter_size=16, stride=8, padding='SAME')
with scope('score_final'): with scope('score_final'):
segLogits = conv(emLogits, num_classes, 1) segLogits = conv(emLogits, num_classes, 1)
emLogits = relu(conv(emLogits, 4, 1)) emLogits = relu(conv(emLogits, 4, 1))
...@@ -415,7 +534,8 @@ def encoder(input): ...@@ -415,7 +534,8 @@ def encoder(input):
model = vgg_backbone(layers=16) model = vgg_backbone(layers=16)
#output = model.net(input) #output = model.net(input)
_, encode_feature_dict = model.net(input, end_points=13, decode_points=[7, 10, 13]) _, encode_feature_dict = model.net(
input, end_points=13, decode_points=[7, 10, 13])
output = {} output = {}
output['pool3'] = encode_feature_dict[7] output['pool3'] = encode_feature_dict[7]
output['pool4'] = encode_feature_dict[10] output['pool4'] = encode_feature_dict[10]
...@@ -427,8 +547,9 @@ def encoder(input): ...@@ -427,8 +547,9 @@ def encoder(input):
stage2, inputs_shape_2 = ENet_stage2(stage1) stage2, inputs_shape_2 = ENet_stage2(stage1)
output = (initial, stage1, stage2, inputs_shape_1, inputs_shape_2) output = (initial, stage1, stage2, inputs_shape_1, inputs_shape_2)
else: else:
raise Exception("LaneNet expect enet and vgg backbone, but received {}". raise Exception(
format(cfg.MODEL.LANENET.BACKBONE)) "LaneNet expect enet and vgg backbone, but received {}".format(
cfg.MODEL.LANENET.BACKBONE))
return output return output
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -58,7 +58,8 @@ class LaneNetDataset(): ...@@ -58,7 +58,8 @@ class LaneNetDataset():
if self.shuffle and cfg.NUM_TRAINERS > 1: if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines) np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)] self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1 self.shuffle_seed += 1
elif self.shuffle: elif self.shuffle:
np.random.shuffle(self.lines) np.random.shuffle(self.lines)
...@@ -86,7 +87,8 @@ class LaneNetDataset(): ...@@ -86,7 +87,8 @@ class LaneNetDataset():
if self.shuffle and cfg.NUM_TRAINERS > 1: if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines) np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // self.num_trainers num_lines = len(self.all_lines) // self.num_trainers
self.lines = self.all_lines[num_lines * self.trainer_id: num_lines * (self.trainer_id + 1)] self.lines = self.all_lines[num_lines * self.trainer_id:num_lines *
(self.trainer_id + 1)]
self.shuffle_seed += 1 self.shuffle_seed += 1
elif self.shuffle: elif self.shuffle:
np.random.shuffle(self.lines) np.random.shuffle(self.lines)
...@@ -118,7 +120,8 @@ class LaneNetDataset(): ...@@ -118,7 +120,8 @@ class LaneNetDataset():
def batch_reader(is_test=False, drop_last=drop_last): def batch_reader(is_test=False, drop_last=drop_last):
if is_test: if is_test:
imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], [] imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
for img, grt, grt_instance, img_name, valid_shape, org_shape in reader(): for img, grt, grt_instance, img_name, valid_shape, org_shape in reader(
):
imgs.append(img) imgs.append(img)
grts.append(grt) grts.append(grt)
grts_instance.append(grt_instance) grts_instance.append(grt_instance)
...@@ -126,14 +129,15 @@ class LaneNetDataset(): ...@@ -126,14 +129,15 @@ class LaneNetDataset():
valid_shapes.append(valid_shape) valid_shapes.append(valid_shape)
org_shapes.append(org_shape) org_shapes.append(org_shape)
if len(imgs) == batch_size: if len(imgs) == batch_size:
yield np.array(imgs), np.array( yield np.array(imgs), np.array(grts), np.array(
grts), np.array(grts_instance), img_names, np.array(valid_shapes), np.array( grts_instance), img_names, np.array(
org_shapes) valid_shapes), np.array(org_shapes)
imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], [] imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
if not drop_last and len(imgs) > 0: if not drop_last and len(imgs) > 0:
yield np.array(imgs), np.array(grts), np.array(grts_instance), img_names, np.array( yield np.array(imgs), np.array(grts), np.array(
valid_shapes), np.array(org_shapes) grts_instance), img_names, np.array(
valid_shapes), np.array(org_shapes)
else: else:
imgs, labs, labs_instance, ignore = [], [], [], [] imgs, labs, labs_instance, ignore = [], [], [], []
bs = 0 bs = 0
...@@ -144,12 +148,14 @@ class LaneNetDataset(): ...@@ -144,12 +148,14 @@ class LaneNetDataset():
ignore.append(ig) ignore.append(ig)
bs += 1 bs += 1
if bs == batch_size: if bs == batch_size:
yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore) yield np.array(imgs), np.array(labs), np.array(
labs_instance), np.array(ignore)
bs = 0 bs = 0
imgs, labs, labs_instance, ignore = [], [], [], [] imgs, labs, labs_instance, ignore = [], [], [], []
if not drop_last and bs > 0: if not drop_last and bs > 0:
yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore) yield np.array(imgs), np.array(labs), np.array(
labs_instance), np.array(ignore)
return batch_reader(is_test, drop_last) return batch_reader(is_test, drop_last)
...@@ -299,10 +305,12 @@ class LaneNetDataset(): ...@@ -299,10 +305,12 @@ class LaneNetDataset():
img, grt = aug.rand_crop(img, grt, mode=mode) img, grt = aug.rand_crop(img, grt, mode=mode)
elif ModelPhase.is_eval(mode): elif ModelPhase.is_eval(mode):
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode) img, grt, grt_instance = aug.resize(
img, grt, grt_instance, mode=mode)
elif ModelPhase.is_visual(mode): elif ModelPhase.is_visual(mode):
ori_img = img.copy() ori_img = img.copy()
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode) img, grt, grt_instance = aug.resize(
img, grt, grt_instance, mode=mode)
valid_shape = [img.shape[0], img.shape[1]] valid_shape = [img.shape[0], img.shape[1]]
else: else:
raise ValueError("Dataset mode={} Error!".format(mode)) raise ValueError("Dataset mode={} Error!".format(mode))
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3 ...@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3
cfg.DATASET.SEPARATOR = ' ' cfg.DATASET.SEPARATOR = ' '
# 忽略的像素标签值, 默认为255,一般无需改动 # 忽略的像素标签值, 默认为255,一般无需改动
cfg.DATASET.IGNORE_INDEX = 255 cfg.DATASET.IGNORE_INDEX = 255
# 数据增强是图像的padding值 # 数据增强是图像的padding值
cfg.DATASET.PADDING_VALUE = [127.5,127.5,127.5] cfg.DATASET.PADDING_VALUE = [127.5, 127.5, 127.5]
########################### 数据增强配置 ###################################### ########################### 数据增强配置 ######################################
# 图像镜像左右翻转 # 图像镜像左右翻转
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
generate tusimple training dataset generate tusimple training dataset
""" """
...@@ -14,12 +28,16 @@ import numpy as np ...@@ -14,12 +28,16 @@ import numpy as np
def init_args(): def init_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--src_dir', type=str, help='The origin path of unzipped tusimple dataset') parser.add_argument(
'--src_dir',
type=str,
help='The origin path of unzipped tusimple dataset')
return parser.parse_args() return parser.parse_args()
def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, instance_dst_dir): def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir,
instance_dst_dir):
assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path) assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)
...@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst ...@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
h_samples = info_dict['h_samples'] h_samples = info_dict['h_samples']
lanes = info_dict['lanes'] lanes = info_dict['lanes']
image_name_new = '{:s}.png'.format('{:d}'.format(line_index + image_nums).zfill(4)) image_name_new = '{:s}.png'.format(
'{:d}'.format(line_index + image_nums).zfill(4))
src_image = cv2.imread(image_path, cv2.IMREAD_COLOR) src_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
dst_binary_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8) dst_binary_image = np.zeros(
dst_instance_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8) [src_image.shape[0], src_image.shape[1]], np.uint8)
dst_instance_image = np.zeros(
[src_image.shape[0], src_image.shape[1]], np.uint8)
for lane_index, lane in enumerate(lanes): for lane_index, lane in enumerate(lanes):
assert len(h_samples) == len(lane) assert len(h_samples) == len(lane)
...@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst ...@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
lane_pts = np.vstack((lane_x, lane_y)).transpose() lane_pts = np.vstack((lane_x, lane_y)).transpose()
lane_pts = np.array([lane_pts], np.int64) lane_pts = np.array([lane_pts], np.int64)
cv2.polylines(dst_binary_image, lane_pts, isClosed=False, cv2.polylines(
color=255, thickness=5) dst_binary_image,
cv2.polylines(dst_instance_image, lane_pts, isClosed=False, lane_pts,
color=lane_index * 50 + 20, thickness=5) isClosed=False,
color=255,
dst_binary_image_path = ops.join(src_dir, binary_dst_dir, image_name_new) thickness=5)
dst_instance_image_path = ops.join(src_dir, instance_dst_dir, image_name_new) cv2.polylines(
dst_instance_image,
lane_pts,
isClosed=False,
color=lane_index * 50 + 20,
thickness=5)
dst_binary_image_path = ops.join(src_dir, binary_dst_dir,
image_name_new)
dst_instance_image_path = ops.join(src_dir, instance_dst_dir,
image_name_new)
dst_rgb_image_path = ops.join(src_dir, ori_dst_dir, image_name_new) dst_rgb_image_path = ops.join(src_dir, ori_dst_dir, image_name_new)
cv2.imwrite(dst_binary_image_path, dst_binary_image) cv2.imwrite(dst_binary_image_path, dst_binary_image)
...@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst ...@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
print('Process {:s} success'.format(image_name)) print('Process {:s} success'.format(image_name))
def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train', split=False): def gen_sample(src_dir,
b_gt_image_dir,
i_gt_image_dir,
image_dir,
phase='train',
split=False):
label_list = [] label_list = []
with open('{:s}/{}ing/{}.txt'.format(src_dir, phase, phase), 'w') as file: with open('{:s}/{}ing/{}.txt'.format(src_dir, phase, phase), 'w') as file:
...@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train' ...@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
image_path = ops.join(image_dir, image_name) image_path = ops.join(image_dir, image_name)
assert ops.exists(image_path), '{:s} not exist'.format(image_path) assert ops.exists(image_path), '{:s} not exist'.format(image_path)
assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(instance_gt_image_path) assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(
instance_gt_image_path)
b_gt_image = cv2.imread(binary_gt_image_path, cv2.IMREAD_COLOR) b_gt_image = cv2.imread(binary_gt_image_path, cv2.IMREAD_COLOR)
i_gt_image = cv2.imread(instance_gt_image_path, cv2.IMREAD_COLOR) i_gt_image = cv2.imread(instance_gt_image_path, cv2.IMREAD_COLOR)
...@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train' ...@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
print('image: {:s} corrupt'.format(image_name)) print('image: {:s} corrupt'.format(image_name))
continue continue
else: else:
info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path, instance_gt_image_path) info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path,
instance_gt_image_path)
file.write(info + '\n') file.write(info + '\n')
label_list.append(info) label_list.append(info)
if phase == 'train' and split: if phase == 'train' and split:
...@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train' ...@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
val_list_len = len(label_list) // 10 val_list_len = len(label_list) // 10
val_label_list = label_list[:val_list_len] val_label_list = label_list[:val_list_len]
train_label_list = label_list[val_list_len:] train_label_list = label_list[val_list_len:]
with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase), 'w') as file: with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase),
'w') as file:
for info in train_label_list: for info in train_label_list:
file.write(info + '\n') file.write(info + '\n')
with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase), 'w') as file: with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase),
'w') as file:
for info in val_label_list: for info in val_label_list:
file.write(info + '\n') file.write(info + '\n')
return return
...@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir): ...@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir):
for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)): for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1] json_label_name = ops.split(json_label_path)[1]
shutil.copyfile(json_label_path, ops.join(traing_folder_path, json_label_name)) shutil.copyfile(json_label_path,
ops.join(traing_folder_path, json_label_name))
for json_label_path in glob.glob('{:s}/test_label.json'.format(src_dir)): for json_label_path in glob.glob('{:s}/test_label.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1] json_label_name = ops.split(json_label_path)[1]
shutil.copyfile(json_label_path, ops.join(testing_folder_path, json_label_name)) shutil.copyfile(json_label_path,
ops.join(testing_folder_path, json_label_name))
train_gt_image_dir = ops.join('training', 'gt_image') train_gt_image_dir = ops.join('training', 'gt_image')
train_gt_binary_dir = ops.join('training', 'gt_binary_image') train_gt_binary_dir = ops.join('training', 'gt_binary_image')
...@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir): ...@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir):
os.makedirs(os.path.join(src_dir, test_gt_instance_dir), exist_ok=True) os.makedirs(os.path.join(src_dir, test_gt_instance_dir), exist_ok=True)
for json_label_path in glob.glob('{:s}/*.json'.format(traing_folder_path)): for json_label_path in glob.glob('{:s}/*.json'.format(traing_folder_path)):
process_json_file(json_label_path, src_dir, train_gt_image_dir, train_gt_binary_dir, train_gt_instance_dir) process_json_file(json_label_path, src_dir, train_gt_image_dir,
train_gt_binary_dir, train_gt_instance_dir)
gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir, train_gt_image_dir, 'train', True) gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir,
train_gt_image_dir, 'train', True)
if __name__ == '__main__': if __name__ == '__main__':
......
#!/usr/bin/env python3 # coding: utf8
# -*- coding: utf-8 -*- # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code heavily base on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py # this code heavily base on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
""" """
LaneNet model post process LaneNet model post process
...@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5): ...@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5):
:return: :return:
""" """
if len(image.shape) == 3: if len(image.shape) == 3:
raise ValueError('Binary segmentation result image should be a single channel image') raise ValueError(
'Binary segmentation result image should be a single channel image')
if image.dtype is not np.uint8: if image.dtype is not np.uint8:
image = np.array(image, np.uint8) image = np.array(image, np.uint8)
kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size)) kernel = cv2.getStructuringElement(
shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
# close operation fille hole # close operation fille hole
closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1) closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
...@@ -46,13 +61,15 @@ def _connect_components_analysis(image): ...@@ -46,13 +61,15 @@ def _connect_components_analysis(image):
else: else:
gray_image = image gray_image = image
return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S) return cv2.connectedComponentsWithStats(
gray_image, connectivity=8, ltype=cv2.CV_32S)
class _LaneFeat(object): class _LaneFeat(object):
""" """
""" """
def __init__(self, feat, coord, class_id=-1): def __init__(self, feat, coord, class_id=-1):
""" """
lane feat object lane feat object
...@@ -108,18 +125,21 @@ class _LaneNetCluster(object): ...@@ -108,18 +125,21 @@ class _LaneNetCluster(object):
""" """
Instance segmentation result cluster Instance segmentation result cluster
""" """
def __init__(self): def __init__(self):
""" """
""" """
self._color_map = [np.array([255, 0, 0]), self._color_map = [
np.array([0, 255, 0]), np.array([255, 0, 0]),
np.array([0, 0, 255]), np.array([0, 255, 0]),
np.array([125, 125, 0]), np.array([0, 0, 255]),
np.array([0, 125, 125]), np.array([125, 125, 0]),
np.array([125, 0, 125]), np.array([0, 125, 125]),
np.array([50, 100, 50]), np.array([125, 0, 125]),
np.array([100, 50, 100])] np.array([50, 100, 50]),
np.array([100, 50, 100])
]
@staticmethod @staticmethod
def _embedding_feats_dbscan_cluster(embedding_image_feats): def _embedding_feats_dbscan_cluster(embedding_image_feats):
...@@ -186,15 +206,16 @@ class _LaneNetCluster(object): ...@@ -186,15 +206,16 @@ class _LaneNetCluster(object):
# get embedding feats and coords # get embedding feats and coords
get_lane_embedding_feats_result = self._get_lane_embedding_feats( get_lane_embedding_feats_result = self._get_lane_embedding_feats(
binary_seg_ret=binary_seg_result, binary_seg_ret=binary_seg_result,
instance_seg_ret=instance_seg_result instance_seg_ret=instance_seg_result)
)
# dbscan cluster # dbscan cluster
dbscan_cluster_result = self._embedding_feats_dbscan_cluster( dbscan_cluster_result = self._embedding_feats_dbscan_cluster(
embedding_image_feats=get_lane_embedding_feats_result['lane_embedding_feats'] embedding_image_feats=get_lane_embedding_feats_result[
) 'lane_embedding_feats'])
mask = np.zeros(shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3], dtype=np.uint8) mask = np.zeros(
shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3],
dtype=np.uint8)
db_labels = dbscan_cluster_result['db_labels'] db_labels = dbscan_cluster_result['db_labels']
unique_labels = dbscan_cluster_result['unique_labels'] unique_labels = dbscan_cluster_result['unique_labels']
coord = get_lane_embedding_feats_result['lane_coordinates'] coord = get_lane_embedding_feats_result['lane_coordinates']
...@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object): ...@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object):
""" """
lanenet post process for lane generation lanenet post process for lane generation
""" """
def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'): def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'):
""" """
convert front car view to bird view convert front car view to bird view
""" """
assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(ipm_remap_file_path) assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(
ipm_remap_file_path)
self._cluster = _LaneNetCluster() self._cluster = _LaneNetCluster()
self._ipm_remap_file_path = ipm_remap_file_path self._ipm_remap_file_path = ipm_remap_file_path
...@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object): ...@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object):
self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x'] self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x']
self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y'] self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y']
self._color_map = [np.array([255, 0, 0]), self._color_map = [
np.array([0, 255, 0]), np.array([255, 0, 0]),
np.array([0, 0, 255]), np.array([0, 255, 0]),
np.array([125, 125, 0]), np.array([0, 0, 255]),
np.array([0, 125, 125]), np.array([125, 125, 0]),
np.array([125, 0, 125]), np.array([0, 125, 125]),
np.array([50, 100, 50]), np.array([125, 0, 125]),
np.array([100, 50, 100])] np.array([50, 100, 50]),
np.array([100, 50, 100])
]
def _load_remap_matrix(self): def _load_remap_matrix(self):
fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ) fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ)
...@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object): ...@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object):
return ret return ret
def postprocess(self, binary_seg_result, instance_seg_result=None, def postprocess(self,
min_area_threshold=100, source_image=None, binary_seg_result,
instance_seg_result=None,
min_area_threshold=100,
source_image=None,
data_source='tusimple'): data_source='tusimple'):
# convert binary_seg_result # convert binary_seg_result
binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8) binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8)
# apply image morphology operation to fill in the hold and reduce the small area # apply image morphology operation to fill in the hold and reduce the small area
morphological_ret = _morphological_process(binary_seg_result, kernel_size=5) morphological_ret = _morphological_process(
connect_components_analysis_ret = _connect_components_analysis(image=morphological_ret) binary_seg_result, kernel_size=5)
connect_components_analysis_ret = _connect_components_analysis(
image=morphological_ret)
labels = connect_components_analysis_ret[1] labels = connect_components_analysis_ret[1]
stats = connect_components_analysis_ret[2] stats = connect_components_analysis_ret[2]
...@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object): ...@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object):
# apply embedding features cluster # apply embedding features cluster
mask_image, lane_coords = self._cluster.apply_lane_feats_cluster( mask_image, lane_coords = self._cluster.apply_lane_feats_cluster(
binary_seg_result=morphological_ret, binary_seg_result=morphological_ret,
instance_seg_result=instance_seg_result instance_seg_result=instance_seg_result)
)
if mask_image is None: if mask_image is None:
return { return {
...@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object): ...@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object):
for lane_index, coords in enumerate(lane_coords): for lane_index, coords in enumerate(lane_coords):
if data_source == 'tusimple': if data_source == 'tusimple':
tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8) tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256), np.int_(coords[:, 0] * 1280 / 512)))] = 255 tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256),
np.int_(coords[:, 0] * 1280 / 512)))] = 255
else: else:
raise ValueError('Wrong data source now only support tusimple') raise ValueError('Wrong data source now only support tusimple')
tmp_ipm_mask = cv2.remap( tmp_ipm_mask = cv2.remap(
tmp_mask, tmp_mask,
self._remap_to_ipm_x, self._remap_to_ipm_x,
self._remap_to_ipm_y, self._remap_to_ipm_y,
interpolation=cv2.INTER_NEAREST interpolation=cv2.INTER_NEAREST)
)
nonzero_y = np.array(tmp_ipm_mask.nonzero()[0]) nonzero_y = np.array(tmp_ipm_mask.nonzero()[0])
nonzero_x = np.array(tmp_ipm_mask.nonzero()[1]) nonzero_x = np.array(tmp_ipm_mask.nonzero()[1])
...@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object): ...@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object):
[ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape [ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10) plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
fit_x = fit_param[0] * plot_y ** 2 + fit_param[1] * plot_y + fit_param[2] fit_x = fit_param[0] * plot_y**2 + fit_param[
1] * plot_y + fit_param[2]
lane_pts = [] lane_pts = []
for index in range(0, plot_y.shape[0], 5): for index in range(0, plot_y.shape[0], 5):
src_x = self._remap_to_ipm_x[ src_x = self._remap_to_ipm_x[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))] int(plot_y[index]),
int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
if src_x <= 0: if src_x <= 0:
continue continue
src_y = self._remap_to_ipm_y[ src_y = self._remap_to_ipm_y[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))] int(plot_y[index]),
int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
src_y = src_y if src_y > 0 else 0 src_y = src_y if src_y > 0 else 0
lane_pts.append([src_x, src_y]) lane_pts.append([src_x, src_y])
...@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object): ...@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object):
continue continue
lane_color = self._color_map[index].tolist() lane_color = self._color_map[index].tolist()
cv2.circle(source_image, (int(interpolation_src_pt_x), cv2.circle(
int(interpolation_src_pt_y)), 5, lane_color, -1) source_image,
(int(interpolation_src_pt_x), int(interpolation_src_pt_y)),
5, lane_color, -1)
ret = { ret = {
'mask_image': mask_image, 'mask_image': mask_image,
'fit_params': fit_params, 'fit_params': fit_params,
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress ...@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress
if __name__ == "__main__": if __name__ == "__main__":
download_file_and_uncompress( download_file_and_uncompress(
url='https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar', url=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar',
savepath=LOCAL_PATH, savepath=LOCAL_PATH,
extrapath=LOCAL_PATH) extrapath=LOCAL_PATH)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .load_model import * from .load_model import *
from .unet import * from .unet import *
from .hrnet import * from .hrnet import *
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
import paddle.fluid as fluid import paddle.fluid as fluid
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np import numpy as np
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet import UNet from .unet import UNet
from .hrnet import HRNet from .hrnet import HRNet
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import os.path as osp import os.path as osp
import sys import sys
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp import os.path as osp
import argparse import argparse
import transforms.transforms as T import transforms.transforms as T
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp import os.path as osp
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os import os
...@@ -6,20 +20,20 @@ args = get_arguments() ...@@ -6,20 +20,20 @@ args = get_arguments()
cfg = AttrDict() cfg = AttrDict()
# 待预测图像所在路径 # 待预测图像所在路径
cfg.data_dir = os.path.join(args.example , "data", "test_images") cfg.data_dir = os.path.join(args.example, "data", "test_images")
# 待预测图像名称列表 # 待预测图像名称列表
cfg.data_list_file = os.path.join(args.example , "data", "test.txt") cfg.data_list_file = os.path.join(args.example, "data", "test.txt")
# 模型加载路径 # 模型加载路径
cfg.model_path = os.path.join(args.example , "model") cfg.model_path = os.path.join(args.example, "model")
# 预测结果保存路径 # 预测结果保存路径
cfg.vis_dir = os.path.join(args.example , "result") cfg.vis_dir = os.path.join(args.example, "result")
# 预测类别数 # 预测类别数
cfg.class_num = 2 cfg.class_num = 2
# 均值, 图像预处理减去的均值 # 均值, 图像预处理减去的均值
cfg.MEAN = 127.5, 127.5, 127.5 cfg.MEAN = 127.5, 127.5, 127.5
# 标准差,图像预处理除以标准差 # 标准差,图像预处理除以标准差
cfg.STD = 127.5, 127.5, 127.5 cfg.STD = 127.5, 127.5, 127.5
# 待预测图像输入尺寸 # 待预测图像输入尺寸
cfg.input_size = 1536, 576 cfg.input_size = 1536, 576
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import cv2 import cv2
import numpy as np import numpy as np
...@@ -12,18 +26,19 @@ config = importlib.import_module('config') ...@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg') cfg = getattr(config, 'cfg')
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启 # paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os.environ['FLAGS_eager_delete_tensor_gb']='0.0' os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
import paddle.fluid as fluid import paddle.fluid as fluid
# 预测数据集类 # 预测数据集类
class TestDataSet(): class TestDataSet():
def __init__(self): def __init__(self):
self.data_dir = cfg.data_dir self.data_dir = cfg.data_dir
self.data_list_file = cfg.data_list_file self.data_list_file = cfg.data_list_file
self.data_list = self.get_data_list() self.data_list = self.get_data_list()
self.data_num = len(self.data_list) self.data_num = len(self.data_list)
def get_data_list(self): def get_data_list(self):
# 获取预测图像路径列表 # 获取预测图像路径列表
data_list = [] data_list = []
...@@ -40,7 +55,7 @@ class TestDataSet(): ...@@ -40,7 +55,7 @@ class TestDataSet():
def preprocess(self, img): def preprocess(self, img):
# 图像预处理 # 图像预处理
if cfg.example == 'ACE2P': if cfg.example == 'ACE2P':
reader = importlib.import_module(args.example+'.reader') reader = importlib.import_module(args.example + '.reader')
ACE2P_preprocess = getattr(reader, 'preprocess') ACE2P_preprocess = getattr(reader, 'preprocess')
img = ACE2P_preprocess(img) img = ACE2P_preprocess(img)
else: else:
...@@ -56,10 +71,10 @@ class TestDataSet(): ...@@ -56,10 +71,10 @@ class TestDataSet():
img_path = self.data_list[index] img_path = self.data_list[index]
img = cv2.imread(img_path, cv2.IMREAD_COLOR) img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None: if img is None:
return img, img,img_path, None return img, img, img_path, None
img_name = img_path.split(os.sep)[-1] img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'') name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2] img_shape = img.shape[:2]
img_process = self.preprocess(img) img_process = self.preprocess(img)
...@@ -90,39 +105,44 @@ def infer(): ...@@ -90,39 +105,44 @@ def infer():
if image is None: if image is None:
print(im_name, 'is None') print(im_name, 'is None')
continue continue
# 预测 # 预测
if cfg.example == 'ACE2P': if cfg.example == 'ACE2P':
# ACE2P模型使用多尺度预测 # ACE2P模型使用多尺度预测
reader = importlib.import_module(args.example+'.reader') reader = importlib.import_module(args.example + '.reader')
multi_scale_test = getattr(reader, 'multi_scale_test') multi_scale_test = getattr(reader, 'multi_scale_test')
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape) parsing, logits = multi_scale_test(exe, test_prog, feed_name,
fetch_list, image, im_shape)
else: else:
# HumanSeg,RoadLine模型单尺度预测 # HumanSeg,RoadLine模型单尺度预测
result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list) result = exe.run(
program=test_prog,
feed={feed_name[0]: image},
fetch_list=fetch_list)
parsing = np.argmax(result[0][0], axis=0) parsing = np.argmax(result[0][0], axis=0)
parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1]) parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])
# 预测结果保存 # 预测结果保存
result_path = os.path.join(cfg.vis_dir, im_name + '.png') result_path = os.path.join(cfg.vis_dir, im_name + '.png')
if cfg.example == 'HumanSeg': if cfg.example == 'HumanSeg':
logits = result[0][0][1]*255 logits = result[0][0][1] * 255
logits = cv2.resize(logits, im_shape[::-1]) logits = cv2.resize(logits, im_shape[::-1])
ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO) ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
logits = 255 *(logits - thresh)/(255 - thresh) logits = 255 * (logits - thresh) / (255 - thresh)
# 将分割结果添加到alpha通道 # 将分割结果添加到alpha通道
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2) rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
axis=2)
cv2.imwrite(result_path, rgba) cv2.imwrite(result_path, rgba)
else: else:
output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8)) output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
output_im.putpalette(palette) output_im.putpalette(palette)
output_im.save(result_path) output_im.save(result_path)
if (idx + 1) % 100 == 0: if (idx + 1) % 100 == 0:
print('%d processd' % (idx + 1)) print('%d processd' % (idx + 1))
print('%d processd done' % (idx + 1)) print('%d processd done' % (idx + 1))
return 0 return 0
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # coding: utf8
## Created by: RainbowSecret # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
## Microsoft Research #
## yuyua@microsoft.com # Licensed under the Apache License, Version 2.0 (the "License");
## Copyright (c) 2018 # you may not use this file except in compliance with the License.
## # You may obtain a copy of the License at
## This source code is licensed under the MIT-style license found in the #
## LICENSE file in the root directory of this source tree # http://www.apache.org/licenses/LICENSE-2.0
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import argparse import argparse
import os import os
def get_arguments(): def get_arguments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", parser.add_argument(
action="store_true", "--use_gpu", action="store_true", help="Use gpu or cpu to test.")
help="Use gpu or cpu to test.") parser.add_argument(
parser.add_argument('--example', '--example', type=str, help='RoadLine, HumanSeg or ACE2P')
type=str,
help='RoadLine, HumanSeg or ACE2P')
return parser.parse_args() return parser.parse_args()
...@@ -34,6 +48,7 @@ class AttrDict(dict): ...@@ -34,6 +48,7 @@ class AttrDict(dict):
else: else:
self[name] = value self[name] = value
def merge_cfg_from_args(args, cfg): def merge_cfg_from_args(args, cfg):
"""Merge config keys, values in args into the global config.""" """Merge config keys, values in args into the global config."""
for k, v in vars(args).items(): for k, v in vars(args).items():
...@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg): ...@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value = v value = v
if value is not None: if value is not None:
cfg[k] = value cfg[k] = value
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -20,6 +21,8 @@ from PIL import Image ...@@ -20,6 +21,8 @@ from PIL import Image
import glob import glob
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
def remove_colormap(filename): def remove_colormap(filename):
gray_anno = np.array(Image.open(filename)) gray_anno = np.array(Image.open(filename))
return gray_anno return gray_anno
...@@ -30,6 +33,7 @@ def save_annotation(annotation, filename): ...@@ -30,6 +33,7 @@ def save_annotation(annotation, filename):
annotation = Image.fromarray(annotation) annotation = Image.fromarray(annotation)
annotation.save(filename) annotation.save(filename)
def convert_list(origin_file, seg_file, output_folder): def convert_list(origin_file, seg_file, output_folder):
with open(seg_file, 'w') as fid_seg: with open(seg_file, 'w') as fid_seg:
with open(origin_file) as fid_ori: with open(origin_file) as fid_ori:
...@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder): ...@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder):
new_line = ' '.join([img_name, anno_name]) new_line = ' '.join([img_name, anno_name])
fid_seg.write(new_line + "\n") fid_seg.write(new_line + "\n")
if __name__ == "__main__": if __name__ == "__main__":
pascal_root = "./VOCtrainval_11-May-2012/VOC2012" pascal_root = "./VOCtrainval_11-May-2012/VOC2012"
pascal_root = os.path.join(LOCAL_PATH, pascal_root) pascal_root = os.path.join(LOCAL_PATH, pascal_root)
...@@ -54,7 +59,7 @@ if __name__ == "__main__": ...@@ -54,7 +59,7 @@ if __name__ == "__main__":
# 标注图转换后存储目录 # 标注图转换后存储目录
output_folder = os.path.join(pascal_root, "SegmentationClassAug") output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert") print("annotation convert and file list convert")
if not os.path.exists(os.path.join(LOCAL_PATH, output_folder)): if not os.path.exists(os.path.join(LOCAL_PATH, output_folder)):
os.mkdir(os.path.join(LOCAL_PATH, output_folder)) os.mkdir(os.path.join(LOCAL_PATH, output_folder))
...@@ -67,5 +72,5 @@ if __name__ == "__main__": ...@@ -67,5 +72,5 @@ if __name__ == "__main__":
convert_list(train_path, train_path.replace('txt', 'list'), output_folder) convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder) convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder) convert_list(trainval_path, trainval_path.replace('txt', 'list'),
output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap ...@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap
from convert_voc2012 import save_annotation from convert_voc2012 import save_annotation
def download_VOC_dataset(savepath, extrapath): def download_VOC_dataset(savepath, extrapath):
url = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar" url = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
download_file_and_uncompress( download_file_and_uncompress(
url=url, savepath=savepath, extrapath=extrapath) url=url, savepath=savepath, extrapath=extrapath)
if __name__ == "__main__": if __name__ == "__main__":
download_VOC_dataset(LOCAL_PATH, LOCAL_PATH) download_VOC_dataset(LOCAL_PATH, LOCAL_PATH)
print("Dataset download finish!") print("Dataset download finish!")
...@@ -45,10 +46,10 @@ if __name__ == "__main__": ...@@ -45,10 +46,10 @@ if __name__ == "__main__":
train_path = os.path.join(txt_folder, "train.txt") train_path = os.path.join(txt_folder, "train.txt")
val_path = os.path.join(txt_folder, "val.txt") val_path = os.path.join(txt_folder, "val.txt")
trainval_path = os.path.join(txt_folder, "trainval.txt") trainval_path = os.path.join(txt_folder, "trainval.txt")
# 标注图转换后存储目录 # 标注图转换后存储目录
output_folder = os.path.join(pascal_root, "SegmentationClassAug") output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert") print("annotation convert and file list convert")
if not os.path.exists(output_folder): if not os.path.exists(output_folder):
os.mkdir(output_folder) os.mkdir(output_folder)
...@@ -61,5 +62,5 @@ if __name__ == "__main__": ...@@ -61,5 +62,5 @@ if __name__ == "__main__":
convert_list(train_path, train_path.replace('txt', 'list'), output_folder) convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder) convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder) convert_list(trainval_path, trainval_path.replace('txt', 'list'),
output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -31,7 +31,8 @@ gflags.DEFINE_string("conf", default="", help="Configuration File Path") ...@@ -31,7 +31,8 @@ gflags.DEFINE_string("conf", default="", help="Configuration File Path")
gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images") gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images")
gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model") gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
gflags.DEFINE_string("trt_mode", default="", help="Use optimized model") gflags.DEFINE_string("trt_mode", default="", help="Use optimized model")
gflags.DEFINE_string("ext", default=".jpeg|.jpg", help="Input Image File Extensions") gflags.DEFINE_string(
"ext", default=".jpeg|.jpg", help="Input Image File Extensions")
gflags.FLAGS = gflags.FLAGS gflags.FLAGS = gflags.FLAGS
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,4 +14,4 @@ ...@@ -14,4 +14,4 @@
# limitations under the License. # limitations under the License.
import models import models
import utils import utils
from . import tools from . import tools
\ No newline at end of file
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -427,12 +440,17 @@ def max_img_size_statistics(): ...@@ -427,12 +440,17 @@ def max_img_size_statistics():
logger.info("max width and max height of images are ({},{})".format( logger.info("max width and max height of images are ({},{})".format(
max_width, max_height)) max_width, max_height))
def num_classes_loss_matching_check(): def num_classes_loss_matching_check():
loss_type = cfg.SOLVER.LOSS loss_type = cfg.SOLVER.LOSS
num_classes = cfg.DATASET.NUM_CLASSES num_classes = cfg.DATASET.NUM_CLASSES
if num_classes > 2 and (("dice_loss" in loss_type) or ("bce_loss" in loss_type)): if num_classes > 2 and (("dice_loss" in loss_type) or
logger.info(error_print("loss check." ("bce_loss" in loss_type)):
" Dice loss and bce loss is only applicable to binary classfication")) logger.info(
error_print(
"loss check."
" Dice loss and bce loss is only applicable to binary classfication"
))
else: else:
logger.info(correct_print("loss check")) logger.info(correct_print("loss check"))
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -362,7 +362,7 @@ def hsv_color_jitter(crop_img, ...@@ -362,7 +362,7 @@ def hsv_color_jitter(crop_img,
saturation_jitter_ratio > 0 or \ saturation_jitter_ratio > 0 or \
contrast_jitter_ratio > 0: contrast_jitter_ratio > 0:
crop_img = random_jitter(crop_img, saturation_jitter_ratio, crop_img = random_jitter(crop_img, saturation_jitter_ratio,
brightness_jitter_ratio, contrast_jitter_ratio) brightness_jitter_ratio, contrast_jitter_ratio)
return crop_img return crop_img
...@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN): ...@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN):
crop_width = cfg.EVAL_CROP_SIZE[0] crop_width = cfg.EVAL_CROP_SIZE[0]
crop_height = cfg.EVAL_CROP_SIZE[1] crop_height = cfg.EVAL_CROP_SIZE[1]
if not ModelPhase.is_train(mode): if not ModelPhase.is_train(mode):
if (crop_height < img_height or crop_width < img_width): if (crop_height < img_height or crop_width < img_width):
raise Exception( raise Exception(
"Crop size({},{}) must large than img size({},{}) when in EvalPhase." "Crop size({},{}) must large than img size({},{}) when in EvalPhase."
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
""" """
...@@ -14,10 +28,10 @@ except ImportError: ...@@ -14,10 +28,10 @@ except ImportError:
class GeneratorEnqueuer(object): class GeneratorEnqueuer(object):
""" """
Multiple generators Multiple generators
Args: Args:
generators: generators:
wait_time (float): time to sleep in-between calls to `put()`. wait_time (float): time to sleep in-between calls to `put()`.
""" """
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -141,7 +141,7 @@ class ResNet(): ...@@ -141,7 +141,7 @@ class ResNet():
else: else:
conv_name = "res" + str(block + 2) + chr(97 + i) conv_name = "res" + str(block + 2) + chr(97 + i)
dilation_rate = get_dilated_rate(dilation_dict, block) dilation_rate = get_dilated_rate(dilation_dict, block)
conv = self.bottleneck_block( conv = self.bottleneck_block(
input=conv, input=conv,
num_filters=int(num_filters[block] * self.scale), num_filters=int(num_filters[block] * self.scale),
...@@ -215,11 +215,11 @@ class ResNet(): ...@@ -215,11 +215,11 @@ class ResNet():
groups=1, groups=1,
act=None, act=None,
name=None): name=None):
if self.stem == 'pspnet': if self.stem == 'pspnet':
bias_attr=ParamAttr(name=name + "_biases") bias_attr = ParamAttr(name=name + "_biases")
else: else:
bias_attr=False bias_attr = False
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
...@@ -238,13 +238,15 @@ class ResNet(): ...@@ -238,13 +238,15 @@ class ResNet():
bn_name = "bn_" + name bn_name = "bn_" + name
else: else:
bn_name = "bn" + name[3:] bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv, return fluid.layers.batch_norm(
act=act, input=conv,
name=bn_name + '.output.1', act=act,
param_attr=ParamAttr(name=bn_name + '_scale'), name=bn_name + '.output.1',
bias_attr=ParamAttr(bn_name + '_offset'), param_attr=ParamAttr(name=bn_name + '_scale'),
moving_mean_name=bn_name + '_mean', bias_attr=ParamAttr(bn_name + '_offset'),
moving_variance_name=bn_name + '_variance', ) moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
)
def shortcut(self, input, ch_out, stride, is_first, name): def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1] ch_in = input.shape[1]
...@@ -258,7 +260,7 @@ class ResNet(): ...@@ -258,7 +260,7 @@ class ResNet():
strides = [1, stride] strides = [1, stride]
else: else:
strides = [stride, 1] strides = [stride, 1]
conv0 = self.conv_bn_layer( conv0 = self.conv_bn_layer(
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -55,7 +55,8 @@ class VGGNet(): ...@@ -55,7 +55,8 @@ class VGGNet():
channels = [64, 128, 256, 512, 512] channels = [64, 128, 256, 512, 512]
conv = input conv = input
for i in range(len(nums)): for i in range(len(nums)):
conv = self.conv_block(conv, channels[i], nums[i], name="conv" + str(i + 1) + "_") conv = self.conv_block(
conv, channels[i], nums[i], name="conv" + str(i + 1) + "_")
layers_count += nums[i] layers_count += nums[i]
if check_points(layers_count, decode_points): if check_points(layers_count, decode_points):
short_cuts[layers_count] = conv short_cuts[layers_count] = conv
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -197,4 +197,4 @@ def conv_bn_layer(input, ...@@ -197,4 +197,4 @@ def conv_bn_layer(input,
if if_act: if if_act:
return fluid.layers.relu6(bn) return fluid.layers.relu6(bn)
else: else:
return bn return bn
\ No newline at end of file
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -25,7 +25,14 @@ from paddle.fluid.param_attr import ParamAttr ...@@ -25,7 +25,14 @@ from paddle.fluid.param_attr import ParamAttr
from utils.config import cfg from utils.config import cfg
def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_groups=1, if_act=True, name=None): def conv_bn_layer(input,
filter_size,
num_filters,
stride=1,
padding=1,
num_groups=1,
if_act=True,
name=None):
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
...@@ -37,37 +44,74 @@ def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_grou ...@@ -37,37 +44,74 @@ def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_grou
param_attr=ParamAttr(initializer=MSRA(), name=name + '_weights'), param_attr=ParamAttr(initializer=MSRA(), name=name + '_weights'),
bias_attr=False) bias_attr=False)
bn_name = name + '_bn' bn_name = name + '_bn'
bn = fluid.layers.batch_norm(input=conv, bn = fluid.layers.batch_norm(
param_attr=ParamAttr(name=bn_name + "_scale", input=conv,
initializer=fluid.initializer.Constant(1.0)), param_attr=ParamAttr(
bias_attr=ParamAttr(name=bn_name + "_offset", name=bn_name + "_scale",
initializer=fluid.initializer.Constant(0.0)), initializer=fluid.initializer.Constant(1.0)),
moving_mean_name=bn_name + '_mean', bias_attr=ParamAttr(
moving_variance_name=bn_name + '_variance') name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act: if if_act:
bn = fluid.layers.relu(bn) bn = fluid.layers.relu(bn)
return bn return bn
def basic_block(input, num_filters, stride=1, downsample=False, name=None): def basic_block(input, num_filters, stride=1, downsample=False, name=None):
residual = input residual = input
conv = conv_bn_layer(input=input, filter_size=3, num_filters=num_filters, stride=stride, name=name + '_conv1') conv = conv_bn_layer(
conv = conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, if_act=False, name=name + '_conv2') input=input,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv1')
conv = conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
if_act=False,
name=name + '_conv2')
if downsample: if downsample:
residual = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, if_act=False, residual = conv_bn_layer(
name=name + '_downsample') input=input,
filter_size=1,
num_filters=num_filters,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def bottleneck_block(input, num_filters, stride=1, downsample=False, name=None): def bottleneck_block(input, num_filters, stride=1, downsample=False, name=None):
residual = input residual = input
conv = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, name=name + '_conv1') conv = conv_bn_layer(
conv = conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, stride=stride, name=name + '_conv2') input=input,
conv = conv_bn_layer(input=conv, filter_size=1, num_filters=num_filters * 4, if_act=False, filter_size=1,
name=name + '_conv3') num_filters=num_filters,
name=name + '_conv1')
conv = conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv2')
conv = conv_bn_layer(
input=conv,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_conv3')
if downsample: if downsample:
residual = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters * 4, if_act=False, residual = conv_bn_layer(
name=name + '_downsample') input=input,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def fuse_layers(x, channels, multi_scale_output=True, name=None): def fuse_layers(x, channels, multi_scale_output=True, name=None):
out = [] out = []
for i in range(len(channels) if multi_scale_output else 1): for i in range(len(channels) if multi_scale_output else 1):
...@@ -77,40 +121,64 @@ def fuse_layers(x, channels, multi_scale_output=True, name=None): ...@@ -77,40 +121,64 @@ def fuse_layers(x, channels, multi_scale_output=True, name=None):
height = shape[-2] height = shape[-2]
for j in range(len(channels)): for j in range(len(channels)):
if j > i: if j > i:
y = conv_bn_layer(x[j], filter_size=1, num_filters=channels[i], if_act=False, y = conv_bn_layer(
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1)) x[j],
y = fluid.layers.resize_bilinear(input=y, out_shape=[height, width]) filter_size=1,
residual = fluid.layers.elementwise_add(x=residual, y=y, act=None) num_filters=channels[i],
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_bilinear(
input=y, out_shape=[height, width])
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i: elif j < i:
y = x[j] y = x[j]
for k in range(i - j): for k in range(i - j):
if k == i - j - 1: if k == i - j - 1:
y = conv_bn_layer(y, filter_size=3, num_filters=channels[i], stride=2, if_act=False, y = conv_bn_layer(
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1) + '_' + str(k + 1)) y,
filter_size=3,
num_filters=channels[i],
stride=2,
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
else: else:
y = conv_bn_layer(y, filter_size=3, num_filters=channels[j], stride=2, y = conv_bn_layer(
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1) + '_' + str(k + 1)) y,
residual = fluid.layers.elementwise_add(x=residual, y=y, act=None) filter_size=3,
num_filters=channels[j],
stride=2,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
residual = fluid.layers.relu(residual) residual = fluid.layers.relu(residual)
out.append(residual) out.append(residual)
return out return out
def branches(x, block_num, channels, name=None): def branches(x, block_num, channels, name=None):
out = [] out = []
for i in range(len(channels)): for i in range(len(channels)):
residual = x[i] residual = x[i]
for j in range(block_num): for j in range(block_num):
residual = basic_block(residual, channels[i], residual = basic_block(
name=name + '_branch_layer_' + str(i + 1) + '_' + str(j + 1)) residual,
channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' + str(j + 1))
out.append(residual) out.append(residual)
return out return out
def high_resolution_module(x, channels, multi_scale_output=True, name=None): def high_resolution_module(x, channels, multi_scale_output=True, name=None):
residual = branches(x, 4, channels, name=name) residual = branches(x, 4, channels, name=name)
out = fuse_layers(residual, channels, multi_scale_output=multi_scale_output, name=name) out = fuse_layers(
residual, channels, multi_scale_output=multi_scale_output, name=name)
return out return out
def transition_layer(x, in_channels, out_channels, name=None): def transition_layer(x, in_channels, out_channels, name=None):
num_in = len(in_channels) num_in = len(in_channels)
num_out = len(out_channels) num_out = len(out_channels)
...@@ -118,46 +186,76 @@ def transition_layer(x, in_channels, out_channels, name=None): ...@@ -118,46 +186,76 @@ def transition_layer(x, in_channels, out_channels, name=None):
for i in range(num_out): for i in range(num_out):
if i < num_in: if i < num_in:
if in_channels[i] != out_channels[i]: if in_channels[i] != out_channels[i]:
residual = conv_bn_layer(x[i], filter_size=3, num_filters=out_channels[i], residual = conv_bn_layer(
name=name + '_layer_' + str(i + 1)) x[i],
filter_size=3,
num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual) out.append(residual)
else: else:
out.append(x[i]) out.append(x[i])
else: else:
residual = conv_bn_layer(x[-1], filter_size=3, num_filters=out_channels[i], stride=2, residual = conv_bn_layer(
name=name + '_layer_' + str(i + 1)) x[-1],
filter_size=3,
num_filters=out_channels[i],
stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual) out.append(residual)
return out return out
def stage(x, num_modules, channels, multi_scale_output=True, name=None): def stage(x, num_modules, channels, multi_scale_output=True, name=None):
out = x out = x
for i in range(num_modules): for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False: if i == num_modules - 1 and multi_scale_output == False:
out = high_resolution_module(out, channels, multi_scale_output=False, name=name + '_' + str(i + 1)) out = high_resolution_module(
out,
channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else: else:
out = high_resolution_module(out, channels, name=name + '_' + str(i + 1)) out = high_resolution_module(
out, channels, name=name + '_' + str(i + 1))
return out return out
def layer1(input, name=None): def layer1(input, name=None):
conv = input conv = input
for i in range(4): for i in range(4):
conv = bottleneck_block(conv, num_filters=64, downsample=True if i == 0 else False, conv = bottleneck_block(
name=name + '_' + str(i + 1)) conv,
num_filters=64,
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv return conv
def high_resolution_net(input, num_classes): def high_resolution_net(input, num_classes):
channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS
channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS
channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS
num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES
num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES
num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES
x = conv_bn_layer(input=input, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_1') x = conv_bn_layer(
x = conv_bn_layer(input=x, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_2') input=input,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_1')
x = conv_bn_layer(
input=x,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_2')
la1 = layer1(x, name='layer2') la1 = layer1(x, name='layer2')
tr1 = transition_layer([la1], [256], channels_2, name='tr1') tr1 = transition_layer([la1], [256], channels_2, name='tr1')
...@@ -170,18 +268,21 @@ def high_resolution_net(input, num_classes): ...@@ -170,18 +268,21 @@ def high_resolution_net(input, num_classes):
# upsample # upsample
shape = st4[0].shape shape = st4[0].shape
height, width = shape[-2], shape[-1] height, width = shape[-2], shape[-1]
st4[1] = fluid.layers.resize_bilinear( st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=[height, width])
st4[1], out_shape=[height, width]) st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=[height, width])
st4[2] = fluid.layers.resize_bilinear( st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=[height, width])
st4[2], out_shape=[height, width])
st4[3] = fluid.layers.resize_bilinear(
st4[3], out_shape=[height, width])
out = fluid.layers.concat(st4, axis=1) out = fluid.layers.concat(st4, axis=1)
last_channels = sum(channels_4) last_channels = sum(channels_4)
out = conv_bn_layer(input=out, filter_size=1, num_filters=last_channels, stride=1, if_act=True, name='conv-2') out = conv_bn_layer(
out= fluid.layers.conv2d( input=out,
filter_size=1,
num_filters=last_channels,
stride=1,
if_act=True,
name='conv-2')
out = fluid.layers.conv2d(
input=out, input=out,
num_filters=num_classes, num_filters=num_classes,
filter_size=1, filter_size=1,
...@@ -193,7 +294,6 @@ def high_resolution_net(input, num_classes): ...@@ -193,7 +294,6 @@ def high_resolution_net(input, num_classes):
out = fluid.layers.resize_bilinear(out, input.shape[2:]) out = fluid.layers.resize_bilinear(out, input.shape[2:])
return out return out
...@@ -201,6 +301,7 @@ def hrnet(input, num_classes): ...@@ -201,6 +301,7 @@ def hrnet(input, num_classes):
logit = high_resolution_net(input, num_classes) logit = high_resolution_net(input, num_classes)
return logit return logit
if __name__ == '__main__': if __name__ == '__main__':
image_shape = [-1, 3, 769, 769] image_shape = [-1, 3, 769, 769]
image = fluid.data(name='image', shape=image_shape, dtype='float32') image = fluid.data(name='image', shape=image_shape, dtype='float32')
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -24,6 +24,7 @@ from models.libs.model_libs import avg_pool, conv, bn ...@@ -24,6 +24,7 @@ from models.libs.model_libs import avg_pool, conv, bn
from models.backbone.resnet import ResNet as resnet_backbone from models.backbone.resnet import ResNet as resnet_backbone
from utils.config import cfg from utils.config import cfg
def get_logit_interp(input, num_classes, out_shape, name="logit"): def get_logit_interp(input, num_classes, out_shape, name="logit"):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸 # 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
param_attr = fluid.ParamAttr( param_attr = fluid.ParamAttr(
...@@ -33,16 +34,15 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"): ...@@ -33,16 +34,15 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"):
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01)) initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
with scope(name): with scope(name):
logit = conv(input, logit = conv(
num_classes, input,
filter_size=1, num_classes,
param_attr=param_attr, filter_size=1,
bias_attr=True, param_attr=param_attr,
name=name+'_conv') bias_attr=True,
name=name + '_conv')
logit_interp = fluid.layers.resize_bilinear( logit_interp = fluid.layers.resize_bilinear(
logit, logit, out_shape=out_shape, name=name + '_interp')
out_shape=out_shape,
name=name+'_interp')
return logit_interp return logit_interp
...@@ -51,40 +51,44 @@ def psp_module(input, out_features): ...@@ -51,40 +51,44 @@ def psp_module(input, out_features):
# 输入:backbone输出的特征 # 输入:backbone输出的特征
# 输出:对输入进行不同尺度pooling, 卷积操作后插值回原始尺寸,并concat # 输出:对输入进行不同尺度pooling, 卷积操作后插值回原始尺寸,并concat
# 最后进行一个卷积及BN操作 # 最后进行一个卷积及BN操作
cat_layers = [] cat_layers = []
sizes = (1,2,3,6) sizes = (1, 2, 3, 6)
for size in sizes: for size in sizes:
psp_name = "psp" + str(size) psp_name = "psp" + str(size)
with scope(psp_name): with scope(psp_name):
pool = fluid.layers.adaptive_pool2d(input, pool = fluid.layers.adaptive_pool2d(
pool_size=[size, size], input,
pool_type='avg', pool_size=[size, size],
name=psp_name+'_adapool') pool_type='avg',
data = conv(pool, out_features, name=psp_name + '_adapool')
filter_size=1, data = conv(
bias_attr=True, pool,
name= psp_name + '_conv') out_features,
filter_size=1,
bias_attr=True,
name=psp_name + '_conv')
data_bn = bn(data, act='relu') data_bn = bn(data, act='relu')
interp = fluid.layers.resize_bilinear(data_bn, interp = fluid.layers.resize_bilinear(
out_shape=input.shape[2:], data_bn, out_shape=input.shape[2:], name=psp_name + '_interp')
name=psp_name+'_interp')
cat_layers.append(interp) cat_layers.append(interp)
cat_layers = [input] + cat_layers[::-1] cat_layers = [input] + cat_layers[::-1]
cat = fluid.layers.concat(cat_layers, axis=1, name='psp_cat') cat = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
psp_end_name = "psp_end" psp_end_name = "psp_end"
with scope(psp_end_name): with scope(psp_end_name):
data = conv(cat, data = conv(
out_features, cat,
filter_size=3, out_features,
padding=1, filter_size=3,
bias_attr=True, padding=1,
name=psp_end_name) bias_attr=True,
name=psp_end_name)
out = bn(data, act='relu') out = bn(data, act='relu')
return out return out
def resnet(input): def resnet(input):
# PSPNET backbone: resnet, 默认resnet50 # PSPNET backbone: resnet, 默认resnet50
# end_points: resnet终止层数 # end_points: resnet终止层数
...@@ -92,14 +96,14 @@ def resnet(input): ...@@ -92,14 +96,14 @@ def resnet(input):
scale = cfg.MODEL.PSPNET.DEPTH_MULTIPLIER scale = cfg.MODEL.PSPNET.DEPTH_MULTIPLIER
layers = cfg.MODEL.PSPNET.LAYERS layers = cfg.MODEL.PSPNET.LAYERS
end_points = layers - 1 end_points = layers - 1
dilation_dict = {2:2, 3:4} dilation_dict = {2: 2, 3: 4}
model = resnet_backbone(layers, scale, stem='pspnet') model = resnet_backbone(layers, scale, stem='pspnet')
data, _ = model.net(input, data, _ = model.net(
end_points=end_points, input, end_points=end_points, dilation_dict=dilation_dict)
dilation_dict=dilation_dict)
return data return data
def pspnet(input, num_classes): def pspnet(input, num_classes):
# Backbone: ResNet # Backbone: ResNet
res = resnet(input) res = resnet(input)
...@@ -109,4 +113,3 @@ def pspnet(input, num_classes): ...@@ -109,4 +113,3 @@ def pspnet(input, num_classes):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸 # 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
logit = get_logit_interp(dropout, num_classes, input.shape[2:]) logit = get_logit_interp(dropout, num_classes, input.shape[2:])
return logit return logit
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -71,7 +71,8 @@ class SegDataset(object): ...@@ -71,7 +71,8 @@ class SegDataset(object):
if self.shuffle and cfg.NUM_TRAINERS > 1: if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines) np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)] self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1 self.shuffle_seed += 1
elif self.shuffle: elif self.shuffle:
np.random.shuffle(self.lines) np.random.shuffle(self.lines)
...@@ -99,7 +100,8 @@ class SegDataset(object): ...@@ -99,7 +100,8 @@ class SegDataset(object):
if self.shuffle and cfg.NUM_TRAINERS > 1: if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines) np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)] self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1 self.shuffle_seed += 1
elif self.shuffle: elif self.shuffle:
np.random.shuffle(self.lines) np.random.shuffle(self.lines)
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,55 +21,48 @@ import warnings ...@@ -21,55 +21,48 @@ import warnings
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='PaddleSeg generate file list on cityscapes or your customized dataset.') description=
parser.add_argument( 'PaddleSeg generate file list on cityscapes or your customized dataset.'
'dataset_root',
help='dataset root directory',
type=str
) )
parser.add_argument('dataset_root', help='dataset root directory', type=str)
parser.add_argument( parser.add_argument(
'--type', '--type',
help='dataset type: \n' help='dataset type: \n'
'- cityscapes \n' '- cityscapes \n'
'- custom(default)', '- custom(default)',
default="custom", default="custom",
type=str type=str)
)
parser.add_argument( parser.add_argument(
'--separator', '--separator',
dest='separator', dest='separator',
help='file list separator', help='file list separator',
default="|", default="|",
type=str type=str)
)
parser.add_argument( parser.add_argument(
'--folder', '--folder',
help='the folder names of images and labels', help='the folder names of images and labels',
type=str, type=str,
nargs=2, nargs=2,
default=['images', 'annotations'] default=['images', 'annotations'])
)
parser.add_argument( parser.add_argument(
'--second_folder', '--second_folder',
help='the second-level folder names of train set, validation set, test set', help=
'the second-level folder names of train set, validation set, test set',
type=str, type=str,
nargs='*', nargs='*',
default=['train', 'val', 'test'] default=['train', 'val', 'test'])
)
parser.add_argument( parser.add_argument(
'--format', '--format',
help='data format of images and labels, e.g. jpg or png.', help='data format of images and labels, e.g. jpg or png.',
type=str, type=str,
nargs=2, nargs=2,
default=['jpg', 'png'] default=['jpg', 'png'])
)
parser.add_argument( parser.add_argument(
'--postfix', '--postfix',
help='postfix of images or labels', help='postfix of images or labels',
type=str, type=str,
nargs=2, nargs=2,
default=['', ''] default=['', ''])
)
return parser.parse_args() return parser.parse_args()
...@@ -120,15 +113,17 @@ def generate_list(args): ...@@ -120,15 +113,17 @@ def generate_list(args):
num_images = len(image_files) num_images = len(image_files)
if not label_files: if not label_files:
label_dir = os.path.join(dataset_root, args.folder[1], dataset_split) label_dir = os.path.join(dataset_root, args.folder[1],
dataset_split)
warnings.warn("No labels in {} !!!".format(label_dir)) warnings.warn("No labels in {} !!!".format(label_dir))
num_label = len(label_files) num_label = len(label_files)
if num_images != num_label and num_label > 0: if num_images != num_label and num_label > 0:
raise Exception("Number of images = {} number of labels = {} \n" raise Exception(
"Either number of images is equal to number of labels, " "Number of images = {} number of labels = {} \n"
"or number of labels is equal to 0.\n" "Either number of images is equal to number of labels, "
"Please check your dataset!".format(num_images, num_label)) "or number of labels is equal to 0.\n"
"Please check your dataset!".format(num_images, num_label))
file_list = os.path.join(dataset_root, dataset_split + '.txt') file_list = os.path.join(dataset_root, dataset_split + '.txt')
with open(file_list, "w") as f: with open(file_list, "w") as f:
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function from __future__ import print_function
import argparse import argparse
...@@ -11,16 +25,12 @@ from PIL import Image ...@@ -11,16 +25,12 @@ from PIL import Image
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter)
) parser.add_argument(
parser.add_argument('dir_or_file', 'dir_or_file', help='input gray label directory or file list path')
help='input gray label directory or file list path') parser.add_argument('output_dir', help='output colorful label directory')
parser.add_argument('output_dir', parser.add_argument('--dataset_dir', help='dataset directory')
help='output colorful label directory') parser.add_argument('--file_separator', help='file list separator')
parser.add_argument('--dataset_dir',
help='dataset directory')
parser.add_argument('--file_separator',
help='file list separator')
return parser.parse_args() return parser.parse_args()
......
#!/usr/bin/env python # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function from __future__ import print_function
......
#!/usr/bin/env python # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function from __future__ import print_function
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -98,7 +99,7 @@ class SegConfig(dict): ...@@ -98,7 +99,7 @@ class SegConfig(dict):
'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`' 'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`'
) )
if self.MEAN is not None: if self.MEAN is not None:
self.DATASET.PADDING_VALUE = [x*255.0 for x in self.MEAN] self.DATASET.PADDING_VALUE = [x * 255.0 for x in self.MEAN]
if not self.TRAIN_CROP_SIZE: if not self.TRAIN_CROP_SIZE:
raise ValueError( raise ValueError(
...@@ -111,9 +112,12 @@ class SegConfig(dict): ...@@ -111,9 +112,12 @@ class SegConfig(dict):
) )
# Ensure file list is use UTF-8 encoding # Ensure file list is use UTF-8 encoding
train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', 'utf-8').readlines() train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r',
val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', 'utf-8').readlines() 'utf-8').readlines()
test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', 'utf-8').readlines() val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r',
'utf-8').readlines()
test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r',
'utf-8').readlines()
self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets) self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets)
self.DATASET.VAL_TOTAL_IMAGES = len(val_sets) self.DATASET.VAL_TOTAL_IMAGES = len(val_sets)
self.DATASET.TEST_TOTAL_IMAGES = len(test_sets) self.DATASET.TEST_TOTAL_IMAGES = len(test_sets)
...@@ -122,12 +126,13 @@ class SegConfig(dict): ...@@ -122,12 +126,13 @@ class SegConfig(dict):
len(self.MODEL.MULTI_LOSS_WEIGHT) != 3: len(self.MODEL.MULTI_LOSS_WEIGHT) != 3:
self.MODEL.MULTI_LOSS_WEIGHT = [1.0, 0.4, 0.16] self.MODEL.MULTI_LOSS_WEIGHT = [1.0, 0.4, 0.16]
if self.AUG.AUG_METHOD not in ['unpadding', 'stepscaling', 'rangescaling']: if self.AUG.AUG_METHOD not in [
'unpadding', 'stepscaling', 'rangescaling'
]:
raise ValueError( raise ValueError(
'AUG.AUG_METHOD config error, only support `unpadding`, `unpadding` and `rangescaling`' 'AUG.AUG_METHOD config error, only support `unpadding`, `unpadding` and `rangescaling`'
) )
def update_from_list(self, config_list): def update_from_list(self, config_list):
if len(config_list) % 2 != 0: if len(config_list) % 2 != 0:
raise ValueError( raise ValueError(
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. #Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from paddle import fluid from paddle import fluid
def load_fp16_vars(executor, dirname, program): def load_fp16_vars(executor, dirname, program):
load_dirname = os.path.normpath(dirname) load_dirname = os.path.normpath(dirname)
...@@ -28,4 +44,4 @@ def load_fp16_vars(executor, dirname, program): ...@@ -28,4 +44,4 @@ def load_fp16_vars(executor, dirname, program):
'load_as_fp16': var.dtype == fluid.core.VarDesc.VarType.FP16 'load_as_fp16': var.dtype == fluid.core.VarDesc.VarType.FP16
}) })
executor.run(load_prog) executor.run(load_prog)
\ No newline at end of file
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -42,7 +43,7 @@ model_urls = { ...@@ -42,7 +43,7 @@ model_urls = {
"hrnet_w30_bn_imagenet": "hrnet_w30_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar", "https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar",
"hrnet_w32_bn_imagenet": "hrnet_w32_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar" , "https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar",
"hrnet_w40_bn_imagenet": "hrnet_w40_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar", "https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar",
"hrnet_w44_bn_imagenet": "hrnet_w44_bn_imagenet":
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv ...@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv
from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_backbone from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_backbone
from models.backbone.xception import Xception as xception_backbone from models.backbone.xception import Xception as xception_backbone
def encoder(input): def encoder(input):
# 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv # 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
# ASPP_WITH_SEP_CONV:默认为真,使用depthwise可分离卷积,否则使用普通卷积 # ASPP_WITH_SEP_CONV:默认为真,使用depthwise可分离卷积,否则使用普通卷积
...@@ -47,8 +48,7 @@ def encoder(input): ...@@ -47,8 +48,7 @@ def encoder(input):
with scope('encoder'): with scope('encoder'):
channel = 256 channel = 256
with scope("image_pool"): with scope("image_pool"):
image_avg = fluid.layers.reduce_mean( image_avg = fluid.layers.reduce_mean(input, [2, 3], keep_dim=True)
input, [2, 3], keep_dim=True)
image_avg = bn_relu( image_avg = bn_relu(
conv( conv(
image_avg, image_avg,
...@@ -191,7 +191,10 @@ def nas_backbone(input, arch): ...@@ -191,7 +191,10 @@ def nas_backbone(input, arch):
end_points = 8 end_points = 8
decode_point = 3 decode_point = 3
data, decode_shortcuts = arch( data, decode_shortcuts = arch(
input, end_points=end_points, return_block=decode_point, output_stride=16) input,
end_points=end_points,
return_block=decode_point,
output_stride=16)
decode_shortcut = decode_shortcuts[decode_point] decode_shortcut = decode_shortcuts[decode_point]
return data, decode_shortcut return data, decode_shortcut
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -31,7 +32,7 @@ __all__ = ["MobileNetV2SpaceSeg"] ...@@ -31,7 +32,7 @@ __all__ = ["MobileNetV2SpaceSeg"]
class MobileNetV2SpaceSeg(SearchSpaceBase): class MobileNetV2SpaceSeg(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, block_mask=None): def __init__(self, input_size, output_size, block_num, block_mask=None):
super(MobileNetV2SpaceSeg, self).__init__(input_size, output_size, super(MobileNetV2SpaceSeg, self).__init__(input_size, output_size,
block_num, block_mask) block_num, block_mask)
# self.head_num means the first convolution channel # self.head_num means the first convolution channel
self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7 self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7
# self.filter_num1 ~ self.filter_num6 means following convlution channel # self.filter_num1 ~ self.filter_num6 means following convlution channel
...@@ -48,7 +49,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase): ...@@ -48,7 +49,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
self.k_size = np.array([3, 5]) #2 self.k_size = np.array([3, 5]) #2
# self.multiply means expansion_factor of each _inverted_residual_unit # self.multiply means expansion_factor of each _inverted_residual_unit
self.multiply = np.array([1, 2, 3, 4, 6]) #5 self.multiply = np.array([1, 2, 3, 4, 6]) #5
# self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks # self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks
self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6 self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6
def init_tokens(self): def init_tokens(self):
...@@ -72,7 +73,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase): ...@@ -72,7 +73,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
def range_table(self): def range_table(self):
""" """
Get range table of current search space, constrains the range of tokens. Get range table of current search space, constrains the range of tokens.
""" """
# head_num + 6 * [multiple(expansion_factor), filter_num, repeat, kernel_size] # head_num + 6 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# yapf: disable # yapf: disable
...@@ -95,8 +96,8 @@ class MobileNetV2SpaceSeg(SearchSpaceBase): ...@@ -95,8 +96,8 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
tokens = self.init_tokens() tokens = self.init_tokens()
self.bottleneck_params_list = [] self.bottleneck_params_list = []
self.bottleneck_params_list.append( self.bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1,
(1, self.head_num[tokens[0]], 1, 1, 3)) 3))
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]], (self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
...@@ -150,7 +151,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase): ...@@ -150,7 +151,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
return (True if count == points else False) return (True if count == points else False)
#conv1 #conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic. # all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
input = conv_bn_layer( input = conv_bn_layer(
input, input,
num_filters=int(32 * self.scale), num_filters=int(32 * self.scale),
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册