未验证 提交 664dc115 编写于 作者: W wuyefeilin 提交者: GitHub

Unifiy copyright format (#261)

* update model save load

* first add

* update model save and load

* update train.py

* update LaneNet model saving and loading

* adapt slim to paddle-1.8

* update distillation save and load

* update nas model save and load

* update model load op

* update utils.py

* update load_model_utils.py

* update model saving and loading

* update copyright

* update palette.py
上级 020d1072
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os
......@@ -19,10 +33,10 @@ cfg.class_num = 20
# 均值, 图像预处理减去的均值
cfg.MEAN = 0.406, 0.456, 0.485
# 标准差,图像预处理除以标准差
cfg.STD = 0.225, 0.224, 0.229
cfg.STD = 0.225, 0.224, 0.229
# 多尺度预测时图像尺寸
cfg.multi_scales = (377,377), (473,473), (567,567)
cfg.multi_scales = (377, 377), (473, 473), (567, 567)
# 多尺度预测时图像是否水平翻转
cfg.flip = True
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
......@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg')
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os.environ['FLAGS_eager_delete_tensor_gb']='0.0'
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
import paddle.fluid as fluid
# 预测数据集类
class TestDataSet():
def __init__(self):
self.data_dir = cfg.data_dir
self.data_dir = cfg.data_dir
self.data_list_file = cfg.data_list_file
self.data_list = self.get_data_list()
self.data_num = len(self.data_list)
def get_data_list(self):
# 获取预测图像路径列表
data_list = []
......@@ -56,10 +71,10 @@ class TestDataSet():
img_path = self.data_list[index]
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
return img, img,img_path, None
return img, img, img_path, None
img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'')
name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2]
img_process = self.preprocess(img)
......@@ -90,39 +105,44 @@ def infer():
if image is None:
print(im_name, 'is None')
continue
# 预测
if cfg.example == 'ACE2P':
# ACE2P模型使用多尺度预测
reader = importlib.import_module('reader')
multi_scale_test = getattr(reader, 'multi_scale_test')
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape)
parsing, logits = multi_scale_test(exe, test_prog, feed_name,
fetch_list, image, im_shape)
else:
# HumanSeg,RoadLine模型单尺度预测
result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list)
result = exe.run(
program=test_prog,
feed={feed_name[0]: image},
fetch_list=fetch_list)
parsing = np.argmax(result[0][0], axis=0)
parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])
# 预测结果保存
result_path = os.path.join(cfg.vis_dir, im_name + '.png')
if cfg.example == 'HumanSeg':
logits = result[0][0][1]*255
logits = result[0][0][1] * 255
logits = cv2.resize(logits, im_shape[::-1])
ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
logits = 255 *(logits - thresh)/(255 - thresh)
logits = 255 * (logits - thresh) / (255 - thresh)
# 将分割结果添加到alpha通道
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2)
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
axis=2)
cv2.imwrite(result_path, rgba)
else:
else:
output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
output_im.putpalette(palette)
output_im.save(result_path)
if (idx + 1) % 100 == 0:
print('%d processd' % (idx + 1))
print('%d processd done' % (idx + 1))
print('%d processd done' % (idx + 1))
return 0
......
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
from config import cfg
import cv2
def get_affine_points(src_shape, dst_shape, rot_grad=0):
# 获取图像和仿射后图像的三组对应点坐标
# 三组点为仿射变换后图像的中心点, [w/2,0], [0,0],及对应原始图像的点
......@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
# 原始图像三组点
points = [[0, 0]] * 3
points[0] = (np.array([w, h]) - 1) * 0.5
points[0] = (np.array([w, h]) - 1) * 0.5
points[1] = points[0] + 0.5 * affine_shape[0] * np.array([sin_v, -cos_v])
points[2] = points[1] - 0.5 * affine_shape[1] * np.array([cos_v, sin_v])
......@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
return points, points_trans
def preprocess(im):
# ACE2P模型数据预处理
im_shape = im.shape[:2]
......@@ -42,13 +58,10 @@ def preprocess(im):
# 获取图像和仿射变换后图像的对应点坐标
points, points_trans = get_affine_points(im_shape, scale)
# 根据对应点集获得仿射矩阵
trans = cv2.getAffineTransform(np.float32(points),
np.float32(points_trans))
trans = cv2.getAffineTransform(
np.float32(points), np.float32(points_trans))
# 根据仿射矩阵对图像进行仿射
input = cv2.warpAffine(im,
trans,
scale[::-1],
flags=cv2.INTER_LINEAR)
input = cv2.warpAffine(im, trans, scale[::-1], flags=cv2.INTER_LINEAR)
# 减均值测,除以方差,转换数据格式为NCHW
input = input.astype(np.float32)
......@@ -66,19 +79,20 @@ def preprocess(im):
return input_images
def multi_scale_test(exe, test_prog, feed_name, fetch_list,
input_ims, im_shape):
def multi_scale_test(exe, test_prog, feed_name, fetch_list, input_ims,
im_shape):
# 由于部分类别分左右部位, flipped_idx为其水平翻转后对应的标签
flipped_idx = (15, 14, 17, 16, 19, 18)
ms_outputs = []
# 多尺度预测
for idx, scale in enumerate(cfg.multi_scales):
input_im = input_ims[idx]
parsing_output = exe.run(program=test_prog,
feed={feed_name[0]: input_im},
fetch_list=fetch_list)
parsing_output = exe.run(
program=test_prog,
feed={feed_name[0]: input_im},
fetch_list=fetch_list)
output = parsing_output[0][0]
if cfg.flip:
# 若水平翻转,对部分类别进行翻转,与原始预测结果取均值
......@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
# 仿射变换回图像原始尺寸
points, points_trans = get_affine_points(im_shape, scale)
M = cv2.getAffineTransform(np.float32(points_trans), np.float32(points))
logits_result = cv2.warpAffine(output, M, im_shape[::-1], flags=cv2.INTER_LINEAR)
logits_result = cv2.warpAffine(
output, M, im_shape[::-1], flags=cv2.INTER_LINEAR)
ms_outputs.append(logits_result)
# 多尺度预测结果求均值,求预测概率最大的类别
......@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
ms_fused_parsing_output = np.mean(ms_fused_parsing_output, axis=0)
parsing = np.argmax(ms_fused_parsing_output, axis=2)
return parsing, ms_fused_parsing_output
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -7,6 +7,7 @@
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu",
action="store_true",
help="Use gpu or cpu to test.")
parser.add_argument('--example',
type=str,
help='RoadLine, HumanSeg or ACE2P')
parser.add_argument(
"--use_gpu", action="store_true", help="Use gpu or cpu to test.")
parser.add_argument(
'--example', type=str, help='RoadLine, HumanSeg or ACE2P')
return parser.parse_args()
......@@ -34,6 +48,7 @@ class AttrDict(dict):
else:
self[name] = value
def merge_cfg_from_args(args, cfg):
"""Merge config keys, values in args into the global config."""
for k, v in vars(args).items():
......@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value = v
if value is not None:
cfg[k] = value
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# utils for memory management which is allocated on sharedmemory,
# note that these structures may not be thread-safe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import models
import argparse
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .humanseg import HumanSegMobile
from .humanseg import HumanSegServer
from .humanseg import HumanSegLite
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backbone import mobilenet_v2
from .backbone import xception
from .deeplabv3p import DeepLabv3p
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mobilenet_v2 import MobileNetV2
from .xception import Xception
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -10,6 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import os
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase
from pdseg.data_aug import get_random_scale, randomly_scale_image_and_label, random_rotation, \
rand_scale_aspect, hsv_color_jitter, rand_crop
def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
"""
改变图像及标签图像尺寸
......@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
if grt is not None:
grt = cv2.resize(grt, target_size, interpolation=cv2.INTER_NEAREST)
if grt_instance is not None:
grt_instance = cv2.resize(grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
grt_instance = cv2.resize(
grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
elif cfg.AUG.AUG_METHOD == 'stepscaling':
if mode == ModelPhase.TRAIN:
min_scale_factor = cfg.AUG.MIN_SCALE_FACTOR
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,7 +18,6 @@ from __future__ import print_function
import paddle.fluid as fluid
from utils.config import cfg
from pdseg.models.libs.model_libs import scope, name_scope
from pdseg.models.libs.model_libs import bn, bn_relu, relu
......@@ -86,7 +85,12 @@ def bottleneck(inputs,
with scope('down_sample'):
inputs_shape = inputs.shape
with scope('main_max_pool'):
net_main = fluid.layers.conv2d(inputs, inputs_shape[1], filter_size=3, stride=2, padding='SAME')
net_main = fluid.layers.conv2d(
inputs,
inputs_shape[1],
filter_size=3,
stride=2,
padding='SAME')
#First get the difference in depth to pad, then pad with zeros only on the last dimension.
depth_to_pad = abs(inputs_shape[1] - output_depth)
......@@ -95,12 +99,16 @@ def bottleneck(inputs,
net_main = fluid.layers.pad(net_main, paddings=paddings)
with scope('block1'):
net = conv(inputs, reduced_depth, [2, 2], stride=2, padding='same')
net = conv(
inputs, reduced_depth, [2, 2], stride=2, padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
with scope('block2'):
net = conv(net, reduced_depth, [filter_size, filter_size], padding='same')
net = conv(
net,
reduced_depth, [filter_size, filter_size],
padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -137,13 +145,18 @@ def bottleneck(inputs,
# Second conv block --- apply dilated convolution here
with scope('block2'):
net = conv(net, reduced_depth, filter_size, padding='SAME', dilation=dilation_rate)
net = conv(
net,
reduced_depth,
filter_size,
padding='SAME',
dilation=dilation_rate)
net = bn(net)
net = prelu(net, decoder=decoder)
# Final projection with 1x1 kernel (Expansion)
with scope('block3'):
net = conv(net, output_depth, [1,1])
net = conv(net, output_depth, [1, 1])
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -172,9 +185,11 @@ def bottleneck(inputs,
# Second conv block --- apply asymmetric conv here
with scope('block2'):
with scope('asymmetric_conv2a'):
net = conv(net, reduced_depth, [filter_size, 1], padding='same')
net = conv(
net, reduced_depth, [filter_size, 1], padding='same')
with scope('asymmetric_conv2b'):
net = conv(net, reduced_depth, [1, filter_size], padding='same')
net = conv(
net, reduced_depth, [1, filter_size], padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -211,7 +226,8 @@ def bottleneck(inputs,
with scope('unpool'):
net_unpool = conv(inputs, output_depth, [1, 1])
net_unpool = bn(net_unpool)
net_unpool = fluid.layers.resize_bilinear(net_unpool, out_shape=output_shape[2:])
net_unpool = fluid.layers.resize_bilinear(
net_unpool, out_shape=output_shape[2:])
# First 1x1 projection to reduce depth
with scope('block1'):
......@@ -220,7 +236,12 @@ def bottleneck(inputs,
net = prelu(net, decoder=decoder)
with scope('block2'):
net = deconv(net, reduced_depth, filter_size=filter_size, stride=2, padding='same')
net = deconv(
net,
reduced_depth,
filter_size=filter_size,
stride=2,
padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -253,7 +274,10 @@ def bottleneck(inputs,
# Second conv block
with scope('block2'):
net = conv(net, reduced_depth, [filter_size, filter_size], padding='same')
net = conv(
net,
reduced_depth, [filter_size, filter_size],
padding='same')
net = bn(net)
net = prelu(net, decoder=decoder)
......@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'):
= bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.01, type=DOWNSAMPLING,
name_scope='bottleneck1_0')
with scope('bottleneck1_1'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_1')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_1')
with scope('bottleneck1_2'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_2')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_2')
with scope('bottleneck1_3'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_3')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_3')
with scope('bottleneck1_4'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01,
name_scope='bottleneck1_4')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.01,
name_scope='bottleneck1_4')
return net, inputs_shape_1
......@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'):
name_scope='bottleneck2_0')
for i in range(2):
with scope('bottleneck2_{}'.format(str(4 * i + 1))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1,
name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
with scope('bottleneck2_{}'.format(str(4 * i + 2))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)),
name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 1)),
name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
with scope('bottleneck2_{}'.format(str(4 * i + 3))):
net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC,
name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
net = bottleneck(
net,
output_depth=128,
filter_size=5,
regularizer_prob=0.1,
type=ASYMMETRIC,
name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
with scope('bottleneck2_{}'.format(str(4 * i + 4))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)),
name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 2)),
name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
return net, inputs_shape_2
......@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'):
with scope(name_scope):
for i in range(2):
with scope('bottleneck3_{}'.format(str(4 * i + 0))):
net = bottleneck(inputs, output_depth=128, filter_size=3, regularizer_prob=0.1,
name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
net = bottleneck(
inputs,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
with scope('bottleneck3_{}'.format(str(4 * i + 1))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+1)),
name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 1)),
name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
with scope('bottleneck3_{}'.format(str(4 * i + 2))):
net = bottleneck(net, output_depth=128, filter_size=5, regularizer_prob=0.1, type=ASYMMETRIC,
name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
net = bottleneck(
net,
output_depth=128,
filter_size=5,
regularizer_prob=0.1,
type=ASYMMETRIC,
name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
with scope('bottleneck3_{}'.format(str(4 * i + 3))):
net = bottleneck(net, output_depth=128, filter_size=3, regularizer_prob=0.1, type=DILATED, dilation_rate=(2 ** (2*i+2)),
name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
net = bottleneck(
net,
output_depth=128,
filter_size=3,
regularizer_prob=0.1,
type=DILATED,
dilation_rate=(2**(2 * i + 2)),
name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
return net
def ENet_stage4(inputs, inputs_shape, connect_tensor,
skip_connections=True, name_scope='stage4_block'):
def ENet_stage4(inputs,
inputs_shape,
connect_tensor,
skip_connections=True,
name_scope='stage4_block'):
with scope(name_scope):
with scope('bottleneck4_0'):
net = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.1,
type=UPSAMPLING, decoder=True, output_shape=inputs_shape,
name_scope='bottleneck4_0')
net = bottleneck(
inputs,
output_depth=64,
filter_size=3,
regularizer_prob=0.1,
type=UPSAMPLING,
decoder=True,
output_shape=inputs_shape,
name_scope='bottleneck4_0')
if skip_connections:
net = fluid.layers.elementwise_add(net, connect_tensor)
with scope('bottleneck4_1'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck4_1')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck4_1')
with scope('bottleneck4_2'):
net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck4_2')
net = bottleneck(
net,
output_depth=64,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck4_2')
return net
def ENet_stage5(inputs, inputs_shape, connect_tensor, skip_connections=True,
def ENet_stage5(inputs,
inputs_shape,
connect_tensor,
skip_connections=True,
name_scope='stage5_block'):
with scope(name_scope):
net = bottleneck(inputs, output_depth=16, filter_size=3, regularizer_prob=0.1, type=UPSAMPLING,
decoder=True, output_shape=inputs_shape,
name_scope='bottleneck5_0')
net = bottleneck(
inputs,
output_depth=16,
filter_size=3,
regularizer_prob=0.1,
type=UPSAMPLING,
decoder=True,
output_shape=inputs_shape,
name_scope='bottleneck5_0')
if skip_connections:
net = fluid.layers.elementwise_add(net, connect_tensor)
with scope('bottleneck5_1'):
net = bottleneck(net, output_depth=16, filter_size=3, regularizer_prob=0.1, decoder=True,
name_scope='bottleneck5_1')
net = bottleneck(
net,
output_depth=16,
filter_size=3,
regularizer_prob=0.1,
decoder=True,
name_scope='bottleneck5_1')
return net
......@@ -378,14 +493,16 @@ def decoder(input, num_classes):
segStage3 = ENet_stage3(stage2)
segStage4 = ENet_stage4(segStage3, inputs_shape_2, stage1)
segStage5 = ENet_stage5(segStage4, inputs_shape_1, initial)
segLogits = deconv(segStage5, num_classes, filter_size=2, stride=2, padding='SAME')
segLogits = deconv(
segStage5, num_classes, filter_size=2, stride=2, padding='SAME')
# Embedding branch
with scope('LaneNetEm'):
emStage3 = ENet_stage3(stage2)
emStage4 = ENet_stage4(emStage3, inputs_shape_2, stage1)
emStage5 = ENet_stage5(emStage4, inputs_shape_1, initial)
emLogits = deconv(emStage5, 4, filter_size=2, stride=2, padding='SAME')
emLogits = deconv(
emStage5, 4, filter_size=2, stride=2, padding='SAME')
elif 'vgg' in cfg.MODEL.LANENET.BACKBONE:
encoder_list = ['pool5', 'pool4', 'pool3']
......@@ -396,14 +513,16 @@ def decoder(input, num_classes):
encoder_list = encoder_list[1:]
for i in range(len(encoder_list)):
with scope('deconv_{:d}'.format(i + 1)):
deconv_out = deconv(score, 64, filter_size=4, stride=2, padding='SAME')
deconv_out = deconv(
score, 64, filter_size=4, stride=2, padding='SAME')
input_tensor = input[encoder_list[i]]
with scope('score_{:d}'.format(i + 1)):
score = conv(input_tensor, 64, 1)
score = fluid.layers.elementwise_add(deconv_out, score)
with scope('deconv_final'):
emLogits = deconv(score, 64, filter_size=16, stride=8, padding='SAME')
emLogits = deconv(
score, 64, filter_size=16, stride=8, padding='SAME')
with scope('score_final'):
segLogits = conv(emLogits, num_classes, 1)
emLogits = relu(conv(emLogits, 4, 1))
......@@ -415,7 +534,8 @@ def encoder(input):
model = vgg_backbone(layers=16)
#output = model.net(input)
_, encode_feature_dict = model.net(input, end_points=13, decode_points=[7, 10, 13])
_, encode_feature_dict = model.net(
input, end_points=13, decode_points=[7, 10, 13])
output = {}
output['pool3'] = encode_feature_dict[7]
output['pool4'] = encode_feature_dict[10]
......@@ -427,8 +547,9 @@ def encoder(input):
stage2, inputs_shape_2 = ENet_stage2(stage1)
output = (initial, stage1, stage2, inputs_shape_1, inputs_shape_2)
else:
raise Exception("LaneNet expect enet and vgg backbone, but received {}".
format(cfg.MODEL.LANENET.BACKBONE))
raise Exception(
"LaneNet expect enet and vgg backbone, but received {}".format(
cfg.MODEL.LANENET.BACKBONE))
return output
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -58,7 +58,8 @@ class LaneNetDataset():
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
......@@ -86,7 +87,8 @@ class LaneNetDataset():
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // self.num_trainers
self.lines = self.all_lines[num_lines * self.trainer_id: num_lines * (self.trainer_id + 1)]
self.lines = self.all_lines[num_lines * self.trainer_id:num_lines *
(self.trainer_id + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
......@@ -118,7 +120,8 @@ class LaneNetDataset():
def batch_reader(is_test=False, drop_last=drop_last):
if is_test:
imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
for img, grt, grt_instance, img_name, valid_shape, org_shape in reader():
for img, grt, grt_instance, img_name, valid_shape, org_shape in reader(
):
imgs.append(img)
grts.append(grt)
grts_instance.append(grt_instance)
......@@ -126,14 +129,15 @@ class LaneNetDataset():
valid_shapes.append(valid_shape)
org_shapes.append(org_shape)
if len(imgs) == batch_size:
yield np.array(imgs), np.array(
grts), np.array(grts_instance), img_names, np.array(valid_shapes), np.array(
org_shapes)
yield np.array(imgs), np.array(grts), np.array(
grts_instance), img_names, np.array(
valid_shapes), np.array(org_shapes)
imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
if not drop_last and len(imgs) > 0:
yield np.array(imgs), np.array(grts), np.array(grts_instance), img_names, np.array(
valid_shapes), np.array(org_shapes)
yield np.array(imgs), np.array(grts), np.array(
grts_instance), img_names, np.array(
valid_shapes), np.array(org_shapes)
else:
imgs, labs, labs_instance, ignore = [], [], [], []
bs = 0
......@@ -144,12 +148,14 @@ class LaneNetDataset():
ignore.append(ig)
bs += 1
if bs == batch_size:
yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore)
yield np.array(imgs), np.array(labs), np.array(
labs_instance), np.array(ignore)
bs = 0
imgs, labs, labs_instance, ignore = [], [], [], []
if not drop_last and bs > 0:
yield np.array(imgs), np.array(labs), np.array(labs_instance), np.array(ignore)
yield np.array(imgs), np.array(labs), np.array(
labs_instance), np.array(ignore)
return batch_reader(is_test, drop_last)
......@@ -299,10 +305,12 @@ class LaneNetDataset():
img, grt = aug.rand_crop(img, grt, mode=mode)
elif ModelPhase.is_eval(mode):
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode)
img, grt, grt_instance = aug.resize(
img, grt, grt_instance, mode=mode)
elif ModelPhase.is_visual(mode):
ori_img = img.copy()
img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode=mode)
img, grt, grt_instance = aug.resize(
img, grt, grt_instance, mode=mode)
valid_shape = [img.shape[0], img.shape[1]]
else:
raise ValueError("Dataset mode={} Error!".format(mode))
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3
cfg.DATASET.SEPARATOR = ' '
# 忽略的像素标签值, 默认为255,一般无需改动
cfg.DATASET.IGNORE_INDEX = 255
# 数据增强是图像的padding值
cfg.DATASET.PADDING_VALUE = [127.5,127.5,127.5]
# 数据增强是图像的padding值
cfg.DATASET.PADDING_VALUE = [127.5, 127.5, 127.5]
########################### 数据增强配置 ######################################
# 图像镜像左右翻转
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
generate tusimple training dataset
"""
......@@ -14,12 +28,16 @@ import numpy as np
def init_args():
parser = argparse.ArgumentParser()
parser.add_argument('--src_dir', type=str, help='The origin path of unzipped tusimple dataset')
parser.add_argument(
'--src_dir',
type=str,
help='The origin path of unzipped tusimple dataset')
return parser.parse_args()
def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, instance_dst_dir):
def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir,
instance_dst_dir):
assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)
......@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
h_samples = info_dict['h_samples']
lanes = info_dict['lanes']
image_name_new = '{:s}.png'.format('{:d}'.format(line_index + image_nums).zfill(4))
image_name_new = '{:s}.png'.format(
'{:d}'.format(line_index + image_nums).zfill(4))
src_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
dst_binary_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8)
dst_instance_image = np.zeros([src_image.shape[0], src_image.shape[1]], np.uint8)
dst_binary_image = np.zeros(
[src_image.shape[0], src_image.shape[1]], np.uint8)
dst_instance_image = np.zeros(
[src_image.shape[0], src_image.shape[1]], np.uint8)
for lane_index, lane in enumerate(lanes):
assert len(h_samples) == len(lane)
......@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
lane_pts = np.vstack((lane_x, lane_y)).transpose()
lane_pts = np.array([lane_pts], np.int64)
cv2.polylines(dst_binary_image, lane_pts, isClosed=False,
color=255, thickness=5)
cv2.polylines(dst_instance_image, lane_pts, isClosed=False,
color=lane_index * 50 + 20, thickness=5)
dst_binary_image_path = ops.join(src_dir, binary_dst_dir, image_name_new)
dst_instance_image_path = ops.join(src_dir, instance_dst_dir, image_name_new)
cv2.polylines(
dst_binary_image,
lane_pts,
isClosed=False,
color=255,
thickness=5)
cv2.polylines(
dst_instance_image,
lane_pts,
isClosed=False,
color=lane_index * 50 + 20,
thickness=5)
dst_binary_image_path = ops.join(src_dir, binary_dst_dir,
image_name_new)
dst_instance_image_path = ops.join(src_dir, instance_dst_dir,
image_name_new)
dst_rgb_image_path = ops.join(src_dir, ori_dst_dir, image_name_new)
cv2.imwrite(dst_binary_image_path, dst_binary_image)
......@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
print('Process {:s} success'.format(image_name))
def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train', split=False):
def gen_sample(src_dir,
b_gt_image_dir,
i_gt_image_dir,
image_dir,
phase='train',
split=False):
label_list = []
with open('{:s}/{}ing/{}.txt'.format(src_dir, phase, phase), 'w') as file:
......@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
image_path = ops.join(image_dir, image_name)
assert ops.exists(image_path), '{:s} not exist'.format(image_path)
assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(instance_gt_image_path)
assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(
instance_gt_image_path)
b_gt_image = cv2.imread(binary_gt_image_path, cv2.IMREAD_COLOR)
i_gt_image = cv2.imread(instance_gt_image_path, cv2.IMREAD_COLOR)
......@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
print('image: {:s} corrupt'.format(image_name))
continue
else:
info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path, instance_gt_image_path)
info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path,
instance_gt_image_path)
file.write(info + '\n')
label_list.append(info)
if phase == 'train' and split:
......@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
val_list_len = len(label_list) // 10
val_label_list = label_list[:val_list_len]
train_label_list = label_list[val_list_len:]
with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase), 'w') as file:
with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase),
'w') as file:
for info in train_label_list:
file.write(info + '\n')
with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase), 'w') as file:
with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase),
'w') as file:
for info in val_label_list:
file.write(info + '\n')
return
......@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir):
for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1]
shutil.copyfile(json_label_path, ops.join(traing_folder_path, json_label_name))
shutil.copyfile(json_label_path,
ops.join(traing_folder_path, json_label_name))
for json_label_path in glob.glob('{:s}/test_label.json'.format(src_dir)):
json_label_name = ops.split(json_label_path)[1]
shutil.copyfile(json_label_path, ops.join(testing_folder_path, json_label_name))
shutil.copyfile(json_label_path,
ops.join(testing_folder_path, json_label_name))
train_gt_image_dir = ops.join('training', 'gt_image')
train_gt_binary_dir = ops.join('training', 'gt_binary_image')
......@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir):
os.makedirs(os.path.join(src_dir, test_gt_instance_dir), exist_ok=True)
for json_label_path in glob.glob('{:s}/*.json'.format(traing_folder_path)):
process_json_file(json_label_path, src_dir, train_gt_image_dir, train_gt_binary_dir, train_gt_instance_dir)
process_json_file(json_label_path, src_dir, train_gt_image_dir,
train_gt_binary_dir, train_gt_instance_dir)
gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir, train_gt_image_dir, 'train', True)
gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir,
train_gt_image_dir, 'train', True)
if __name__ == '__main__':
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code heavily base on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
"""
LaneNet model post process
......@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5):
:return:
"""
if len(image.shape) == 3:
raise ValueError('Binary segmentation result image should be a single channel image')
raise ValueError(
'Binary segmentation result image should be a single channel image')
if image.dtype is not np.uint8:
image = np.array(image, np.uint8)
kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
kernel = cv2.getStructuringElement(
shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
# close operation fille hole
closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
......@@ -46,13 +61,15 @@ def _connect_components_analysis(image):
else:
gray_image = image
return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S)
return cv2.connectedComponentsWithStats(
gray_image, connectivity=8, ltype=cv2.CV_32S)
class _LaneFeat(object):
"""
"""
def __init__(self, feat, coord, class_id=-1):
"""
lane feat object
......@@ -108,18 +125,21 @@ class _LaneNetCluster(object):
"""
Instance segmentation result cluster
"""
def __init__(self):
"""
"""
self._color_map = [np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])]
self._color_map = [
np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])
]
@staticmethod
def _embedding_feats_dbscan_cluster(embedding_image_feats):
......@@ -186,15 +206,16 @@ class _LaneNetCluster(object):
# get embedding feats and coords
get_lane_embedding_feats_result = self._get_lane_embedding_feats(
binary_seg_ret=binary_seg_result,
instance_seg_ret=instance_seg_result
)
instance_seg_ret=instance_seg_result)
# dbscan cluster
dbscan_cluster_result = self._embedding_feats_dbscan_cluster(
embedding_image_feats=get_lane_embedding_feats_result['lane_embedding_feats']
)
embedding_image_feats=get_lane_embedding_feats_result[
'lane_embedding_feats'])
mask = np.zeros(shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3], dtype=np.uint8)
mask = np.zeros(
shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3],
dtype=np.uint8)
db_labels = dbscan_cluster_result['db_labels']
unique_labels = dbscan_cluster_result['unique_labels']
coord = get_lane_embedding_feats_result['lane_coordinates']
......@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object):
"""
lanenet post process for lane generation
"""
def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'):
"""
convert front car view to bird view
"""
assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(ipm_remap_file_path)
assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(
ipm_remap_file_path)
self._cluster = _LaneNetCluster()
self._ipm_remap_file_path = ipm_remap_file_path
......@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object):
self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x']
self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y']
self._color_map = [np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])]
self._color_map = [
np.array([255, 0, 0]),
np.array([0, 255, 0]),
np.array([0, 0, 255]),
np.array([125, 125, 0]),
np.array([0, 125, 125]),
np.array([125, 0, 125]),
np.array([50, 100, 50]),
np.array([100, 50, 100])
]
def _load_remap_matrix(self):
fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ)
......@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object):
return ret
def postprocess(self, binary_seg_result, instance_seg_result=None,
min_area_threshold=100, source_image=None,
def postprocess(self,
binary_seg_result,
instance_seg_result=None,
min_area_threshold=100,
source_image=None,
data_source='tusimple'):
# convert binary_seg_result
binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8)
# apply image morphology operation to fill in the hold and reduce the small area
morphological_ret = _morphological_process(binary_seg_result, kernel_size=5)
connect_components_analysis_ret = _connect_components_analysis(image=morphological_ret)
morphological_ret = _morphological_process(
binary_seg_result, kernel_size=5)
connect_components_analysis_ret = _connect_components_analysis(
image=morphological_ret)
labels = connect_components_analysis_ret[1]
stats = connect_components_analysis_ret[2]
......@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object):
# apply embedding features cluster
mask_image, lane_coords = self._cluster.apply_lane_feats_cluster(
binary_seg_result=morphological_ret,
instance_seg_result=instance_seg_result
)
instance_seg_result=instance_seg_result)
if mask_image is None:
return {
......@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object):
for lane_index, coords in enumerate(lane_coords):
if data_source == 'tusimple':
tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256), np.int_(coords[:, 0] * 1280 / 512)))] = 255
tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256),
np.int_(coords[:, 0] * 1280 / 512)))] = 255
else:
raise ValueError('Wrong data source now only support tusimple')
tmp_ipm_mask = cv2.remap(
tmp_mask,
self._remap_to_ipm_x,
self._remap_to_ipm_y,
interpolation=cv2.INTER_NEAREST
)
interpolation=cv2.INTER_NEAREST)
nonzero_y = np.array(tmp_ipm_mask.nonzero()[0])
nonzero_x = np.array(tmp_ipm_mask.nonzero()[1])
......@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object):
[ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
fit_x = fit_param[0] * plot_y ** 2 + fit_param[1] * plot_y + fit_param[2]
fit_x = fit_param[0] * plot_y**2 + fit_param[
1] * plot_y + fit_param[2]
lane_pts = []
for index in range(0, plot_y.shape[0], 5):
src_x = self._remap_to_ipm_x[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
int(plot_y[index]),
int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
if src_x <= 0:
continue
src_y = self._remap_to_ipm_y[
int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
int(plot_y[index]),
int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
src_y = src_y if src_y > 0 else 0
lane_pts.append([src_x, src_y])
......@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object):
continue
lane_color = self._color_map[index].tolist()
cv2.circle(source_image, (int(interpolation_src_pt_x),
int(interpolation_src_pt_y)), 5, lane_color, -1)
cv2.circle(
source_image,
(int(interpolation_src_pt_x), int(interpolation_src_pt_y)),
5, lane_color, -1)
ret = {
'mask_image': mask_image,
'fit_params': fit_params,
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress
if __name__ == "__main__":
download_file_and_uncompress(
url='https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar',
url=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar',
savepath=LOCAL_PATH,
extrapath=LOCAL_PATH)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .load_model import *
from .unet import *
from .hrnet import *
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import paddle.fluid as fluid
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import numpy as np
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet import UNet
from .hrnet import HRNet
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import sys
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
import argparse
import transforms.transforms as T
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os
......@@ -6,20 +20,20 @@ args = get_arguments()
cfg = AttrDict()
# 待预测图像所在路径
cfg.data_dir = os.path.join(args.example , "data", "test_images")
cfg.data_dir = os.path.join(args.example, "data", "test_images")
# 待预测图像名称列表
cfg.data_list_file = os.path.join(args.example , "data", "test.txt")
cfg.data_list_file = os.path.join(args.example, "data", "test.txt")
# 模型加载路径
cfg.model_path = os.path.join(args.example , "model")
cfg.model_path = os.path.join(args.example, "model")
# 预测结果保存路径
cfg.vis_dir = os.path.join(args.example , "result")
cfg.vis_dir = os.path.join(args.example, "result")
# 预测类别数
cfg.class_num = 2
# 均值, 图像预处理减去的均值
cfg.MEAN = 127.5, 127.5, 127.5
# 标准差,图像预处理除以标准差
cfg.STD = 127.5, 127.5, 127.5
cfg.STD = 127.5, 127.5, 127.5
# 待预测图像输入尺寸
cfg.input_size = 1536, 576
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
......@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg')
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os.environ['FLAGS_eager_delete_tensor_gb']='0.0'
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
import paddle.fluid as fluid
# 预测数据集类
class TestDataSet():
def __init__(self):
self.data_dir = cfg.data_dir
self.data_dir = cfg.data_dir
self.data_list_file = cfg.data_list_file
self.data_list = self.get_data_list()
self.data_num = len(self.data_list)
def get_data_list(self):
# 获取预测图像路径列表
data_list = []
......@@ -40,7 +55,7 @@ class TestDataSet():
def preprocess(self, img):
# 图像预处理
if cfg.example == 'ACE2P':
reader = importlib.import_module(args.example+'.reader')
reader = importlib.import_module(args.example + '.reader')
ACE2P_preprocess = getattr(reader, 'preprocess')
img = ACE2P_preprocess(img)
else:
......@@ -56,10 +71,10 @@ class TestDataSet():
img_path = self.data_list[index]
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
return img, img,img_path, None
return img, img, img_path, None
img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'')
name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2]
img_process = self.preprocess(img)
......@@ -90,39 +105,44 @@ def infer():
if image is None:
print(im_name, 'is None')
continue
# 预测
if cfg.example == 'ACE2P':
# ACE2P模型使用多尺度预测
reader = importlib.import_module(args.example+'.reader')
reader = importlib.import_module(args.example + '.reader')
multi_scale_test = getattr(reader, 'multi_scale_test')
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape)
parsing, logits = multi_scale_test(exe, test_prog, feed_name,
fetch_list, image, im_shape)
else:
# HumanSeg,RoadLine模型单尺度预测
result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list)
result = exe.run(
program=test_prog,
feed={feed_name[0]: image},
fetch_list=fetch_list)
parsing = np.argmax(result[0][0], axis=0)
parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])
# 预测结果保存
result_path = os.path.join(cfg.vis_dir, im_name + '.png')
if cfg.example == 'HumanSeg':
logits = result[0][0][1]*255
logits = result[0][0][1] * 255
logits = cv2.resize(logits, im_shape[::-1])
ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
logits = 255 *(logits - thresh)/(255 - thresh)
logits = 255 * (logits - thresh) / (255 - thresh)
# 将分割结果添加到alpha通道
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2)
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
axis=2)
cv2.imwrite(result_path, rgba)
else:
else:
output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
output_im.putpalette(palette)
output_im.save(result_path)
if (idx + 1) % 100 == 0:
print('%d processd' % (idx + 1))
print('%d processd done' % (idx + 1))
print('%d processd done' % (idx + 1))
return 0
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: RainbowSecret
## Microsoft Research
## yuyua@microsoft.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu",
action="store_true",
help="Use gpu or cpu to test.")
parser.add_argument('--example',
type=str,
help='RoadLine, HumanSeg or ACE2P')
parser.add_argument(
"--use_gpu", action="store_true", help="Use gpu or cpu to test.")
parser.add_argument(
'--example', type=str, help='RoadLine, HumanSeg or ACE2P')
return parser.parse_args()
......@@ -34,6 +48,7 @@ class AttrDict(dict):
else:
self[name] = value
def merge_cfg_from_args(args, cfg):
"""Merge config keys, values in args into the global config."""
for k, v in vars(args).items():
......@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value = v
if value is not None:
cfg[k] = value
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -20,6 +21,8 @@ from PIL import Image
import glob
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
def remove_colormap(filename):
gray_anno = np.array(Image.open(filename))
return gray_anno
......@@ -30,6 +33,7 @@ def save_annotation(annotation, filename):
annotation = Image.fromarray(annotation)
annotation.save(filename)
def convert_list(origin_file, seg_file, output_folder):
with open(seg_file, 'w') as fid_seg:
with open(origin_file) as fid_ori:
......@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder):
new_line = ' '.join([img_name, anno_name])
fid_seg.write(new_line + "\n")
if __name__ == "__main__":
pascal_root = "./VOCtrainval_11-May-2012/VOC2012"
pascal_root = os.path.join(LOCAL_PATH, pascal_root)
......@@ -54,7 +59,7 @@ if __name__ == "__main__":
# 标注图转换后存储目录
output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert")
if not os.path.exists(os.path.join(LOCAL_PATH, output_folder)):
os.mkdir(os.path.join(LOCAL_PATH, output_folder))
......@@ -67,5 +72,5 @@ if __name__ == "__main__":
convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'),
output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap
from convert_voc2012 import save_annotation
def download_VOC_dataset(savepath, extrapath):
url = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
download_file_and_uncompress(
url=url, savepath=savepath, extrapath=extrapath)
if __name__ == "__main__":
download_VOC_dataset(LOCAL_PATH, LOCAL_PATH)
print("Dataset download finish!")
......@@ -45,10 +46,10 @@ if __name__ == "__main__":
train_path = os.path.join(txt_folder, "train.txt")
val_path = os.path.join(txt_folder, "val.txt")
trainval_path = os.path.join(txt_folder, "trainval.txt")
# 标注图转换后存储目录
output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert")
if not os.path.exists(output_folder):
os.mkdir(output_folder)
......@@ -61,5 +62,5 @@ if __name__ == "__main__":
convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'),
output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -31,7 +31,8 @@ gflags.DEFINE_string("conf", default="", help="Configuration File Path")
gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images")
gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
gflags.DEFINE_string("trt_mode", default="", help="Use optimized model")
gflags.DEFINE_string("ext", default=".jpeg|.jpg", help="Input Image File Extensions")
gflags.DEFINE_string(
"ext", default=".jpeg|.jpg", help="Input Image File Extensions")
gflags.FLAGS = gflags.FLAGS
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,4 +14,4 @@
# limitations under the License.
import models
import utils
from . import tools
\ No newline at end of file
from . import tools
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
......@@ -427,12 +440,17 @@ def max_img_size_statistics():
logger.info("max width and max height of images are ({},{})".format(
max_width, max_height))
def num_classes_loss_matching_check():
loss_type = cfg.SOLVER.LOSS
num_classes = cfg.DATASET.NUM_CLASSES
if num_classes > 2 and (("dice_loss" in loss_type) or ("bce_loss" in loss_type)):
logger.info(error_print("loss check."
" Dice loss and bce loss is only applicable to binary classfication"))
if num_classes > 2 and (("dice_loss" in loss_type) or
("bce_loss" in loss_type)):
logger.info(
error_print(
"loss check."
" Dice loss and bce loss is only applicable to binary classfication"
))
else:
logger.info(correct_print("loss check"))
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -362,7 +362,7 @@ def hsv_color_jitter(crop_img,
saturation_jitter_ratio > 0 or \
contrast_jitter_ratio > 0:
crop_img = random_jitter(crop_img, saturation_jitter_ratio,
brightness_jitter_ratio, contrast_jitter_ratio)
brightness_jitter_ratio, contrast_jitter_ratio)
return crop_img
......@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN):
crop_width = cfg.EVAL_CROP_SIZE[0]
crop_height = cfg.EVAL_CROP_SIZE[1]
if not ModelPhase.is_train(mode):
if not ModelPhase.is_train(mode):
if (crop_height < img_height or crop_width < img_width):
raise Exception(
"Crop size({},{}) must large than img size({},{}) when in EvalPhase."
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
......@@ -14,10 +28,10 @@ except ImportError:
class GeneratorEnqueuer(object):
"""
Multiple generators
Multiple generators
Args:
generators:
generators:
wait_time (float): time to sleep in-between calls to `put()`.
"""
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -141,7 +141,7 @@ class ResNet():
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
dilation_rate = get_dilated_rate(dilation_dict, block)
conv = self.bottleneck_block(
input=conv,
num_filters=int(num_filters[block] * self.scale),
......@@ -215,11 +215,11 @@ class ResNet():
groups=1,
act=None,
name=None):
if self.stem == 'pspnet':
bias_attr=ParamAttr(name=name + "_biases")
bias_attr = ParamAttr(name=name + "_biases")
else:
bias_attr=False
bias_attr = False
conv = fluid.layers.conv2d(
input=input,
......@@ -238,13 +238,15 @@ class ResNet():
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
)
def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1]
......@@ -258,7 +260,7 @@ class ResNet():
strides = [1, stride]
else:
strides = [stride, 1]
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -55,7 +55,8 @@ class VGGNet():
channels = [64, 128, 256, 512, 512]
conv = input
for i in range(len(nums)):
conv = self.conv_block(conv, channels[i], nums[i], name="conv" + str(i + 1) + "_")
conv = self.conv_block(
conv, channels[i], nums[i], name="conv" + str(i + 1) + "_")
layers_count += nums[i]
if check_points(layers_count, decode_points):
short_cuts[layers_count] = conv
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -197,4 +197,4 @@ def conv_bn_layer(input,
if if_act:
return fluid.layers.relu6(bn)
else:
return bn
\ No newline at end of file
return bn
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -25,7 +25,14 @@ from paddle.fluid.param_attr import ParamAttr
from utils.config import cfg
def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_groups=1, if_act=True, name=None):
def conv_bn_layer(input,
filter_size,
num_filters,
stride=1,
padding=1,
num_groups=1,
if_act=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
......@@ -37,37 +44,74 @@ def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_grou
param_attr=ParamAttr(initializer=MSRA(), name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(input=conv,
param_attr=ParamAttr(name=bn_name + "_scale",
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
bn = fluid.layers.relu(bn)
return bn
def basic_block(input, num_filters, stride=1, downsample=False, name=None):
residual = input
conv = conv_bn_layer(input=input, filter_size=3, num_filters=num_filters, stride=stride, name=name + '_conv1')
conv = conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, if_act=False, name=name + '_conv2')
conv = conv_bn_layer(
input=input,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv1')
conv = conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
if_act=False,
name=name + '_conv2')
if downsample:
residual = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, if_act=False,
name=name + '_downsample')
residual = conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def bottleneck_block(input, num_filters, stride=1, downsample=False, name=None):
residual = input
conv = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, name=name + '_conv1')
conv = conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, stride=stride, name=name + '_conv2')
conv = conv_bn_layer(input=conv, filter_size=1, num_filters=num_filters * 4, if_act=False,
name=name + '_conv3')
conv = conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
name=name + '_conv1')
conv = conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv2')
conv = conv_bn_layer(
input=conv,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_conv3')
if downsample:
residual = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters * 4, if_act=False,
name=name + '_downsample')
residual = conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def fuse_layers(x, channels, multi_scale_output=True, name=None):
out = []
for i in range(len(channels) if multi_scale_output else 1):
......@@ -77,40 +121,64 @@ def fuse_layers(x, channels, multi_scale_output=True, name=None):
height = shape[-2]
for j in range(len(channels)):
if j > i:
y = conv_bn_layer(x[j], filter_size=1, num_filters=channels[i], if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_bilinear(input=y, out_shape=[height, width])
residual = fluid.layers.elementwise_add(x=residual, y=y, act=None)
y = conv_bn_layer(
x[j],
filter_size=1,
num_filters=channels[i],
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_bilinear(
input=y, out_shape=[height, width])
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
y = x[j]
for k in range(i - j):
if k == i - j - 1:
y = conv_bn_layer(y, filter_size=3, num_filters=channels[i], stride=2, if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1) + '_' + str(k + 1))
y = conv_bn_layer(
y,
filter_size=3,
num_filters=channels[i],
stride=2,
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
else:
y = conv_bn_layer(y, filter_size=3, num_filters=channels[j], stride=2,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(x=residual, y=y, act=None)
y = conv_bn_layer(
y,
filter_size=3,
num_filters=channels[j],
stride=2,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
residual = fluid.layers.relu(residual)
out.append(residual)
return out
def branches(x, block_num, channels, name=None):
out = []
for i in range(len(channels)):
residual = x[i]
for j in range(block_num):
residual = basic_block(residual, channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' + str(j + 1))
residual = basic_block(
residual,
channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' + str(j + 1))
out.append(residual)
return out
def high_resolution_module(x, channels, multi_scale_output=True, name=None):
residual = branches(x, 4, channels, name=name)
out = fuse_layers(residual, channels, multi_scale_output=multi_scale_output, name=name)
out = fuse_layers(
residual, channels, multi_scale_output=multi_scale_output, name=name)
return out
def transition_layer(x, in_channels, out_channels, name=None):
num_in = len(in_channels)
num_out = len(out_channels)
......@@ -118,46 +186,76 @@ def transition_layer(x, in_channels, out_channels, name=None):
for i in range(num_out):
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = conv_bn_layer(x[i], filter_size=3, num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
residual = conv_bn_layer(
x[i],
filter_size=3,
num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual)
else:
out.append(x[i])
else:
residual = conv_bn_layer(x[-1], filter_size=3, num_filters=out_channels[i], stride=2,
name=name + '_layer_' + str(i + 1))
residual = conv_bn_layer(
x[-1],
filter_size=3,
num_filters=out_channels[i],
stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual)
return out
def stage(x, num_modules, channels, multi_scale_output=True, name=None):
out = x
for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False:
out = high_resolution_module(out, channels, multi_scale_output=False, name=name + '_' + str(i + 1))
out = high_resolution_module(
out,
channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else:
out = high_resolution_module(out, channels, name=name + '_' + str(i + 1))
out = high_resolution_module(
out, channels, name=name + '_' + str(i + 1))
return out
def layer1(input, name=None):
conv = input
for i in range(4):
conv = bottleneck_block(conv, num_filters=64, downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
conv = bottleneck_block(
conv,
num_filters=64,
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv
def high_resolution_net(input, num_classes):
channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS
channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS
channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS
num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES
num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES
num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES
x = conv_bn_layer(input=input, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_1')
x = conv_bn_layer(input=x, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_2')
x = conv_bn_layer(
input=input,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_1')
x = conv_bn_layer(
input=x,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_2')
la1 = layer1(x, name='layer2')
tr1 = transition_layer([la1], [256], channels_2, name='tr1')
......@@ -170,18 +268,21 @@ def high_resolution_net(input, num_classes):
# upsample
shape = st4[0].shape
height, width = shape[-2], shape[-1]
st4[1] = fluid.layers.resize_bilinear(
st4[1], out_shape=[height, width])
st4[2] = fluid.layers.resize_bilinear(
st4[2], out_shape=[height, width])
st4[3] = fluid.layers.resize_bilinear(
st4[3], out_shape=[height, width])
st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=[height, width])
st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=[height, width])
st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=[height, width])
out = fluid.layers.concat(st4, axis=1)
last_channels = sum(channels_4)
out = conv_bn_layer(input=out, filter_size=1, num_filters=last_channels, stride=1, if_act=True, name='conv-2')
out= fluid.layers.conv2d(
out = conv_bn_layer(
input=out,
filter_size=1,
num_filters=last_channels,
stride=1,
if_act=True,
name='conv-2')
out = fluid.layers.conv2d(
input=out,
num_filters=num_classes,
filter_size=1,
......@@ -193,7 +294,6 @@ def high_resolution_net(input, num_classes):
out = fluid.layers.resize_bilinear(out, input.shape[2:])
return out
......@@ -201,6 +301,7 @@ def hrnet(input, num_classes):
logit = high_resolution_net(input, num_classes)
return logit
if __name__ == '__main__':
image_shape = [-1, 3, 769, 769]
image = fluid.data(name='image', shape=image_shape, dtype='float32')
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -24,6 +24,7 @@ from models.libs.model_libs import avg_pool, conv, bn
from models.backbone.resnet import ResNet as resnet_backbone
from utils.config import cfg
def get_logit_interp(input, num_classes, out_shape, name="logit"):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
param_attr = fluid.ParamAttr(
......@@ -33,16 +34,15 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"):
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
with scope(name):
logit = conv(input,
num_classes,
filter_size=1,
param_attr=param_attr,
bias_attr=True,
name=name+'_conv')
logit = conv(
input,
num_classes,
filter_size=1,
param_attr=param_attr,
bias_attr=True,
name=name + '_conv')
logit_interp = fluid.layers.resize_bilinear(
logit,
out_shape=out_shape,
name=name+'_interp')
logit, out_shape=out_shape, name=name + '_interp')
return logit_interp
......@@ -51,40 +51,44 @@ def psp_module(input, out_features):
# 输入:backbone输出的特征
# 输出:对输入进行不同尺度pooling, 卷积操作后插值回原始尺寸,并concat
# 最后进行一个卷积及BN操作
cat_layers = []
sizes = (1,2,3,6)
sizes = (1, 2, 3, 6)
for size in sizes:
psp_name = "psp" + str(size)
with scope(psp_name):
pool = fluid.layers.adaptive_pool2d(input,
pool_size=[size, size],
pool_type='avg',
name=psp_name+'_adapool')
data = conv(pool, out_features,
filter_size=1,
bias_attr=True,
name= psp_name + '_conv')
pool = fluid.layers.adaptive_pool2d(
input,
pool_size=[size, size],
pool_type='avg',
name=psp_name + '_adapool')
data = conv(
pool,
out_features,
filter_size=1,
bias_attr=True,
name=psp_name + '_conv')
data_bn = bn(data, act='relu')
interp = fluid.layers.resize_bilinear(data_bn,
out_shape=input.shape[2:],
name=psp_name+'_interp')
interp = fluid.layers.resize_bilinear(
data_bn, out_shape=input.shape[2:], name=psp_name + '_interp')
cat_layers.append(interp)
cat_layers = [input] + cat_layers[::-1]
cat = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
psp_end_name = "psp_end"
with scope(psp_end_name):
data = conv(cat,
out_features,
filter_size=3,
padding=1,
bias_attr=True,
name=psp_end_name)
data = conv(
cat,
out_features,
filter_size=3,
padding=1,
bias_attr=True,
name=psp_end_name)
out = bn(data, act='relu')
return out
def resnet(input):
# PSPNET backbone: resnet, 默认resnet50
# end_points: resnet终止层数
......@@ -92,14 +96,14 @@ def resnet(input):
scale = cfg.MODEL.PSPNET.DEPTH_MULTIPLIER
layers = cfg.MODEL.PSPNET.LAYERS
end_points = layers - 1
dilation_dict = {2:2, 3:4}
dilation_dict = {2: 2, 3: 4}
model = resnet_backbone(layers, scale, stem='pspnet')
data, _ = model.net(input,
end_points=end_points,
dilation_dict=dilation_dict)
data, _ = model.net(
input, end_points=end_points, dilation_dict=dilation_dict)
return data
def pspnet(input, num_classes):
# Backbone: ResNet
res = resnet(input)
......@@ -109,4 +113,3 @@ def pspnet(input, num_classes):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
logit = get_logit_interp(dropout, num_classes, input.shape[2:])
return logit
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -71,7 +71,8 @@ class SegDataset(object):
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
......@@ -99,7 +100,8 @@ class SegDataset(object):
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,55 +21,48 @@ import warnings
def parse_args():
parser = argparse.ArgumentParser(
description='PaddleSeg generate file list on cityscapes or your customized dataset.')
parser.add_argument(
'dataset_root',
help='dataset root directory',
type=str
description=
'PaddleSeg generate file list on cityscapes or your customized dataset.'
)
parser.add_argument('dataset_root', help='dataset root directory', type=str)
parser.add_argument(
'--type',
help='dataset type: \n'
'- cityscapes \n'
'- custom(default)',
'- cityscapes \n'
'- custom(default)',
default="custom",
type=str
)
type=str)
parser.add_argument(
'--separator',
dest='separator',
help='file list separator',
default="|",
type=str
)
type=str)
parser.add_argument(
'--folder',
help='the folder names of images and labels',
type=str,
nargs=2,
default=['images', 'annotations']
)
default=['images', 'annotations'])
parser.add_argument(
'--second_folder',
help='the second-level folder names of train set, validation set, test set',
help=
'the second-level folder names of train set, validation set, test set',
type=str,
nargs='*',
default=['train', 'val', 'test']
)
default=['train', 'val', 'test'])
parser.add_argument(
'--format',
help='data format of images and labels, e.g. jpg or png.',
type=str,
nargs=2,
default=['jpg', 'png']
)
default=['jpg', 'png'])
parser.add_argument(
'--postfix',
help='postfix of images or labels',
type=str,
nargs=2,
default=['', '']
)
default=['', ''])
return parser.parse_args()
......@@ -120,15 +113,17 @@ def generate_list(args):
num_images = len(image_files)
if not label_files:
label_dir = os.path.join(dataset_root, args.folder[1], dataset_split)
label_dir = os.path.join(dataset_root, args.folder[1],
dataset_split)
warnings.warn("No labels in {} !!!".format(label_dir))
num_label = len(label_files)
if num_images != num_label and num_label > 0:
raise Exception("Number of images = {} number of labels = {} \n"
"Either number of images is equal to number of labels, "
"or number of labels is equal to 0.\n"
"Please check your dataset!".format(num_images, num_label))
raise Exception(
"Number of images = {} number of labels = {} \n"
"Either number of images is equal to number of labels, "
"or number of labels is equal to 0.\n"
"Please check your dataset!".format(num_images, num_label))
file_list = os.path.join(dataset_root, dataset_split + '.txt')
with open(file_list, "w") as f:
......
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
......@@ -11,16 +25,12 @@ from PIL import Image
def parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('dir_or_file',
help='input gray label directory or file list path')
parser.add_argument('output_dir',
help='output colorful label directory')
parser.add_argument('--dataset_dir',
help='dataset directory')
parser.add_argument('--file_separator',
help='file list separator')
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'dir_or_file', help='input gray label directory or file list path')
parser.add_argument('output_dir', help='output colorful label directory')
parser.add_argument('--dataset_dir', help='dataset directory')
parser.add_argument('--file_separator', help='file list separator')
return parser.parse_args()
......
#!/usr/bin/env python
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
......
#!/usr/bin/env python
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -98,7 +99,7 @@ class SegConfig(dict):
'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`'
)
if self.MEAN is not None:
self.DATASET.PADDING_VALUE = [x*255.0 for x in self.MEAN]
self.DATASET.PADDING_VALUE = [x * 255.0 for x in self.MEAN]
if not self.TRAIN_CROP_SIZE:
raise ValueError(
......@@ -111,9 +112,12 @@ class SegConfig(dict):
)
# Ensure file list is use UTF-8 encoding
train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', 'utf-8').readlines()
val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', 'utf-8').readlines()
test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', 'utf-8').readlines()
train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r',
'utf-8').readlines()
val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r',
'utf-8').readlines()
test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r',
'utf-8').readlines()
self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets)
self.DATASET.VAL_TOTAL_IMAGES = len(val_sets)
self.DATASET.TEST_TOTAL_IMAGES = len(test_sets)
......@@ -122,12 +126,13 @@ class SegConfig(dict):
len(self.MODEL.MULTI_LOSS_WEIGHT) != 3:
self.MODEL.MULTI_LOSS_WEIGHT = [1.0, 0.4, 0.16]
if self.AUG.AUG_METHOD not in ['unpadding', 'stepscaling', 'rangescaling']:
if self.AUG.AUG_METHOD not in [
'unpadding', 'stepscaling', 'rangescaling'
]:
raise ValueError(
'AUG.AUG_METHOD config error, only support `unpadding`, `unpadding` and `rangescaling`'
)
def update_from_list(self, config_list):
if len(config_list) % 2 != 0:
raise ValueError(
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle import fluid
def load_fp16_vars(executor, dirname, program):
load_dirname = os.path.normpath(dirname)
......@@ -28,4 +44,4 @@ def load_fp16_vars(executor, dirname, program):
'load_as_fp16': var.dtype == fluid.core.VarDesc.VarType.FP16
})
executor.run(load_prog)
\ No newline at end of file
executor.run(load_prog)
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -42,7 +43,7 @@ model_urls = {
"hrnet_w30_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar",
"hrnet_w32_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar" ,
"https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar",
"hrnet_w40_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar",
"hrnet_w44_bn_imagenet":
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv
from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_backbone
from models.backbone.xception import Xception as xception_backbone
def encoder(input):
# 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
# ASPP_WITH_SEP_CONV:默认为真,使用depthwise可分离卷积,否则使用普通卷积
......@@ -47,8 +48,7 @@ def encoder(input):
with scope('encoder'):
channel = 256
with scope("image_pool"):
image_avg = fluid.layers.reduce_mean(
input, [2, 3], keep_dim=True)
image_avg = fluid.layers.reduce_mean(input, [2, 3], keep_dim=True)
image_avg = bn_relu(
conv(
image_avg,
......@@ -191,7 +191,10 @@ def nas_backbone(input, arch):
end_points = 8
decode_point = 3
data, decode_shortcuts = arch(
input, end_points=end_points, return_block=decode_point, output_stride=16)
input,
end_points=end_points,
return_block=decode_point,
output_stride=16)
decode_shortcut = decode_shortcuts[decode_point]
return data, decode_shortcut
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -31,7 +32,7 @@ __all__ = ["MobileNetV2SpaceSeg"]
class MobileNetV2SpaceSeg(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, block_mask=None):
super(MobileNetV2SpaceSeg, self).__init__(input_size, output_size,
block_num, block_mask)
block_num, block_mask)
# self.head_num means the first convolution channel
self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7
# self.filter_num1 ~ self.filter_num6 means following convlution channel
......@@ -48,7 +49,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
self.k_size = np.array([3, 5]) #2
# self.multiply means expansion_factor of each _inverted_residual_unit
self.multiply = np.array([1, 2, 3, 4, 6]) #5
# self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks
# self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks
self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6
def init_tokens(self):
......@@ -72,7 +73,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
def range_table(self):
"""
Get range table of current search space, constrains the range of tokens.
Get range table of current search space, constrains the range of tokens.
"""
# head_num + 6 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# yapf: disable
......@@ -95,8 +96,8 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
tokens = self.init_tokens()
self.bottleneck_params_list = []
self.bottleneck_params_list.append(
(1, self.head_num[tokens[0]], 1, 1, 3))
self.bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1,
3))
self.bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
......@@ -150,7 +151,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
return (True if count == points else False)
#conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
input = conv_bn_layer(
input,
num_filters=int(32 * self.scale),
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册