Commit 7835c042 authored by chenguowei01

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSeg into humanseg

...@@ -35,7 +35,7 @@ PaddleSeg is an end-to-end image segmentation development kit built on [PaddlePaddle](https://www.paddlepaddle.org.cn)
- **High performance**
PaddleSeg supports training acceleration strategies such as multi-process I/O and multi-GPU parallelism. Combined with the GPU memory optimization of the PaddlePaddle core framework, this greatly reduces the memory footprint of segmentation models, so developers can train image segmentation models at lower cost and higher efficiency.
- **Industrial-grade deployment**
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os
...@@ -19,10 +33,10 @@ cfg.class_num = 20
# Mean value subtracted from the image during preprocessing
cfg.MEAN = 0.406, 0.456, 0.485
# Standard deviation the image is divided by during preprocessing
cfg.STD = 0.225, 0.224, 0.229
# Image sizes used for multi-scale prediction
cfg.multi_scales = (377, 377), (473, 473), (567, 567)
# Whether to horizontally flip the image during multi-scale prediction
cfg.flip = True
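
For reference, cfg.MEAN and cfg.STD are consumed by the preprocessing step shown further below (subtract the mean, divide by the standard deviation, convert to NCHW). A minimal sketch of that normalization, with a hypothetical helper name:

import numpy as np

def normalize_image(im, mean=cfg.MEAN, std=cfg.STD):
    # Hypothetical helper: scale to [0, 1], normalize per channel, convert HWC -> NCHW.
    im = im.astype(np.float32) / 255.0
    im -= np.array(mean, dtype=np.float32)
    im /= np.array(std, dtype=np.float32)
    return np.expand_dims(im.transpose(2, 0, 1), axis=0)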
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
...@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg')
# Paddle garbage-collection flag; the ACE2P model is large, so enable this when GPU memory is insufficient
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
import paddle.fluid as fluid


# Dataset class for prediction
class TestDataSet():
    def __init__(self):
        self.data_dir = cfg.data_dir
        self.data_list_file = cfg.data_list_file
        self.data_list = self.get_data_list()
        self.data_num = len(self.data_list)

    def get_data_list(self):
        # Get the list of image paths for prediction
        data_list = []
...@@ -56,10 +71,10 @@ class TestDataSet():
        img_path = self.data_list[index]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            return img, img, img_path, None
        img_name = img_path.split(os.sep)[-1]
        name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
        img_shape = img.shape[:2]
        img_process = self.preprocess(img)
...@@ -90,39 +105,44 @@ def infer():
        if image is None:
            print(im_name, 'is None')
            continue
        # Prediction
        if cfg.example == 'ACE2P':
            # The ACE2P model uses multi-scale prediction
            reader = importlib.import_module('reader')
            multi_scale_test = getattr(reader, 'multi_scale_test')
            parsing, logits = multi_scale_test(exe, test_prog, feed_name,
                                               fetch_list, image, im_shape)
        else:
            # The HumanSeg and RoadLine models use single-scale prediction
            result = exe.run(
                program=test_prog,
                feed={feed_name[0]: image},
                fetch_list=fetch_list)
            parsing = np.argmax(result[0][0], axis=0)
            parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])

        # Save the prediction result
        result_path = os.path.join(cfg.vis_dir, im_name + '.png')
        if cfg.example == 'HumanSeg':
            logits = result[0][0][1] * 255
            logits = cv2.resize(logits, im_shape[::-1])
            ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
            logits = 255 * (logits - thresh) / (255 - thresh)
            # Add the segmentation result as an alpha channel
            rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
                                  axis=2)
            cv2.imwrite(result_path, rgba)
        else:
            output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
            output_im.putpalette(palette)
            output_im.save(result_path)

        if (idx + 1) % 100 == 0:
            print('%d processed' % (idx + 1))

    print('%d processed done' % (idx + 1))
    return 0
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
from config import cfg
import cv2


def get_affine_points(src_shape, dst_shape, rot_grad=0):
    # Get the three pairs of corresponding points between the original image and the affine-transformed image.
    # The three points are the center of the transformed image, [w/2, 0], [0, 0], and their counterparts in the original image.
...@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
    # Three points in the original image
    points = [[0, 0]] * 3
    points[0] = (np.array([w, h]) - 1) * 0.5
    points[1] = points[0] + 0.5 * affine_shape[0] * np.array([sin_v, -cos_v])
    points[2] = points[1] - 0.5 * affine_shape[1] * np.array([cos_v, sin_v])
...@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
    return points, points_trans


def preprocess(im):
    # Data preprocessing for the ACE2P model
    im_shape = im.shape[:2]
...@@ -42,13 +58,10 @@ def preprocess(im):
        # Get corresponding point coordinates for the image and its affine-transformed version
        points, points_trans = get_affine_points(im_shape, scale)
        # Compute the affine matrix from the point pairs
        trans = cv2.getAffineTransform(
            np.float32(points), np.float32(points_trans))
        # Warp the image with the affine matrix
        input = cv2.warpAffine(im, trans, scale[::-1], flags=cv2.INTER_LINEAR)

        # Subtract the mean, divide by the standard deviation, and convert the layout to NCHW
        input = input.astype(np.float32)
...@@ -66,19 +79,20 @@ def preprocess(im):
    return input_images
def multi_scale_test(exe, test_prog, feed_name, fetch_list, input_ims,
                     im_shape):
    # Some categories are split into left/right parts; flipped_idx maps each to its label after horizontal flipping
    flipped_idx = (15, 14, 17, 16, 19, 18)
    ms_outputs = []

    # Multi-scale prediction
    for idx, scale in enumerate(cfg.multi_scales):
        input_im = input_ims[idx]
        parsing_output = exe.run(
            program=test_prog,
            feed={feed_name[0]: input_im},
            fetch_list=fetch_list)
        output = parsing_output[0][0]
        if cfg.flip:
            # If flipping is enabled, flip the relevant categories back and average with the original prediction
...@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
        # Affine-transform back to the original image size
        points, points_trans = get_affine_points(im_shape, scale)
        M = cv2.getAffineTransform(np.float32(points_trans), np.float32(points))
        logits_result = cv2.warpAffine(
            output, M, im_shape[::-1], flags=cv2.INTER_LINEAR)
        ms_outputs.append(logits_result)

    # Average the multi-scale predictions and take the class with the highest probability
...@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
    ms_fused_parsing_output = np.mean(ms_fused_parsing_output, axis=0)
    parsing = np.argmax(ms_fused_parsing_output, axis=2)
    return parsing, ms_fused_parsing_output
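
As infer() above shows, multi_scale_test is driven by an executor and an inference program loaded from disk. A minimal calling sketch; the model directory and image path are illustrative, not part of this commit:

import cv2
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
# load_inference_model returns (program, feed_target_names, fetch_targets)
test_prog, feed_name, fetch_list = fluid.io.load_inference_model(
    dirname='ACE2P/infer_model', executor=exe)  # hypothetical export directory

image = cv2.imread('demo.jpg', cv2.IMREAD_COLOR)
input_ims = preprocess(image)  # one NCHW tensor per scale in cfg.multi_scales
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list,
                                   input_ims, image.shape[:2])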
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
...@@ -7,6 +7,7 @@
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_gpu", action="store_true", help="Use gpu or cpu to test.")
    parser.add_argument(
        '--example', type=str, help='RoadLine, HumanSeg or ACE2P')
    return parser.parse_args()
...@@ -34,6 +48,7 @@ class AttrDict(dict):
        else:
            self[name] = value


def merge_cfg_from_args(args, cfg):
    """Merge config keys, values in args into the global config."""
    for k, v in vars(args).items():
...@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
        value = v
        if value is not None:
            cfg[k] = value
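
Taken together, an inference entry point would wire these helpers up roughly as follows (a sketch; the imports mirror the config.py and infer.py excerpts above):

from utils.util import get_arguments, merge_cfg_from_args
from config import cfg  # the AttrDict-based global config

args = get_arguments()          # e.g. python infer.py --example HumanSeg --use_gpu
merge_cfg_from_args(args, cfg)  # copies --use_gpu and --example into the global cfg
print(cfg.example, cfg.use_gpu)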
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -12,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# utils for memory management which is allocated on sharedmemory,
# note that these structures may not be thread-safe
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import models
import argparse
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .humanseg import HumanSegMobile
from .humanseg import HumanSegServer
from .humanseg import HumanSegLite
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...@@ -24,6 +25,7 @@ import time
import tqdm
import cv2
import yaml
import shutil
import paddleslim as slim
import utils
...@@ -102,7 +104,7 @@ class SegModel(object):
        # Current model status
        self.status = 'Normal'

    def _get_single_card_bs(self, batch_size):
        if batch_size % len(self.places) == 0:
            return int(batch_size // len(self.places))
        else:
...@@ -144,7 +146,7 @@ class SegModel(object):
            capacity=64,
            use_double_buffer=True,
            iterable=True)
        batch_size_each_gpu = self._get_single_card_bs(batch_size)
        self.train_data_loader.set_sample_list_generator(
            dataset.generator(batch_size=batch_size_each_gpu),
            places=self.places)
...@@ -242,30 +244,11 @@ class SegModel(object):
        if self.status == 'Normal':
            fluid.save(self.train_prog, osp.join(save_dir, 'model'))
            model_info['status'] = 'Normal'
        elif self.status == 'Quant':
            fluid.save(self.test_prog, osp.join(save_dir, 'model'))
            model_info['status'] = 'QuantOnline'

        with open(
                osp.join(save_dir, 'model.yml'), encoding='utf-8',
                mode='w') as f:
...@@ -307,40 +290,57 @@ class SegModel(object):
        logging.info("Model for inference deploy saved in {}.".format(save_dir))
    def export_quant_model(self,
                           dataset=None,
                           save_dir=None,
                           batch_size=1,
                           batch_nums=10,
                           cache_dir=".temp",
                           quant_type="offline"):
        if quant_type == "offline":
            self.arrange_transform(transforms=dataset.transforms, mode='quant')
            dataset.num_samples = batch_size * batch_nums
            try:
                from utils import HumanSegPostTrainingQuantization
            except:
                raise Exception(
                    "Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.1"
                )
            is_use_cache_file = True
            if cache_dir is None:
                is_use_cache_file = False
            post_training_quantization = HumanSegPostTrainingQuantization(
                executor=self.exe,
                dataset=dataset,
                program=self.test_prog,
                inputs=self.test_inputs,
                outputs=self.test_outputs,
                batch_size=batch_size,
                batch_nums=batch_nums,
                scope=None,
                algo='KL',
                quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                is_full_quantize=False,
                is_use_cache_file=is_use_cache_file,
                cache_dir=cache_dir)
            post_training_quantization.quantize()
            post_training_quantization.save_quantized_model(save_dir)
            if cache_dir is not None:
                shutil.rmtree(cache_dir)
        else:
            float_prog, _ = slim.quant.convert(
                self.test_prog, self.exe.place, save_int8=True)
            test_input_names = [
                var.name for var in list(self.test_inputs.values())
            ]
            test_outputs = list(self.test_outputs.values())
            fluid.io.save_inference_model(
                dirname=save_dir,
                executor=self.exe,
                params_filename='__params__',
                feeded_var_names=test_input_names,
                target_vars=test_outputs,
                main_program=float_prog)

        model_info = self.get_model_info()
        model_info['status'] = 'Quant'
...@@ -592,6 +592,16 @@ class SegModel(object):
                    'Current evaluated best model in eval_dataset is epoch_{}, miou={}'
                    .format(best_model_epoch, best_miou))

        if quant:
            if osp.exists(osp.join(save_dir, "best_model")):
                fluid.load(
                    program=self.test_prog,
                    model_path=osp.join(save_dir, "best_model"),
                    executor=self.exe)
            self.export_quant_model(
                save_dir=osp.join(save_dir, "best_model_export"),
                quant_type="online")
    def evaluate(self, eval_dataset, batch_size=1, epoch_id=None):
        """Evaluate.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -24,7 +25,7 @@ import models


def load_model(model_dir):
    if not osp.exists(osp.join(model_dir, "model.yml")):
        raise Exception("There's no model.yml in {}".format(model_dir))
    with open(osp.join(model_dir, "model.yml")) as f:
        info = yaml.load(f.read(), Loader=yaml.Loader)
    status = info['status']
...@@ -33,7 +34,7 @@ def load_model(model_dir):
        raise Exception("There's no attribute {} in models".format(
            info['Model']))
    model = getattr(models, info['Model'])(**info['_init_params'])
    if status in ["Normal", "QuantOnline"]:
        startup_prog = fluid.Program()
        model.test_prog = fluid.Program()
        with fluid.program_guard(model.test_prog, startup_prog):
...@@ -41,11 +42,16 @@ def load_model(model_dir):
                model.test_inputs, model.test_outputs = model.build_net(
                    mode='test')
        model.test_prog = model.test_prog.clone(for_test=True)
        if status == "QuantOnline":
            print('test quant online')
            import paddleslim as slim
            model.test_prog = slim.quant.quant_aware(
                model.test_prog, model.exe.place, for_test=True)
        model.exe.run(startup_prog)
        fluid.load(model.test_prog, osp.join(model_dir, 'model'))
        if status == "QuantOnline":
            model.test_prog = slim.quant.convert(model.test_prog,
                                                 model.exe.place)
    elif status in ['Infer', 'Quant']:
        [prog, input_names, outputs] = fluid.io.load_inference_model(
......
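
A usage sketch for the loader above; the module path and model directory are assumptions, not shown in this commit:

from models.load_model import load_model  # assumed module path

model = load_model('output/best_model')   # rebuilds the test program and loads the weights
# For a 'QuantOnline' checkpoint the test program is additionally wrapped with
# slim.quant.quant_aware() before loading and converted with slim.quant.convert() afterwards.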
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backbone import mobilenet_v2
from .backbone import xception
from .deeplabv3p import DeepLabv3p
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mobilenet_v2 import MobileNetV2
from .xception import Xception
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -10,6 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
visualdl >= 2.0.0b1
paddleslim
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
from models import HumanSegMobile, HumanSegLite, HumanSegServer
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import os
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -205,11 +206,9 @@ def load_pretrained_weights(exe, main_prog, weights_dir, fuse_bn=False):
            vars_to_load.append(var)
            logging.debug("Weight {} will be loaded".format(var.name))

    params_dict = fluid.io.load_program_state(
        weights_dir, var_list=vars_to_load)
    fluid.io.set_program_state(main_prog, params_dict)
    if len(vars_to_load) == 0:
        logging.warning(
            "There is no pretrain weights loaded, maybe you should check your pretrain model!"
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from datasets.dataset import Dataset
import transforms
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import os.path as osp
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase
from pdseg.data_aug import get_random_scale, randomly_scale_image_and_label, random_rotation, \
    rand_scale_aspect, hsv_color_jitter, rand_crop


def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
    """
    Resize the image and the label images.
...@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
        if grt is not None:
            grt = cv2.resize(grt, target_size, interpolation=cv2.INTER_NEAREST)
        if grt_instance is not None:
            grt_instance = cv2.resize(
                grt_instance, target_size, interpolation=cv2.INTER_NEAREST)
    elif cfg.AUG.AUG_METHOD == 'stepscaling':
        if mode == ModelPhase.TRAIN:
            min_scale_factor = cfg.AUG.MIN_SCALE_FACTOR
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -122,7 +122,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
    if ckpt_dir is not None:
        print('load test model:', ckpt_dir)
        try:
            fluid.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
        except:
            fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)

    # Use streaming confusion matrix to calculate mean_iou
    np.set_printoptions(
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -19,8 +19,9 @@ from utils.config import cfg
def unsorted_segment_sum(data, segment_ids, unique_labels, feature_dims):
    unique_labels_shape = fluid.layers.shape(unique_labels)
    zeros = fluid.layers.fill_constant(
        shape=[unique_labels_shape[0], feature_dims], dtype='float32', value=0)
    segment_ids = fluid.layers.unsqueeze(segment_ids, axes=[1])
    segment_ids.stop_gradient = True
    segment_sum = fluid.layers.scatter_nd_add(zeros, segment_ids, data)
...@@ -30,29 +31,23 @@ def unsorted_segment_sum(data, segment_ids, unique_labels, feature_dims):
def norm(x, axis=-1):
    distance = fluid.layers.reduce_sum(
        fluid.layers.abs(x), dim=axis, keep_dim=True)
    return distance


def discriminative_loss_single(prediction, correct_label, feature_dim,
                               label_shape, delta_v, delta_d, param_var,
                               param_dist, param_reg):

    correct_label = fluid.layers.reshape(correct_label,
                                         [label_shape[1] * label_shape[0]])
    prediction = fluid.layers.transpose(prediction, [1, 2, 0])
    reshaped_pred = fluid.layers.reshape(
        prediction, [label_shape[1] * label_shape[0], feature_dim])

    unique_labels, unique_id, counts = fluid.layers.unique_with_counts(
        correct_label)
    correct_label.stop_gradient = True
    counts = fluid.layers.cast(counts, 'float32')
    num_instances = fluid.layers.shape(unique_labels)
...@@ -69,24 +64,29 @@ def discriminative_loss_single(
    distance = norm(tmp)
    distance = distance - delta_v
    distance_pos = fluid.layers.greater_equal(distance,
                                              fluid.layers.zeros_like(distance))
    distance_pos = fluid.layers.cast(distance_pos, 'float32')
    distance = distance * distance_pos

    distance = fluid.layers.square(distance)
    l_var = unsorted_segment_sum(
        distance, unique_id, unique_labels, feature_dims=1)
    l_var = fluid.layers.elementwise_div(l_var, counts_rsp)
    l_var = fluid.layers.reduce_sum(l_var)
    l_var = l_var / fluid.layers.cast(num_instances * (num_instances - 1),
                                      'float32')

    mu_interleaved_rep = fluid.layers.expand(mu, [num_instances, 1])
    mu_band_rep = fluid.layers.expand(mu, [1, num_instances])
    mu_band_rep = fluid.layers.reshape(
        mu_band_rep, (num_instances * num_instances, feature_dim))

    mu_diff = fluid.layers.elementwise_sub(mu_band_rep, mu_interleaved_rep)
    intermediate_tensor = fluid.layers.reduce_sum(
        fluid.layers.abs(mu_diff), dim=1)
    intermediate_tensor.stop_gradient = True
    zero_vector = fluid.layers.zeros([1], 'float32')
    bool_mask = fluid.layers.not_equal(intermediate_tensor, zero_vector)
...@@ -95,7 +95,8 @@ def discriminative_loss_single( ...@@ -95,7 +95,8 @@ def discriminative_loss_single(
mu_norm = norm(mu_diff_bool) mu_norm = norm(mu_diff_bool)
mu_norm = 2. * delta_d - mu_norm mu_norm = 2. * delta_d - mu_norm
mu_norm_pos = fluid.layers.greater_equal(mu_norm, fluid.layers.zeros_like(mu_norm)) mu_norm_pos = fluid.layers.greater_equal(mu_norm,
fluid.layers.zeros_like(mu_norm))
mu_norm_pos = fluid.layers.cast(mu_norm_pos, 'float32') mu_norm_pos = fluid.layers.cast(mu_norm_pos, 'float32')
mu_norm = mu_norm * mu_norm_pos mu_norm = mu_norm * mu_norm_pos
mu_norm_pos.stop_gradient = True mu_norm_pos.stop_gradient = True
...@@ -122,8 +123,8 @@ def discriminative_loss(prediction, correct_label, feature_dim, image_shape, ...@@ -122,8 +123,8 @@ def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
output_ta_reg = 0. output_ta_reg = 0.
for i in range(batch_size): for i in range(batch_size):
disc_loss_single, l_var_single, l_dist_single, l_reg_single = discriminative_loss_single( disc_loss_single, l_var_single, l_dist_single, l_reg_single = discriminative_loss_single(
prediction[i], correct_label[i], feature_dim, image_shape, delta_v, delta_d, param_var, param_dist, prediction[i], correct_label[i], feature_dim, image_shape, delta_v,
param_reg) delta_d, param_var, param_dist, param_reg)
output_ta_loss += disc_loss_single output_ta_loss += disc_loss_single
output_ta_var += l_var_single output_ta_var += l_var_single
output_ta_dist += l_dist_single output_ta_dist += l_dist_single
...@@ -134,5 +135,3 @@ def discriminative_loss(prediction, correct_label, feature_dim, image_shape, ...@@ -134,5 +135,3 @@ def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
l_dist = output_ta_dist / batch_size l_dist = output_ta_dist / batch_size
l_reg = output_ta_reg / batch_size l_reg = output_ta_reg / batch_size
return disc_loss, l_var, l_dist, l_reg return disc_loss, l_var, l_dist, l_reg
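# For readers unfamiliar with the discriminative (instance embedding) loss
# reflowed above, here is a small self-contained NumPy sketch of its three
# terms (variance pull, distance push, regularization). The toy data and the
# exact normalization constants are illustrative and not taken from this diff.
import numpy as np


def discriminative_loss_np(embedding, labels, delta_v=0.5, delta_d=3.0):
    # embedding: (N, feature_dim) per-pixel embeddings, labels: (N,) instance ids
    instance_ids = np.unique(labels)
    mus = np.stack([embedding[labels == i].mean(axis=0) for i in instance_ids])

    # variance term: pull pixels towards their instance mean (L1 norm, as in norm())
    l_var = 0.0
    for mu, i in zip(mus, instance_ids):
        dist = np.abs(embedding[labels == i] - mu).sum(axis=1)
        l_var += np.mean(np.maximum(dist - delta_v, 0.0)**2)
    l_var /= len(instance_ids)

    # distance term: push means of different instances at least 2*delta_d apart
    l_dist = 0.0
    for a in range(len(mus)):
        for b in range(len(mus)):
            if a != b:
                l_dist += np.maximum(
                    2.0 * delta_d - np.abs(mus[a] - mus[b]).sum(), 0.0)**2
    l_dist /= max(len(mus) * (len(mus) - 1), 1)

    # regularization term: keep instance means close to the origin
    l_reg = np.mean(np.abs(mus).sum(axis=1))
    return l_var, l_dist, l_reg


print(discriminative_loss_np(
    np.random.rand(100, 4).astype('float32'),
    np.random.randint(0, 3, size=100)))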
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ from __future__ import print_function

import paddle.fluid as fluid
from utils.config import cfg
from pdseg.models.libs.model_libs import scope, name_scope
from pdseg.models.libs.model_libs import bn, bn_relu, relu
@@ -86,7 +85,12 @@ def bottleneck(inputs,
        with scope('down_sample'):
            inputs_shape = inputs.shape
            with scope('main_max_pool'):
                net_main = fluid.layers.conv2d(
                    inputs,
                    inputs_shape[1],
                    filter_size=3,
                    stride=2,
                    padding='SAME')

            # First get the difference in depth to pad, then pad with zeros only on the last dimension.
            depth_to_pad = abs(inputs_shape[1] - output_depth)
@@ -95,12 +99,16 @@ def bottleneck(inputs,
            net_main = fluid.layers.pad(net_main, paddings=paddings)

            with scope('block1'):
                net = conv(
                    inputs, reduced_depth, [2, 2], stride=2, padding='same')
                net = bn(net)
                net = prelu(net, decoder=decoder)

            with scope('block2'):
                net = conv(
                    net,
                    reduced_depth, [filter_size, filter_size],
                    padding='same')
                net = bn(net)
                net = prelu(net, decoder=decoder)
@@ -137,13 +145,18 @@ def bottleneck(inputs,
        # Second conv block --- apply dilated convolution here
        with scope('block2'):
            net = conv(
                net,
                reduced_depth,
                filter_size,
                padding='SAME',
                dilation=dilation_rate)
            net = bn(net)
            net = prelu(net, decoder=decoder)

        # Final projection with 1x1 kernel (Expansion)
        with scope('block3'):
            net = conv(net, output_depth, [1, 1])
            net = bn(net)
            net = prelu(net, decoder=decoder)
@@ -172,9 +185,11 @@ def bottleneck(inputs,
        # Second conv block --- apply asymmetric conv here
        with scope('block2'):
            with scope('asymmetric_conv2a'):
                net = conv(
                    net, reduced_depth, [filter_size, 1], padding='same')
            with scope('asymmetric_conv2b'):
                net = conv(
                    net, reduced_depth, [1, filter_size], padding='same')
            net = bn(net)
            net = prelu(net, decoder=decoder)
@@ -211,7 +226,8 @@ def bottleneck(inputs,
        with scope('unpool'):
            net_unpool = conv(inputs, output_depth, [1, 1])
            net_unpool = bn(net_unpool)
            net_unpool = fluid.layers.resize_bilinear(
                net_unpool, out_shape=output_shape[2:])

        # First 1x1 projection to reduce depth
        with scope('block1'):
@@ -220,7 +236,12 @@ def bottleneck(inputs,
            net = prelu(net, decoder=decoder)

        with scope('block2'):
            net = deconv(
                net,
                reduced_depth,
                filter_size=filter_size,
                stride=2,
                padding='same')
            net = bn(net)
            net = prelu(net, decoder=decoder)
@@ -253,7 +274,10 @@ def bottleneck(inputs,
        # Second conv block
        with scope('block2'):
            net = conv(
                net,
                reduced_depth, [filter_size, filter_size],
                padding='same')
            net = bn(net)
            net = prelu(net, decoder=decoder)
@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'):
                = bottleneck(inputs, output_depth=64, filter_size=3, regularizer_prob=0.01, type=DOWNSAMPLING,
                             name_scope='bottleneck1_0')
        with scope('bottleneck1_1'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.01,
                name_scope='bottleneck1_1')
        with scope('bottleneck1_2'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.01,
                name_scope='bottleneck1_2')
        with scope('bottleneck1_3'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.01,
                name_scope='bottleneck1_3')
        with scope('bottleneck1_4'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.01,
                name_scope='bottleneck1_4')
        return net, inputs_shape_1
@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'):
                             name_scope='bottleneck2_0')
        for i in range(2):
            with scope('bottleneck2_{}'.format(str(4 * i + 1))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    name_scope='bottleneck2_{}'.format(str(4 * i + 1)))
            with scope('bottleneck2_{}'.format(str(4 * i + 2))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    type=DILATED,
                    dilation_rate=(2**(2 * i + 1)),
                    name_scope='bottleneck2_{}'.format(str(4 * i + 2)))
            with scope('bottleneck2_{}'.format(str(4 * i + 3))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=5,
                    regularizer_prob=0.1,
                    type=ASYMMETRIC,
                    name_scope='bottleneck2_{}'.format(str(4 * i + 3)))
            with scope('bottleneck2_{}'.format(str(4 * i + 4))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    type=DILATED,
                    dilation_rate=(2**(2 * i + 2)),
                    name_scope='bottleneck2_{}'.format(str(4 * i + 4)))
        return net, inputs_shape_2
@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'):
    with scope(name_scope):
        for i in range(2):
            with scope('bottleneck3_{}'.format(str(4 * i + 0))):
                net = bottleneck(
                    inputs,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    name_scope='bottleneck3_{}'.format(str(4 * i + 0)))
            with scope('bottleneck3_{}'.format(str(4 * i + 1))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    type=DILATED,
                    dilation_rate=(2**(2 * i + 1)),
                    name_scope='bottleneck3_{}'.format(str(4 * i + 1)))
            with scope('bottleneck3_{}'.format(str(4 * i + 2))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=5,
                    regularizer_prob=0.1,
                    type=ASYMMETRIC,
                    name_scope='bottleneck3_{}'.format(str(4 * i + 2)))
            with scope('bottleneck3_{}'.format(str(4 * i + 3))):
                net = bottleneck(
                    net,
                    output_depth=128,
                    filter_size=3,
                    regularizer_prob=0.1,
                    type=DILATED,
                    dilation_rate=(2**(2 * i + 2)),
                    name_scope='bottleneck3_{}'.format(str(4 * i + 3)))
        return net
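# As a quick reference, the dilation schedule produced by the 2**(2 * i + 1)
# and 2**(2 * i + 2) expressions in the stage 2/3 loops above works out as
# follows (illustrative snippet only):
for i in range(2):
    print(i, 2**(2 * i + 1), 2**(2 * i + 2))
# 0 2 4
# 1 8 16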
def ENet_stage4(inputs,
                inputs_shape,
                connect_tensor,
                skip_connections=True,
                name_scope='stage4_block'):
    with scope(name_scope):
        with scope('bottleneck4_0'):
            net = bottleneck(
                inputs,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.1,
                type=UPSAMPLING,
                decoder=True,
                output_shape=inputs_shape,
                name_scope='bottleneck4_0')

        if skip_connections:
            net = fluid.layers.elementwise_add(net, connect_tensor)
        with scope('bottleneck4_1'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.1,
                decoder=True,
                name_scope='bottleneck4_1')
        with scope('bottleneck4_2'):
            net = bottleneck(
                net,
                output_depth=64,
                filter_size=3,
                regularizer_prob=0.1,
                decoder=True,
                name_scope='bottleneck4_2')
    return net


def ENet_stage5(inputs,
                inputs_shape,
                connect_tensor,
                skip_connections=True,
                name_scope='stage5_block'):
    with scope(name_scope):
        net = bottleneck(
            inputs,
            output_depth=16,
            filter_size=3,
            regularizer_prob=0.1,
            type=UPSAMPLING,
            decoder=True,
            output_shape=inputs_shape,
            name_scope='bottleneck5_0')

        if skip_connections:
            net = fluid.layers.elementwise_add(net, connect_tensor)
        with scope('bottleneck5_1'):
            net = bottleneck(
                net,
                output_depth=16,
                filter_size=3,
                regularizer_prob=0.1,
                decoder=True,
                name_scope='bottleneck5_1')
    return net
@@ -378,14 +493,16 @@ def decoder(input, num_classes):
            segStage3 = ENet_stage3(stage2)
            segStage4 = ENet_stage4(segStage3, inputs_shape_2, stage1)
            segStage5 = ENet_stage5(segStage4, inputs_shape_1, initial)
            segLogits = deconv(
                segStage5, num_classes, filter_size=2, stride=2, padding='SAME')

        # Embedding branch
        with scope('LaneNetEm'):
            emStage3 = ENet_stage3(stage2)
            emStage4 = ENet_stage4(emStage3, inputs_shape_2, stage1)
            emStage5 = ENet_stage5(emStage4, inputs_shape_1, initial)
            emLogits = deconv(
                emStage5, 4, filter_size=2, stride=2, padding='SAME')

    elif 'vgg' in cfg.MODEL.LANENET.BACKBONE:
        encoder_list = ['pool5', 'pool4', 'pool3']
@@ -396,14 +513,16 @@ def decoder(input, num_classes):
            encoder_list = encoder_list[1:]
            for i in range(len(encoder_list)):
                with scope('deconv_{:d}'.format(i + 1)):
                    deconv_out = deconv(
                        score, 64, filter_size=4, stride=2, padding='SAME')
                input_tensor = input[encoder_list[i]]
                with scope('score_{:d}'.format(i + 1)):
                    score = conv(input_tensor, 64, 1)
                score = fluid.layers.elementwise_add(deconv_out, score)
            with scope('deconv_final'):
                emLogits = deconv(
                    score, 64, filter_size=16, stride=8, padding='SAME')
            with scope('score_final'):
                segLogits = conv(emLogits, num_classes, 1)
            emLogits = relu(conv(emLogits, 4, 1))
@@ -415,7 +534,8 @@ def encoder(input):
        model = vgg_backbone(layers=16)
        #output = model.net(input)

        _, encode_feature_dict = model.net(
            input, end_points=13, decode_points=[7, 10, 13])
        output = {}
        output['pool3'] = encode_feature_dict[7]
        output['pool4'] = encode_feature_dict[10]
@@ -427,8 +547,9 @@ def encoder(input):
        stage2, inputs_shape_2 = ENet_stage2(stage1)
        output = (initial, stage1, stage2, inputs_shape_1, inputs_shape_2)
    else:
        raise Exception(
            "LaneNet expect enet and vgg backbone, but received {}".format(
                cfg.MODEL.LANENET.BACKBONE))
    return output
...
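# A minimal sketch of how the encoder()/decoder() pair above is typically
# driven from the surrounding model builder. The names `image`, `num_classes`,
# and the assumption that decoder() returns (segLogits, emLogits) are
# illustrative only and not confirmed by this diff.
def build_lanenet_branches(image, num_classes):
    # encoder() returns backbone features (ENet stages or VGG pools, per cfg)
    feats = encoder(image)
    # decoder() is assumed here to return the binary segmentation logits and
    # the 4-dim per-pixel embedding logits used by the discriminative loss
    seg_logits, em_logits = decoder(feats, num_classes)
    return seg_logits, em_logits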
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -58,7 +58,8 @@ class LaneNetDataset():
        if self.shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
            self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
                                        (cfg.TRAINER_ID + 1)]
            self.shuffle_seed += 1
        elif self.shuffle:
            np.random.shuffle(self.lines)
@@ -86,7 +87,8 @@ class LaneNetDataset():
        if self.shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            num_lines = len(self.all_lines) // self.num_trainers
            self.lines = self.all_lines[num_lines * self.trainer_id:num_lines *
                                        (self.trainer_id + 1)]
            self.shuffle_seed += 1
        elif self.shuffle:
            np.random.shuffle(self.lines)
@@ -118,7 +120,8 @@ class LaneNetDataset():
        def batch_reader(is_test=False, drop_last=drop_last):
            if is_test:
                imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
                for img, grt, grt_instance, img_name, valid_shape, org_shape in reader(
                ):
                    imgs.append(img)
                    grts.append(grt)
                    grts_instance.append(grt_instance)
@@ -126,14 +129,15 @@ class LaneNetDataset():
                    valid_shapes.append(valid_shape)
                    org_shapes.append(org_shape)
                    if len(imgs) == batch_size:
                        yield np.array(imgs), np.array(grts), np.array(
                            grts_instance), img_names, np.array(
                                valid_shapes), np.array(org_shapes)
                        imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []

                if not drop_last and len(imgs) > 0:
                    yield np.array(imgs), np.array(grts), np.array(
                        grts_instance), img_names, np.array(
                            valid_shapes), np.array(org_shapes)
            else:
                imgs, labs, labs_instance, ignore = [], [], [], []
                bs = 0
@@ -144,12 +148,14 @@ class LaneNetDataset():
                    ignore.append(ig)
                    bs += 1
                    if bs == batch_size:
                        yield np.array(imgs), np.array(labs), np.array(
                            labs_instance), np.array(ignore)
                        bs = 0
                        imgs, labs, labs_instance, ignore = [], [], [], []

                if not drop_last and bs > 0:
                    yield np.array(imgs), np.array(labs), np.array(
                        labs_instance), np.array(ignore)

        return batch_reader(is_test, drop_last)
@@ -299,10 +305,12 @@ class LaneNetDataset():
            img, grt = aug.rand_crop(img, grt, mode=mode)
        elif ModelPhase.is_eval(mode):
            img, grt, grt_instance = aug.resize(
                img, grt, grt_instance, mode=mode)
        elif ModelPhase.is_visual(mode):
            ori_img = img.copy()
            img, grt, grt_instance = aug.resize(
                img, grt, grt_instance, mode=mode)
            valid_shape = [img.shape[0], img.shape[1]]
        else:
            raise ValueError("Dataset mode={} Error!".format(mode))
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,10 +40,10 @@ from pdseg.utils.timer import Timer, calculate_eta
from reader import LaneNetDataset
from models.model_builder import build_model
from models.model_builder import ModelPhase
from models.model_builder import parse_shape_from_file
from eval import evaluate
from vis import visualize
from utils import dist_utils
from utils.load_model_utils import load_pretrained_weights


def parse_args():
@@ -101,37 +101,6 @@ def parse_args():
    return parser.parse_args()
def save_vars(executor, dirname, program=None, vars=None):
"""
    Temporary resolution for Windows save-variables compatibility.
Will fix in PaddlePaddle v1.5.2
"""
save_program = fluid.Program()
save_block = save_program.global_block()
for each_var in vars:
# NOTE: don't save the variable which type is RAW
if each_var.type == fluid.core.VarDesc.VarType.RAW:
continue
new_var = save_block.create_var(
name=each_var.name,
shape=each_var.shape,
dtype=each_var.dtype,
type=each_var.type,
lod_level=each_var.lod_level,
persistable=True)
file_path = os.path.join(dirname, new_var.name)
file_path = os.path.normpath(file_path)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': file_path})
executor.run(save_program)
def save_checkpoint(exe, program, ckpt_name):
    """
    Save checkpoint for evaluation or resume training
@@ -141,29 +110,22 @@ def save_checkpoint(exe, program, ckpt_name):
    if not os.path.isdir(ckpt_dir):
        os.makedirs(ckpt_dir)

    fluid.save(program, os.path.join(ckpt_dir, 'model'))

    return ckpt_dir


def load_checkpoint(exe, program):
    """
    Load checkpoint for resuming training
    """
    model_path = cfg.TRAIN.RESUME_MODEL_DIR
    print('Resume model training from:', model_path)
    if not os.path.exists(model_path):
        raise ValueError(
            "TRAIN.PRETRAIN_MODEL {} not exist!".format(model_path))
    fluid.load(program, os.path.join(model_path, 'model'), exe)

    # Check whether the path ends with a path separator
    if model_path[-1] == os.sep:
        model_path = model_path[0:-1]
@@ -178,7 +140,6 @@ def load_checkpoint(exe, program):
    else:
        raise ValueError("Resume model path is not valid!")
    print("Model checkpoint loaded successfully!")

    return begin_epoch
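# The helpers above now go through the fluid.save/fluid.load pair. A minimal
# round-trip sketch under the same assumptions (a built `train_prog`, an `exe`
# executor, and an illustrative checkpoint directory):
import os
import paddle.fluid as fluid

ckpt_dir = './saved_model/epoch_1'  # illustrative path, not from this diff
os.makedirs(ckpt_dir, exist_ok=True)
fluid.save(train_prog, os.path.join(ckpt_dir, 'model'))       # writes model.pdparams / model.pdopt
fluid.load(train_prog, os.path.join(ckpt_dir, 'model'), exe)  # restores parameters into the program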
@@ -271,44 +232,7 @@ def train(cfg):
        begin_epoch = load_checkpoint(exe, train_prog)
    # Load pretrained model
    elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
        load_pretrained_weights(exe, train_prog,
                                cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_vars = []
load_fail_vars = []
def var_shape_matched(var, shape):
"""
            Check whether the persistable variable shape matches the current network
"""
var_exist = os.path.exists(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_shape != shape:
print(var.name, var_shape, shape)
return var_shape == shape
return False
for x in train_prog.list_vars():
if isinstance(x, fluid.framework.Parameter):
shape = tuple(fluid.global_scope().find_var(
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print_info(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print_info("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
    else:
        print_info(
            'Pretrained model dir {} not exists, training from scratch...'.
@@ -393,8 +317,7 @@ def train(cfg):
                    avg_emb_loss, avg_acc, avg_fp, avg_fn, speed,
                    calculate_eta(all_step - step, speed)))
            if args.use_vdl:
                log_writer.add_scalar('Train/loss', avg_loss, step)
                log_writer.add_scalar('Train/lr', lr[0], step)
                log_writer.add_scalar('Train/speed', speed, step)
            sys.stdout.flush()
@@ -423,8 +346,7 @@ def train(cfg):
                use_gpu=args.use_gpu,
                use_mpio=args.use_mpio)
            if args.use_vdl:
                log_writer.add_scalar('Evaluate/accuracy', accuracy, step)
                log_writer.add_scalar('Evaluate/fp', fp, step)
                log_writer.add_scalar('Evaluate/fn', fn, step)
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3
cfg.DATASET.SEPARATOR = ' '
# Pixel label value to ignore; defaults to 255 and normally needs no change
cfg.DATASET.IGNORE_INDEX = 255
# Padding value applied to images during data augmentation
cfg.DATASET.PADDING_VALUE = [127.5, 127.5, 127.5]

########################### Data augmentation configuration ######################################
# Horizontal (left-right) mirror flip of the image
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
generate tusimple training dataset generate tusimple training dataset
""" """
...@@ -14,12 +28,16 @@ import numpy as np ...@@ -14,12 +28,16 @@ import numpy as np
def init_args(): def init_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--src_dir', type=str, help='The origin path of unzipped tusimple dataset') parser.add_argument(
'--src_dir',
type=str,
help='The origin path of unzipped tusimple dataset')
return parser.parse_args() return parser.parse_args()
def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, instance_dst_dir): def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir,
instance_dst_dir):
assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path) assert ops.exists(json_file_path), '{:s} not exist'.format(json_file_path)
@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
        h_samples = info_dict['h_samples']
        lanes = info_dict['lanes']

        image_name_new = '{:s}.png'.format(
            '{:d}'.format(line_index + image_nums).zfill(4))

        src_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        dst_binary_image = np.zeros(
            [src_image.shape[0], src_image.shape[1]], np.uint8)
        dst_instance_image = np.zeros(
            [src_image.shape[0], src_image.shape[1]], np.uint8)

        for lane_index, lane in enumerate(lanes):
            assert len(h_samples) == len(lane)
@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
            lane_pts = np.vstack((lane_x, lane_y)).transpose()
            lane_pts = np.array([lane_pts], np.int64)

            cv2.polylines(
                dst_binary_image,
                lane_pts,
                isClosed=False,
                color=255,
                thickness=5)
            cv2.polylines(
                dst_instance_image,
                lane_pts,
                isClosed=False,
                color=lane_index * 50 + 20,
                thickness=5)

        dst_binary_image_path = ops.join(src_dir, binary_dst_dir,
                                         image_name_new)
        dst_instance_image_path = ops.join(src_dir, instance_dst_dir,
                                           image_name_new)
        dst_rgb_image_path = ops.join(src_dir, ori_dst_dir, image_name_new)

        cv2.imwrite(dst_binary_image_path, dst_binary_image)
@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
        print('Process {:s} success'.format(image_name))
def gen_sample(src_dir,
               b_gt_image_dir,
               i_gt_image_dir,
               image_dir,
               phase='train',
               split=False):

    label_list = []
    with open('{:s}/{}ing/{}.txt'.format(src_dir, phase, phase), 'w') as file:
@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
            image_path = ops.join(image_dir, image_name)

            assert ops.exists(image_path), '{:s} not exist'.format(image_path)
            assert ops.exists(instance_gt_image_path), '{:s} not exist'.format(
                instance_gt_image_path)

            b_gt_image = cv2.imread(binary_gt_image_path, cv2.IMREAD_COLOR)
            i_gt_image = cv2.imread(instance_gt_image_path, cv2.IMREAD_COLOR)
@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
                print('image: {:s} corrupt'.format(image_name))
                continue
            else:
                info = '{:s} {:s} {:s}'.format(image_path, binary_gt_image_path,
                                               instance_gt_image_path)
                file.write(info + '\n')
                label_list.append(info)
    if phase == 'train' and split:
@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
        val_list_len = len(label_list) // 10
        val_label_list = label_list[:val_list_len]
        train_label_list = label_list[val_list_len:]
        with open('{:s}/{}ing/train_part.txt'.format(src_dir, phase, phase),
                  'w') as file:
            for info in train_label_list:
                file.write(info + '\n')
        with open('{:s}/{}ing/val_part.txt'.format(src_dir, phase, phase),
                  'w') as file:
            for info in val_label_list:
                file.write(info + '\n')
    return
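# Each line written by gen_sample() is a space-separated triple of image path,
# binary ground-truth path, and instance ground-truth path. A tiny sketch of
# parsing one such line back (the paths below are made up for illustration):
line = '/data/training/gt_image/0001.png /data/training/gt_binary_image/0001.png /data/training/gt_instance_image/0001.png'
img_path, binary_gt_path, instance_gt_path = line.strip().split(' ')
print(img_path, binary_gt_path, instance_gt_path)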
@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir):
    for json_label_path in glob.glob('{:s}/label*.json'.format(src_dir)):
        json_label_name = ops.split(json_label_path)[1]
        shutil.copyfile(json_label_path,
                        ops.join(traing_folder_path, json_label_name))

    for json_label_path in glob.glob('{:s}/test_label.json'.format(src_dir)):
        json_label_name = ops.split(json_label_path)[1]
        shutil.copyfile(json_label_path,
                        ops.join(testing_folder_path, json_label_name))

    train_gt_image_dir = ops.join('training', 'gt_image')
    train_gt_binary_dir = ops.join('training', 'gt_binary_image')
@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir):
    os.makedirs(os.path.join(src_dir, test_gt_instance_dir), exist_ok=True)

    for json_label_path in glob.glob('{:s}/*.json'.format(traing_folder_path)):
        process_json_file(json_label_path, src_dir, train_gt_image_dir,
                          train_gt_binary_dir, train_gt_instance_dir)

    gen_sample(src_dir, train_gt_binary_dir, train_gt_instance_dir,
               train_gt_image_dir, 'train', True)


if __name__ == '__main__':
...
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code is heavily based on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
"""
LaneNet model post process
"""
@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5):
    :return:
    """
    if len(image.shape) == 3:
        raise ValueError(
            'Binary segmentation result image should be a single channel image')

    if image.dtype is not np.uint8:
        image = np.array(image, np.uint8)

    kernel = cv2.getStructuringElement(
        shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))

    # close operation fills holes
    closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
@@ -46,13 +61,15 @@ def _connect_components_analysis(image):
    else:
        gray_image = image

    return cv2.connectedComponentsWithStats(
        gray_image, connectivity=8, ltype=cv2.CV_32S)
class _LaneFeat(object):
    """
    """

    def __init__(self, feat, coord, class_id=-1):
        """
        lane feat object
@@ -108,18 +125,21 @@ class _LaneNetCluster(object):
    """
    Instance segmentation result cluster
    """

    def __init__(self):
        """
        """
        self._color_map = [
            np.array([255, 0, 0]),
            np.array([0, 255, 0]),
            np.array([0, 0, 255]),
            np.array([125, 125, 0]),
            np.array([0, 125, 125]),
            np.array([125, 0, 125]),
            np.array([50, 100, 50]),
            np.array([100, 50, 100])
        ]

    @staticmethod
    def _embedding_feats_dbscan_cluster(embedding_image_feats):
@@ -186,15 +206,16 @@ class _LaneNetCluster(object):
        # get embedding feats and coords
        get_lane_embedding_feats_result = self._get_lane_embedding_feats(
            binary_seg_ret=binary_seg_result,
            instance_seg_ret=instance_seg_result)

        # dbscan cluster
        dbscan_cluster_result = self._embedding_feats_dbscan_cluster(
            embedding_image_feats=get_lane_embedding_feats_result[
                'lane_embedding_feats'])

        mask = np.zeros(
            shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3],
            dtype=np.uint8)

        db_labels = dbscan_cluster_result['db_labels']
        unique_labels = dbscan_cluster_result['unique_labels']
        coord = get_lane_embedding_feats_result['lane_coordinates']
@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object):
    """
    lanenet post process for lane generation
    """

    def __init__(self, ipm_remap_file_path='./utils/tusimple_ipm_remap.yml'):
        """
        convert front car view to bird view
        """
        assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(
            ipm_remap_file_path)

        self._cluster = _LaneNetCluster()
        self._ipm_remap_file_path = ipm_remap_file_path
@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object):
        self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x']
        self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y']

        self._color_map = [
            np.array([255, 0, 0]),
            np.array([0, 255, 0]),
            np.array([0, 0, 255]),
            np.array([125, 125, 0]),
            np.array([0, 125, 125]),
            np.array([125, 0, 125]),
            np.array([50, 100, 50]),
            np.array([100, 50, 100])
        ]

    def _load_remap_matrix(self):
        fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ)
@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object):
        return ret

    def postprocess(self,
                    binary_seg_result,
                    instance_seg_result=None,
                    min_area_threshold=100,
                    source_image=None,
                    data_source='tusimple'):

        # convert binary_seg_result
        binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8)

        # apply image morphology operation to fill the holes and reduce small areas
        morphological_ret = _morphological_process(
            binary_seg_result, kernel_size=5)
        connect_components_analysis_ret = _connect_components_analysis(
            image=morphological_ret)

        labels = connect_components_analysis_ret[1]
        stats = connect_components_analysis_ret[2]
@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object):
        # apply embedding features cluster
        mask_image, lane_coords = self._cluster.apply_lane_feats_cluster(
            binary_seg_result=morphological_ret,
            instance_seg_result=instance_seg_result)

        if mask_image is None:
            return {
@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object):
        for lane_index, coords in enumerate(lane_coords):
            if data_source == 'tusimple':
                tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
                tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256),
                                np.int_(coords[:, 0] * 1280 / 512)))] = 255
            else:
                raise ValueError('Wrong data source now only support tusimple')
            tmp_ipm_mask = cv2.remap(
                tmp_mask,
                self._remap_to_ipm_x,
                self._remap_to_ipm_y,
                interpolation=cv2.INTER_NEAREST)
            nonzero_y = np.array(tmp_ipm_mask.nonzero()[0])
            nonzero_x = np.array(tmp_ipm_mask.nonzero()[1])
@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object):
            [ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
            plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
            fit_x = fit_param[0] * plot_y**2 + fit_param[
                1] * plot_y + fit_param[2]

            lane_pts = []
            for index in range(0, plot_y.shape[0], 5):
                src_x = self._remap_to_ipm_x[
                    int(plot_y[index]),
                    int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
                if src_x <= 0:
                    continue
                src_y = self._remap_to_ipm_y[
                    int(plot_y[index]),
                    int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
                src_y = src_y if src_y > 0 else 0

                lane_pts.append([src_x, src_y])
@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object):
                    continue
                lane_color = self._color_map[index].tolist()
                cv2.circle(
                    source_image,
                    (int(interpolation_src_pt_x), int(interpolation_src_pt_y)),
                    5, lane_color, -1)
        ret = {
            'mask_image': mask_image,
            'fit_params': fit_params,
...
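# The clustering step referenced above (_embedding_feats_dbscan_cluster) groups
# lane pixels by their embedding vectors. A standalone sketch with scikit-learn's
# DBSCAN follows; the eps/min_samples values and the random data are illustrative
# only and are not taken from this diff.
import numpy as np
from sklearn.cluster import DBSCAN

embedding_feats = np.random.rand(500, 4).astype('float32')  # per-pixel embeddings
db = DBSCAN(eps=0.35, min_samples=50).fit(embedding_feats)
db_labels = db.labels_                      # -1 marks noise pixels
unique_labels = np.unique(db_labels)
print('instances found:', np.sum(unique_labels != -1))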
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import six
import numpy as np
def parse_param_file(param_file, return_shape=True):
from paddle.fluid.proto.framework_pb2 import VarType
f = open(param_file, 'rb')
version = np.fromstring(f.read(4), dtype='int32')
lod_level = np.fromstring(f.read(8), dtype='int64')
for i in range(int(lod_level)):
_size = np.fromstring(f.read(8), dtype='int64')
_ = f.read(_size)
version = np.fromstring(f.read(4), dtype='int32')
tensor_desc = VarType.TensorDesc()
tensor_desc_size = np.fromstring(f.read(4), dtype='int32')
tensor_desc.ParseFromString(f.read(int(tensor_desc_size)))
tensor_shape = tuple(tensor_desc.dims)
if return_shape:
f.close()
return tuple(tensor_desc.dims)
if tensor_desc.data_type != 5:
raise Exception(
"Unexpected data type while parse {}".format(param_file))
data_size = 4
for i in range(len(tensor_shape)):
data_size *= tensor_shape[i]
weight = np.fromstring(f.read(data_size), dtype='float32')
f.close()
return np.reshape(weight, tensor_shape)
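A quick usage sketch for `parse_param_file`; the parameter path below is illustrative and stands for any single-variable file saved by `fluid.io.save_params`.

```python
# Illustrative path: replace with a real parameter file from your checkpoint directory.
param_file = 'pretrained_model/conv1_weights'

shape = parse_param_file(param_file, return_shape=True)     # e.g. (64, 3, 3, 3)
weights = parse_param_file(param_file, return_shape=False)  # float32 ndarray with that shape
```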
def load_pdparams(exe, main_prog, model_dir):
import paddle.fluid as fluid
from paddle.fluid.proto.framework_pb2 import VarType
from paddle.fluid.framework import Program
vars_to_load = list()
vars_not_load = list()
import pickle
with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
params_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
unused_vars = list()
for var in main_prog.list_vars():
if not isinstance(var, fluid.framework.Parameter):
continue
if var.name not in params_dict:
print("{} is not in saved model".format(var.name))
vars_not_load.append(var.name)
continue
if var.shape != params_dict[var.name].shape:
unused_vars.append(var.name)
vars_not_load.append(var.name)
print(
"[SKIP] Shape of pretrained weight {} doesn't match.(Pretrained: {}, Actual: {})"
.format(var.name, params_dict[var.name].shape, var.shape))
continue
vars_to_load.append(var)
for var_name in unused_vars:
del params_dict[var_name]
fluid.io.set_program_state(main_prog, params_dict)
if len(vars_to_load) == 0:
print(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else:
print("There are {}/{} varaibles in {} are loaded.".format(
len(vars_to_load),
len(vars_to_load) + len(vars_not_load), model_dir))
def load_pretrained_weights(exe, main_prog, weights_dir):
if not osp.exists(weights_dir):
raise Exception("Path {} not exists.".format(weights_dir))
if osp.exists(osp.join(weights_dir, "model.pdparams")):
return load_pdparams(exe, main_prog, weights_dir)
import paddle.fluid as fluid
vars_to_load = list()
vars_not_load = list()
for var in main_prog.list_vars():
if not isinstance(var, fluid.framework.Parameter):
continue
if not osp.exists(osp.join(weights_dir, var.name)):
print("[SKIP] Pretrained weight {}/{} doesn't exist".format(
weights_dir, var.name))
vars_not_load.append(var)
continue
pretrained_shape = parse_param_file(osp.join(weights_dir, var.name))
actual_shape = tuple(var.shape)
if pretrained_shape != actual_shape:
print(
"[SKIP] Shape of pretrained weight {}/{} doesn't match.(Pretrained: {}, Actual: {})"
.format(weights_dir, var.name, pretrained_shape, actual_shape))
vars_not_load.append(var)
continue
vars_to_load.append(var)
params_dict = fluid.io.load_program_state(
weights_dir, var_list=vars_to_load)
fluid.io.set_program_state(main_prog, params_dict)
if len(vars_to_load) == 0:
print(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else:
print("There are {}/{} varaibles in {} are loaded.".format(
len(vars_to_load),
len(vars_to_load) + len(vars_not_load), weights_dir))
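A sketch of how `load_pretrained_weights` is typically wired up, assuming the segmentation network has already been added to the default main program; the weights directory is a placeholder.

```python
import paddle.fluid as fluid

# Pick GPU if this Paddle build supports it, otherwise fall back to CPU.
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
main_prog = fluid.default_main_program()

exe.run(fluid.default_startup_program())                      # initialize parameters first
load_pretrained_weights(exe, main_prog, 'pretrained_model/')  # placeholder directory
```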
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -45,6 +45,7 @@ from models.model_builder import ModelPhase
from utils import lanenet_postprocess
import matplotlib.pyplot as plt
def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg visualization tools')
parser.add_argument(
...@@ -106,7 +107,6 @@ def minmax_scale(input_arr):
return output_arr
def visualize(cfg,
vis_file_list=None,
use_gpu=False,
...@@ -119,7 +119,6 @@ def visualize(cfg,
if vis_file_list is None:
vis_file_list = cfg.DATASET.TEST_FILE_LIST
dataset = LaneNetDataset(
file_list=vis_file_list,
mode=ModelPhase.VISUAL,
...@@ -139,7 +138,12 @@ def visualize(cfg,
ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir
if ckpt_dir is not None:
print('load test model:', ckpt_dir)
try:
fluid.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
except:
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
save_dir = os.path.join(vis_dir, 'visual_results')
makedirs(save_dir)
...@@ -161,22 +165,26 @@ def visualize(cfg,
for i in range(num_imgs):
gt_image = org_imgs[i]
binary_seg_image, instance_seg_image = segLogits[i].squeeze(
-1), emLogits[i].transpose((1, 2, 0))
postprocess_result = postprocessor.postprocess(
binary_seg_result=binary_seg_image,
instance_seg_result=instance_seg_image,
source_image=gt_image)
pred_binary_fn = os.path.join(
save_dir, to_png_fn(img_names[i], name='_pred_binary'))
pred_lane_fn = os.path.join(
save_dir, to_png_fn(img_names[i], name='_pred_lane'))
pred_instance_fn = os.path.join(
save_dir, to_png_fn(img_names[i], name='_pred_instance'))
dirname = os.path.dirname(pred_binary_fn)
makedirs(dirname)
mask_image = postprocess_result['mask_image']
for i in range(4):
instance_seg_image[:, :, i] = minmax_scale(
instance_seg_image[:, :, i])
embedding_image = np.array(instance_seg_image).astype(np.uint8)
plt.figure('mask_image')
...@@ -189,13 +197,13 @@ def visualize(cfg,
plt.imshow(binary_seg_image * 255, cmap='gray')
plt.show()
cv2.imwrite(pred_binary_fn,
np.array(binary_seg_image * 255).astype(np.uint8))
cv2.imwrite(pred_lane_fn, postprocess_result['source_image'])
cv2.imwrite(pred_instance_fn, mask_image)
print(pred_lane_fn, 'saved!')
if __name__ == '__main__':
args = parse_args()
if args.cfg_file is not None:
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress
if __name__ == "__main__":
download_file_and_uncompress(
url=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar',
savepath=LOCAL_PATH,
extrapath=LOCAL_PATH)
......
...@@ -3,6 +3,7 @@
提供基于PaddleSeg最新的分割特色模型:
- [人像分割](./HumanSeg)
- [遥感分割](./RemoteSensing)
- [人体解析](./ACE2P)
- [车道线分割](./LaneNet)
- [工业表盘分割](#工业表盘分割)
...@@ -12,6 +13,14 @@
HumanSeg系列全新升级,提供三个适用于不同场景的模型,其中包含适用于移动端实时分割场景的`HumanSeg-lite`,并提供了包含光流的后处理优化,使人像分割在视频场景中更加顺畅,更多详情请参考[HumanSeg](./HumanSeg)
## 遥感分割 Remote Sensing Segmentation
PaddleSeg遥感影像分割涵盖图像预处理、数据增强、模型训练、预测流程。
针对遥感数据多通道、分布范围大、分布不均的特点,我们支持多通道训练预测,内置10+多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。
内置U-Net, HRNet两种主流分割网络,可选择不同的损失函数如Dice Loss, BCE Loss等方式强化小目标和不均衡样本场景下的分割精度。更多详情请参考[RemoteSensing](./RemoteSensing)
以下是遥感云检测的示例效果:
![](./RemoteSensing/docs/imgs/rs.png)
## 人体解析 Human Parsing ## 人体解析 Human Parsing
......
# PaddleSeg遥感影像分割
遥感影像分割是图像分割领域中的重要应用场景,广泛应用于土地测绘、环境监测、城市建设等领域。遥感影像分割的目标多种多样,有诸如积雪、农作物、道路、建筑、水源等地物目标,也有例如云层的空中目标。
PaddleSeg遥感影像分割涵盖图像预处理、数据增强、模型训练、预测流程,帮助用户利用深度学习技术解决遥感影像分割问题。

## 特点
- 针对遥感数据多通道、分布范围大、分布不均的特点,我们支持多通道训练预测,内置10+多通道预处理和数据增强的策略,可结合实际业务场景进行定制组合,提升模型泛化能力和鲁棒性。
- 内置U-Net, HRNet两种主流分割网络,可选择不同的损失函数如Dice Loss, BCE Loss等方式强化小目标和不均衡样本场景下的分割精度。
以下是遥感云检测的示例效果:
![](./docs/imgs/rs.png)
## 前置依赖
**Note:** 若没有特殊说明,以下所有命令需要在`PaddleSeg/contrib/RemoteSensing/`目录下执行。
- Paddle 1.7.1+
由于图像分割模型计算开销大,推荐在GPU版本的PaddlePaddle下使用。
PaddlePaddle的安装, 请按照[官网指引](https://paddlepaddle.org.cn/install/quick)安装合适自己的版本。
...@@ -18,7 +24,6 @@ PaddlePaddle的安装, 请按照[官网指引](https://paddlepaddle.org.cn/insta
- 其他依赖安装
通过以下命令安装python包依赖,请确保至少执行过一次以下命令:
```
pip install -r requirements.txt
```
...@@ -63,9 +68,9 @@ RemoteSensing # 根目录 ...@@ -63,9 +68,9 @@ RemoteSensing # 根目录
```
其中,相应的文件名可根据需要自行定义。
遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleSeg以numpy.ndarray数据类型进行图像预处理。为统一接口并方便数据加载,我们采用numpy存储格式`npy`作为原图格式,采用`png`无损压缩格式作为标注图片格式。
原图的尺寸应为(h, w, channel),其中h, w为图像的高和宽,channel为图像的通道数。
标注图像为单通道图像,像素值即为对应的类别,像素标注类别需要从0开始递增,
例如0,1,2,3表示有4种类别,标注类别最多为256类。其中可以指定特定的像素值用于表示该值的像素不参与训练和评估(默认为255)。
`train_list.txt`和`val_list.txt`文本以空格为分隔符分为两列,第一列为图像文件相对于dataset的相对路径,第二列为标注图像文件相对于dataset的相对路径。如下所示:
...@@ -93,154 +98,38 @@ labelB
### 1. 准备数据集
为了快速体验,我们准备了一个小型demo数据集,已位于`RemoteSensing/dataset/demo/`目录下。
对于您自己的数据集,您需要按照上述的数据协议进行格式转换,可分别使用numpy和Pillow库保存遥感数据和标注图片。其中numpy API示例如下:
```python
import numpy as np

# 将遥感数据保存到以 .npy 为扩展名的文件中
# img类型:numpy.ndarray
np.save(save_path, img)
```
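For the single-channel PNG annotations described above, a Pillow-based sketch is shown below; the file name and the all-zero label array are only illustrative.

```python
import numpy as np
from PIL import Image

# label: 2-D array of class ids starting from 0 (255 marks ignored pixels by default).
label = np.zeros((256, 256), dtype=np.uint8)
Image.fromarray(label, mode='P').save('dataset/demo/annotations/demo1.png')  # illustrative path
```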
### 2. 训练代码开发
通过如下`train_demo.py`代码进行训练。
> 导入RemoteSensing api
```python
import transforms.transforms as T
from readers.reader import Reader
from models import UNet
```
> 定义训练和验证时的数据处理和增强流程, 在`train_transforms`中加入了`RandomVerticalFlip`,`RandomHorizontalFlip`等数据增强方式。
```python
train_transforms = T.Compose([
T.RandomVerticalFlip(0.5),
T.RandomHorizontalFlip(0.5),
T.ResizeStepScaling(0.5, 2.0, 0.25),
T.RandomPaddingCrop(256),
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
eval_transforms = T.Compose([
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
```
> 定义数据读取器
```python
import os
import os.path as osp
train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt')
label_list = osp.join(data_dir, 'labels.txt')
train_reader = Reader(
data_dir=data_dir,
file_list=train_list,
label_list=label_list,
transforms=train_transforms,
num_workers=8,
buffer_size=16,
shuffle=True,
parallel_method='thread')
eval_reader = Reader(
data_dir=data_dir,
file_list=val_list,
label_list=label_list,
transforms=eval_transforms,
num_workers=8,
buffer_size=16,
shuffle=False,
parallel_method='thread')
```
> 模型构建
```python
model = UNet(
num_classes=2, input_channel=channel, use_bce_loss=True, use_dice_loss=True)
```
> 模型训练,并开启边训边评估
```python
model.train(
num_epochs=num_epochs,
train_reader=train_reader,
train_batch_size=train_batch_size,
eval_reader=eval_reader,
save_interval_epochs=5,
log_interval_steps=10,
save_dir=save_dir,
pretrain_weights=None,
optimizer=None,
learning_rate=lr,
use_vdl=True
)
```
### 3. 模型训练
> 设置GPU卡号
```shell script
export CUDA_VISIBLE_DEVICES=0
```
> 在RemoteSensing目录下运行`train_demo.py`即可开始训练。以U-Net为例:
```shell script
python train_demo.py --model_type unet --data_dir dataset/demo/ --save_dir saved_model/unet/ --channel 3 --num_epochs 20
```
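The HRNet backbone is bundled alongside U-Net in this library; assuming `--model_type` accepts `hrnet` the same way it accepts `unet` above, the corresponding command would look like:

```shell script
python train_demo.py --model_type hrnet --data_dir dataset/demo/ --save_dir saved_model/hrnet/ --channel 3 --num_epochs 20
```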
### 4. 模型预测代码开发
通过如下`predict_demo.py`代码进行预测。
> 导入RemoteSensing api
```python
from models import load_model
```
> 加载训练过程中最好的模型,设置预测结果保存路径。
```python
import os
import os.path as osp
model = load_model(osp.join(save_dir, 'best_model'))
pred_dir = osp.join(save_dir, 'pred')
if not osp.exists(pred_dir):
os.mkdir(pred_dir)
```
> 使用模型对验证集进行测试,并保存预测结果。
```python
import numpy as np
from PIL import Image as Image
val_list = osp.join(data_dir, 'val.txt')
color_map = [0, 0, 0, 255, 255, 255]
with open(val_list) as f:
lines = f.readlines()
for line in lines:
img_path = line.split(' ')[0]
print('Predicting {}'.format(img_path))
img_path_ = osp.join(data_dir, img_path)
pred = model.predict(img_path_)
# 以伪彩色png图片保存预测结果
pred_name = osp.basename(img_path).rstrip('npy') + 'png'
pred_path = osp.join(pred_dir, pred_name)
pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P')
pred_mask.putpalette(color_map)
pred_mask.save(pred_path)
``` ```
### 5. 模型预测
> 设置GPU卡号
```shell script
export CUDA_VISIBLE_DEVICES=0
```
> 在RemoteSensing目录下运行`predict_demo.py`即可开始预测。以刚训练好的U-Net最优模型为例:
```shell script
python predict_demo.py --data_dir dataset/demo/ --file_list val.txt --load_model_dir saved_model/unet/best_model
```
## API说明
您可以使用`RemoteSensing`目录下提供的API构建自己的分割代码。
- [数据处理-transforms](docs/transforms.md)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .load_model import * from .load_model import *
from .unet import * from .unet import *
from .hrnet import *
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import
import paddle.fluid as fluid
...@@ -19,15 +20,16 @@ import numpy as np
import time
import math
import yaml
import tqdm
import cv2
import copy
import utils.logging as logging
from collections import OrderedDict
from os import path as osp
from utils.utils import seconds_to_hms, get_environ_info
from utils.metrics import ConfusionMatrix
import transforms.transforms as T
import utils
def dict2str(dict_input):
...@@ -41,12 +43,45 @@ def dict2str(dict_input):
return out.strip(', ')
class BaseModel(object):
def __init__(self,
num_classes=2,
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255,
sync_bn=True):
self.init_params = locals()
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
"dice loss and bce loss is only applicable to binary classfication"
)
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
self.sync_bn = sync_bn
self.labels = None
self.env_info = get_environ_info()
if self.env_info['place'] == 'cpu':
self.places = fluid.cpu_places()
else:
self.places = fluid.cuda_places()
...@@ -60,10 +95,6 @@ class BaseAPI:
self.test_outputs = None
self.train_data_loader = None
self.eval_metrics = None
# 若模型是从inference model加载进来的,无法调用训练接口进行训练
self.trainable = True
# 是否使用多卡间同步BatchNorm均值和方差
self.sync_bn = False
# 当前模型状态 # 当前模型状态
self.status = 'Normal' self.status = 'Normal'
...@@ -73,16 +104,20 @@ class BaseAPI: ...@@ -73,16 +104,20 @@ class BaseAPI:
else:
raise Exception("Please support correct batch_size, \
which can be divided by available cards({}) in {}".
format(self.env_info['num'],
self.env_info['place']))
def build_net(self, mode='train'):
"""应根据不同的情况进行构建"""
pass
def build_program(self): def build_program(self):
# 构建训练网络 # build training network
self.train_inputs, self.train_outputs = self.build_net(mode='train') self.train_inputs, self.train_outputs = self.build_net(mode='train')
self.train_prog = fluid.default_main_program() self.train_prog = fluid.default_main_program()
startup_prog = fluid.default_startup_program() startup_prog = fluid.default_startup_program()
# 构建预测网络 # build prediction network
self.test_prog = fluid.Program() self.test_prog = fluid.Program()
with fluid.program_guard(self.test_prog, startup_prog): with fluid.program_guard(self.test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
...@@ -90,15 +125,15 @@ class BaseAPI: ...@@ -90,15 +125,15 @@ class BaseAPI:
mode='test') mode='test')
self.test_prog = self.test_prog.clone(for_test=True) self.test_prog = self.test_prog.clone(for_test=True)
def arrange_transforms(self, transforms, mode='train'): def arrange_transform(self, transforms, mode='train'):
# 给transforms添加arrange操作 arrange_transform = T.ArrangeSegmenter
if transforms.transforms[-1].__class__.__name__.startswith('Arrange'): if type(transforms.transforms[-1]).__name__.startswith('Arrange'):
transforms.transforms[-1] = T.ArrangeSegmenter(mode=mode) transforms.transforms[-1] = arrange_transform(mode=mode)
else: else:
transforms.transforms.append(T.ArrangeSegmenter(mode=mode)) transforms.transforms.append(arrange_transform(mode=mode))
def build_train_data_loader(self, reader, batch_size): def build_train_data_loader(self, dataset, batch_size):
# 初始化data_loader # init data_loader
if self.train_data_loader is None: if self.train_data_loader is None:
self.train_data_loader = fluid.io.DataLoader.from_generator( self.train_data_loader = fluid.io.DataLoader.from_generator(
feed_list=list(self.train_inputs.values()), feed_list=list(self.train_inputs.values()),
...@@ -106,72 +141,92 @@ class BaseAPI: ...@@ -106,72 +141,92 @@ class BaseAPI:
use_double_buffer=True, use_double_buffer=True,
iterable=True) iterable=True)
batch_size_each_gpu = self._get_single_card_bs(batch_size) batch_size_each_gpu = self._get_single_card_bs(batch_size)
generator = reader.generator(
batch_size=batch_size_each_gpu, drop_last=True)
self.train_data_loader.set_sample_list_generator( self.train_data_loader.set_sample_list_generator(
reader.generator(batch_size=batch_size_each_gpu), dataset.generator(batch_size=batch_size_each_gpu),
places=self.places) places=self.places)
def net_initialize(self, def net_initialize(self,
startup_prog=None, startup_prog=None,
pretrain_weights=None, pretrain_weights=None,
fuse_bn=False, resume_weights=None):
save_dir='.',
sensitivities_file=None,
eval_metric_loss=0.05):
if hasattr(self, 'backbone'):
backbone = self.backbone
else:
backbone = self.__class__.__name__
pretrain_weights = get_pretrain_weights(pretrain_weights, backbone,
save_dir)
if startup_prog is None: if startup_prog is None:
startup_prog = fluid.default_startup_program() startup_prog = fluid.default_startup_program()
self.exe.run(startup_prog) self.exe.run(startup_prog)
if pretrain_weights is not None: if resume_weights is not None:
logging.info("Resume weights from {}".format(resume_weights))
if not osp.exists(resume_weights):
raise Exception("Path {} not exists.".format(resume_weights))
fluid.load(self.train_prog, osp.join(resume_weights, 'model'),
self.exe)
# Check if the path ends with a path separator
if resume_weights[-1] == os.sep:
resume_weights = resume_weights[0:-1]
epoch_name = osp.basename(resume_weights)
# If the resume directory name ends with a digit, restore the epoch status
epoch = epoch_name.split('_')[-1]
if epoch.isdigit():
self.begin_epoch = int(epoch)
else:
raise ValueError("Resume model path is not valid!")
logging.info("Model checkpoint loaded successfully!")
elif pretrain_weights is not None:
logging.info( logging.info(
"Load pretrain weights from {}.".format(pretrain_weights)) "Load pretrain weights from {}.".format(pretrain_weights))
utils.utils.load_pretrain_weights(self.exe, self.train_prog, utils.load_pretrained_weights(self.exe, self.train_prog,
pretrain_weights, fuse_bn) pretrain_weights)
# 进行裁剪
if sensitivities_file is not None:
from .slim.prune_config import get_sensitivities
sensitivities_file = get_sensitivities(sensitivities_file, self,
save_dir)
from .slim.prune import get_params_ratios, prune_program
prune_params_ratios = get_params_ratios(
sensitivities_file, eval_metric_loss=eval_metric_loss)
prune_program(self, prune_params_ratios)
self.status = 'Prune'
def get_model_info(self): def get_model_info(self):
# 存储相应的信息到yml文件
info = dict() info = dict()
info['Model'] = self.__class__.__name__ info['Model'] = self.__class__.__name__
info['_Attributes'] = {}
if 'self' in self.init_params: if 'self' in self.init_params:
del self.init_params['self'] del self.init_params['self']
if '__class__' in self.init_params: if '__class__' in self.init_params:
del self.init_params['__class__'] del self.init_params['__class__']
info['_init_params'] = self.init_params info['_init_params'] = self.init_params
info['_Attributes'] = dict()
info['_Attributes']['num_classes'] = self.num_classes info['_Attributes']['num_classes'] = self.num_classes
info['_Attributes']['labels'] = self.labels info['_Attributes']['labels'] = self.labels
try: try:
primary_metric_key = list(self.eval_metrics.keys())[0] info['_Attributes']['eval_metric'] = dict()
primary_metric_value = float(self.eval_metrics[primary_metric_key]) for k, v in self.eval_metrics.items():
info['_Attributes']['eval_metrics'] = { if isinstance(v, np.ndarray):
primary_metric_key: primary_metric_value if v.size > 1:
} v = [float(i) for i in v]
else:
v = float(v)
info['_Attributes']['eval_metric'][k] = v
except: except:
pass pass
if hasattr(self, 'test_transforms'): if hasattr(self, 'test_transforms'):
if self.test_transforms is not None: if self.test_transforms is not None:
info['Transforms'] = list() info['test_transforms'] = list()
for op in self.test_transforms.transforms: for op in self.test_transforms.transforms:
name = op.__class__.__name__ name = op.__class__.__name__
attr = op.__dict__ attr = op.__dict__
info['Transforms'].append({name: attr}) info['test_transforms'].append({name: attr})
if hasattr(self, 'train_transforms'):
if self.train_transforms is not None:
info['train_transforms'] = list()
for op in self.train_transforms.transforms:
name = op.__class__.__name__
attr = op.__dict__
info['train_transforms'].append({name: attr})
if hasattr(self, 'train_init'):
if 'self' in self.train_init:
del self.train_init['self']
if 'train_reader' in self.train_init:
del self.train_init['train_reader']
if 'eval_reader' in self.train_init:
del self.train_init['eval_reader']
if 'optimizer' in self.train_init:
del self.train_init['optimizer']
info['train_init'] = self.train_init
return info return info
def save_model(self, save_dir): def save_model(self, save_dir):
...@@ -179,76 +234,139 @@ class BaseAPI: ...@@ -179,76 +234,139 @@ class BaseAPI:
if osp.exists(save_dir): if osp.exists(save_dir):
os.remove(save_dir) os.remove(save_dir)
os.makedirs(save_dir) os.makedirs(save_dir)
fluid.save(self.train_prog, osp.join(save_dir, 'model'))
model_info = self.get_model_info() model_info = self.get_model_info()
if self.status == 'Normal':
fluid.save(self.train_prog, osp.join(save_dir, 'model'))
model_info['status'] = self.status model_info['status'] = self.status
with open( with open(
osp.join(save_dir, 'model.yml'), encoding='utf-8', osp.join(save_dir, 'model.yml'), encoding='utf-8',
mode='w') as f: mode='w') as f:
yaml.dump(model_info, f) yaml.dump(model_info, f)
# 评估结果保存
if hasattr(self, 'eval_details'): # The flag of model for saving successfully
with open(osp.join(save_dir, 'eval_details.json'), 'w') as f:
json.dump(self.eval_details, f)
if self.status == 'Prune':
# 保存裁剪的shape
shapes = {}
for block in self.train_prog.blocks:
for param in block.all_parameters():
pd_var = fluid.global_scope().find_var(param.name)
pd_param = pd_var.get_tensor()
shapes[param.name] = np.array(pd_param).shape
with open(
osp.join(save_dir, 'prune.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(shapes, f)
# 模型保存成功的标志
open(osp.join(save_dir, '.success'), 'w').close() open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model saved in {}.".format(save_dir)) logging.info("Model saved in {}.".format(save_dir))
def train_loop(self, def export_inference_model(self, save_dir):
num_epochs, test_input_names = [var.name for var in list(self.test_inputs.values())]
train_reader, test_outputs = list(self.test_outputs.values())
train_batch_size, fluid.io.save_inference_model(
eval_reader=None, dirname=save_dir,
eval_best_metric=None, executor=self.exe,
save_interval_epochs=1, params_filename='__params__',
log_interval_steps=10, feeded_var_names=test_input_names,
save_dir='output', target_vars=test_outputs,
use_vdl=False): main_program=self.test_prog)
model_info = self.get_model_info()
model_info['status'] = 'Infer'
# Save input and output descrition of model
model_info['_ModelInputsOutputs'] = dict()
model_info['_ModelInputsOutputs']['test_inputs'] = [
[k, v.name] for k, v in self.test_inputs.items()
]
model_info['_ModelInputsOutputs']['test_outputs'] = [
[k, v.name] for k, v in self.test_outputs.items()
]
with open(
osp.join(save_dir, 'model.yml'), encoding='utf-8',
mode='w') as f:
yaml.dump(model_info, f)
# The flag of model for saving successfully
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model for inference deploy saved in {}.".format(save_dir))
def default_optimizer(self,
learning_rate,
num_epochs,
num_steps_each_epoch,
lr_decay_power=0.9,
regularization_coeff=4e-5):
decay_step = num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay(
learning_rate,
decay_step,
end_learning_rate=0,
power=lr_decay_power)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(
regularization_coeff=regularization_coeff))
return optimizer
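`default_optimizer` pairs Momentum with `fluid.layers.polynomial_decay`; below is a small NumPy sketch of the learning-rate schedule it produces with `end_learning_rate=0` (the hyperparameter values are examples, not defaults from the source).

```python
import numpy as np

# lr(t) = learning_rate * (1 - t / decay_step) ** power, which is what
# polynomial_decay computes when end_learning_rate=0.
learning_rate, num_epochs, num_steps_each_epoch, power = 0.01, 20, 100, 0.9
decay_step = num_epochs * num_steps_each_epoch
t = np.arange(decay_step)
lr = learning_rate * (1 - t / decay_step) ** power
print(lr[0], lr[decay_step // 2], lr[-1])   # 0.01, ~0.0054, ~0.0
```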
def train(self,
num_epochs,
train_reader,
train_batch_size=2,
eval_reader=None,
eval_best_metric=None,
save_interval_epochs=1,
log_interval_steps=2,
save_dir='output',
pretrain_weights=None,
resume_weights=None,
optimizer=None,
learning_rate=0.01,
lr_decay_power=0.9,
regularization_coeff=4e-5,
use_vdl=False):
self.labels = train_reader.labels
self.train_transforms = train_reader.transforms
self.train_init = locals()
self.begin_epoch = 0
if optimizer is None:
num_steps_each_epoch = train_reader.num_samples // train_batch_size
optimizer = self.default_optimizer(
learning_rate=learning_rate,
num_epochs=num_epochs,
num_steps_each_epoch=num_steps_each_epoch,
lr_decay_power=lr_decay_power,
regularization_coeff=regularization_coeff)
self.optimizer = optimizer
self.build_program()
self.net_initialize(
startup_prog=fluid.default_startup_program(),
pretrain_weights=pretrain_weights,
resume_weights=resume_weights)
if self.begin_epoch >= num_epochs:
raise ValueError(
("begin epoch[{}] is larger than num_epochs[{}]").format(
self.begin_epoch, num_epochs))
if not osp.isdir(save_dir): if not osp.isdir(save_dir):
if osp.exists(save_dir): if osp.exists(save_dir):
os.remove(save_dir) os.remove(save_dir)
os.makedirs(save_dir) os.makedirs(save_dir)
if use_vdl:
from visualdl import LogWriter # add arrange op tor transforms
vdl_logdir = osp.join(save_dir, 'vdl_log') self.arrange_transform(transforms=train_reader.transforms, mode='train')
# 给transform添加arrange操作
self.arrange_transforms(
transforms=train_reader.transforms, mode='train')
# 构建train_data_loader
self.build_train_data_loader( self.build_train_data_loader(
reader=train_reader, batch_size=train_batch_size) dataset=train_reader, batch_size=train_batch_size)
if eval_reader is not None: if eval_reader is not None:
self.eval_transforms = eval_reader.transforms self.eval_transforms = eval_reader.transforms
self.test_transforms = copy.deepcopy(eval_reader.transforms) self.test_transforms = copy.deepcopy(eval_reader.transforms)
# 获取实时变化的learning rate
lr = self.optimizer._learning_rate lr = self.optimizer._learning_rate
lr.persistable = True
if isinstance(lr, fluid.framework.Variable): if isinstance(lr, fluid.framework.Variable):
self.train_outputs['lr'] = lr self.train_outputs['lr'] = lr
# 在多卡上跑训练 # 多卡训练
if self.parallel_train_prog is None: if self.parallel_train_prog is None:
build_strategy = fluid.compiler.BuildStrategy() build_strategy = fluid.compiler.BuildStrategy()
build_strategy.fuse_all_optimizer_ops = False if self.env_info['place'] != 'cpu' and len(self.places) > 1:
if __init__.env_info['place'] != 'cpu' and len(self.places) > 1:
build_strategy.sync_batch_norm = self.sync_bn build_strategy.sync_batch_norm = self.sync_bn
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1 exec_strategy.num_iteration_per_drop_scope = 1
self.parallel_train_prog = fluid.CompiledProgram( self.parallel_train_prog = fluid.CompiledProgram(
self.train_prog).with_data_parallel( self.train_prog).with_data_parallel(
loss_name=self.train_outputs['loss'].name, loss_name=self.train_outputs['loss'].name,
...@@ -259,16 +377,27 @@ class BaseAPI: ...@@ -259,16 +377,27 @@ class BaseAPI:
train_reader.num_samples / train_batch_size) train_reader.num_samples / train_batch_size)
num_steps = 0 num_steps = 0
time_stat = list() time_stat = list()
time_train_one_epoch = None
time_eval_one_epoch = None
total_num_steps_eval = 0
# eval times
total_eval_times = math.ceil(num_epochs / save_interval_epochs)
eval_batch_size = train_batch_size
if eval_reader is not None:
total_num_steps_eval = math.ceil(
eval_reader.num_samples / eval_batch_size)
if use_vdl: if use_vdl:
# VisualDL component from visualdl import LogWriter
vdl_logdir = osp.join(save_dir, 'vdl_log')
log_writer = LogWriter(vdl_logdir) log_writer = LogWriter(vdl_logdir)
best_metric = -1.0
best_accuracy = -1.0
best_model_epoch = 1 best_model_epoch = 1
for i in range(num_epochs): for i in range(self.begin_epoch, num_epochs):
records = list() records = list()
step_start_time = time.time() step_start_time = time.time()
epoch_start_time = time.time()
for step, data in enumerate(self.train_data_loader()): for step, data in enumerate(self.train_data_loader()):
outputs = self.exe.run( outputs = self.exe.run(
self.parallel_train_prog, self.parallel_train_prog,
...@@ -277,22 +406,15 @@ class BaseAPI: ...@@ -277,22 +406,15 @@ class BaseAPI:
outputs_avg = np.mean(np.array(outputs), axis=1) outputs_avg = np.mean(np.array(outputs), axis=1)
records.append(outputs_avg) records.append(outputs_avg)
# time estimated to complete the training
current_time = time.time()
step_cost_time = current_time - step_start_time
step_start_time = current_time
if len(time_stat) < 20: if len(time_stat) < 20:
time_stat.append(step_cost_time) time_stat.append(step_cost_time)
else: else:
time_stat[num_steps % 20] = step_cost_time time_stat[num_steps % 20] = step_cost_time
eta = ((num_epochs - i) * total_num_steps - step -
1) * np.mean(time_stat)
eta_h = math.floor(eta / 3600)
eta_m = math.floor((eta - eta_h * 3600) / 60)
eta_s = int(eta - eta_h * 3600 - eta_m * 60)
eta_str = "{}:{}:{}".format(eta_h, eta_m, eta_s)
# 每间隔log_interval_steps,输出loss信息
num_steps += 1 num_steps += 1
if num_steps % log_interval_steps == 0: if num_steps % log_interval_steps == 0:
step_metrics = OrderedDict( step_metrics = OrderedDict(
...@@ -301,38 +423,52 @@ class BaseAPI: ...@@ -301,38 +423,52 @@ class BaseAPI:
if use_vdl: if use_vdl:
for k, v in step_metrics.items(): for k, v in step_metrics.items():
log_writer.add_scalar( log_writer.add_scalar(
tag="Training: {}".format(k), step=num_steps,
value=v, tag='train/{}'.format(k),
step=num_steps) value=v)
# 计算剩余时间
avg_step_time = np.mean(time_stat)
if time_train_one_epoch is not None:
eta = (num_epochs - i - 1) * time_train_one_epoch + (
total_num_steps - step - 1) * avg_step_time
else:
eta = ((num_epochs - i) * total_num_steps - step -
1) * avg_step_time
if time_eval_one_epoch is not None:
eval_eta = (total_eval_times - i // save_interval_epochs
) * time_eval_one_epoch
else:
eval_eta = (total_eval_times - i // save_interval_epochs
) * total_num_steps_eval * avg_step_time
eta_str = seconds_to_hms(eta + eval_eta)
logging.info( logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, {}, eta={}".format( "[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
i + 1, num_epochs, step + 1, total_num_steps, .format(i + 1, num_epochs, step + 1, total_num_steps,
dict2str(step_metrics), eta_str)) dict2str(step_metrics), round(avg_step_time, 2),
eta_str))
train_metrics = OrderedDict( train_metrics = OrderedDict(
zip(list(self.train_outputs.keys()), np.mean(records, axis=0))) zip(list(self.train_outputs.keys()), np.mean(records, axis=0)))
logging.info('[TRAIN] Epoch {} finished, {} .'.format( logging.info('[TRAIN] Epoch {} finished, {} .'.format(
i + 1, dict2str(train_metrics))) i + 1, dict2str(train_metrics)))
time_train_one_epoch = time.time() - epoch_start_time
# 每间隔save_interval_epochs, 在验证集上评估和对模型进行保存 eval_epoch_start_time = time.time()
if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1: if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1)) current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
if not osp.isdir(current_save_dir): if not osp.isdir(current_save_dir):
os.makedirs(current_save_dir) os.makedirs(current_save_dir)
if eval_reader is not None: if eval_reader is not None:
# 检测目前仅支持单卡评估,训练数据batch大小与显卡数量之商为验证数据batch大小。 self.eval_metrics = self.evaluate(
eval_batch_size = train_batch_size
self.eval_metrics, self.eval_details = self.evaluate(
eval_reader=eval_reader, eval_reader=eval_reader,
batch_size=eval_batch_size, batch_size=eval_batch_size,
verbose=True, epoch_id=i + 1)
epoch_id=i + 1,
return_details=True)
logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
i + 1, dict2str(self.eval_metrics)))
# 保存最优模型 # 保存最优模型
current_metric = self.eval_metrics[eval_best_metric] current_metric = self.eval_metrics[eval_best_metric]
if current_metric > best_accuracy: if current_metric > best_metric:
best_accuracy = current_metric best_metric = current_metric
best_model_epoch = i + 1 best_model_epoch = i + 1
best_model_dir = osp.join(save_dir, "best_model") best_model_dir = osp.join(save_dir, "best_model")
self.save_model(save_dir=best_model_dir) self.save_model(save_dir=best_model_dir)
...@@ -344,10 +480,131 @@ class BaseAPI: ...@@ -344,10 +480,131 @@ class BaseAPI:
if v.size > 1: if v.size > 1:
continue continue
log_writer.add_scalar( log_writer.add_scalar(
tag="Evaluation: {}".format(k), step=num_steps,
step=i + 1, tag='evaluate/{}'.format(k),
value=v) value=v)
self.save_model(save_dir=current_save_dir) self.save_model(save_dir=current_save_dir)
logging.info( time_eval_one_epoch = time.time() - eval_epoch_start_time
'Current evaluated best model in eval_reader is epoch_{}, {}={}' if eval_reader is not None:
.format(best_model_epoch, eval_best_metric, best_accuracy)) logging.info(
'Current evaluated best model in validation dataset is epoch_{}, {}={}'
.format(best_model_epoch, eval_best_metric,
best_metric))
def evaluate(self, eval_reader, batch_size=1, epoch_id=None):
"""评估。
Args:
eval_reader (reader): 评估数据读取器。
batch_size (int): 评估时的batch大小。默认1。
epoch_id (int): 当前评估模型所在的训练轮数。
return_details (bool): 是否返回详细信息。默认False。
Returns:
dict: 当return_details为False时,返回dict。包含关键字:'miou'、'category_iou'、'macc'、
'category_acc'和'kappa',分别表示平均iou、各类别iou、平均准确率、各类别准确率和kappa系数。
tuple (metrics, eval_details):当return_details为True时,增加返回dict (eval_details),
包含关键字:'confusion_matrix',表示评估的混淆矩阵。
"""
self.arrange_transform(transforms=eval_reader.transforms, mode='train')
total_steps = math.ceil(eval_reader.num_samples * 1.0 / batch_size)
conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
data_generator = eval_reader.generator(
batch_size=batch_size, drop_last=False)
if not hasattr(self, 'parallel_test_prog'):
self.parallel_test_prog = fluid.CompiledProgram(
self.test_prog).with_data_parallel(
share_vars_from=self.parallel_train_prog)
logging.info(
"Start to evaluating(total_samples={}, total_steps={})...".format(
eval_reader.num_samples, total_steps))
for step, data in tqdm.tqdm(
enumerate(data_generator()), total=total_steps):
images = np.array([d[0] for d in data])
images = images.astype(np.float32)
labels = np.array([d[1] for d in data])
num_samples = images.shape[0]
if num_samples < batch_size:
num_pad_samples = batch_size - num_samples
pad_images = np.tile(images[0:1], (num_pad_samples, 1, 1, 1))
images = np.concatenate([images, pad_images])
feed_data = {'image': images}
outputs = self.exe.run(
self.parallel_test_prog,
feed=feed_data,
fetch_list=list(self.test_outputs.values()),
return_numpy=True)
pred = outputs[0]
if num_samples < batch_size:
pred = pred[0:num_samples]
mask = labels != self.ignore_index
conf_mat.calculate(pred=pred, label=labels, ignore=mask)
_, iou = conf_mat.mean_iou()
logging.debug("[EVAL] Epoch={}, Step={}/{}, iou={}".format(
epoch_id, step + 1, total_steps, iou))
category_iou, miou = conf_mat.mean_iou()
category_acc, macc = conf_mat.accuracy()
precision, recall = conf_mat.precision_recall()
metrics = OrderedDict(
zip([
'miou', 'category_iou', 'macc', 'category_acc', 'kappa',
'precision', 'recall'
], [
miou, category_iou, macc, category_acc,
conf_mat.kappa(), precision, recall
]))
logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
epoch_id, dict2str(metrics)))
return metrics
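A usage sketch for `evaluate`, assuming `model` and `eval_reader` were built as in the training walkthrough; the returned OrderedDict carries both scalar and per-class metrics.

```python
# Hypothetical objects: `model` is a trained UNet/HRNet wrapper and `eval_reader`
# a validation Reader with the same transforms used during training.
metrics = model.evaluate(eval_reader=eval_reader, batch_size=4, epoch_id=1)
print(metrics['miou'], metrics['kappa'])    # scalar summaries
print(metrics['category_iou'])              # per-class IoU array
```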
def predict(self, im_file, transforms=None):
"""预测。
Args:
im_file(str|np.ndarray): 预测图像。
transforms(transforms.transforms): 数据预处理操作。
Returns:
dict: 包含关键字'label_map'和'score_map', 'label_map'存储预测结果灰度图,
像素值表示对应的类别,'score_map'存储各类别的概率,shape=(h, w, num_classes)
"""
if isinstance(im_file, str):
if not osp.exists(im_file):
raise ValueError(
'The Image file does not exist: {}'.format(im_file))
if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.")
if transforms is not None:
self.arrange_transform(transforms=transforms, mode='test')
im, im_info = transforms(im_file)
else:
self.arrange_transform(transforms=self.test_transforms, mode='test')
im, im_info = self.test_transforms(im_file)
im = im.astype(np.float32)
im = np.expand_dims(im, axis=0)
result = self.exe.run(
self.test_prog,
feed={'image': im},
fetch_list=list(self.test_outputs.values()))
pred = result[0]
logit = result[1]
logit = np.squeeze(logit)
logit = np.transpose(logit, (1, 2, 0))
pred = np.squeeze(pred).astype('uint8')
keys = list(im_info.keys())
for k in keys[::-1]:
if k == 'shape_before_resize':
h, w = im_info[k][0], im_info[k][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
logit = cv2.resize(logit, (w, h), cv2.INTER_LINEAR)
elif k == 'shape_before_padding':
h, w = im_info[k][0], im_info[k][1]
pred = pred[0:h, 0:w]
logit = logit[0:h, 0:w, :]
return {'label_map': pred, 'score_map': logit}
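`predict` returns the raw label and score maps; the sketch below saves the label map as a pseudo-color PNG, mirroring the palette trick from the RemoteSensing README. The paths, the `model` object, and the two-class palette are illustrative.

```python
import numpy as np
from PIL import Image

result = model.predict('dataset/demo/img1.npy')   # placeholder input path
label_map = result['label_map']                   # uint8 class id per pixel
score_map = result['score_map']                   # (h, w, num_classes) probabilities

mask = Image.fromarray(label_map.astype(np.uint8), mode='P')
mask.putpalette([0, 0, 0, 255, 255, 255])         # class 0 -> black, class 1 -> white
mask.save('pred/img1.png')                        # placeholder output path
```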
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import paddle.fluid as fluid
import os
from os import path as osp
import numpy as np
from collections import OrderedDict
import copy
import math
import time
import tqdm
import cv2
import yaml
import utils
import utils.logging as logging
from utils.utils import seconds_to_hms, get_environ_info
from utils.metrics import ConfusionMatrix
import nets
import transforms.transforms as T
from .base import BaseModel
def dict2str(dict_input):
out = ''
for k, v in dict_input.items():
try:
v = round(float(v), 6)
except:
pass
out = out + '{}={}, '.format(k, v)
return out.strip(', ')
class HRNet(BaseModel):
def __init__(self,
num_classes=2,
input_channel=3,
stage1_num_modules=1,
stage1_num_blocks=[4],
stage1_num_channels=[64],
stage2_num_modules=1,
stage2_num_blocks=[4, 4],
stage2_num_channels=[18, 36],
stage3_num_modules=4,
stage3_num_blocks=[4, 4, 4],
stage3_num_channels=[18, 36, 72],
stage4_num_modules=3,
stage4_num_blocks=[4, 4, 4, 4],
stage4_num_channels=[18, 36, 72, 144],
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255,
sync_bn=True):
super().__init__(
num_classes=num_classes,
use_bce_loss=use_bce_loss,
use_dice_loss=use_dice_loss,
class_weight=class_weight,
ignore_index=ignore_index,
sync_bn=sync_bn)
self.init_params = locals()
self.input_channel = input_channel
self.stage1_num_modules = stage1_num_modules
self.stage1_num_blocks = stage1_num_blocks
self.stage1_num_channels = stage1_num_channels
self.stage2_num_modules = stage2_num_modules
self.stage2_num_blocks = stage2_num_blocks
self.stage2_num_channels = stage2_num_channels
self.stage3_num_modules = stage3_num_modules
self.stage3_num_blocks = stage3_num_blocks
self.stage3_num_channels = stage3_num_channels
self.stage4_num_modules = stage4_num_modules
self.stage4_num_blocks = stage4_num_blocks
self.stage4_num_channels = stage4_num_channels
def build_net(self, mode='train'):
"""应根据不同的情况进行构建"""
model = nets.HRNet(
self.num_classes,
self.input_channel,
mode=mode,
stage1_num_modules=self.stage1_num_modules,
stage1_num_blocks=self.stage1_num_blocks,
stage1_num_channels=self.stage1_num_channels,
stage2_num_modules=self.stage2_num_modules,
stage2_num_blocks=self.stage2_num_blocks,
stage2_num_channels=self.stage2_num_channels,
stage3_num_modules=self.stage3_num_modules,
stage3_num_blocks=self.stage3_num_blocks,
stage3_num_channels=self.stage3_num_channels,
stage4_num_modules=self.stage4_num_modules,
stage4_num_blocks=self.stage4_num_blocks,
stage4_num_channels=self.stage4_num_channels,
use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight,
ignore_index=self.ignore_index)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
if mode == 'train':
self.optimizer.minimize(model_out)
outputs['loss'] = model_out
else:
outputs['pred'] = model_out[0]
outputs['logit'] = model_out[1]
return inputs, outputs
def train(self,
num_epochs,
train_reader,
train_batch_size=2,
eval_reader=None,
eval_best_metric='kappa',
save_interval_epochs=1,
log_interval_steps=2,
save_dir='output',
pretrain_weights=None,
resume_weights=None,
optimizer=None,
learning_rate=0.01,
lr_decay_power=0.9,
regularization_coeff=5e-4,
use_vdl=False):
super().train(
num_epochs=num_epochs,
train_reader=train_reader,
train_batch_size=train_batch_size,
eval_reader=eval_reader,
eval_best_metric=eval_best_metric,
save_interval_epochs=save_interval_epochs,
log_interval_steps=log_interval_steps,
save_dir=save_dir,
pretrain_weights=pretrain_weights,
resume_weights=resume_weights,
optimizer=optimizer,
learning_rate=learning_rate,
lr_decay_power=lr_decay_power,
regularization_coeff=regularization_coeff,
use_vdl=use_vdl)
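A construction sketch for the HRNet wrapper above, keeping the default stage configuration; only the channel count and loss switches are set explicitly, and the Reader objects are assumed to exist.

```python
# 3-channel input, binary segmentation with combined BCE + Dice loss (illustrative values).
model = HRNet(num_classes=2, input_channel=3, use_bce_loss=True, use_dice_loss=True)

# train_reader / eval_reader are assumed to be Reader instances built beforehand.
model.train(num_epochs=20, train_reader=train_reader, train_batch_size=4,
            eval_reader=eval_reader, save_dir='saved_model/hrnet')
```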
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -25,7 +26,7 @@ import models ...@@ -25,7 +26,7 @@ import models
def load_model(model_dir): def load_model(model_dir):
if not osp.exists(osp.join(model_dir, "model.yml")): if not osp.exists(osp.join(model_dir, "model.yml")):
raise Exception("There's not model.yml in {}".format(model_dir)) raise Exception("There's no model.yml in {}".format(model_dir))
with open(osp.join(model_dir, "model.yml")) as f: with open(osp.join(model_dir, "model.yml")) as f:
info = yaml.load(f.read(), Loader=yaml.Loader) info = yaml.load(f.read(), Loader=yaml.Loader)
status = info['status'] status = info['status']
...@@ -35,8 +36,7 @@ def load_model(model_dir): ...@@ -35,8 +36,7 @@ def load_model(model_dir):
info['Model'])) info['Model']))
model = getattr(models, info['Model'])(**info['_init_params']) model = getattr(models, info['Model'])(**info['_init_params'])
if status == "Normal" or \ if status == "Normal":
status == "Prune":
startup_prog = fluid.Program() startup_prog = fluid.Program()
model.test_prog = fluid.Program() model.test_prog = fluid.Program()
with fluid.program_guard(model.test_prog, startup_prog): with fluid.program_guard(model.test_prog, startup_prog):
...@@ -45,17 +45,12 @@ def load_model(model_dir): ...@@ -45,17 +45,12 @@ def load_model(model_dir):
mode='test') mode='test')
model.test_prog = model.test_prog.clone(for_test=True) model.test_prog = model.test_prog.clone(for_test=True)
model.exe.run(startup_prog) model.exe.run(startup_prog)
if status == "Prune":
from .slim.prune import update_program
model.test_prog = update_program(model.test_prog, model_dir,
model.places[0])
import pickle import pickle
with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f: with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
load_dict = pickle.load(f) load_dict = pickle.load(f)
fluid.io.set_program_state(model.test_prog, load_dict) fluid.io.set_program_state(model.test_prog, load_dict)
elif status == "Infer" or \ elif status == "Infer":
status == "Quant":
[prog, input_names, outputs] = fluid.io.load_inference_model( [prog, input_names, outputs] = fluid.io.load_inference_model(
model_dir, model.exe, params_filename='__params__') model_dir, model.exe, params_filename='__params__')
model.test_prog = prog model.test_prog = prog
...@@ -67,8 +62,8 @@ def load_model(model_dir): ...@@ -67,8 +62,8 @@ def load_model(model_dir):
for i, out in enumerate(outputs): for i, out in enumerate(outputs):
var_desc = test_outputs_info[i] var_desc = test_outputs_info[i]
model.test_outputs[var_desc[0]] = out model.test_outputs[var_desc[0]] = out
if 'Transforms' in info: if 'test_transforms' in info:
model.test_transforms = build_transforms(info['Transforms']) model.test_transforms = build_transforms(info['test_transforms'])
model.eval_transforms = copy.deepcopy(model.test_transforms) model.eval_transforms = copy.deepcopy(model.test_transforms)
if '_Attributes' in info: if '_Attributes' in info:
......
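`load_model` rebuilds a model from its `model.yml` plus saved weights ("Normal" checkpoints) or from an exported inference model ("Infer"), and restores the saved `test_transforms`. A usage sketch with placeholder paths:

```python
model = load_model('saved_model/unet/best_model')   # directory written by save_model()
print(model.num_classes, model.labels)
result = model.predict('dataset/demo/img1.npy')     # uses the restored test_transforms
```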
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
#You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
#Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
#limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
import os.path as osp
import numpy as np import numpy as np
import math import math
import cv2 import cv2
import paddle.fluid as fluid import paddle.fluid as fluid
import utils.logging as logging import utils.logging as logging
from collections import OrderedDict from collections import OrderedDict
from .base import BaseAPI from .base import BaseModel
from utils.metrics import ConfusionMatrix from utils.metrics import ConfusionMatrix
import nets import nets
class UNet(BaseAPI): class UNet(BaseModel):
"""实现UNet网络的构建并进行训练、评估、预测和模型导出。 """实现UNet网络的构建并进行训练、评估、预测和模型导出。
Args: Args:
...@@ -55,9 +55,16 @@ class UNet(BaseAPI): ...@@ -55,9 +55,16 @@ class UNet(BaseAPI):
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
class_weight=None, class_weight=None,
ignore_index=255): ignore_index=255,
sync_bn=True):
super().__init__(
num_classes=num_classes,
use_bce_loss=use_bce_loss,
use_dice_loss=use_dice_loss,
class_weight=class_weight,
ignore_index=ignore_index,
sync_bn=sync_bn)
self.init_params = locals() self.init_params = locals()
super(UNet, self).__init__()
# dice_loss and bce_loss only apply to binary segmentation # dice_loss and bce_loss only apply to binary segmentation
if num_classes > 2 and (use_bce_loss or use_dice_loss): if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError( raise ValueError(
...@@ -115,24 +122,6 @@ class UNet(BaseAPI): ...@@ -115,24 +122,6 @@ class UNet(BaseAPI):
outputs['logit'] = model_out[1] outputs['logit'] = model_out[1]
return inputs, outputs return inputs, outputs
def default_optimizer(self,
learning_rate,
num_epochs,
num_steps_each_epoch,
lr_decay_power=0.9):
decay_step = num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay(
learning_rate,
decay_step,
end_learning_rate=0,
power=lr_decay_power)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(
regularization_coeff=4e-05))
return optimizer
def train(self, def train(self,
num_epochs, num_epochs,
train_reader, train_reader,
...@@ -142,13 +131,13 @@ class UNet(BaseAPI): ...@@ -142,13 +131,13 @@ class UNet(BaseAPI):
save_interval_epochs=1, save_interval_epochs=1,
log_interval_steps=2, log_interval_steps=2,
save_dir='output', save_dir='output',
pretrain_weights='COCO', pretrain_weights=None,
resume_weights=None,
optimizer=None, optimizer=None,
learning_rate=0.01, learning_rate=0.01,
lr_decay_power=0.9, lr_decay_power=0.9,
use_vdl=False, regularization_coeff=5e-4,
sensitivities_file=None, use_vdl=False):
eval_metric_loss=0.05):
"""训练。 """训练。
Args: Args:
...@@ -160,46 +149,17 @@ class UNet(BaseAPI): ...@@ -160,46 +149,17 @@ class UNet(BaseAPI):
save_interval_epochs (int): Interval for saving the model (in epochs). Default: 1. save_interval_epochs (int): Interval for saving the model (in epochs). Default: 1.
log_interval_steps (int): Interval for printing training logs (in steps). Default: 2. log_interval_steps (int): Interval for printing training logs (in steps). Default: 2.
save_dir (str): Directory for saving the model. Default: 'output'. save_dir (str): Directory for saving the model. Default: 'output'.
pretrain_weights (str): If set to a path, pretrained weights are loaded from that path; if set to the string 'COCO', pretrain_weights (str): If set to a path, pretrained weights are loaded from that path; if None, no pretrained weights are used.
weights pretrained on COCO images are downloaded automatically; if None, no pretrained weights are used. Default: 'COCO'.
optimizer (paddle.fluid.optimizer): Optimizer. If this parameter is None, the default optimizer is used: optimizer (paddle.fluid.optimizer): Optimizer. If this parameter is None, the default optimizer is used:
fluid.optimizer.Momentum with a polynomial learning-rate decay policy. fluid.optimizer.Momentum with a polynomial learning-rate decay policy.
learning_rate (float): Initial learning rate of the default optimizer. Default: 0.01. learning_rate (float): Initial learning rate of the default optimizer. Default: 0.01.
lr_decay_power (float): Polynomial decay power of the default optimizer's learning rate. Default: 0.9. lr_decay_power (float): Polynomial decay power of the default optimizer's learning rate. Default: 0.9.
use_vdl (bool): Whether to use VisualDL for visualization. Default: False. use_vdl (bool): Whether to use VisualDL for visualization. Default: False.
sensitivities_file (str): If set to a path, sensitivity information from that path is used for pruning; if set to the string 'DEFAULT',
sensitivity information computed on ImageNet images is downloaded automatically for pruning; if None, no pruning is performed. Default: None.
eval_metric_loss (float): Tolerable accuracy loss. Default: 0.05.
Raises: Raises:
ValueError: The model was loaded from an inference model and cannot be trained. ValueError: The model was loaded from an inference model and cannot be trained.
""" """
if not self.trainable: super().train(
raise ValueError(
"Model is not trainable since it was loaded from a inference model."
)
self.labels = train_reader.labels
if optimizer is None:
num_steps_each_epoch = train_reader.num_samples // train_batch_size
optimizer = self.default_optimizer(
learning_rate=learning_rate,
num_epochs=num_epochs,
num_steps_each_epoch=num_steps_each_epoch,
lr_decay_power=lr_decay_power)
self.optimizer = optimizer
# Build the training, evaluation and prediction networks
self.build_program()
# Initialize network weights
self.net_initialize(
startup_prog=fluid.default_startup_program(),
pretrain_weights=pretrain_weights,
save_dir=save_dir,
sensitivities_file=sensitivities_file,
eval_metric_loss=eval_metric_loss)
# Start training
self.train_loop(
num_epochs=num_epochs, num_epochs=num_epochs,
train_reader=train_reader, train_reader=train_reader,
train_batch_size=train_batch_size, train_batch_size=train_batch_size,
...@@ -208,6 +168,12 @@ class UNet(BaseAPI): ...@@ -208,6 +168,12 @@ class UNet(BaseAPI):
save_interval_epochs=save_interval_epochs, save_interval_epochs=save_interval_epochs,
log_interval_steps=log_interval_steps, log_interval_steps=log_interval_steps,
save_dir=save_dir, save_dir=save_dir,
pretrain_weights=pretrain_weights,
resume_weights=resume_weights,
optimizer=optimizer,
learning_rate=learning_rate,
lr_decay_power=lr_decay_power,
regularization_coeff=regularization_coeff,
use_vdl=use_vdl) use_vdl=use_vdl)
def evaluate(self, def evaluate(self,
...@@ -231,7 +197,7 @@ class UNet(BaseAPI): ...@@ -231,7 +197,7 @@ class UNet(BaseAPI):
tuple (metrics, eval_details): When return_details is True, a dict (eval_details) is additionally returned, tuple (metrics, eval_details): When return_details is True, a dict (eval_details) is additionally returned,
containing the key 'confusion_matrix', which holds the evaluation confusion matrix. containing the key 'confusion_matrix', which holds the evaluation confusion matrix.
""" """
self.arrange_transforms(transforms=eval_reader.transforms, mode='eval') self.arrange_transform(transforms=eval_reader.transforms, mode='eval')
total_steps = math.ceil(eval_reader.num_samples * 1.0 / batch_size) total_steps = math.ceil(eval_reader.num_samples * 1.0 / batch_size)
conf_mat = ConfusionMatrix(self.num_classes, streaming=True) conf_mat = ConfusionMatrix(self.num_classes, streaming=True)
data_generator = eval_reader.generator( data_generator = eval_reader.generator(
...@@ -272,11 +238,16 @@ class UNet(BaseAPI): ...@@ -272,11 +238,16 @@ class UNet(BaseAPI):
category_iou, miou = conf_mat.mean_iou() category_iou, miou = conf_mat.mean_iou()
category_acc, macc = conf_mat.accuracy() category_acc, macc = conf_mat.accuracy()
precision, recall = conf_mat.precision_recall()
metrics = OrderedDict( metrics = OrderedDict(
zip(['miou', 'category_iou', 'macc', 'category_acc', 'kappa'], zip([
[miou, category_iou, macc, category_acc, 'miou', 'category_iou', 'macc', 'category_acc', 'kappa',
conf_mat.kappa()])) 'precision', 'recall'
], [
miou, category_iou, macc, category_acc,
conf_mat.kappa(), precision, recall
]))
if return_details: if return_details:
eval_details = { eval_details = {
'confusion_matrix': conf_mat.confusion_matrix.tolist() 'confusion_matrix': conf_mat.confusion_matrix.tolist()
...@@ -296,11 +267,10 @@ class UNet(BaseAPI): ...@@ -296,11 +267,10 @@ class UNet(BaseAPI):
if transforms is None and not hasattr(self, 'test_transforms'): if transforms is None and not hasattr(self, 'test_transforms'):
raise Exception("transforms need to be defined, now is None.") raise Exception("transforms need to be defined, now is None.")
if transforms is not None: if transforms is not None:
self.arrange_transforms(transforms=transforms, mode='test') self.arrange_transform(transforms=transforms, mode='test')
im, im_info = transforms(im_file) im, im_info = transforms(im_file)
else: else:
self.arrange_transforms( self.arrange_transform(transforms=self.test_transforms, mode='test')
transforms=self.test_transforms, mode='test')
im, im_info = self.test_transforms(im_file) im, im_info = self.test_transforms(im_file)
im = im.astype(np.float32) im = im.astype(np.float32)
im = np.expand_dims(im, axis=0) im = np.expand_dims(im, axis=0)
...@@ -319,4 +289,4 @@ class UNet(BaseAPI): ...@@ -319,4 +289,4 @@ class UNet(BaseAPI):
h, w = im_info[k][0], im_info[k][1] h, w = im_info[k][0], im_info[k][1]
pred = pred[0:h, 0:w] pred = pred[0:h, 0:w]
return pred return {'label_map': pred}
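# A condensed usage sketch of the UNet API above. The readers, paths and the .npy file
# name are placeholders; train_reader/eval_reader stand for Reader objects built as in
# the training script further down, and the eval_reader keyword is assumed from the
# elided part of the train() signature.
model = UNet(num_classes=2, use_bce_loss=True, use_dice_loss=True)
model.train(
    num_epochs=10,
    train_reader=train_reader,
    train_batch_size=4,
    eval_reader=eval_reader,
    save_dir='output',
    learning_rate=0.01,
    use_vdl=False)
metrics = model.evaluate(eval_reader, batch_size=4)
print(metrics['miou'], metrics['precision'], metrics['recall'])
result = model.predict('demo_image.npy')  # returns a dict with key 'label_map'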
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet import UNet from .unet import UNet
from .hrnet import HRNet
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
from .loss import softmax_with_loss
from .loss import dice_loss
from .loss import bce_loss
from .libs import sigmoid_to_softmax
class HRNet(object):
def __init__(self,
num_classes,
input_channel=3,
mode='train',
stage1_num_modules=1,
stage1_num_blocks=[4],
stage1_num_channels=[64],
stage2_num_modules=1,
stage2_num_blocks=[4, 4],
stage2_num_channels=[18, 36],
stage3_num_modules=4,
stage3_num_blocks=[4, 4, 4],
stage3_num_channels=[18, 36, 72],
stage4_num_modules=3,
stage4_num_blocks=[4, 4, 4, 4],
stage4_num_channels=[18, 36, 72, 144],
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
# dice_loss and bce_loss only apply to binary segmentation
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
"dice loss and bce loss is only applicable to binary classfication"
)
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expected class_weight to be a list or string but received {}'.
format(type(class_weight)))
self.num_classes = num_classes
self.input_channel = input_channel
self.mode = mode
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
self.stage1_num_modules = stage1_num_modules
self.stage1_num_blocks = stage1_num_blocks
self.stage1_num_channels = stage1_num_channels
self.stage2_num_modules = stage2_num_modules
self.stage2_num_blocks = stage2_num_blocks
self.stage2_num_channels = stage2_num_channels
self.stage3_num_modules = stage3_num_modules
self.stage3_num_blocks = stage3_num_blocks
self.stage3_num_channels = stage3_num_channels
self.stage4_num_modules = stage4_num_modules
self.stage4_num_blocks = stage4_num_blocks
self.stage4_num_channels = stage4_num_channels
def build_net(self, inputs):
if self.use_dice_loss or self.use_bce_loss:
self.num_classes = 1
image = inputs['image']
logit = self._high_resolution_net(image, self.num_classes)
if self.num_classes == 1:
out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if self.mode == 'train':
label = inputs['label']
mask = label != self.ignore_index
return self._get_loss(logit, label, mask)
else:
if self.num_classes == 1:
logit = sigmoid_to_softmax(logit)
else:
logit = fluid.layers.softmax(logit, axis=1)
return pred, logit
return logit
def generate_inputs(self):
inputs = OrderedDict()
inputs['image'] = fluid.data(
dtype='float32',
shape=[None, self.input_channel, None, None],
name='image')
if self.mode == 'train':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
elif self.mode == 'eval':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
return inputs
def _get_loss(self, logit, label, mask):
avg_loss = 0
if not (self.use_dice_loss or self.use_bce_loss):
avg_loss += softmax_with_loss(
logit,
label,
mask,
num_classes=self.num_classes,
weight=self.class_weight,
ignore_index=self.ignore_index)
else:
if self.use_dice_loss:
avg_loss += dice_loss(logit, label, mask)
if self.use_bce_loss:
avg_loss += bce_loss(
logit, label, mask, ignore_index=self.ignore_index)
return avg_loss
def _conv_bn_layer(self,
input,
filter_size,
num_filters,
stride=1,
padding=1,
num_groups=1,
if_act=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=num_groups,
act=None,
param_attr=ParamAttr(initializer=MSRA(), name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
bn = fluid.layers.relu(bn)
return bn
def _basic_block(self,
input,
num_filters,
stride=1,
downsample=False,
name=None):
residual = input
conv = self._conv_bn_layer(
input=input,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv1')
conv = self._conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
if_act=False,
name=name + '_conv2')
if downsample:
residual = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def _bottleneck_block(self,
input,
num_filters,
stride=1,
downsample=False,
name=None):
residual = input
conv = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters,
name=name + '_conv1')
conv = self._conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv2')
conv = self._conv_bn_layer(
input=conv,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_conv3')
if downsample:
residual = self._conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def _fuse_layers(self, x, channels, multi_scale_output=True, name=None):
out = []
for i in range(len(channels) if multi_scale_output else 1):
residual = x[i]
shape = fluid.layers.shape(residual)[-2:]
for j in range(len(channels)):
if j > i:
y = self._conv_bn_layer(
x[j],
filter_size=1,
num_filters=channels[i],
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_bilinear(input=y, out_shape=shape)
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
y = x[j]
for k in range(i - j):
if k == i - j - 1:
y = self._conv_bn_layer(
y,
filter_size=3,
num_filters=channels[i],
stride=2,
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
else:
y = self._conv_bn_layer(
y,
filter_size=3,
num_filters=channels[j],
stride=2,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
residual = fluid.layers.relu(residual)
out.append(residual)
return out
def _branches(self, x, block_num, channels, name=None):
out = []
for i in range(len(channels)):
residual = x[i]
for j in range(block_num[i]):
residual = self._basic_block(
residual,
channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' +
str(j + 1))
out.append(residual)
return out
def _high_resolution_module(self,
x,
blocks,
channels,
multi_scale_output=True,
name=None):
residual = self._branches(x, blocks, channels, name=name)
out = self._fuse_layers(
residual,
channels,
multi_scale_output=multi_scale_output,
name=name)
return out
def _transition_layer(self, x, in_channels, out_channels, name=None):
num_in = len(in_channels)
num_out = len(out_channels)
out = []
for i in range(num_out):
if i < num_in:
if in_channels[i] != out_channels[i]:
residual = self._conv_bn_layer(
x[i],
filter_size=3,
num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual)
else:
out.append(x[i])
else:
residual = self._conv_bn_layer(
x[-1],
filter_size=3,
num_filters=out_channels[i],
stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual)
return out
def _stage(self,
x,
num_modules,
num_blocks,
num_channels,
multi_scale_output=True,
name=None):
out = x
for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False:
out = self._high_resolution_module(
out,
num_blocks,
num_channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else:
out = self._high_resolution_module(
out, num_blocks, num_channels, name=name + '_' + str(i + 1))
return out
def _layer1(self, input, num_modules, num_blocks, num_channels, name=None):
# num_modules defaults to 1; the reference implementation uses [1], consider whether extra handling is needed to align with it.
conv = input
for i in range(num_blocks[0]):
conv = self._bottleneck_block(
conv,
num_filters=num_channels[0],
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv
def _high_resolution_net(self, input, num_classes):
x = self._conv_bn_layer(
input=input,
filter_size=3,
num_filters=self.stage1_num_channels[0],
stride=2,
if_act=True,
name='layer1_1')
x = self._conv_bn_layer(
input=x,
filter_size=3,
num_filters=self.stage1_num_channels[0],
stride=2,
if_act=True,
name='layer1_2')
la1 = self._layer1(
x,
self.stage1_num_modules,
self.stage1_num_blocks,
self.stage1_num_channels,
name='layer2')
tr1 = self._transition_layer([la1],
self.stage1_num_channels,
self.stage2_num_channels,
name='tr1')
st2 = self._stage(
tr1,
self.stage2_num_modules,
self.stage2_num_blocks,
self.stage2_num_channels,
name='st2')
tr2 = self._transition_layer(
st2, self.stage2_num_channels, self.stage3_num_channels, name='tr2')
st3 = self._stage(
tr2,
self.stage3_num_modules,
self.stage3_num_blocks,
self.stage3_num_channels,
name='st3')
tr3 = self._transition_layer(
st3, self.stage3_num_channels, self.stage4_num_channels, name='tr3')
st4 = self._stage(
tr3,
self.stage4_num_modules,
self.stage4_num_blocks,
self.stage4_num_channels,
name='st4')
# upsample
shape = fluid.layers.shape(st4[0])[-2:]
st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=shape)
st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=shape)
st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=shape)
out = fluid.layers.concat(st4, axis=1)
last_channels = sum(self.stage4_num_channels)
out = self._conv_bn_layer(
input=out,
filter_size=1,
num_filters=last_channels,
stride=1,
if_act=True,
name='conv-2')
out = fluid.layers.conv2d(
input=out,
num_filters=num_classes,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(initializer=MSRA(), name='conv-1_weights'),
bias_attr=False)
input_shape = fluid.layers.shape(input)[-2:]
out = fluid.layers.resize_bilinear(out, input_shape)
return out
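# A minimal graph-construction sketch for the HRNet class above. This is a hypothetical
# driver: the optimizer settings are illustrative and are not taken from this repository.
import paddle.fluid as fluid

hrnet = HRNet(num_classes=2, mode='train')
inputs = hrnet.generate_inputs()    # 'image' plus, in train mode, 'label'
loss = hrnet.build_net(inputs)      # in train mode this returns the training loss
fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9).minimize(loss)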
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import os.path as osp import os.path as osp
import sys
import numpy as np import numpy as np
from PIL import Image as Image from PIL import Image as Image
import argparse import argparse
...@@ -8,46 +24,81 @@ from models import load_model ...@@ -8,46 +24,81 @@ from models import load_model
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='RemoteSensing predict') parser = argparse.ArgumentParser(description='RemoteSensing predict')
parser.add_argument(
'--single_img',
dest='single_img',
help='single image path to predict',
default=None,
type=str)
parser.add_argument( parser.add_argument(
'--data_dir', '--data_dir',
dest='data_dir', dest='data_dir',
help='dataset directory', help='dataset directory',
default=None, default=None,
type=str) type=str)
parser.add_argument(
'--file_list',
dest='file_list',
help='file name of predict file list',
default=None,
type=str)
parser.add_argument( parser.add_argument(
'--load_model_dir', '--load_model_dir',
dest='load_model_dir', dest='load_model_dir',
help='model load directory', help='model load directory',
default=None, default=None,
type=str) type=str)
parser.add_argument(
'--save_img_dir',
dest='save_img_dir',
help='save directory name of predict results',
default='predict_results',
type=str)
if len(sys.argv) < 2:
parser.print_help()
sys.exit(1)
return parser.parse_args() return parser.parse_args()
args = parse_args() args = parse_args()
data_dir = args.data_dir data_dir = args.data_dir
file_list = args.file_list
single_img = args.single_img
load_model_dir = args.load_model_dir load_model_dir = args.load_model_dir
save_img_dir = args.save_img_dir
if not osp.exists(save_img_dir):
os.makedirs(save_img_dir)
# predict # predict
model = load_model(load_model_dir) model = load_model(load_model_dir)
pred_dir = osp.join(load_model_dir, 'predict')
if not osp.exists(pred_dir): color_map = [0, 0, 0, 0, 255, 0]
os.mkdir(pred_dir) if single_img is not None:
pred = model.predict(single_img)
val_list = osp.join(data_dir, 'val.txt') # Save the prediction as a pseudo-color PNG image
color_map = [0, 0, 0, 255, 255, 255] pred_name = osp.basename(single_img).rstrip('npy') + 'png'
with open(val_list) as f: pred_path = osp.join(save_img_dir, pred_name)
lines = f.readlines() pred_mask = Image.fromarray(pred['label_map'].astype(np.uint8), mode='P')
for line in lines: pred_mask.putpalette(color_map)
img_path = line.split(' ')[0] pred_mask.save(pred_path)
print('Predicting {}'.format(img_path)) elif (file_list is not None) and (data_dir is not None):
img_path_ = osp.join(data_dir, img_path) with open(osp.join(data_dir, file_list)) as f:
lines = f.readlines()
pred = model.predict(img_path_) for line in lines:
img_path = line.split(' ')[0]
# Save the prediction as a pseudo-color PNG image pred_name = osp.basename(img_path).rstrip('npy') + 'png'
pred_name = osp.basename(img_path).rstrip('npy') + 'png' img_path_ = osp.join(data_dir, img_path)
pred_path = osp.join(pred_dir, pred_name)
pred_mask = Image.fromarray(pred.astype(np.uint8), mode='P') pred = model.predict(img_path_)
pred_mask.putpalette(color_map)
pred_mask.save(pred_path) # Save the prediction as a pseudo-color PNG image
pred_name = osp.basename(img_path).rstrip('npy') + 'png'
pred_path = osp.join(save_img_dir, pred_name)
pred_mask = Image.fromarray(
pred['label_map'].astype(np.uint8), mode='P')
pred_mask.putpalette(color_map)
pred_mask.save(pred_path)
else:
raise Exception(
'You should either set the parameter single_img, or set the parameters data_dir, file_list.'
)
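# A small palette illustration for the color_map used above (the demo array and output
# file name are arbitrary): palette index 0 is rendered black and index 1 green, and the
# same recipe colorizes any 'P'-mode label map.
from PIL import Image
import numpy as np

demo = Image.fromarray(np.array([[0, 1], [1, 0]], dtype=np.uint8), mode='P')
demo.putpalette([0, 0, 0, 0, 255, 0])
demo.convert('RGB').save('palette_demo.png')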
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp import os.path as osp
import argparse import argparse
import transforms.transforms as T import transforms.transforms as T
from readers.reader import Reader from readers.reader import Reader
from models import UNet from models import UNet, HRNet
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='RemoteSensing training') parser = argparse.ArgumentParser(description='RemoteSensing training')
parser.add_argument(
'--model_type',
dest='model_type',
help="Model type for traing, which is one of ('unet', 'hrnet')",
type=str,
default='hrnet')
parser.add_argument( parser.add_argument(
'--data_dir', '--data_dir',
dest='data_dir', dest='data_dir',
...@@ -43,7 +64,6 @@ def parse_args(): ...@@ -43,7 +64,6 @@ def parse_args():
args = parse_args() args = parse_args()
data_dir = args.data_dir data_dir = args.data_dir
save_dir = args.save_dir save_dir = args.save_dir
channel = args.channel channel = args.channel
...@@ -52,17 +72,9 @@ train_batch_size = args.train_batch_size ...@@ -52,17 +72,9 @@ train_batch_size = args.train_batch_size
lr = args.lr lr = args.lr
# Define transforms for training and validation # Define transforms for training and validation
train_transforms = T.Compose([ train_transforms = T.Compose([T.RandomHorizontalFlip(0.5), T.Normalize()])
T.RandomVerticalFlip(0.5),
T.RandomHorizontalFlip(0.5),
T.ResizeStepScaling(0.5, 2.0, 0.25),
T.RandomPaddingCrop(256),
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
eval_transforms = T.Compose([ eval_transforms = T.Compose([T.Normalize()])
T.Normalize(mean=[0.5] * channel, std=[0.5] * channel),
])
train_list = osp.join(data_dir, 'train.txt') train_list = osp.join(data_dir, 'train.txt')
val_list = osp.join(data_dir, 'val.txt') val_list = osp.join(data_dir, 'val.txt')
...@@ -74,23 +86,30 @@ train_reader = Reader( ...@@ -74,23 +86,30 @@ train_reader = Reader(
file_list=train_list, file_list=train_list,
label_list=label_list, label_list=label_list,
transforms=train_transforms, transforms=train_transforms,
num_workers=8, shuffle=True)
buffer_size=16,
shuffle=True,
parallel_method='thread')
eval_reader = Reader( eval_reader = Reader(
data_dir=data_dir, data_dir=data_dir,
file_list=val_list, file_list=val_list,
label_list=label_list, label_list=label_list,
transforms=eval_transforms, transforms=eval_transforms)
num_workers=8,
buffer_size=16,
shuffle=False,
parallel_method='thread')
model = UNet( if args.model_type == 'unet':
num_classes=2, input_channel=channel, use_bce_loss=True, use_dice_loss=True) model = UNet(
num_classes=2,
input_channel=channel,
use_bce_loss=True,
use_dice_loss=True)
elif args.model_type == 'hrnet':
model = HRNet(
num_classes=2,
input_channel=channel,
use_bce_loss=True,
use_dice_loss=True)
else:
raise ValueError(
"--model_type: {} is set wrong, it shold be one of ('unet', "
"'hrnet')".format(args.model_type))
model.train( model.train(
num_epochs=num_epochs, num_epochs=num_epochs,
...@@ -100,7 +119,5 @@ model.train( ...@@ -100,7 +119,5 @@ model.train(
save_interval_epochs=5, save_interval_epochs=5,
log_interval_steps=10, log_interval_steps=10,
save_dir=save_dir, save_dir=save_dir,
pretrain_weights=None,
optimizer=None,
learning_rate=lr, learning_rate=lr,
use_vdl=True) use_vdl=True)
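# An optional follow-up sketch: assuming evaluation can be run on the same model object
# right after training (batch_size=4 is an arbitrary example value).
metrics = model.evaluate(eval_reader, batch_size=4)
print(metrics)  # OrderedDict with miou, category_iou, macc, category_acc, kappa, precision, recall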
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -143,3 +143,14 @@ class ConfusionMatrix(object): ...@@ -143,3 +143,14 @@ class ConfusionMatrix(object):
kappa = (po - pe) / (1 - pe) kappa = (po - pe) / (1 - pe)
return kappa return kappa
def precision_recall(self):
'''
Precision and recall of the foreground class (value=1) for binary segmentation.
'''
TP = self.confusion_matrix[1, 1]
FN = self.confusion_matrix[1, 0]
FP = self.confusion_matrix[0, 1]
recall = TP / (TP + FN)
precision = TP / (TP + FP)
return precision, recall
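# A worked example of the two formulas above with illustrative numbers; the indexing
# follows the method, i.e. row/column 1 is the foreground class.
import numpy as np

cm = np.array([[90, 10],
               [5, 45]], dtype=np.float64)
tp, fn, fp = cm[1, 1], cm[1, 0], cm[0, 1]
print(tp / (tp + fp))  # precision = 45 / 55 ≈ 0.818
print(tp / (tp + fn))  # recall    = 45 / 50 = 0.9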
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path as osp import os.path as osp
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,13 +13,10 @@ ...@@ -12,13 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import time
import os import os
import os.path as osp import os.path as osp
import numpy as np import numpy as np
import six import six
import yaml
import math import math
from . import logging from . import logging
...@@ -204,11 +202,9 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False): ...@@ -204,11 +202,9 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
vars_to_load.append(var) vars_to_load.append(var)
logging.debug("Weight {} will be load".format(var.name)) logging.debug("Weight {} will be load".format(var.name))
fluid.io.load_vars( params_dict = fluid.io.load_program_state(
executor=exe, weights_dir, var_list=vars_to_load)
dirname=weights_dir, fluid.io.set_program_state(main_prog, params_dict)
main_program=main_prog,
vars=vars_to_load)
if len(vars_to_load) == 0: if len(vars_to_load) == 0:
logging.warning( logging.warning(
"There is no pretrain weights loaded, maybe you should check you pretrain model!" "There is no pretrain weights loaded, maybe you should check you pretrain model!"
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.util import AttrDict, merge_cfg_from_args, get_arguments from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os import os
...@@ -6,20 +20,20 @@ args = get_arguments() ...@@ -6,20 +20,20 @@ args = get_arguments()
cfg = AttrDict() cfg = AttrDict()
# Directory containing the images to predict # Directory containing the images to predict
cfg.data_dir = os.path.join(args.example , "data", "test_images") cfg.data_dir = os.path.join(args.example, "data", "test_images")
# File listing the names of the images to predict # File listing the names of the images to predict
cfg.data_list_file = os.path.join(args.example , "data", "test.txt") cfg.data_list_file = os.path.join(args.example, "data", "test.txt")
# Model loading path # Model loading path
cfg.model_path = os.path.join(args.example , "model") cfg.model_path = os.path.join(args.example, "model")
# Directory for saving prediction results # Directory for saving prediction results
cfg.vis_dir = os.path.join(args.example , "result") cfg.vis_dir = os.path.join(args.example, "result")
# Number of prediction classes # Number of prediction classes
cfg.class_num = 2 cfg.class_num = 2
# Mean values subtracted from images during preprocessing # Mean values subtracted from images during preprocessing
cfg.MEAN = 127.5, 127.5, 127.5 cfg.MEAN = 127.5, 127.5, 127.5
# Standard deviation values that images are divided by during preprocessing # Standard deviation values that images are divided by during preprocessing
cfg.STD = 127.5, 127.5, 127.5 cfg.STD = 127.5, 127.5, 127.5
# Input size of the images to predict # Input size of the images to predict
cfg.input_size = 1536, 576 cfg.input_size = 1536, 576
......
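# A minimal preprocessing sketch implied by the MEAN/STD comments above. The dummy array
# and its height/width ordering are assumptions for illustration, not the reader code.
import numpy as np

dummy = np.zeros((576, 1536, 3), dtype=np.float32)  # stand-in image near cfg.input_size
normalized = (dummy - np.array(cfg.MEAN, dtype=np.float32)) / np.array(cfg.STD, dtype=np.float32)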
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
import cv2 import cv2
import numpy as np import numpy as np
...@@ -12,18 +26,19 @@ config = importlib.import_module('config') ...@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg = getattr(config, 'cfg') cfg = getattr(config, 'cfg')
# Paddle garbage-collection FLAG; the ACE2P model is large, so enable this when GPU memory is insufficient # Paddle garbage-collection FLAG; the ACE2P model is large, so enable this when GPU memory is insufficient
os.environ['FLAGS_eager_delete_tensor_gb']='0.0' os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
import paddle.fluid as fluid import paddle.fluid as fluid
# Dataset class for prediction # Dataset class for prediction
class TestDataSet(): class TestDataSet():
def __init__(self): def __init__(self):
self.data_dir = cfg.data_dir self.data_dir = cfg.data_dir
self.data_list_file = cfg.data_list_file self.data_list_file = cfg.data_list_file
self.data_list = self.get_data_list() self.data_list = self.get_data_list()
self.data_num = len(self.data_list) self.data_num = len(self.data_list)
def get_data_list(self): def get_data_list(self):
# Get the list of image paths to predict # Get the list of image paths to predict
data_list = [] data_list = []
...@@ -40,7 +55,7 @@ class TestDataSet(): ...@@ -40,7 +55,7 @@ class TestDataSet():
def preprocess(self, img): def preprocess(self, img):
# Image preprocessing # Image preprocessing
if cfg.example == 'ACE2P': if cfg.example == 'ACE2P':
reader = importlib.import_module(args.example+'.reader') reader = importlib.import_module(args.example + '.reader')
ACE2P_preprocess = getattr(reader, 'preprocess') ACE2P_preprocess = getattr(reader, 'preprocess')
img = ACE2P_preprocess(img) img = ACE2P_preprocess(img)
else: else:
...@@ -56,10 +71,10 @@ class TestDataSet(): ...@@ -56,10 +71,10 @@ class TestDataSet():
img_path = self.data_list[index] img_path = self.data_list[index]
img = cv2.imread(img_path, cv2.IMREAD_COLOR) img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None: if img is None:
return img, img,img_path, None return img, img, img_path, None
img_name = img_path.split(os.sep)[-1] img_name = img_path.split(os.sep)[-1]
name_prefix = img_name.replace('.'+img_name.split('.')[-1],'') name_prefix = img_name.replace('.' + img_name.split('.')[-1], '')
img_shape = img.shape[:2] img_shape = img.shape[:2]
img_process = self.preprocess(img) img_process = self.preprocess(img)
...@@ -90,39 +105,44 @@ def infer(): ...@@ -90,39 +105,44 @@ def infer():
if image is None: if image is None:
print(im_name, 'is None') print(im_name, 'is None')
continue continue
# Prediction # Prediction
if cfg.example == 'ACE2P': if cfg.example == 'ACE2P':
# The ACE2P model uses multi-scale prediction # The ACE2P model uses multi-scale prediction
reader = importlib.import_module(args.example+'.reader') reader = importlib.import_module(args.example + '.reader')
multi_scale_test = getattr(reader, 'multi_scale_test') multi_scale_test = getattr(reader, 'multi_scale_test')
parsing, logits = multi_scale_test(exe, test_prog, feed_name, fetch_list, image, im_shape) parsing, logits = multi_scale_test(exe, test_prog, feed_name,
fetch_list, image, im_shape)
else: else:
# The HumanSeg and RoadLine models use single-scale prediction # The HumanSeg and RoadLine models use single-scale prediction
result = exe.run(program=test_prog, feed={feed_name[0]: image}, fetch_list=fetch_list) result = exe.run(
program=test_prog,
feed={feed_name[0]: image},
fetch_list=fetch_list)
parsing = np.argmax(result[0][0], axis=0) parsing = np.argmax(result[0][0], axis=0)
parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1]) parsing = cv2.resize(parsing.astype(np.uint8), im_shape[::-1])
# Save the prediction result # Save the prediction result
result_path = os.path.join(cfg.vis_dir, im_name + '.png') result_path = os.path.join(cfg.vis_dir, im_name + '.png')
if cfg.example == 'HumanSeg': if cfg.example == 'HumanSeg':
logits = result[0][0][1]*255 logits = result[0][0][1] * 255
logits = cv2.resize(logits, im_shape[::-1]) logits = cv2.resize(logits, im_shape[::-1])
ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO) ret, logits = cv2.threshold(logits, thresh, 0, cv2.THRESH_TOZERO)
logits = 255 *(logits - thresh)/(255 - thresh) logits = 255 * (logits - thresh) / (255 - thresh)
# Add the segmentation result as an alpha channel # Add the segmentation result as an alpha channel
rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)), axis=2) rgba = np.concatenate((ori_img, np.expand_dims(logits, axis=2)),
axis=2)
cv2.imwrite(result_path, rgba) cv2.imwrite(result_path, rgba)
else: else:
output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8)) output_im = PILImage.fromarray(np.asarray(parsing, dtype=np.uint8))
output_im.putpalette(palette) output_im.putpalette(palette)
output_im.save(result_path) output_im.save(result_path)
if (idx + 1) % 100 == 0: if (idx + 1) % 100 == 0:
print('%d processed' % (idx + 1)) print('%d processed' % (idx + 1))
print('%d processed done' % (idx + 1)) print('%d processed done' % (idx + 1))
return 0 return 0
......
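# An optional compositing sketch for the RGBA output written above; 'result.png' is a
# placeholder for a file produced by the HumanSeg branch.
import cv2
import numpy as np

rgba = cv2.imread('result.png', cv2.IMREAD_UNCHANGED).astype(np.float32)
alpha = rgba[:, :, 3:4] / 255.0
composite = rgba[:, :, :3] * alpha + 255.0 * (1.0 - alpha)  # blend over a white background
cv2.imwrite('composite.png', composite.astype(np.uint8))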
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # coding: utf8
## Created by: RainbowSecret # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
## Microsoft Research #
## yuyua@microsoft.com # Licensed under the Apache License, Version 2.0 (the "License");
## Copyright (c) 2018 # you may not use this file except in compliance with the License.
## # You may obtain a copy of the License at
## This source code is licensed under the MIT-style license found in the #
## LICENSE file in the root directory of this source tree # http://www.apache.org/licenses/LICENSE-2.0
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import argparse import argparse
import os import os
def get_arguments(): def get_arguments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", parser.add_argument(
action="store_true", "--use_gpu", action="store_true", help="Use gpu or cpu to test.")
help="Use gpu or cpu to test.") parser.add_argument(
parser.add_argument('--example', '--example', type=str, help='RoadLine, HumanSeg or ACE2P')
type=str,
help='RoadLine, HumanSeg or ACE2P')
return parser.parse_args() return parser.parse_args()
...@@ -34,6 +48,7 @@ class AttrDict(dict): ...@@ -34,6 +48,7 @@ class AttrDict(dict):
else: else:
self[name] = value self[name] = value
def merge_cfg_from_args(args, cfg): def merge_cfg_from_args(args, cfg):
"""Merge config keys, values in args into the global config.""" """Merge config keys, values in args into the global config."""
for k, v in vars(args).items(): for k, v in vars(args).items():
...@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg): ...@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value = v value = v
if value is not None: if value is not None:
cfg[k] = value cfg[k] = value
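# A short usage sketch of the helpers above (a hypothetical driver; the flags come from
# get_arguments defined in this module).
args = get_arguments()           # parses --use_gpu and --example
cfg = AttrDict()
merge_cfg_from_args(args, cfg)   # copies parsed, non-None values such as cfg.use_gpu and cfg.example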
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -20,6 +21,8 @@ from PIL import Image ...@@ -20,6 +21,8 @@ from PIL import Image
import glob import glob
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
def remove_colormap(filename): def remove_colormap(filename):
gray_anno = np.array(Image.open(filename)) gray_anno = np.array(Image.open(filename))
return gray_anno return gray_anno
...@@ -30,6 +33,7 @@ def save_annotation(annotation, filename): ...@@ -30,6 +33,7 @@ def save_annotation(annotation, filename):
annotation = Image.fromarray(annotation) annotation = Image.fromarray(annotation)
annotation.save(filename) annotation.save(filename)
def convert_list(origin_file, seg_file, output_folder): def convert_list(origin_file, seg_file, output_folder):
with open(seg_file, 'w') as fid_seg: with open(seg_file, 'w') as fid_seg:
with open(origin_file) as fid_ori: with open(origin_file) as fid_ori:
...@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder): ...@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder):
new_line = ' '.join([img_name, anno_name]) new_line = ' '.join([img_name, anno_name])
fid_seg.write(new_line + "\n") fid_seg.write(new_line + "\n")
if __name__ == "__main__": if __name__ == "__main__":
pascal_root = "./VOCtrainval_11-May-2012/VOC2012" pascal_root = "./VOCtrainval_11-May-2012/VOC2012"
pascal_root = os.path.join(LOCAL_PATH, pascal_root) pascal_root = os.path.join(LOCAL_PATH, pascal_root)
...@@ -54,7 +59,7 @@ if __name__ == "__main__": ...@@ -54,7 +59,7 @@ if __name__ == "__main__":
# Directory for storing converted annotation images # Directory for storing converted annotation images
output_folder = os.path.join(pascal_root, "SegmentationClassAug") output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert") print("annotation convert and file list convert")
if not os.path.exists(os.path.join(LOCAL_PATH, output_folder)): if not os.path.exists(os.path.join(LOCAL_PATH, output_folder)):
os.mkdir(os.path.join(LOCAL_PATH, output_folder)) os.mkdir(os.path.join(LOCAL_PATH, output_folder))
...@@ -67,5 +72,5 @@ if __name__ == "__main__": ...@@ -67,5 +72,5 @@ if __name__ == "__main__":
convert_list(train_path, train_path.replace('txt', 'list'), output_folder) convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder) convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder) convert_list(trainval_path, trainval_path.replace('txt', 'list'),
output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap ...@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap
from convert_voc2012 import save_annotation from convert_voc2012 import save_annotation
def download_VOC_dataset(savepath, extrapath): def download_VOC_dataset(savepath, extrapath):
url = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar" url = "https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
download_file_and_uncompress( download_file_and_uncompress(
url=url, savepath=savepath, extrapath=extrapath) url=url, savepath=savepath, extrapath=extrapath)
if __name__ == "__main__": if __name__ == "__main__":
download_VOC_dataset(LOCAL_PATH, LOCAL_PATH) download_VOC_dataset(LOCAL_PATH, LOCAL_PATH)
print("Dataset download finish!") print("Dataset download finish!")
...@@ -45,10 +46,10 @@ if __name__ == "__main__": ...@@ -45,10 +46,10 @@ if __name__ == "__main__":
train_path = os.path.join(txt_folder, "train.txt") train_path = os.path.join(txt_folder, "train.txt")
val_path = os.path.join(txt_folder, "val.txt") val_path = os.path.join(txt_folder, "val.txt")
trainval_path = os.path.join(txt_folder, "trainval.txt") trainval_path = os.path.join(txt_folder, "trainval.txt")
# Directory for storing converted annotation images # Directory for storing converted annotation images
output_folder = os.path.join(pascal_root, "SegmentationClassAug") output_folder = os.path.join(pascal_root, "SegmentationClassAug")
print("annotation convert and file list convert") print("annotation convert and file list convert")
if not os.path.exists(output_folder): if not os.path.exists(output_folder):
os.mkdir(output_folder) os.mkdir(output_folder)
...@@ -61,5 +62,5 @@ if __name__ == "__main__": ...@@ -61,5 +62,5 @@ if __name__ == "__main__":
convert_list(train_path, train_path.replace('txt', 'list'), output_folder) convert_list(train_path, train_path.replace('txt', 'list'), output_folder)
convert_list(val_path, val_path.replace('txt', 'list'), output_folder) convert_list(val_path, val_path.replace('txt', 'list'), output_folder)
convert_list(trainval_path, trainval_path.replace('txt', 'list'), output_folder) convert_list(trainval_path, trainval_path.replace('txt', 'list'),
output_folder)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
...@@ -44,7 +44,7 @@ yum install -y libXext libSM libXrender ...@@ -44,7 +44,7 @@ yum install -y libXext libSM libXrender
### 5.1 Prepare the model
Use the [model export tool](../../docs/model_export.md) to export your model, or download our [portrait segmentation sample model](https://bj.bcebos.com/paddleseg/inference/human_freeze_model.zip) for testing.
The exported model directory usually contains three files:
```
├── model   # model file
├── params  # parameters file
...@@ -79,7 +79,7 @@ DEPLOY: ...@@ -79,7 +79,7 @@ DEPLOY:
### 5.2 Run the prediction program
Enter the following command in a terminal to run prediction:
```bash
python infer.py --conf=/path/to/deploy.yaml --input_dir=/path/to/images_directory --use_pr=False python infer.py --conf=/path/to/deploy.yaml --input_dir=/path/to/images_directory
```
The parameters are described below:
...@@ -87,9 +87,6 @@ python infer.py --conf=/path/to/deploy.yaml --input_dir/path/to/images_directory ...@@ -87,9 +87,6 @@ python infer.py --conf=/path/to/deploy.yaml --input_dir/path/to/images_directory
|-------|-------|----------|
| conf | Yes | Path to the model's YAML configuration file |
| input_dir | Yes | Directory of images to run prediction on |
| use_pr | No | Whether to use the optimized model; defaults to False |
* Optimized model: models exported with `PaddleSeg 0.3.0` are optimized models, while models exported by earlier versions are not. The optimized model folds image pre- and post-processing into the network and runs them on the `GPU`, which improves performance compared with doing that processing on the `CPU`.
**Note**: if the hardware supports it and `PaddlePaddle` was compiled from source with `TensorRT` integration, the flag `--trt_mode=fp16` enables `FP16` precision, and `trt_mode=fp32` uses `FP32` precision.
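With the `--use_pr` command-line flag removed in this change, the optimized-model switch lives in the deployment configuration instead (the `USE_PR` field written by `export_inference_config` further down in this diff). A minimal sketch of reading that switch from `deploy.yaml` with PyYAML; only the fields shown in this diff are assumed to exist:

```python
import yaml  # assumes PyYAML is available

def load_deploy_conf(conf_path):
    # deploy.yaml has a single top-level DEPLOY section (see export_inference_config).
    with open(conf_path) as f:
        deploy_conf = yaml.safe_load(f)["DEPLOY"]
    # USE_PR selects the optimized exported model; BATCH_SIZE and CHANNELS are
    # among the other fields DeployConfig reads in infer.py.
    return {
        "use_pr": bool(deploy_conf["USE_PR"]),
        "batch_size": deploy_conf["BATCH_SIZE"],
        "channels": deploy_conf["CHANNELS"],
    }
```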
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -29,9 +29,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed ...@@ -29,9 +29,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
gflags.DEFINE_string("conf", default="", help="Configuration File Path") gflags.DEFINE_string("conf", default="", help="Configuration File Path")
gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images") gflags.DEFINE_string("input_dir", default="", help="Directory of Input Images")
gflags.DEFINE_boolean("use_pr", default=False, help="Use optimized model")
gflags.DEFINE_string("trt_mode", default="", help="Use optimized model") gflags.DEFINE_string("trt_mode", default="", help="Use optimized model")
gflags.DEFINE_string("ext", default=".jpeg|.jpg", help="Input Image File Extensions") gflags.DEFINE_string(
"ext", default=".jpeg|.jpg", help="Input Image File Extensions")
gflags.FLAGS = gflags.FLAGS gflags.FLAGS = gflags.FLAGS
...@@ -103,6 +103,9 @@ class DeployConfig: ...@@ -103,6 +103,9 @@ class DeployConfig:
self.batch_size = deploy_conf["BATCH_SIZE"] self.batch_size = deploy_conf["BATCH_SIZE"]
# 9. channels # 9. channels
self.channels = deploy_conf["CHANNELS"] self.channels = deploy_conf["CHANNELS"]
# 10. use_pr
self.use_pr = deploy_conf["USE_PR"]
class ImageReader: class ImageReader:
...@@ -257,23 +260,24 @@ class Predictor: ...@@ -257,23 +260,24 @@ class Predictor:
# record starting time point # record starting time point
total_start = time.time() total_start = time.time()
batch_size = self.config.batch_size batch_size = self.config.batch_size
use_pr = self.config.use_pr
for i in range(0, len(images), batch_size): for i in range(0, len(images), batch_size):
real_batch_size = batch_size real_batch_size = batch_size
if i + batch_size >= len(images): if i + batch_size >= len(images):
real_batch_size = len(images) - i real_batch_size = len(images) - i
reader_start = time.time() reader_start = time.time()
img_datas = self.image_reader.process(images[i:i + real_batch_size], img_datas = self.image_reader.process(images[i:i + real_batch_size],
gflags.FLAGS.use_pr) use_pr)
input_data = np.concatenate([item[1] for item in img_datas]) input_data = np.concatenate([item[1] for item in img_datas])
input_data = self.create_tensor( input_data = self.create_tensor(
input_data, real_batch_size, use_pr=gflags.FLAGS.use_pr) input_data, real_batch_size, use_pr=use_pr)
reader_end = time.time() reader_end = time.time()
infer_start = time.time() infer_start = time.time()
output_data = self.predictor.run(input_data)[0] output_data = self.predictor.run(input_data)[0]
infer_end = time.time() infer_end = time.time()
output_data = output_data.as_ndarray() output_data = output_data.as_ndarray()
post_start = time.time() post_start = time.time()
self.output_result(img_datas, output_data, gflags.FLAGS.use_pr) self.output_result(img_datas, output_data, use_pr)
post_end = time.time() post_end = time.time()
reader_time += (reader_end - reader_start) reader_time += (reader_end - reader_start)
infer_time += (infer_end - infer_start) infer_time += (infer_end - infer_start)
......
...@@ -59,16 +59,16 @@ ...@@ -59,16 +59,16 @@
* After the data format conversion, the dataset directory is structured as follows:
```
my_dataset                          # root directory
|-- outputs                         # export directory of the annotation tool
|   |-- annotations                 # dataset ground truth
|   |   |-- xxx.png                 # pixel-level ground-truth labels
|   |   |...
|   |-- class_names.txt             # class names of the dataset
|   |-- xxx.json                    # annotation JSON files
|-- xxx.jpg(png or other)           # original dataset images
|-- ...
```
<div align="center"> <div align="center">
...@@ -76,16 +76,10 @@ ...@@ -76,16 +76,10 @@
<p>图5 格式转换后的数据集目录的结构示意图</p> <p>图5 格式转换后的数据集目录的结构示意图</p>
</div> </div>
* The conversion script depends on labelme and pillow; if they are not installed, install them first. For labelme, see the [official installation guide](https://github.com/wkentaro/labelme). Pillow can be installed with:
```shell
pip install pillow
```
* Run the following command to convert the annotated data into a dataset in the format above:
```
python pdseg/tools/jingling2seg.py <PATH/TO/LABEL_JSON_FILE>
```
Here `<PATH/TO/LABEL_JSON_FILE>` is the folder that contains the JSON files produced by the Jingling (精灵标注) tool, usually the `outputs` directory under the save location chosen in step (3) of using the tool.
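After the conversion it can be worth sanity-checking that the output matches the layout above before training. A small sketch under that assumption (paths are illustrative):

```python
import os

def check_converted_dataset(root):
    # Expected after jingling2seg: outputs/annotations/*.png plus
    # outputs/class_names.txt, with the original images at the dataset root.
    ann_dir = os.path.join(root, "outputs", "annotations")
    names_file = os.path.join(root, "outputs", "class_names.txt")
    assert os.path.isdir(ann_dir), "missing annotations directory"
    assert os.path.isfile(names_file), "missing class_names.txt"
    labels = [f for f in os.listdir(ann_dir) if f.endswith(".png")]
    print("found {} label images".format(len(labels)))
```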
...@@ -101,4 +95,4 @@ python pdseg/tools/jingling2seg.py docs/annotation/jingling_demo/outputs/ ...@@ -101,4 +95,4 @@ python pdseg/tools/jingling2seg.py docs/annotation/jingling_demo/outputs/
<div align="center"> <div align="center">
<img src="../imgs/annotation/jingling-5.png" width="600px"/> <img src="../imgs/annotation/jingling-5.png" width="600px"/>
<p>图6 格式转换后的数据集各目录的内容示意图</p> <p>图6 格式转换后的数据集各目录的内容示意图</p>
</div> </div>
...@@ -53,11 +53,11 @@ LableMe产出的真值文件可参考我们给出的文件夹[docs/annotation/la ...@@ -53,11 +53,11 @@ LableMe产出的真值文件可参考我们给出的文件夹[docs/annotation/la
<img src="../imgs/annotation/image-5.png" width="600px"/> <img src="../imgs/annotation/image-5.png" width="600px"/>
<p>图5 LableMe产出的真值文件的示意图</p> <p>图5 LableMe产出的真值文件的示意图</p>
</div> </div>
**Note:**
To annotate an object with a hole in it: after outlining the object, draw another polygon along the edge of the hole and assign it to a different class; if the hole is background, assign it to `_background_`. For example:
<div align="center"> <div align="center">
<img src="../imgs/annotation/image-10.jpg" width="600px"/> <img src="../imgs/annotation/image-10.jpg" width="600px"/>
<p>图6 带空洞目标的标注示意图</p> <p>图6 带空洞目标的标注示意图</p>
...@@ -69,16 +69,16 @@ LableMe产出的真值文件可参考我们给出的文件夹[docs/annotation/la ...@@ -69,16 +69,16 @@ LableMe产出的真值文件可参考我们给出的文件夹[docs/annotation/la
* After the data format conversion, the dataset directory is structured as follows:
```
my_dataset                          # root directory
|-- annotations                     # dataset ground truth
|   |-- xxx.png                     # pixel-level ground-truth labels
|   |...
|-- class_names.txt                 # class names of the dataset
|-- xxx.jpg(png or other)           # original dataset images
|-- ...
|-- xxx.json                        # annotation JSON files
|-- ...
```
<div align="center"> <div align="center">
...@@ -86,16 +86,10 @@ LableMe产出的真值文件可参考我们给出的文件夹[docs/annotation/la ...@@ -86,16 +86,10 @@ LableMe产出的真值文件可参考我们给出的文件夹[docs/annotation/la
<p>图7 格式转换后的数据集目录的结构示意图</p> <p>图7 格式转换后的数据集目录的结构示意图</p>
</div> </div>
* The conversion script depends on labelme and pillow; if they are not installed, install them first. For labelme, see the [official installation guide](https://github.com/wkentaro/labelme). Pillow can be installed with:
```shell
pip install pillow
```
* Run the following command to convert the annotated data into a dataset in the format above:
```
python pdseg/tools/labelme2seg.py <PATH/TO/LABEL_JSON_FILE>
```
Here `<PATH/TO/LABEL_JSON_FILE>` is the folder that contains both the images and the JSON files produced by LabelMe; it is also the folder where the converted annotations are written.
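Once labelme2seg has produced `annotations/*.png`, the images and labels can be paired into a PaddleSeg file list. A sketch assuming one `image label` pair per line separated by a space; adjust the separator and extensions to your dataset configuration:

```python
import glob
import os

def write_file_list(root, list_name="train.list"):
    # Pair each converted label in annotations/ with the original image of the
    # same basename located at the dataset root.
    with open(os.path.join(root, list_name), "w") as dst:
        for lbl in sorted(glob.glob(os.path.join(root, "annotations", "*.png"))):
            stem = os.path.splitext(os.path.basename(lbl))[0]
            for ext in (".jpg", ".png", ".jpeg"):
                img = os.path.join(root, stem + ext)
                if os.path.exists(img):
                    dst.write("{} {}\n".format(
                        os.path.relpath(img, root), os.path.relpath(lbl, root)))
                    break
```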
...@@ -111,4 +105,4 @@ python pdseg/tools/labelme2seg.py docs/annotation/labelme_demo/ ...@@ -111,4 +105,4 @@ python pdseg/tools/labelme2seg.py docs/annotation/labelme_demo/
<div align="center"> <div align="center">
<img src="../imgs/annotation/image-7.png" width="600px"/> <img src="../imgs/annotation/image-7.png" width="600px"/>
<p>图8 格式转换后的数据集各目录的内容示意图</p> <p>图8 格式转换后的数据集各目录的内容示意图</p>
</div> </div>
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,4 +14,4 @@ ...@@ -14,4 +14,4 @@
# limitations under the License. # limitations under the License.
import models import models
import utils import utils
from . import tools from . import tools
\ No newline at end of file
# coding: utf8 # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -427,12 +440,17 @@ def max_img_size_statistics(): ...@@ -427,12 +440,17 @@ def max_img_size_statistics():
logger.info("max width and max height of images are ({},{})".format( logger.info("max width and max height of images are ({},{})".format(
max_width, max_height)) max_width, max_height))
def num_classes_loss_matching_check(): def num_classes_loss_matching_check():
loss_type = cfg.SOLVER.LOSS loss_type = cfg.SOLVER.LOSS
num_classes = cfg.DATASET.NUM_CLASSES num_classes = cfg.DATASET.NUM_CLASSES
if num_classes > 2 and (("dice_loss" in loss_type) or ("bce_loss" in loss_type)): if num_classes > 2 and (("dice_loss" in loss_type) or
logger.info(error_print("loss check." ("bce_loss" in loss_type)):
" Dice loss and bce loss is only applicable to binary classfication")) logger.info(
error_print(
"loss check."
" Dice loss and bce loss is only applicable to binary classfication"
))
else: else:
logger.info(correct_print("loss check")) logger.info(correct_print("loss check"))
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -362,7 +362,7 @@ def hsv_color_jitter(crop_img, ...@@ -362,7 +362,7 @@ def hsv_color_jitter(crop_img,
saturation_jitter_ratio > 0 or \ saturation_jitter_ratio > 0 or \
contrast_jitter_ratio > 0: contrast_jitter_ratio > 0:
crop_img = random_jitter(crop_img, saturation_jitter_ratio, crop_img = random_jitter(crop_img, saturation_jitter_ratio,
brightness_jitter_ratio, contrast_jitter_ratio) brightness_jitter_ratio, contrast_jitter_ratio)
return crop_img return crop_img
...@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN): ...@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN):
crop_width = cfg.EVAL_CROP_SIZE[0] crop_width = cfg.EVAL_CROP_SIZE[0]
crop_height = cfg.EVAL_CROP_SIZE[1] crop_height = cfg.EVAL_CROP_SIZE[1]
if not ModelPhase.is_train(mode): if not ModelPhase.is_train(mode):
if (crop_height < img_height or crop_width < img_width): if (crop_height < img_height or crop_width < img_width):
raise Exception( raise Exception(
"Crop size({},{}) must large than img size({},{}) when in EvalPhase." "Crop size({},{}) must large than img size({},{}) when in EvalPhase."
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
""" """
...@@ -14,10 +28,10 @@ except ImportError: ...@@ -14,10 +28,10 @@ except ImportError:
class GeneratorEnqueuer(object): class GeneratorEnqueuer(object):
""" """
Multiple generators Multiple generators
Args: Args:
generators: generators:
wait_time (float): time to sleep in-between calls to `put()`. wait_time (float): time to sleep in-between calls to `put()`.
""" """
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -22,13 +22,9 @@ import os ...@@ -22,13 +22,9 @@ import os
os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0"
import sys import sys
import time
import argparse import argparse
import functools
import pprint import pprint
import cv2
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from utils.config import cfg from utils.config import cfg
...@@ -116,7 +112,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs): ...@@ -116,7 +112,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if ckpt_dir is not None: if ckpt_dir is not None:
print('load test model:', ckpt_dir) print('load test model:', ckpt_dir)
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog) try:
fluid.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
except:
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
# Use streaming confusion matrix to calculate mean_iou # Use streaming confusion matrix to calculate mean_iou
np.set_printoptions( np.set_printoptions(
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -49,10 +49,11 @@ def parse_args(): ...@@ -49,10 +49,11 @@ def parse_args():
sys.exit(1) sys.exit(1)
return parser.parse_args() return parser.parse_args()
def export_inference_config(): def export_inference_config():
deploy_cfg = '''DEPLOY: deploy_cfg = '''DEPLOY:
USE_GPU : 1 USE_GPU : 1
USE_PR : 1 USE_PR : 0
MODEL_PATH : "%s" MODEL_PATH : "%s"
MODEL_FILENAME : "%s" MODEL_FILENAME : "%s"
PARAMS_FILENAME : "%s" PARAMS_FILENAME : "%s"
...@@ -66,9 +67,8 @@ def export_inference_config(): ...@@ -66,9 +67,8 @@ def export_inference_config():
PREDICTOR_MODE : "ANALYSIS" PREDICTOR_MODE : "ANALYSIS"
BATCH_SIZE : 1 BATCH_SIZE : 1
''' % (cfg.FREEZE.SAVE_DIR, cfg.FREEZE.MODEL_FILENAME, ''' % (cfg.FREEZE.SAVE_DIR, cfg.FREEZE.MODEL_FILENAME,
cfg.FREEZE.PARAMS_FILENAME, cfg.EVAL_CROP_SIZE, cfg.FREEZE.PARAMS_FILENAME, cfg.EVAL_CROP_SIZE, cfg.MEAN, cfg.STD,
cfg.MEAN, cfg.STD, cfg.DATASET.IMAGE_TYPE, cfg.DATASET.IMAGE_TYPE, cfg.DATASET.NUM_CLASSES, len(cfg.STD))
cfg.DATASET.NUM_CLASSES, len(cfg.STD))
if not os.path.exists(cfg.FREEZE.SAVE_DIR): if not os.path.exists(cfg.FREEZE.SAVE_DIR):
os.mkdir(cfg.FREEZE.SAVE_DIR) os.mkdir(cfg.FREEZE.SAVE_DIR)
yaml_path = os.path.join(cfg.FREEZE.SAVE_DIR, 'deploy.yaml') yaml_path = os.path.join(cfg.FREEZE.SAVE_DIR, 'deploy.yaml')
...@@ -94,7 +94,13 @@ def export_inference_model(args): ...@@ -94,7 +94,13 @@ def export_inference_model(args):
infer_prog = infer_prog.clone(for_test=True) infer_prog = infer_prog.clone(for_test=True)
if os.path.exists(cfg.TEST.TEST_MODEL): if os.path.exists(cfg.TEST.TEST_MODEL):
fluid.io.load_params(exe, cfg.TEST.TEST_MODEL, main_program=infer_prog) print('load test model:', cfg.TEST.TEST_MODEL)
try:
fluid.load(infer_prog, os.path.join(cfg.TEST.TEST_MODEL, 'model'),
exe)
except:
fluid.io.load_params(
exe, cfg.TEST.TEST_MODEL, main_program=infer_prog)
else: else:
print("TEST.TEST_MODEL diretory is empty!") print("TEST.TEST_MODEL diretory is empty!")
exit(-1) exit(-1)
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -141,7 +141,7 @@ class ResNet(): ...@@ -141,7 +141,7 @@ class ResNet():
else: else:
conv_name = "res" + str(block + 2) + chr(97 + i) conv_name = "res" + str(block + 2) + chr(97 + i)
dilation_rate = get_dilated_rate(dilation_dict, block) dilation_rate = get_dilated_rate(dilation_dict, block)
conv = self.bottleneck_block( conv = self.bottleneck_block(
input=conv, input=conv,
num_filters=int(num_filters[block] * self.scale), num_filters=int(num_filters[block] * self.scale),
...@@ -215,11 +215,11 @@ class ResNet(): ...@@ -215,11 +215,11 @@ class ResNet():
groups=1, groups=1,
act=None, act=None,
name=None): name=None):
if self.stem == 'pspnet': if self.stem == 'pspnet':
bias_attr=ParamAttr(name=name + "_biases") bias_attr = ParamAttr(name=name + "_biases")
else: else:
bias_attr=False bias_attr = False
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
...@@ -238,13 +238,15 @@ class ResNet(): ...@@ -238,13 +238,15 @@ class ResNet():
bn_name = "bn_" + name bn_name = "bn_" + name
else: else:
bn_name = "bn" + name[3:] bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv, return fluid.layers.batch_norm(
act=act, input=conv,
name=bn_name + '.output.1', act=act,
param_attr=ParamAttr(name=bn_name + '_scale'), name=bn_name + '.output.1',
bias_attr=ParamAttr(bn_name + '_offset'), param_attr=ParamAttr(name=bn_name + '_scale'),
moving_mean_name=bn_name + '_mean', bias_attr=ParamAttr(bn_name + '_offset'),
moving_variance_name=bn_name + '_variance', ) moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
)
def shortcut(self, input, ch_out, stride, is_first, name): def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1] ch_in = input.shape[1]
...@@ -258,7 +260,7 @@ class ResNet(): ...@@ -258,7 +260,7 @@ class ResNet():
strides = [1, stride] strides = [1, stride]
else: else:
strides = [stride, 1] strides = [stride, 1]
conv0 = self.conv_bn_layer( conv0 = self.conv_bn_layer(
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -55,7 +55,8 @@ class VGGNet(): ...@@ -55,7 +55,8 @@ class VGGNet():
channels = [64, 128, 256, 512, 512] channels = [64, 128, 256, 512, 512]
conv = input conv = input
for i in range(len(nums)): for i in range(len(nums)):
conv = self.conv_block(conv, channels[i], nums[i], name="conv" + str(i + 1) + "_") conv = self.conv_block(
conv, channels[i], nums[i], name="conv" + str(i + 1) + "_")
layers_count += nums[i] layers_count += nums[i]
if check_points(layers_count, decode_points): if check_points(layers_count, decode_points):
short_cuts[layers_count] = conv short_cuts[layers_count] = conv
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -197,4 +197,4 @@ def conv_bn_layer(input, ...@@ -197,4 +197,4 @@ def conv_bn_layer(input,
if if_act: if if_act:
return fluid.layers.relu6(bn) return fluid.layers.relu6(bn)
else: else:
return bn return bn
\ No newline at end of file
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -111,53 +111,6 @@ def sigmoid_to_softmax(logit): ...@@ -111,53 +111,6 @@ def sigmoid_to_softmax(logit):
return logit return logit
def export_preprocess(image):
"""导出模型的预处理流程"""
image = fluid.layers.transpose(image, [0, 3, 1, 2])
origin_shape = fluid.layers.shape(image)[-2:]
# Resize according to the configured AUG_METHOD
if cfg.AUG.AUG_METHOD == 'unpadding':
h_fix = cfg.AUG.FIX_RESIZE_SIZE[1]
w_fix = cfg.AUG.FIX_RESIZE_SIZE[0]
image = fluid.layers.resize_bilinear(
image, out_shape=[h_fix, w_fix], align_corners=False, align_mode=0)
elif cfg.AUG.AUG_METHOD == 'rangescaling':
size = cfg.AUG.INF_RESIZE_VALUE
value = fluid.layers.reduce_max(origin_shape)
scale = float(size) / value.astype('float32')
image = fluid.layers.resize_bilinear(
image, scale=scale, align_corners=False, align_mode=0)
# Record the image shape after resizing
valid_shape = fluid.layers.shape(image)[-2:]
# Pad up to EVAL_CROP_SIZE
width = cfg.EVAL_CROP_SIZE[0]
height = cfg.EVAL_CROP_SIZE[1]
pad_target = fluid.layers.assign(
np.array([height, width]).astype('float32'))
up = fluid.layers.assign(np.array([0]).astype('float32'))
down = pad_target[0] - valid_shape[0]
left = up
right = pad_target[1] - valid_shape[1]
paddings = fluid.layers.concat([up, down, left, right])
paddings = fluid.layers.cast(paddings, 'int32')
image = fluid.layers.pad2d(image, paddings=paddings, pad_value=127.5)
# normalize
mean = np.array(cfg.MEAN).reshape(1, len(cfg.MEAN), 1, 1)
mean = fluid.layers.assign(mean.astype('float32'))
std = np.array(cfg.STD).reshape(1, len(cfg.STD), 1, 1)
std = fluid.layers.assign(std.astype('float32'))
image = (image / 255 - mean) / std
# Let downstream layers obtain the feature-map shape via image.shape
image = fluid.layers.reshape(
image, shape=[-1, cfg.DATASET.DATA_DIM, height, width])
return image, valid_shape, origin_shape
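The `export_preprocess` graph removed here resized and padded the input to `EVAL_CROP_SIZE` and then standardized it with `cfg.MEAN`/`cfg.STD`. As a reference for doing the same outside the network, a NumPy sketch of the normalization step only (not the resize/padding logic):

```python
import numpy as np

def normalize_for_deploy(img, mean, std):
    # img: HWC uint8 image; mean/std: per-channel values as in cfg.MEAN / cfg.STD.
    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # to NCHW
    mean = np.asarray(mean, dtype="float32").reshape(1, -1, 1, 1)
    std = np.asarray(std, dtype="float32").reshape(1, -1, 1, 1)
    return (img / 255.0 - mean) / std
```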
def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN): def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
if not ModelPhase.is_valid_phase(phase): if not ModelPhase.is_valid_phase(phase):
raise ValueError("ModelPhase {} is not valid!".format(phase)) raise ValueError("ModelPhase {} is not valid!".format(phase))
...@@ -176,21 +129,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN): ...@@ -176,21 +129,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
# When exporting the model, add image standardization preprocessing to shorten the image-processing pipeline at deployment time
# At deployment time, only a batch_size dimension then needs to be added to the input image
if ModelPhase.is_predict(phase): image = fluid.data(name='image', shape=image_shape, dtype='float32')
if cfg.SLIM.PREPROCESS:
image = fluid.data(
name='image', shape=image_shape, dtype='float32')
else:
origin_image = fluid.data(
name='image',
shape=[-1, -1, -1, cfg.DATASET.DATA_DIM],
dtype='float32')
image, valid_shape, origin_shape = export_preprocess(
origin_image)
else:
image = fluid.data(
name='image', shape=image_shape, dtype='float32')
label = fluid.data(name='label', shape=grt_shape, dtype='int32') label = fluid.data(name='label', shape=grt_shape, dtype='int32')
mask = fluid.data(name='mask', shape=grt_shape, dtype='int32') mask = fluid.data(name='mask', shape=grt_shape, dtype='int32')
...@@ -223,6 +162,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN): ...@@ -223,6 +162,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
raise Exception( raise Exception(
"softmax loss or lovasz softmax loss can not combine with bce loss or dice loss or lovasz hinge loss." "softmax loss or lovasz softmax loss can not combine with bce loss or dice loss or lovasz hinge loss."
) )
cfg.PHASE = phase
logits = seg_model(image, class_num) logits = seg_model(image, class_num)
# 根据选择的loss函数计算相应的损失函数 # 根据选择的loss函数计算相应的损失函数
...@@ -292,21 +232,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN): ...@@ -292,21 +232,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
else: else:
logit = softmax(logit) logit = softmax(logit)
# 获取有效部分 return image, logit
if cfg.SLIM.PREPROCESS:
return image, logit
else:
logit = fluid.layers.slice(
logit, axes=[2, 3], starts=[0, 0], ends=valid_shape)
logit = fluid.layers.resize_bilinear(
logit,
out_shape=origin_shape,
align_corners=False,
align_mode=0)
logit = fluid.layers.argmax(logit, axis=1)
return origin_image, logit
if class_num == 1: if class_num == 1:
out = sigmoid_to_softmax(logit) out = sigmoid_to_softmax(logit)
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -25,12 +25,15 @@ from models.libs.model_libs import separate_conv ...@@ -25,12 +25,15 @@ from models.libs.model_libs import separate_conv
from utils.config import cfg from utils.config import cfg
def learning_to_downsample(x, dw_channels1=32, dw_channels2=48, out_channels=64): def learning_to_downsample(x, dw_channels1=32, dw_channels2=48,
out_channels=64):
x = relu(bn(conv(x, dw_channels1, 3, 2))) x = relu(bn(conv(x, dw_channels1, 3, 2)))
with scope('dsconv1'): with scope('dsconv1'):
x = separate_conv(x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu) x = separate_conv(
x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
with scope('dsconv2'): with scope('dsconv2'):
x = separate_conv(x, out_channels, stride=2, filter=3, act=fluid.layers.relu) x = separate_conv(
x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
return x return x
...@@ -43,7 +46,9 @@ def dropout2d(input, prob, is_train=False): ...@@ -43,7 +46,9 @@ def dropout2d(input, prob, is_train=False):
return input return input
channels = input.shape[1] channels = input.shape[1]
keep_prob = 1.0 - prob keep_prob = 1.0 - prob
random_tensor = keep_prob + fluid.layers.uniform_random_batch_size_like(input, [-1, channels, 1, 1], min=0., max=1.) shape = fluid.layers.shape(input)
random_tensor = keep_prob + fluid.layers.uniform_random(
[shape[0], channels, 1, 1], min=0., max=1.)
binary_tensor = fluid.layers.floor(random_tensor) binary_tensor = fluid.layers.floor(random_tensor)
output = input / keep_prob * binary_tensor output = input / keep_prob * binary_tensor
return output return output
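The rewritten `dropout2d` draws one Bernoulli sample per (batch, channel) pair, zeroes whole feature maps, and rescales by `keep_prob`, i.e. spatial dropout. A NumPy sketch of the same behaviour, for intuition (training mode only):

```python
import numpy as np

def dropout2d_np(x, prob, rng=np.random):
    # x: NCHW array. Drop whole channels with probability `prob`, then rescale
    # so the expected activation stays unchanged.
    keep_prob = 1.0 - prob
    n, c = x.shape[:2]
    mask = np.floor(keep_prob + rng.uniform(0.0, 1.0, size=(n, c, 1, 1)))
    return x / keep_prob * mask
```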
...@@ -136,18 +141,23 @@ def psp_module(input, out_features): ...@@ -136,18 +141,23 @@ def psp_module(input, out_features):
for size in sizes: for size in sizes:
psp_name = "psp" + str(size) psp_name = "psp" + str(size)
with scope(psp_name): with scope(psp_name):
pool = fluid.layers.adaptive_pool2d(input, pool = fluid.layers.adaptive_pool2d(
pool_size=[size, size], input,
pool_type='avg', pool_size=[size, size],
name=psp_name + '_adapool') pool_type='avg',
data = conv(pool, out_features, name=psp_name + '_adapool')
filter_size=1, data = conv(
bias_attr=False, pool,
name=psp_name + '_conv') out_features,
filter_size=1,
bias_attr=False,
name=psp_name + '_conv')
data_bn = bn(data, act='relu') data_bn = bn(data, act='relu')
interp = fluid.layers.resize_bilinear(data_bn, interp = fluid.layers.resize_bilinear(
out_shape=input.shape[2:], data_bn,
name=psp_name + '_interp', align_mode=0) out_shape=input.shape[2:],
name=psp_name + '_interp',
align_mode=0)
cat_layers.append(interp) cat_layers.append(interp)
cat_layers = [input] + cat_layers cat_layers = [input] + cat_layers
out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat') out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
...@@ -158,7 +168,11 @@ def psp_module(input, out_features): ...@@ -158,7 +168,11 @@ def psp_module(input, out_features):
class FeatureFusionModule: class FeatureFusionModule:
"""Feature fusion module""" """Feature fusion module"""
def __init__(self, higher_in_channels, lower_in_channels, out_channels, scale_factor=4): def __init__(self,
higher_in_channels,
lower_in_channels,
out_channels,
scale_factor=4):
self.higher_in_channels = higher_in_channels self.higher_in_channels = higher_in_channels
self.lower_in_channels = lower_in_channels self.lower_in_channels = lower_in_channels
self.out_channels = out_channels self.out_channels = out_channels
...@@ -166,14 +180,19 @@ class FeatureFusionModule: ...@@ -166,14 +180,19 @@ class FeatureFusionModule:
def net(self, higher_res_feature, lower_res_feature): def net(self, higher_res_feature, lower_res_feature):
h, w = higher_res_feature.shape[2:] h, w = higher_res_feature.shape[2:]
lower_res_feature = fluid.layers.resize_bilinear(lower_res_feature, [h, w], align_mode=0) lower_res_feature = fluid.layers.resize_bilinear(
lower_res_feature, [h, w], align_mode=0)
with scope('dwconv'): with scope('dwconv'):
lower_res_feature = relu(bn(conv(lower_res_feature, self.out_channels, 1)))#(lower_res_feature) lower_res_feature = relu(
bn(conv(lower_res_feature, self.out_channels,
1))) #(lower_res_feature)
with scope('conv_lower_res'): with scope('conv_lower_res'):
lower_res_feature = bn(conv(lower_res_feature, self.out_channels, 1, bias_attr=True)) lower_res_feature = bn(
conv(lower_res_feature, self.out_channels, 1, bias_attr=True))
with scope('conv_higher_res'): with scope('conv_higher_res'):
higher_res_feature = bn(conv(higher_res_feature, self.out_channels, 1, bias_attr=True)) higher_res_feature = bn(
conv(higher_res_feature, self.out_channels, 1, bias_attr=True))
out = higher_res_feature + lower_res_feature out = higher_res_feature + lower_res_feature
return relu(out) return relu(out)
...@@ -182,8 +201,12 @@ class FeatureFusionModule: ...@@ -182,8 +201,12 @@ class FeatureFusionModule:
class GlobalFeatureExtractor(): class GlobalFeatureExtractor():
"""Global feature extractor module""" """Global feature extractor module"""
def __init__(self, in_channels=64, block_channels=(64, 96, 128), out_channels=128, def __init__(self,
t=6, num_blocks=(3, 3, 3)): in_channels=64,
block_channels=(64, 96, 128),
out_channels=128,
t=6,
num_blocks=(3, 3, 3)):
self.in_channels = in_channels self.in_channels = in_channels
self.block_channels = block_channels self.block_channels = block_channels
self.out_channels = out_channels self.out_channels = out_channels
...@@ -191,12 +214,15 @@ class GlobalFeatureExtractor(): ...@@ -191,12 +214,15 @@ class GlobalFeatureExtractor():
self.num_blocks = num_blocks self.num_blocks = num_blocks
def net(self, x): def net(self, x):
x, _ = inverted_blocks(x, self.in_channels, self.t, self.block_channels[0], x, _ = inverted_blocks(x, self.in_channels, self.t,
self.num_blocks[0], 2, 'inverted_block_1') self.block_channels[0], self.num_blocks[0], 2,
x, _ = inverted_blocks(x, self.block_channels[0], self.t, self.block_channels[1], 'inverted_block_1')
self.num_blocks[1], 2, 'inverted_block_2') x, _ = inverted_blocks(x, self.block_channels[0], self.t,
x, _ = inverted_blocks(x, self.block_channels[1], self.t, self.block_channels[2], self.block_channels[1], self.num_blocks[1], 2,
self.num_blocks[2], 1, 'inverted_block_3') 'inverted_block_2')
x, _ = inverted_blocks(x, self.block_channels[1], self.t,
self.block_channels[2], self.num_blocks[2], 1,
'inverted_block_3')
x = psp_module(x, self.block_channels[2] // 4) x = psp_module(x, self.block_channels[2] // 4)
with scope('out'): with scope('out'):
x = relu(bn(conv(x, self.out_channels, 1))) x = relu(bn(conv(x, self.out_channels, 1)))
...@@ -213,10 +239,21 @@ class Classifier: ...@@ -213,10 +239,21 @@ class Classifier:
def net(self, x): def net(self, x):
with scope('dsconv1'): with scope('dsconv1'):
x = separate_conv(x, self.dw_channels, stride=self.stride, filter=3, act=fluid.layers.relu) x = separate_conv(
x,
self.dw_channels,
stride=self.stride,
filter=3,
act=fluid.layers.relu)
with scope('dsconv2'): with scope('dsconv2'):
x = separate_conv(x, self.dw_channels, stride=self.stride, filter=3, act=fluid.layers.relu) x = separate_conv(
x = dropout2d(x, 0.1, is_train=cfg.PHASE=='train') x,
self.dw_channels,
stride=self.stride,
filter=3,
act=fluid.layers.relu)
x = dropout2d(x, 0.1, is_train=cfg.PHASE == 'train')
x = conv(x, self.num_classes, 1, bias_attr=True) x = conv(x, self.num_classes, 1, bias_attr=True)
return x return x
...@@ -233,7 +270,8 @@ def fast_scnn(img, num_classes): ...@@ -233,7 +270,8 @@ def fast_scnn(img, num_classes):
size = img.shape[2:] size = img.shape[2:]
classifier = Classifier(128, num_classes) classifier = Classifier(128, num_classes)
global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6, [3, 3, 3]) global_feature_extractor = GlobalFeatureExtractor(64, [64, 96, 128], 128, 6,
[3, 3, 3])
feature_fusion = FeatureFusionModule(64, 128, 128) feature_fusion = FeatureFusionModule(64, 128, 128)
with scope('learning_to_downsample'): with scope('learning_to_downsample'):
...@@ -249,15 +287,18 @@ def fast_scnn(img, num_classes): ...@@ -249,15 +287,18 @@ def fast_scnn(img, num_classes):
if len(cfg.MODEL.MULTI_LOSS_WEIGHT) == 3: if len(cfg.MODEL.MULTI_LOSS_WEIGHT) == 3:
with scope('aux_layer_higher'): with scope('aux_layer_higher'):
higher_logit = aux_layer(higher_res_features, num_classes) higher_logit = aux_layer(higher_res_features, num_classes)
higher_logit = fluid.layers.resize_bilinear(higher_logit, size, align_mode=0) higher_logit = fluid.layers.resize_bilinear(
higher_logit, size, align_mode=0)
with scope('aux_layer_lower'): with scope('aux_layer_lower'):
lower_logit = aux_layer(lower_res_feature, num_classes) lower_logit = aux_layer(lower_res_feature, num_classes)
lower_logit = fluid.layers.resize_bilinear(lower_logit, size, align_mode=0) lower_logit = fluid.layers.resize_bilinear(
lower_logit, size, align_mode=0)
return logit, higher_logit, lower_logit return logit, higher_logit, lower_logit
elif len(cfg.MODEL.MULTI_LOSS_WEIGHT) == 2: elif len(cfg.MODEL.MULTI_LOSS_WEIGHT) == 2:
with scope('aux_layer_higher'): with scope('aux_layer_higher'):
higher_logit = aux_layer(higher_res_features, num_classes) higher_logit = aux_layer(higher_res_features, num_classes)
higher_logit = fluid.layers.resize_bilinear(higher_logit, size, align_mode=0) higher_logit = fluid.layers.resize_bilinear(
higher_logit, size, align_mode=0)
return logit, higher_logit return logit, higher_logit
return logit return logit
\ No newline at end of file
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -25,7 +25,14 @@ from paddle.fluid.param_attr import ParamAttr ...@@ -25,7 +25,14 @@ from paddle.fluid.param_attr import ParamAttr
from utils.config import cfg from utils.config import cfg
def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_groups=1, if_act=True, name=None): def conv_bn_layer(input,
filter_size,
num_filters,
stride=1,
padding=1,
num_groups=1,
if_act=True,
name=None):
conv = fluid.layers.conv2d( conv = fluid.layers.conv2d(
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
...@@ -37,37 +44,74 @@ def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_grou ...@@ -37,37 +44,74 @@ def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_grou
param_attr=ParamAttr(initializer=MSRA(), name=name + '_weights'), param_attr=ParamAttr(initializer=MSRA(), name=name + '_weights'),
bias_attr=False) bias_attr=False)
bn_name = name + '_bn' bn_name = name + '_bn'
bn = fluid.layers.batch_norm(input=conv, bn = fluid.layers.batch_norm(
param_attr=ParamAttr(name=bn_name + "_scale", input=conv,
initializer=fluid.initializer.Constant(1.0)), param_attr=ParamAttr(
bias_attr=ParamAttr(name=bn_name + "_offset", name=bn_name + "_scale",
initializer=fluid.initializer.Constant(0.0)), initializer=fluid.initializer.Constant(1.0)),
moving_mean_name=bn_name + '_mean', bias_attr=ParamAttr(
moving_variance_name=bn_name + '_variance') name=bn_name + "_offset",
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act: if if_act:
bn = fluid.layers.relu(bn) bn = fluid.layers.relu(bn)
return bn return bn
def basic_block(input, num_filters, stride=1, downsample=False, name=None): def basic_block(input, num_filters, stride=1, downsample=False, name=None):
residual = input residual = input
conv = conv_bn_layer(input=input, filter_size=3, num_filters=num_filters, stride=stride, name=name + '_conv1') conv = conv_bn_layer(
conv = conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, if_act=False, name=name + '_conv2') input=input,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv1')
conv = conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
if_act=False,
name=name + '_conv2')
if downsample: if downsample:
residual = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, if_act=False, residual = conv_bn_layer(
name=name + '_downsample') input=input,
filter_size=1,
num_filters=num_filters,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def bottleneck_block(input, num_filters, stride=1, downsample=False, name=None): def bottleneck_block(input, num_filters, stride=1, downsample=False, name=None):
residual = input residual = input
conv = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters, name=name + '_conv1') conv = conv_bn_layer(
conv = conv_bn_layer(input=conv, filter_size=3, num_filters=num_filters, stride=stride, name=name + '_conv2') input=input,
conv = conv_bn_layer(input=conv, filter_size=1, num_filters=num_filters * 4, if_act=False, filter_size=1,
name=name + '_conv3') num_filters=num_filters,
name=name + '_conv1')
conv = conv_bn_layer(
input=conv,
filter_size=3,
num_filters=num_filters,
stride=stride,
name=name + '_conv2')
conv = conv_bn_layer(
input=conv,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_conv3')
if downsample: if downsample:
residual = conv_bn_layer(input=input, filter_size=1, num_filters=num_filters * 4, if_act=False, residual = conv_bn_layer(
name=name + '_downsample') input=input,
filter_size=1,
num_filters=num_filters * 4,
if_act=False,
name=name + '_downsample')
return fluid.layers.elementwise_add(x=residual, y=conv, act='relu') return fluid.layers.elementwise_add(x=residual, y=conv, act='relu')
def fuse_layers(x, channels, multi_scale_output=True, name=None): def fuse_layers(x, channels, multi_scale_output=True, name=None):
out = [] out = []
for i in range(len(channels) if multi_scale_output else 1): for i in range(len(channels) if multi_scale_output else 1):
...@@ -77,40 +121,64 @@ def fuse_layers(x, channels, multi_scale_output=True, name=None): ...@@ -77,40 +121,64 @@ def fuse_layers(x, channels, multi_scale_output=True, name=None):
height = shape[-2] height = shape[-2]
for j in range(len(channels)): for j in range(len(channels)):
if j > i: if j > i:
y = conv_bn_layer(x[j], filter_size=1, num_filters=channels[i], if_act=False, y = conv_bn_layer(
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1)) x[j],
y = fluid.layers.resize_bilinear(input=y, out_shape=[height, width]) filter_size=1,
residual = fluid.layers.elementwise_add(x=residual, y=y, act=None) num_filters=channels[i],
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1))
y = fluid.layers.resize_bilinear(
input=y, out_shape=[height, width])
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i: elif j < i:
y = x[j] y = x[j]
for k in range(i - j): for k in range(i - j):
if k == i - j - 1: if k == i - j - 1:
y = conv_bn_layer(y, filter_size=3, num_filters=channels[i], stride=2, if_act=False, y = conv_bn_layer(
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1) + '_' + str(k + 1)) y,
filter_size=3,
num_filters=channels[i],
stride=2,
if_act=False,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
else: else:
y = conv_bn_layer(y, filter_size=3, num_filters=channels[j], stride=2, y = conv_bn_layer(
name=name + '_layer_' + str(i + 1) + '_' + str(j + 1) + '_' + str(k + 1)) y,
residual = fluid.layers.elementwise_add(x=residual, y=y, act=None) filter_size=3,
num_filters=channels[j],
stride=2,
name=name + '_layer_' + str(i + 1) + '_' +
str(j + 1) + '_' + str(k + 1))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
residual = fluid.layers.relu(residual) residual = fluid.layers.relu(residual)
out.append(residual) out.append(residual)
return out return out
def branches(x, block_num, channels, name=None): def branches(x, block_num, channels, name=None):
out = [] out = []
for i in range(len(channels)): for i in range(len(channels)):
residual = x[i] residual = x[i]
for j in range(block_num): for j in range(block_num):
residual = basic_block(residual, channels[i], residual = basic_block(
name=name + '_branch_layer_' + str(i + 1) + '_' + str(j + 1)) residual,
channels[i],
name=name + '_branch_layer_' + str(i + 1) + '_' + str(j + 1))
out.append(residual) out.append(residual)
return out return out
def high_resolution_module(x, channels, multi_scale_output=True, name=None): def high_resolution_module(x, channels, multi_scale_output=True, name=None):
residual = branches(x, 4, channels, name=name) residual = branches(x, 4, channels, name=name)
out = fuse_layers(residual, channels, multi_scale_output=multi_scale_output, name=name) out = fuse_layers(
residual, channels, multi_scale_output=multi_scale_output, name=name)
return out return out
def transition_layer(x, in_channels, out_channels, name=None): def transition_layer(x, in_channels, out_channels, name=None):
num_in = len(in_channels) num_in = len(in_channels)
num_out = len(out_channels) num_out = len(out_channels)
...@@ -118,46 +186,76 @@ def transition_layer(x, in_channels, out_channels, name=None): ...@@ -118,46 +186,76 @@ def transition_layer(x, in_channels, out_channels, name=None):
for i in range(num_out): for i in range(num_out):
if i < num_in: if i < num_in:
if in_channels[i] != out_channels[i]: if in_channels[i] != out_channels[i]:
residual = conv_bn_layer(x[i], filter_size=3, num_filters=out_channels[i], residual = conv_bn_layer(
name=name + '_layer_' + str(i + 1)) x[i],
filter_size=3,
num_filters=out_channels[i],
name=name + '_layer_' + str(i + 1))
out.append(residual) out.append(residual)
else: else:
out.append(x[i]) out.append(x[i])
else: else:
residual = conv_bn_layer(x[-1], filter_size=3, num_filters=out_channels[i], stride=2, residual = conv_bn_layer(
name=name + '_layer_' + str(i + 1)) x[-1],
filter_size=3,
num_filters=out_channels[i],
stride=2,
name=name + '_layer_' + str(i + 1))
out.append(residual) out.append(residual)
return out return out
def stage(x, num_modules, channels, multi_scale_output=True, name=None): def stage(x, num_modules, channels, multi_scale_output=True, name=None):
out = x out = x
for i in range(num_modules): for i in range(num_modules):
if i == num_modules - 1 and multi_scale_output == False: if i == num_modules - 1 and multi_scale_output == False:
out = high_resolution_module(out, channels, multi_scale_output=False, name=name + '_' + str(i + 1)) out = high_resolution_module(
out,
channels,
multi_scale_output=False,
name=name + '_' + str(i + 1))
else: else:
out = high_resolution_module(out, channels, name=name + '_' + str(i + 1)) out = high_resolution_module(
out, channels, name=name + '_' + str(i + 1))
return out return out
def layer1(input, name=None): def layer1(input, name=None):
conv = input conv = input
for i in range(4): for i in range(4):
conv = bottleneck_block(conv, num_filters=64, downsample=True if i == 0 else False, conv = bottleneck_block(
name=name + '_' + str(i + 1)) conv,
num_filters=64,
downsample=True if i == 0 else False,
name=name + '_' + str(i + 1))
return conv return conv
def high_resolution_net(input, num_classes): def high_resolution_net(input, num_classes):
channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS channels_2 = cfg.MODEL.HRNET.STAGE2.NUM_CHANNELS
channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS channels_3 = cfg.MODEL.HRNET.STAGE3.NUM_CHANNELS
channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS channels_4 = cfg.MODEL.HRNET.STAGE4.NUM_CHANNELS
num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES num_modules_2 = cfg.MODEL.HRNET.STAGE2.NUM_MODULES
num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES num_modules_3 = cfg.MODEL.HRNET.STAGE3.NUM_MODULES
num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES num_modules_4 = cfg.MODEL.HRNET.STAGE4.NUM_MODULES
x = conv_bn_layer(input=input, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_1') x = conv_bn_layer(
x = conv_bn_layer(input=x, filter_size=3, num_filters=64, stride=2, if_act=True, name='layer1_2') input=input,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_1')
x = conv_bn_layer(
input=x,
filter_size=3,
num_filters=64,
stride=2,
if_act=True,
name='layer1_2')
la1 = layer1(x, name='layer2') la1 = layer1(x, name='layer2')
tr1 = transition_layer([la1], [256], channels_2, name='tr1') tr1 = transition_layer([la1], [256], channels_2, name='tr1')
...@@ -170,18 +268,21 @@ def high_resolution_net(input, num_classes): ...@@ -170,18 +268,21 @@ def high_resolution_net(input, num_classes):
# upsample # upsample
shape = st4[0].shape shape = st4[0].shape
height, width = shape[-2], shape[-1] height, width = shape[-2], shape[-1]
st4[1] = fluid.layers.resize_bilinear( st4[1] = fluid.layers.resize_bilinear(st4[1], out_shape=[height, width])
st4[1], out_shape=[height, width]) st4[2] = fluid.layers.resize_bilinear(st4[2], out_shape=[height, width])
st4[2] = fluid.layers.resize_bilinear( st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=[height, width])
st4[2], out_shape=[height, width])
st4[3] = fluid.layers.resize_bilinear(
st4[3], out_shape=[height, width])
out = fluid.layers.concat(st4, axis=1) out = fluid.layers.concat(st4, axis=1)
last_channels = sum(channels_4) last_channels = sum(channels_4)
out = conv_bn_layer(input=out, filter_size=1, num_filters=last_channels, stride=1, if_act=True, name='conv-2') out = conv_bn_layer(
out= fluid.layers.conv2d( input=out,
filter_size=1,
num_filters=last_channels,
stride=1,
if_act=True,
name='conv-2')
out = fluid.layers.conv2d(
input=out, input=out,
num_filters=num_classes, num_filters=num_classes,
filter_size=1, filter_size=1,
...@@ -193,7 +294,6 @@ def high_resolution_net(input, num_classes): ...@@ -193,7 +294,6 @@ def high_resolution_net(input, num_classes):
out = fluid.layers.resize_bilinear(out, input.shape[2:]) out = fluid.layers.resize_bilinear(out, input.shape[2:])
return out return out
...@@ -201,6 +301,7 @@ def hrnet(input, num_classes): ...@@ -201,6 +301,7 @@ def hrnet(input, num_classes):
logit = high_resolution_net(input, num_classes) logit = high_resolution_net(input, num_classes)
return logit return logit
if __name__ == '__main__': if __name__ == '__main__':
image_shape = [-1, 3, 769, 769] image_shape = [-1, 3, 769, 769]
image = fluid.data(name='image', shape=image_shape, dtype='float32') image = fluid.data(name='image', shape=image_shape, dtype='float32')
......
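For reference, a minimal shape-arithmetic sketch of what high_resolution_net above computes (plain Python; the stage-4 channel widths are hypothetical HRNet-W18-style values, not read from cfg): the two stride-2 stem convs reduce the input to roughly 1/4 resolution, the four stage-4 branches are upsampled to that resolution and concatenated, and a 1x1 conv maps the result to num_classes before resize_bilinear restores the input size.

```python
def conv_out(hw, stride=2, k=3, pad=1):
    # standard conv output-size formula for the stride-2 stem convs above
    return tuple((x + 2 * pad - k) // stride + 1 for x in hw)

input_hw = (769, 769)                    # matches the __main__ example shape
stem_hw = conv_out(conv_out(input_hw))   # two stride-2 convs -> roughly 1/4 resolution
channels_4 = [18, 36, 72, 144]           # hypothetical STAGE4.NUM_CHANNELS (HRNet-W18 style)
last_channels = sum(channels_4)          # channel count after fluid.layers.concat(st4, axis=1)
print(stem_hw, last_channels)            # (193, 193) 270
```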
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -24,6 +24,7 @@ from models.libs.model_libs import avg_pool, conv, bn ...@@ -24,6 +24,7 @@ from models.libs.model_libs import avg_pool, conv, bn
from models.backbone.resnet import ResNet as resnet_backbone from models.backbone.resnet import ResNet as resnet_backbone
from utils.config import cfg from utils.config import cfg
def get_logit_interp(input, num_classes, out_shape, name="logit"): def get_logit_interp(input, num_classes, out_shape, name="logit"):
# The number of classes determines the last conv's output channels; the result is interpolated back to the original size # The number of classes determines the last conv's output channels; the result is interpolated back to the original size
param_attr = fluid.ParamAttr( param_attr = fluid.ParamAttr(
...@@ -33,16 +34,15 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"): ...@@ -33,16 +34,15 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"):
initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01)) initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
with scope(name): with scope(name):
logit = conv(input, logit = conv(
num_classes, input,
filter_size=1, num_classes,
param_attr=param_attr, filter_size=1,
bias_attr=True, param_attr=param_attr,
name=name+'_conv') bias_attr=True,
name=name + '_conv')
logit_interp = fluid.layers.resize_bilinear( logit_interp = fluid.layers.resize_bilinear(
logit, logit, out_shape=out_shape, name=name + '_interp')
out_shape=out_shape,
name=name+'_interp')
return logit_interp return logit_interp
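A condensed sketch of the pattern get_logit_interp implements, assuming the same paddle.fluid 1.x APIs used throughout this commit (tensor sizes are hypothetical): a 1x1 conv maps the features to num_classes channels, then resize_bilinear brings the logit back to the original image size.

```python
import paddle.fluid as fluid

num_classes = 19                                   # hypothetical class count
feat = fluid.data(name='feat', shape=[-1, 256, 97, 97], dtype='float32')
logit = fluid.layers.conv2d(feat, num_filters=num_classes, filter_size=1)
logit_interp = fluid.layers.resize_bilinear(logit, out_shape=[769, 769])
```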
...@@ -51,40 +51,44 @@ def psp_module(input, out_features): ...@@ -51,40 +51,44 @@ def psp_module(input, out_features):
# Input: features produced by the backbone # Input: features produced by the backbone
# Output: the input pooled at several scales, convolved, interpolated back to the original size, and concatenated # Output: the input pooled at several scales, convolved, interpolated back to the original size, and concatenated
# followed by one final conv + BN # followed by one final conv + BN
cat_layers = [] cat_layers = []
sizes = (1,2,3,6) sizes = (1, 2, 3, 6)
for size in sizes: for size in sizes:
psp_name = "psp" + str(size) psp_name = "psp" + str(size)
with scope(psp_name): with scope(psp_name):
pool = fluid.layers.adaptive_pool2d(input, pool = fluid.layers.adaptive_pool2d(
pool_size=[size, size], input,
pool_type='avg', pool_size=[size, size],
name=psp_name+'_adapool') pool_type='avg',
data = conv(pool, out_features, name=psp_name + '_adapool')
filter_size=1, data = conv(
bias_attr=True, pool,
name= psp_name + '_conv') out_features,
filter_size=1,
bias_attr=True,
name=psp_name + '_conv')
data_bn = bn(data, act='relu') data_bn = bn(data, act='relu')
interp = fluid.layers.resize_bilinear(data_bn, interp = fluid.layers.resize_bilinear(
out_shape=input.shape[2:], data_bn, out_shape=input.shape[2:], name=psp_name + '_interp')
name=psp_name+'_interp')
cat_layers.append(interp) cat_layers.append(interp)
cat_layers = [input] + cat_layers[::-1] cat_layers = [input] + cat_layers[::-1]
cat = fluid.layers.concat(cat_layers, axis=1, name='psp_cat') cat = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
psp_end_name = "psp_end" psp_end_name = "psp_end"
with scope(psp_end_name): with scope(psp_end_name):
data = conv(cat, data = conv(
out_features, cat,
filter_size=3, out_features,
padding=1, filter_size=3,
bias_attr=True, padding=1,
name=psp_end_name) bias_attr=True,
name=psp_end_name)
out = bn(data, act='relu') out = bn(data, act='relu')
return out return out
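A condensed sketch of the pyramid pooling that psp_module performs, again assuming fluid 1.x APIs and hypothetical tensor sizes: adaptive-average-pool the features to 1x1, 2x2, 3x3 and 6x6, apply a 1x1 conv to each, upsample back to the feature resolution, and concatenate with the input (the real module additionally applies BN+ReLU per branch and a 3x3 conv + BN on the concatenation).

```python
import paddle.fluid as fluid

feat = fluid.data(name='backbone_feat', shape=[-1, 2048, 97, 97], dtype='float32')
branches = [feat]
for size in (1, 2, 3, 6):
    pooled = fluid.layers.adaptive_pool2d(feat, pool_size=[size, size], pool_type='avg')
    pooled = fluid.layers.conv2d(pooled, num_filters=512, filter_size=1)
    branches.append(fluid.layers.resize_bilinear(pooled, out_shape=feat.shape[2:]))
fused = fluid.layers.concat(branches, axis=1)      # 2048 + 4 * 512 channels
```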
def resnet(input): def resnet(input):
# PSPNet backbone: ResNet, ResNet-50 by default # PSPNet backbone: ResNet, ResNet-50 by default
# end_points: the ResNet layer at which the backbone stops # end_points: the ResNet layer at which the backbone stops
...@@ -92,14 +96,14 @@ def resnet(input): ...@@ -92,14 +96,14 @@ def resnet(input):
scale = cfg.MODEL.PSPNET.DEPTH_MULTIPLIER scale = cfg.MODEL.PSPNET.DEPTH_MULTIPLIER
layers = cfg.MODEL.PSPNET.LAYERS layers = cfg.MODEL.PSPNET.LAYERS
end_points = layers - 1 end_points = layers - 1
dilation_dict = {2:2, 3:4} dilation_dict = {2: 2, 3: 4}
model = resnet_backbone(layers, scale, stem='pspnet') model = resnet_backbone(layers, scale, stem='pspnet')
data, _ = model.net(input, data, _ = model.net(
end_points=end_points, input, end_points=end_points, dilation_dict=dilation_dict)
dilation_dict=dilation_dict)
return data return data
def pspnet(input, num_classes): def pspnet(input, num_classes):
# Backbone: ResNet # Backbone: ResNet
res = resnet(input) res = resnet(input)
...@@ -109,4 +113,3 @@ def pspnet(input, num_classes): ...@@ -109,4 +113,3 @@ def pspnet(input, num_classes):
# The number of classes determines the last conv's output channels; the result is interpolated back to the original size # The number of classes determines the last conv's output channels; the result is interpolated back to the original size
logit = get_logit_interp(dropout, num_classes, input.shape[2:]) logit = get_logit_interp(dropout, num_classes, input.shape[2:])
return logit return logit
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -71,7 +71,8 @@ class SegDataset(object): ...@@ -71,7 +71,8 @@ class SegDataset(object):
if self.shuffle and cfg.NUM_TRAINERS > 1: if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines) np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)] self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1 self.shuffle_seed += 1
elif self.shuffle: elif self.shuffle:
np.random.shuffle(self.lines) np.random.shuffle(self.lines)
...@@ -99,7 +100,8 @@ class SegDataset(object): ...@@ -99,7 +100,8 @@ class SegDataset(object):
if self.shuffle and cfg.NUM_TRAINERS > 1: if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines) np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)] self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
(cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1 self.shuffle_seed += 1
elif self.shuffle: elif self.shuffle:
np.random.shuffle(self.lines) np.random.shuffle(self.lines)
......
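A plain-numpy sketch of the multi-trainer sharding used in the reader above (values hypothetical): every trainer shuffles the full file list with the same seed, then keeps only its own contiguous slice, so the shards stay disjoint across trainers.

```python
import numpy as np

all_lines = list(range(10))          # stand-in for the lines of the train file list
NUM_TRAINERS, TRAINER_ID, shuffle_seed = 2, 1, 32
np.random.RandomState(shuffle_seed).shuffle(all_lines)
num_lines = len(all_lines) // NUM_TRAINERS
lines = all_lines[num_lines * TRAINER_ID:num_lines * (TRAINER_ID + 1)]
print(lines)                         # this trainer's shard of the shuffled list
```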
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,55 +21,48 @@ import warnings ...@@ -21,55 +21,48 @@ import warnings
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='PaddleSeg generate file list on cityscapes or your customized dataset.') description=
parser.add_argument( 'PaddleSeg generate file list on cityscapes or your customized dataset.'
'dataset_root',
help='dataset root directory',
type=str
) )
parser.add_argument('dataset_root', help='dataset root directory', type=str)
parser.add_argument( parser.add_argument(
'--type', '--type',
help='dataset type: \n' help='dataset type: \n'
'- cityscapes \n' '- cityscapes \n'
'- custom(default)', '- custom(default)',
default="custom", default="custom",
type=str type=str)
)
parser.add_argument( parser.add_argument(
'--separator', '--separator',
dest='separator', dest='separator',
help='file list separator', help='file list separator',
default="|", default="|",
type=str type=str)
)
parser.add_argument( parser.add_argument(
'--folder', '--folder',
help='the folder names of images and labels', help='the folder names of images and labels',
type=str, type=str,
nargs=2, nargs=2,
default=['images', 'annotations'] default=['images', 'annotations'])
)
parser.add_argument( parser.add_argument(
'--second_folder', '--second_folder',
help='the second-level folder names of train set, validation set, test set', help=
'the second-level folder names of train set, validation set, test set',
type=str, type=str,
nargs='*', nargs='*',
default=['train', 'val', 'test'] default=['train', 'val', 'test'])
)
parser.add_argument( parser.add_argument(
'--format', '--format',
help='data format of images and labels, e.g. jpg or png.', help='data format of images and labels, e.g. jpg or png.',
type=str, type=str,
nargs=2, nargs=2,
default=['jpg', 'png'] default=['jpg', 'png'])
)
parser.add_argument( parser.add_argument(
'--postfix', '--postfix',
help='postfix of images or labels', help='postfix of images or labels',
type=str, type=str,
nargs=2, nargs=2,
default=['', ''] default=['', ''])
)
return parser.parse_args() return parser.parse_args()
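Based on the arguments defined above, a hedged usage example for a custom dataset (the script name and paths are hypothetical; this is equivalent to passing the same flags on the command line):

```python
import sys

# python create_dataset_list.py ./mydata --type custom --separator "|" \
#     --folder images annotations --second_folder train val test --format jpg png
sys.argv = ['create_dataset_list.py', './mydata', '--type', 'custom',
            '--separator', '|', '--folder', 'images', 'annotations',
            '--second_folder', 'train', 'val', 'test', '--format', 'jpg', 'png']
args = parse_args()
print(args.dataset_root, args.folder, args.second_folder)
```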
...@@ -120,15 +113,17 @@ def generate_list(args): ...@@ -120,15 +113,17 @@ def generate_list(args):
num_images = len(image_files) num_images = len(image_files)
if not label_files: if not label_files:
label_dir = os.path.join(dataset_root, args.folder[1], dataset_split) label_dir = os.path.join(dataset_root, args.folder[1],
dataset_split)
warnings.warn("No labels in {} !!!".format(label_dir)) warnings.warn("No labels in {} !!!".format(label_dir))
num_label = len(label_files) num_label = len(label_files)
if num_images != num_label and num_label > 0: if num_images != num_label and num_label > 0:
raise Exception("Number of images = {} number of labels = {} \n" raise Exception(
"Either number of images is equal to number of labels, " "Number of images = {} number of labels = {} \n"
"or number of labels is equal to 0.\n" "Either number of images is equal to number of labels, "
"Please check your dataset!".format(num_images, num_label)) "or number of labels is equal to 0.\n"
"Please check your dataset!".format(num_images, num_label))
file_list = os.path.join(dataset_root, dataset_split + '.txt') file_list = os.path.join(dataset_root, dataset_split + '.txt')
with open(file_list, "w") as f: with open(file_list, "w") as f:
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function from __future__ import print_function
import argparse import argparse
...@@ -11,16 +25,12 @@ from PIL import Image ...@@ -11,16 +25,12 @@ from PIL import Image
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter)
) parser.add_argument(
parser.add_argument('dir_or_file', 'dir_or_file', help='input gray label directory or file list path')
help='input gray label directory or file list path') parser.add_argument('output_dir', help='output colorful label directory')
parser.add_argument('output_dir', parser.add_argument('--dataset_dir', help='dataset directory')
help='output colorful label directory') parser.add_argument('--file_separator', help='file list separator')
parser.add_argument('--dataset_dir',
help='dataset directory')
parser.add_argument('--file_separator',
help='file list separator')
return parser.parse_args() return parser.parse_args()
......
#!/usr/bin/env python # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function from __future__ import print_function
...@@ -7,12 +20,11 @@ import glob ...@@ -7,12 +20,11 @@ import glob
import json import json
import os import os
import os.path as osp import os.path as osp
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import labelme
from gray2pseudo_color import get_color_map_list from gray2pseudo_color import get_color_map_list
from labelme2seg import shape2label
def parse_args(): def parse_args():
...@@ -89,10 +101,10 @@ def main(args): ...@@ -89,10 +101,10 @@ def main(args):
img_shape = (data_size['height'], data_size['width'], img_shape = (data_size['height'], data_size['width'],
data_size['depth']) data_size['depth'])
lbl = labelme.utils.shapes_to_label( lbl = shape2label(
img_shape=img_shape, img_size=img_shape,
shapes=data_shapes, shapes=data_shapes,
label_name_to_value=class_name_to_id, class_name_mapping=class_name_to_id,
) )
if osp.splitext(out_png_file)[1] != '.png': if osp.splitext(out_png_file)[1] != '.png':
......
#!/usr/bin/env python # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import glob import glob
import math
import json import json
import os import os
import os.path as osp import os.path as osp
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import labelme import PIL.ImageDraw
import cv2
from gray2pseudo_color import get_color_map_list from gray2pseudo_color import get_color_map_list
...@@ -64,12 +78,12 @@ def main(args): ...@@ -64,12 +78,12 @@ def main(args):
data = json.load(f) data = json.load(f)
img_file = osp.join(osp.dirname(label_file), data['imagePath']) img_file = osp.join(osp.dirname(label_file), data['imagePath'])
img = np.asarray(PIL.Image.open(img_file)) img = np.asarray(cv2.imread(img_file))
lbl = labelme.utils.shapes_to_label( lbl = shape2label(
img_shape=img.shape, img_size=img.shape,
shapes=data['shapes'], shapes=data['shapes'],
label_name_to_value=class_name_to_id, class_name_mapping=class_name_to_id,
) )
if osp.splitext(out_png_file)[1] != '.png': if osp.splitext(out_png_file)[1] != '.png':
...@@ -85,6 +99,27 @@ def main(args): ...@@ -85,6 +99,27 @@ def main(args):
'Please consider using the .npy format.' % out_png_file) 'Please consider using the .npy format.' % out_png_file)
def shape2mask(img_size, points):
label_mask = PIL.Image.fromarray(np.zeros(img_size[:2], dtype=np.uint8))
image_draw = PIL.ImageDraw.Draw(label_mask)
points_list = [tuple(point) for point in points]
assert len(points_list) > 2, 'Polygon must have points more than 2'
image_draw.polygon(xy=points_list, outline=1, fill=1)
return np.array(label_mask, dtype=bool)
def shape2label(img_size, shapes, class_name_mapping):
label = np.zeros(img_size[:2], dtype=np.int32)
for shape in shapes:
points = shape['points']
class_name = shape['label']
shape_type = shape.get('shape_type', None)
class_id = class_name_mapping[class_name]
label_mask = shape2mask(img_size[:2], points)
label[label_mask] = class_id
return label
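A minimal usage sketch of the shape2label helper defined above, with a hypothetical triangle polygon labelled 'person': the polygon is filled into an int32 label map using the class id from the mapping.

```python
shapes = [{'label': 'person', 'points': [[10, 10], [60, 15], [35, 55]]}]
class_name_to_id = {'person': 1}
label = shape2label(img_size=(80, 80), shapes=shapes,
                    class_name_mapping=class_name_to_id)
print(label.shape, label.dtype, label.max())   # (80, 80) int32 1
```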
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
main(args) main(args)
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -26,9 +26,7 @@ import argparse ...@@ -26,9 +26,7 @@ import argparse
import pprint import pprint
import random import random
import shutil import shutil
import functools
import paddle
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import profiler from paddle.fluid import profiler
...@@ -39,10 +37,10 @@ from metrics import ConfusionMatrix ...@@ -39,10 +37,10 @@ from metrics import ConfusionMatrix
from reader import SegDataset from reader import SegDataset
from models.model_builder import build_model from models.model_builder import build_model
from models.model_builder import ModelPhase from models.model_builder import ModelPhase
from models.model_builder import parse_shape_from_file
from eval import evaluate from eval import evaluate
from vis import visualize from vis import visualize
from utils import dist_utils from utils import dist_utils
from utils.load_model_utils import load_pretrained_weights
def parse_args(): def parse_args():
...@@ -118,38 +116,7 @@ def parse_args(): ...@@ -118,38 +116,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def save_vars(executor, dirname, program=None, vars=None): def save_checkpoint(program, ckpt_name):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program = fluid.Program()
save_block = save_program.global_block()
for each_var in vars:
# NOTE: don't save the variable which type is RAW
if each_var.type == fluid.core.VarDesc.VarType.RAW:
continue
new_var = save_block.create_var(
name=each_var.name,
shape=each_var.shape,
dtype=each_var.dtype,
type=each_var.type,
lod_level=each_var.lod_level,
persistable=True)
file_path = os.path.join(dirname, new_var.name)
file_path = os.path.normpath(file_path)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': file_path})
executor.run(save_program)
def save_checkpoint(exe, program, ckpt_name):
""" """
Save checkpoint for evaluation or resume training Save checkpoint for evaluation or resume training
""" """
...@@ -158,29 +125,22 @@ def save_checkpoint(exe, program, ckpt_name): ...@@ -158,29 +125,22 @@ def save_checkpoint(exe, program, ckpt_name):
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir) os.makedirs(ckpt_dir)
save_vars( fluid.save(program, os.path.join(ckpt_dir, 'model'))
exe,
ckpt_dir,
program,
vars=list(filter(fluid.io.is_persistable, program.list_vars())))
return ckpt_dir return ckpt_dir
def load_checkpoint(exe, program): def load_checkpoint(exe, program):
""" """
Load checkpoint from pretrained model directory for resume training Load checkpoint for resuming training
""" """
print('Resume model training from:', cfg.TRAIN.RESUME_MODEL_DIR)
if not os.path.exists(cfg.TRAIN.RESUME_MODEL_DIR):
raise ValueError("TRAIN.PRETRAIN_MODEL {} not exist!".format(
cfg.TRAIN.RESUME_MODEL_DIR))
fluid.io.load_persistables(
exe, cfg.TRAIN.RESUME_MODEL_DIR, main_program=program)
model_path = cfg.TRAIN.RESUME_MODEL_DIR model_path = cfg.TRAIN.RESUME_MODEL_DIR
print('Resume model training from:', model_path)
if not os.path.exists(model_path):
raise ValueError(
"TRAIN.PRETRAIN_MODEL {} not exist!".format(model_path))
fluid.load(program, os.path.join(model_path, 'model'), exe)
# Check whether the path ends with a path separator # Check whether the path ends with a path separator
if model_path[-1] == os.sep: if model_path[-1] == os.sep:
model_path = model_path[0:-1] model_path = model_path[0:-1]
...@@ -195,7 +155,6 @@ def load_checkpoint(exe, program): ...@@ -195,7 +155,6 @@ def load_checkpoint(exe, program):
else: else:
raise ValueError("Resume model path is not valid!") raise ValueError("Resume model path is not valid!")
print("Model checkpoint loaded successfully!") print("Model checkpoint loaded successfully!")
return begin_epoch return begin_epoch
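A minimal sketch of the fluid.save / fluid.load pair that save_checkpoint and load_checkpoint now rely on (fluid 1.x APIs; the network and path are hypothetical). fluid.save writes the parameters and optimizer state under a single 'model' prefix, which is what load_checkpoint reads back when resuming.

```python
import os
import paddle.fluid as fluid

x = fluid.data(name='x', shape=[-1, 8], dtype='float32')
y = fluid.layers.fc(x, size=2)                    # something with parameters to save
prog = fluid.default_main_program()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

ckpt_dir = './snapshots/10'                       # hypothetical checkpoint directory
os.makedirs(ckpt_dir, exist_ok=True)
fluid.save(prog, os.path.join(ckpt_dir, 'model'))        # writes model.pdparams / model.pdopt
fluid.load(prog, os.path.join(ckpt_dir, 'model'), exe)   # restores them when resuming
```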
...@@ -247,8 +206,6 @@ def train(cfg): ...@@ -247,8 +206,6 @@ def train(cfg):
yield item[0], item[1], item[2] yield item[0], item[1], item[2]
# Get device environment # Get device environment
# places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
# place = places[0]
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places() places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
...@@ -304,42 +261,7 @@ def train(cfg): ...@@ -304,42 +261,7 @@ def train(cfg):
begin_epoch = load_checkpoint(exe, train_prog) begin_epoch = load_checkpoint(exe, train_prog)
# Load pretrained model # Load pretrained model
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR): elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR) load_pretrained_weights(exe, train_prog, cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_vars = []
load_fail_vars = []
def var_shape_matched(var, shape):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist = os.path.exists(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
return var_shape == shape
return False
for x in train_prog.list_vars():
if isinstance(x, fluid.framework.Parameter):
shape = tuple(fluid.global_scope().find_var(
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print_info(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print_info("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else: else:
print_info( print_info(
'Pretrained model dir {} not exists, training from scratch...'. 'Pretrained model dir {} not exists, training from scratch...'.
...@@ -418,12 +340,9 @@ def train(cfg): ...@@ -418,12 +340,9 @@ def train(cfg):
step) step)
log_writer.add_scalar('Train/mean_acc', mean_acc, log_writer.add_scalar('Train/mean_acc', mean_acc,
step) step)
log_writer.add_scalar('Train/loss', avg_loss, log_writer.add_scalar('Train/loss', avg_loss, step)
step) log_writer.add_scalar('Train/lr', lr[0], step)
log_writer.add_scalar('Train/lr', lr[0], log_writer.add_scalar('Train/step/sec', speed, step)
step)
log_writer.add_scalar('Train/step/sec', speed,
step)
sys.stdout.flush() sys.stdout.flush()
avg_loss = 0.0 avg_loss = 0.0
cm.zero_matrix() cm.zero_matrix()
...@@ -445,12 +364,9 @@ def train(cfg): ...@@ -445,12 +364,9 @@ def train(cfg):
).format(epoch, step, lr[0], avg_loss, speed, ).format(epoch, step, lr[0], avg_loss, speed,
calculate_eta(all_step - step, speed))) calculate_eta(all_step - step, speed)))
if args.use_vdl: if args.use_vdl:
log_writer.add_scalar('Train/loss', avg_loss, log_writer.add_scalar('Train/loss', avg_loss, step)
step) log_writer.add_scalar('Train/lr', lr[0], step)
log_writer.add_scalar('Train/lr', lr[0], log_writer.add_scalar('Train/speed', speed, step)
step)
log_writer.add_scalar('Train/speed', speed,
step)
sys.stdout.flush() sys.stdout.flush()
avg_loss = 0.0 avg_loss = 0.0
timer.restart() timer.restart()
...@@ -470,7 +386,7 @@ def train(cfg): ...@@ -470,7 +386,7 @@ def train(cfg):
if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0
or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0: or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0:
ckpt_dir = save_checkpoint(exe, train_prog, epoch) ckpt_dir = save_checkpoint(train_prog, epoch)
if args.do_eval: if args.do_eval:
print("Evaluation start") print("Evaluation start")
...@@ -480,10 +396,8 @@ def train(cfg): ...@@ -480,10 +396,8 @@ def train(cfg):
use_gpu=args.use_gpu, use_gpu=args.use_gpu,
use_mpio=args.use_mpio) use_mpio=args.use_mpio)
if args.use_vdl: if args.use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou, log_writer.add_scalar('Evaluate/mean_iou', mean_iou, step)
step) log_writer.add_scalar('Evaluate/mean_acc', mean_acc, step)
log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
step)
if mean_iou > best_mIoU: if mean_iou > best_mIoU:
best_mIoU = mean_iou best_mIoU = mean_iou
...@@ -505,7 +419,7 @@ def train(cfg): ...@@ -505,7 +419,7 @@ def train(cfg):
# save final model # save final model
if cfg.TRAINER_ID == 0: if cfg.TRAINER_ID == 0:
save_checkpoint(exe, train_prog, 'final') save_checkpoint(train_prog, 'final')
def main(args): def main(args):
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -98,7 +99,7 @@ class SegConfig(dict): ...@@ -98,7 +99,7 @@ class SegConfig(dict):
'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`' 'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`'
) )
if self.MEAN is not None: if self.MEAN is not None:
self.DATASET.PADDING_VALUE = [x*255.0 for x in self.MEAN] self.DATASET.PADDING_VALUE = [x * 255.0 for x in self.MEAN]
if not self.TRAIN_CROP_SIZE: if not self.TRAIN_CROP_SIZE:
raise ValueError( raise ValueError(
...@@ -111,9 +112,12 @@ class SegConfig(dict): ...@@ -111,9 +112,12 @@ class SegConfig(dict):
) )
# Ensure file lists use UTF-8 encoding # Ensure file lists use UTF-8 encoding
train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', 'utf-8').readlines() train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r',
val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', 'utf-8').readlines() 'utf-8').readlines()
test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', 'utf-8').readlines() val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r',
'utf-8').readlines()
test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r',
'utf-8').readlines()
self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets) self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets)
self.DATASET.VAL_TOTAL_IMAGES = len(val_sets) self.DATASET.VAL_TOTAL_IMAGES = len(val_sets)
self.DATASET.TEST_TOTAL_IMAGES = len(test_sets) self.DATASET.TEST_TOTAL_IMAGES = len(test_sets)
...@@ -122,12 +126,13 @@ class SegConfig(dict): ...@@ -122,12 +126,13 @@ class SegConfig(dict):
len(self.MODEL.MULTI_LOSS_WEIGHT) != 3: len(self.MODEL.MULTI_LOSS_WEIGHT) != 3:
self.MODEL.MULTI_LOSS_WEIGHT = [1.0, 0.4, 0.16] self.MODEL.MULTI_LOSS_WEIGHT = [1.0, 0.4, 0.16]
if self.AUG.AUG_METHOD not in ['unpadding', 'stepscaling', 'rangescaling']: if self.AUG.AUG_METHOD not in [
'unpadding', 'stepscaling', 'rangescaling'
]:
raise ValueError( raise ValueError(
'AUG.AUG_METHOD config error, only support `unpadding`, `stepscaling` and `rangescaling`' 'AUG.AUG_METHOD config error, only support `unpadding`, `stepscaling` and `rangescaling`'
) )
def update_from_list(self, config_list): def update_from_list(self, config_list):
if len(config_list) % 2 != 0: if len(config_list) % 2 != 0:
raise ValueError( raise ValueError(
......
# -*- coding: utf-8 -*- # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. #Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
......
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from paddle import fluid from paddle import fluid
def load_fp16_vars(executor, dirname, program): def load_fp16_vars(executor, dirname, program):
load_dirname = os.path.normpath(dirname) load_dirname = os.path.normpath(dirname)
...@@ -28,4 +44,4 @@ def load_fp16_vars(executor, dirname, program): ...@@ -28,4 +44,4 @@ def load_fp16_vars(executor, dirname, program):
'load_as_fp16': var.dtype == fluid.core.VarDesc.VarType.FP16 'load_as_fp16': var.dtype == fluid.core.VarDesc.VarType.FP16
}) })
executor.run(load_prog) executor.run(load_prog)
\ No newline at end of file
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import six
import numpy as np
def parse_param_file(param_file, return_shape=True):
from paddle.fluid.proto.framework_pb2 import VarType
f = open(param_file, 'rb')
version = np.fromstring(f.read(4), dtype='int32')
lod_level = np.fromstring(f.read(8), dtype='int64')
for i in range(int(lod_level)):
_size = np.fromstring(f.read(8), dtype='int64')
_ = f.read(_size)
version = np.fromstring(f.read(4), dtype='int32')
tensor_desc = VarType.TensorDesc()
tensor_desc_size = np.fromstring(f.read(4), dtype='int32')
tensor_desc.ParseFromString(f.read(int(tensor_desc_size)))
tensor_shape = tuple(tensor_desc.dims)
if return_shape:
f.close()
return tuple(tensor_desc.dims)
if tensor_desc.data_type != 5:
raise Exception(
"Unexpected data type while parse {}".format(param_file))
data_size = 4
for i in range(len(tensor_shape)):
data_size *= tensor_shape[i]
weight = np.fromstring(f.read(data_size), dtype='float32')
f.close()
return np.reshape(weight, tensor_shape)
def load_pdparams(exe, main_prog, model_dir):
import paddle.fluid as fluid
from paddle.fluid.proto.framework_pb2 import VarType
from paddle.fluid.framework import Program
vars_to_load = list()
vars_not_load = list()
import pickle
with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
params_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
unused_vars = list()
for var in main_prog.list_vars():
if not isinstance(var, fluid.framework.Parameter):
continue
if var.name not in params_dict:
print("{} is not in saved model".format(var.name))
vars_not_load.append(var.name)
continue
if var.shape != params_dict[var.name].shape:
unused_vars.append(var.name)
vars_not_load.append(var.name)
print(
"[SKIP] Shape of pretrained weight {} doesn't match.(Pretrained: {}, Actual: {})"
.format(var.name, params_dict[var.name].shape, var.shape))
continue
vars_to_load.append(var)
for var_name in unused_vars:
del params_dict[var_name]
fluid.io.set_program_state(main_prog, params_dict)
if len(vars_to_load) == 0:
print(
"No pretrained weights were loaded; please check your pretrained model!"
)
else:
print("{}/{} variables in {} were loaded.".format(
len(vars_to_load),
len(vars_to_load) + len(vars_not_load), model_dir))
def load_pretrained_weights(exe, main_prog, weights_dir):
if not osp.exists(weights_dir):
raise Exception("Path {} not exists.".format(weights_dir))
if osp.exists(osp.join(weights_dir, "model.pdparams")):
return load_pdparams(exe, main_prog, weights_dir)
import paddle.fluid as fluid
vars_to_load = list()
vars_not_load = list()
for var in main_prog.list_vars():
if not isinstance(var, fluid.framework.Parameter):
continue
if not osp.exists(osp.join(weights_dir, var.name)):
print("[SKIP] Pretrained weight {}/{} doesn't exist".format(
weights_dir, var.name))
vars_not_load.append(var)
continue
pretrained_shape = parse_param_file(osp.join(weights_dir, var.name))
actual_shape = tuple(var.shape)
if pretrained_shape != actual_shape:
print(
"[SKIP] Shape of pretrained weight {}/{} doesn't match.(Pretrained: {}, Actual: {})"
.format(weights_dir, var.name, pretrained_shape, actual_shape))
vars_not_load.append(var)
continue
vars_to_load.append(var)
params_dict = fluid.io.load_program_state(
weights_dir, var_list=vars_to_load)
fluid.io.set_program_state(main_prog, params_dict)
if len(vars_to_load) == 0:
print(
"No pretrained weights were loaded; please check your pretrained model!"
)
else:
print("{}/{} variables in {} were loaded.".format(
len(vars_to_load),
len(vars_to_load) + len(vars_not_load), weights_dir))
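A hedged usage sketch of the helper above (the weights directory is hypothetical): it is called once after the startup program has run, exactly as train.py now does for cfg.TRAIN.PRETRAINED_MODEL_DIR. The helper itself decides between the pickled model.pdparams format and per-variable parameter files, and skips any weight whose shape no longer matches the network.

```python
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
load_pretrained_weights(exe, fluid.default_main_program(),
                        './pretrained_model/hrnet_w18_bn_cityscapes')  # hypothetical dir
```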
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -115,7 +115,12 @@ def visualize(cfg, ...@@ -115,7 +115,12 @@ def visualize(cfg,
ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir ckpt_dir = cfg.TEST.TEST_MODEL if not ckpt_dir else ckpt_dir
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog) if ckpt_dir is not None:
print('load test model:', ckpt_dir)
try:
fluid.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
except:
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
save_dir = vis_dir save_dir = vis_dir
makedirs(save_dir) makedirs(save_dir)
...@@ -169,18 +174,13 @@ def visualize(cfg, ...@@ -169,18 +174,13 @@ def visualize(cfg,
print("VisualDL visualization epoch", epoch) print("VisualDL visualization epoch", epoch)
pred_mask_np = np.array(pred_mask.convert("RGB")) pred_mask_np = np.array(pred_mask.convert("RGB"))
log_writer.add_image( log_writer.add_image("Predict/{}".format(img_name),
"Predict/{}".format(img_name), pred_mask_np, epoch)
pred_mask_np,
epoch)
# Original image # Original image
# BGR->RGB # BGR->RGB
img = cv2.imread( img = cv2.imread(os.path.join(cfg.DATASET.DATA_DIR,
os.path.join(cfg.DATASET.DATA_DIR, img_name))[..., ::-1] img_name))[..., ::-1]
log_writer.add_image( log_writer.add_image("Images/{}".format(img_name), img, epoch)
"Images/{}".format(img_name),
img,
epoch)
# add ground truth (label) images # add ground truth (label) images
grt = grts[i] grt = grts[i]
if grt is not None: if grt is not None:
...@@ -189,10 +189,8 @@ def visualize(cfg, ...@@ -189,10 +189,8 @@ def visualize(cfg,
grt_pil.putpalette(color_map) grt_pil.putpalette(color_map)
grt_pil = grt_pil.resize((org_shape[1], org_shape[0])) grt_pil = grt_pil.resize((org_shape[1], org_shape[0]))
grt = np.array(grt_pil.convert("RGB")) grt = np.array(grt_pil.convert("RGB"))
log_writer.add_image( log_writer.add_image("Label/{}".format(img_name), grt,
"Label/{}".format(img_name), epoch)
grt,
epoch)
# If in local_test mode, only visualize 5 images just for testing # If in local_test mode, only visualize 5 images just for testing
# procedure # procedure
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -42,7 +43,7 @@ model_urls = { ...@@ -42,7 +43,7 @@ model_urls = {
"hrnet_w30_bn_imagenet": "hrnet_w30_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar", "https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar",
"hrnet_w32_bn_imagenet": "hrnet_w32_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar" , "https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar",
"hrnet_w40_bn_imagenet": "hrnet_w40_bn_imagenet":
"https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar", "https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar",
"hrnet_w44_bn_imagenet": "hrnet_w44_bn_imagenet":
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -44,6 +44,7 @@ from model_builder import parse_shape_from_file ...@@ -44,6 +44,7 @@ from model_builder import parse_shape_from_file
from eval import evaluate from eval import evaluate
from vis import visualize from vis import visualize
from utils import dist_utils from utils import dist_utils
from utils.load_model_utils import load_pretrained_weights
import solver import solver
from paddleslim.dist.single_distiller import merge, l2_loss from paddleslim.dist.single_distiller import merge, l2_loss
...@@ -116,38 +117,7 @@ def parse_args(): ...@@ -116,38 +117,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def save_vars(executor, dirname, program=None, vars=None): def save_checkpoint(program, ckpt_name):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program = fluid.Program()
save_block = save_program.global_block()
for each_var in vars:
# NOTE: don't save the variable which type is RAW
if each_var.type == fluid.core.VarDesc.VarType.RAW:
continue
new_var = save_block.create_var(
name=each_var.name,
shape=each_var.shape,
dtype=each_var.dtype,
type=each_var.type,
lod_level=each_var.lod_level,
persistable=True)
file_path = os.path.join(dirname, new_var.name)
file_path = os.path.normpath(file_path)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': file_path})
executor.run(save_program)
def save_checkpoint(exe, program, ckpt_name):
""" """
Save checkpoint for evaluation or resume training Save checkpoint for evaluation or resume training
""" """
...@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name): ...@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name):
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir) os.makedirs(ckpt_dir)
save_vars( fluid.save(program, os.path.join(ckpt_dir, 'model'))
exe,
ckpt_dir,
program,
vars=list(filter(fluid.io.is_persistable, program.list_vars())))
return ckpt_dir return ckpt_dir
def load_checkpoint(exe, program): def load_checkpoint(exe, program):
""" """
Load checkpoint from pretrained model directory for resume training Load checkpoint for resuming training
""" """
print('Resume model training from:', cfg.TRAIN.RESUME_MODEL_DIR)
if not os.path.exists(cfg.TRAIN.RESUME_MODEL_DIR):
raise ValueError("TRAIN.PRETRAIN_MODEL {} not exist!".format(
cfg.TRAIN.RESUME_MODEL_DIR))
fluid.io.load_persistables(
exe, cfg.TRAIN.RESUME_MODEL_DIR, main_program=program)
model_path = cfg.TRAIN.RESUME_MODEL_DIR model_path = cfg.TRAIN.RESUME_MODEL_DIR
print('Resume model training from:', model_path)
if not os.path.exists(model_path):
raise ValueError(
"TRAIN.PRETRAIN_MODEL {} not exist!".format(model_path))
fluid.load(program, os.path.join(model_path, 'model'), exe)
# Check whether the path ends with a path separator # Check whether the path ends with a path separator
if model_path[-1] == os.sep: if model_path[-1] == os.sep:
model_path = model_path[0:-1] model_path = model_path[0:-1]
...@@ -193,7 +156,6 @@ def load_checkpoint(exe, program): ...@@ -193,7 +156,6 @@ def load_checkpoint(exe, program):
else: else:
raise ValueError("Resume model path is not valid!") raise ValueError("Resume model path is not valid!")
print("Model checkpoint loaded successfully!") print("Model checkpoint loaded successfully!")
return begin_epoch return begin_epoch
...@@ -289,7 +251,11 @@ def train(cfg): ...@@ -289,7 +251,11 @@ def train(cfg):
ckpt_dir = cfg.SLIM.KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR ckpt_dir = cfg.SLIM.KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR
assert ckpt_dir is not None assert ckpt_dir is not None
print('load teacher model:', ckpt_dir) print('load teacher model:', ckpt_dir)
fluid.io.load_params(exe, ckpt_dir, main_program=teacher_program) if os.path.exists(ckpt_dir):
try:
fluid.load(teacher_program, os.path.join(ckpt_dir, 'model'), exe)
except:
fluid.io.load_params(exe, ckpt_dir, main_program=teacher_program)
# cfg = load_config(FLAGS.config) # cfg = load_config(FLAGS.config)
cfg.update_from_file(args.cfg_file) cfg.update_from_file(args.cfg_file)
...@@ -355,42 +321,8 @@ def train(cfg): ...@@ -355,42 +321,8 @@ def train(cfg):
begin_epoch = load_checkpoint(exe, fluid.default_main_program()) begin_epoch = load_checkpoint(exe, fluid.default_main_program())
# Load pretrained model # Load pretrained model
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR): elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR) load_pretrained_weights(exe, fluid.default_main_program(),
load_vars = [] cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_fail_vars = []
def var_shape_matched(var, shape):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist = os.path.exists(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
return var_shape == shape
return False
for x in fluid.default_main_program().list_vars():
if isinstance(x, fluid.framework.Parameter):
shape = tuple(fluid.global_scope().find_var(
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print_info(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print_info("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else: else:
print_info( print_info(
'Pretrained model dir {} not exists, training from scratch...'. 'Pretrained model dir {} not exists, training from scratch...'.
...@@ -475,12 +407,9 @@ def train(cfg): ...@@ -475,12 +407,9 @@ def train(cfg):
step) step)
log_writer.add_scalar('Train/mean_acc', mean_acc, log_writer.add_scalar('Train/mean_acc', mean_acc,
step) step)
log_writer.add_scalar('Train/loss', avg_loss, log_writer.add_scalar('Train/loss', avg_loss, step)
step) log_writer.add_scalar('Train/lr', lr[0], step)
log_writer.add_scalar('Train/lr', lr[0], log_writer.add_scalar('Train/step/sec', speed, step)
step)
log_writer.add_scalar('Train/step/sec', speed,
step)
sys.stdout.flush() sys.stdout.flush()
avg_loss = 0.0 avg_loss = 0.0
cm.zero_matrix() cm.zero_matrix()
...@@ -503,16 +432,13 @@ def train(cfg): ...@@ -503,16 +432,13 @@ def train(cfg):
speed = args.log_steps / timer.elapsed_time() speed = args.log_steps / timer.elapsed_time()
print(( print((
"epoch={} step={} lr={:.5f} loss={:.4f} teacher loss={:.4f} distill loss={:.4f} step/sec={:.3f} | ETA {}" "epoch={} step={} lr={:.5f} loss={:.4f} teacher loss={:.4f} distill loss={:.4f} step/sec={:.3f} | ETA {}"
).format(epoch, step, lr[0], avg_loss, ).format(epoch, step, lr[0], avg_loss, avg_t_loss,
avg_t_loss, avg_d_loss, speed, avg_d_loss, speed,
calculate_eta(all_step - step, speed))) calculate_eta(all_step - step, speed)))
if args.use_vdl: if args.use_vdl:
log_writer.add_scalar('Train/loss', avg_loss, log_writer.add_scalar('Train/loss', avg_loss, step)
step) log_writer.add_scalar('Train/lr', lr[0], step)
log_writer.add_scalar('Train/lr', lr[0], log_writer.add_scalar('Train/speed', speed, step)
step)
log_writer.add_scalar('Train/speed', speed,
step)
sys.stdout.flush() sys.stdout.flush()
avg_loss = 0.0 avg_loss = 0.0
avg_t_loss = 0.0 avg_t_loss = 0.0
...@@ -527,7 +453,7 @@ def train(cfg): ...@@ -527,7 +453,7 @@ def train(cfg):
if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 if (epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0
or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0: or epoch == cfg.SOLVER.NUM_EPOCHS) and cfg.TRAINER_ID == 0:
ckpt_dir = save_checkpoint(exe, fluid.default_main_program(), epoch) ckpt_dir = save_checkpoint(fluid.default_main_program(), epoch)
if args.do_eval: if args.do_eval:
print("Evaluation start") print("Evaluation start")
...@@ -537,10 +463,8 @@ def train(cfg): ...@@ -537,10 +463,8 @@ def train(cfg):
use_gpu=args.use_gpu, use_gpu=args.use_gpu,
use_mpio=args.use_mpio) use_mpio=args.use_mpio)
if args.use_vdl: if args.use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou, log_writer.add_scalar('Evaluate/mean_iou', mean_iou, step)
step) log_writer.add_scalar('Evaluate/mean_acc', mean_acc, step)
log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
step)
if mean_iou > best_mIoU: if mean_iou > best_mIoU:
best_mIoU = mean_iou best_mIoU = mean_iou
...@@ -560,11 +484,11 @@ def train(cfg): ...@@ -560,11 +484,11 @@ def train(cfg):
ckpt_dir=ckpt_dir, ckpt_dir=ckpt_dir,
log_writer=log_writer) log_writer=log_writer)
if cfg.TRAINER_ID == 0: if cfg.TRAINER_ID == 0:
ckpt_dir = save_checkpoint(exe, fluid.default_main_program(), epoch) ckpt_dir = save_checkpoint(fluid.default_main_program(), epoch)
# save final model # save final model
if cfg.TRAINER_ID == 0: if cfg.TRAINER_ID == 0:
save_checkpoint(exe, fluid.default_main_program(), 'final') save_checkpoint(fluid.default_main_program(), 'final')
def main(args): def main(args):
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv ...@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv
from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_backbone from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_backbone
from models.backbone.xception import Xception as xception_backbone from models.backbone.xception import Xception as xception_backbone
def encoder(input): def encoder(input):
# Encoder configuration using the ASPP architecture: pooling + 1x1_conv + three parallel dilated convs at different rates, concatenated and then fused by a 1x1 conv # Encoder configuration using the ASPP architecture: pooling + 1x1_conv + three parallel dilated convs at different rates, concatenated and then fused by a 1x1 conv
# ASPP_WITH_SEP_CONV: True by default, use depthwise separable convs; otherwise use plain convs # ASPP_WITH_SEP_CONV: True by default, use depthwise separable convs; otherwise use plain convs
...@@ -47,8 +48,7 @@ def encoder(input): ...@@ -47,8 +48,7 @@ def encoder(input):
with scope('encoder'): with scope('encoder'):
channel = 256 channel = 256
with scope("image_pool"): with scope("image_pool"):
image_avg = fluid.layers.reduce_mean( image_avg = fluid.layers.reduce_mean(input, [2, 3], keep_dim=True)
input, [2, 3], keep_dim=True)
image_avg = bn_relu( image_avg = bn_relu(
conv( conv(
image_avg, image_avg,
...@@ -191,7 +191,10 @@ def nas_backbone(input, arch): ...@@ -191,7 +191,10 @@ def nas_backbone(input, arch):
end_points = 8 end_points = 8
decode_point = 3 decode_point = 3
data, decode_shortcuts = arch( data, decode_shortcuts = arch(
input, end_points=end_points, return_block=decode_point, output_stride=16) input,
end_points=end_points,
return_block=decode_point,
output_stride=16)
decode_shortcut = decode_shortcuts[decode_point] decode_shortcut = decode_shortcuts[decode_point]
return data, decode_shortcut return data, decode_shortcut
......
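A condensed fluid 1.x sketch of the ASPP layout the encoder comment above describes (channel counts and dilation rates are hypothetical, and the separable-conv option is omitted): a global-pooling branch, a 1x1-conv branch, and three dilated 3x3 convs run in parallel, then everything is concatenated and fused by a final 1x1 conv.

```python
import paddle.fluid as fluid

feat = fluid.data(name='backbone_feat', shape=[-1, 2048, 49, 49], dtype='float32')
channel = 256
pool = fluid.layers.reduce_mean(feat, dim=[2, 3], keep_dim=True)    # image-level pooling
pool = fluid.layers.conv2d(pool, channel, 1)
pool = fluid.layers.resize_bilinear(pool, out_shape=feat.shape[2:])
branches = [pool, fluid.layers.conv2d(feat, channel, 1)]
for rate in (6, 12, 18):                        # hypothetical rates for output_stride=16
    branches.append(
        fluid.layers.conv2d(feat, channel, 3, padding=rate, dilation=rate))
aspp = fluid.layers.conv2d(fluid.layers.concat(branches, axis=1), channel, 1)
```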
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -123,7 +123,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs): ...@@ -123,7 +123,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if ckpt_dir is not None: if ckpt_dir is not None:
print('load test model:', ckpt_dir) print('load test model:', ckpt_dir)
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog) try:
fluid.load(test_prog, os.path.join(ckpt_dir, 'model'), exe)
except:
fluid.io.load_params(exe, ckpt_dir, main_program=test_prog)
# Use streaming confusion matrix to calculate mean_iou # Use streaming confusion matrix to calculate mean_iou
np.set_printoptions( np.set_printoptions(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -31,7 +32,7 @@ __all__ = ["MobileNetV2SpaceSeg"] ...@@ -31,7 +32,7 @@ __all__ = ["MobileNetV2SpaceSeg"]
class MobileNetV2SpaceSeg(SearchSpaceBase): class MobileNetV2SpaceSeg(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, block_mask=None): def __init__(self, input_size, output_size, block_num, block_mask=None):
super(MobileNetV2SpaceSeg, self).__init__(input_size, output_size, super(MobileNetV2SpaceSeg, self).__init__(input_size, output_size,
block_num, block_mask) block_num, block_mask)
# self.head_num means the first convolution channel # self.head_num means the first convolution channel
self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7 self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7
# self.filter_num1 ~ self.filter_num6 means the following convolution channels # self.filter_num1 ~ self.filter_num6 means the following convolution channels
...@@ -48,7 +49,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase): ...@@ -48,7 +49,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
self.k_size = np.array([3, 5]) #2 self.k_size = np.array([3, 5]) #2
# self.multiply means expansion_factor of each _inverted_residual_unit # self.multiply means expansion_factor of each _inverted_residual_unit
self.multiply = np.array([1, 2, 3, 4, 6]) #5 self.multiply = np.array([1, 2, 3, 4, 6]) #5
# self.repeat means the repeat count of _inverted_residual_unit in each _invresi_blocks # self.repeat means the repeat count of _inverted_residual_unit in each _invresi_blocks
self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6 self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6
def init_tokens(self): def init_tokens(self):
...@@ -72,7 +73,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase): ...@@ -72,7 +73,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
def range_table(self): def range_table(self):
""" """
Get the range table of the current search space, which constrains the range of tokens. Get the range table of the current search space, which constrains the range of tokens.
""" """
# head_num + 6 * [multiple(expansion_factor), filter_num, repeat, kernel_size] # head_num + 6 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# yapf: disable # yapf: disable
...@@ -95,8 +96,8 @@ class MobileNetV2SpaceSeg(SearchSpaceBase): ...@@ -95,8 +96,8 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
tokens = self.init_tokens() tokens = self.init_tokens()
self.bottleneck_params_list = [] self.bottleneck_params_list = []
self.bottleneck_params_list.append( self.bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1,
(1, self.head_num[tokens[0]], 1, 1, 3)) 3))
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]], (self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
...@@ -150,7 +151,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase): ...@@ -150,7 +151,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
return (True if count == points else False) return (True if count == points else False)
#conv1 #conv1
# all padding is 'SAME' in the conv2d, so the actual padding can be computed automatically. # all padding is 'SAME' in the conv2d, so the actual padding can be computed automatically.
input = conv_bn_layer( input = conv_bn_layer(
input, input,
num_filters=int(32 * self.scale), num_filters=int(32 * self.scale),
......
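Each NAS token is an index into one of the candidate arrays above, and bottleneck_params_list collects tuples of (expansion_factor, channels, repeat, stride, kernel_size) for the inverted-residual blocks. A small self-contained sketch of that decoding; filter_num1 here holds illustrative values rather than the repository's exact list:

import numpy as np

head_num = np.array([3, 4, 8, 12, 16, 24, 32])
multiply = np.array([1, 2, 3, 4, 6])                    # expansion factors
filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48])   # illustrative values
repeat = np.array([1, 2, 3, 4, 5, 6])
k_size = np.array([3, 5])

def decode_first_blocks(tokens):
    # Stem block: expansion 1, head_num channels, 1 repeat, stride 1, kernel 3.
    params = [(1, int(head_num[tokens[0]]), 1, 1, 3)]
    # First searched block: (expansion, channels, repeats, stride, kernel).
    params.append((int(multiply[tokens[1]]), int(filter_num1[tokens[2]]),
                   int(repeat[tokens[3]]), 2, int(k_size[tokens[4]])))
    return params

print(decode_first_blocks([0, 0, 0, 0, 0]))  # [(1, 3, 1, 1, 3), (1, 3, 1, 2, 3)]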
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -47,6 +47,7 @@ from model_builder import parse_shape_from_file ...@@ -47,6 +47,7 @@ from model_builder import parse_shape_from_file
from eval_nas import evaluate from eval_nas import evaluate
from vis import visualize from vis import visualize
from utils import dist_utils from utils import dist_utils
from utils.load_model_utils import load_pretrained_weights
from mobilenetv2_search_space import MobileNetV2SpaceSeg from mobilenetv2_search_space import MobileNetV2SpaceSeg
from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory
...@@ -116,38 +117,7 @@ def parse_args(): ...@@ -116,38 +117,7 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def save_vars(executor, dirname, program=None, vars=None): def save_checkpoint(program, ckpt_name):
"""
Temporary workaround for Windows save-variables compatibility.
Will fix in PaddlePaddle v1.5.2
"""
save_program = fluid.Program()
save_block = save_program.global_block()
for each_var in vars:
# NOTE: don't save the variable which type is RAW
if each_var.type == fluid.core.VarDesc.VarType.RAW:
continue
new_var = save_block.create_var(
name=each_var.name,
shape=each_var.shape,
dtype=each_var.dtype,
type=each_var.type,
lod_level=each_var.lod_level,
persistable=True)
file_path = os.path.join(dirname, new_var.name)
file_path = os.path.normpath(file_path)
save_block.append_op(
type='save',
inputs={'X': [new_var]},
outputs={},
attrs={'file_path': file_path})
executor.run(save_program)
def save_checkpoint(exe, program, ckpt_name):
""" """
Save checkpoint for evaluation or resume training Save checkpoint for evaluation or resume training
""" """
...@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name): ...@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name):
if not os.path.isdir(ckpt_dir): if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir) os.makedirs(ckpt_dir)
save_vars( fluid.save(program, os.path.join(ckpt_dir, 'model'))
exe,
ckpt_dir,
program,
vars=list(filter(fluid.io.is_persistable, program.list_vars())))
return ckpt_dir return ckpt_dir
def load_checkpoint(exe, program): def load_checkpoint(exe, program):
""" """
Load checkpoint from pretrained model directory for resume training Load checkpoint for resuming training
""" """
print('Resume model training from:', cfg.TRAIN.RESUME_MODEL_DIR)
if not os.path.exists(cfg.TRAIN.RESUME_MODEL_DIR):
raise ValueError("TRAIN.PRETRAIN_MODEL {} not exist!".format(
cfg.TRAIN.RESUME_MODEL_DIR))
fluid.io.load_persistables(
exe, cfg.TRAIN.RESUME_MODEL_DIR, main_program=program)
model_path = cfg.TRAIN.RESUME_MODEL_DIR model_path = cfg.TRAIN.RESUME_MODEL_DIR
print('Resume model training from:', model_path)
if not os.path.exists(model_path):
raise ValueError(
"TRAIN.PRETRAIN_MODEL {} not exist!".format(model_path))
fluid.load(program, os.path.join(model_path, 'model'), exe)
# Check whether the path ends with a path separator # Check whether the path ends with a path separator
if model_path[-1] == os.sep: if model_path[-1] == os.sep:
model_path = model_path[0:-1] model_path = model_path[0:-1]
...@@ -193,7 +156,6 @@ def load_checkpoint(exe, program): ...@@ -193,7 +156,6 @@ def load_checkpoint(exe, program):
else: else:
raise ValueError("Resume model path is not valid!") raise ValueError("Resume model path is not valid!")
print("Model checkpoint loaded successfully!") print("Model checkpoint loaded successfully!")
return begin_epoch return begin_epoch
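In short, the refactored pair above saves the whole program with fluid.save and restores it with fluid.load, then derives the resume epoch from the checkpoint directory name. A condensed usage sketch with an illustrative directory layout:

# Checkpoints are written to <MODEL_SAVE_DIR>/<epoch>/model.pdparams (+ .pdopt).
ckpt_dir = save_checkpoint(train_prog, 10)   # e.g. 'saved_model/10'
# With cfg.TRAIN.RESUME_MODEL_DIR pointing at that directory,
# load_checkpoint restores the weights and returns 11 as the next epoch.
begin_epoch = load_checkpoint(exe, train_prog)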
...@@ -245,8 +207,6 @@ def train(cfg): ...@@ -245,8 +207,6 @@ def train(cfg):
yield item[0], item[1], item[2] yield item[0], item[1], item[2]
# Get device environment # Get device environment
# places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
# place = places[0]
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places() places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
...@@ -326,43 +286,8 @@ def train(cfg): ...@@ -326,43 +286,8 @@ def train(cfg):
begin_epoch = load_checkpoint(exe, train_prog) begin_epoch = load_checkpoint(exe, train_prog)
# Load pretrained model # Load pretrained model
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR): elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR) load_pretrained_weights(exe, train_prog,
load_vars = [] cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_fail_vars = []
def var_shape_matched(var, shape):
"""
Check whether the persistable variable shape matches the current network
"""
var_exist = os.path.exists(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
return var_shape == shape
return False
for x in train_prog.list_vars():
if isinstance(x, fluid.framework.Parameter):
shape = tuple(fluid.global_scope().find_var(
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print_info(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print_info(
"{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else: else:
print_info( print_info(
'Pretrained model dir {} not exists, training from scratch...'. 'Pretrained model dir {} not exists, training from scratch...'.
...@@ -419,8 +344,7 @@ def train(cfg): ...@@ -419,8 +344,7 @@ def train(cfg):
except Exception as e: except Exception as e:
print(e) print(e)
if epoch > cfg.SLIM.NAS_START_EVAL_EPOCH: if epoch > cfg.SLIM.NAS_START_EVAL_EPOCH:
ckpt_dir = save_checkpoint(exe, train_prog, ckpt_dir = save_checkpoint(train_prog, '{}_tmp'.format(port))
'{}_tmp'.format(port))
_, mean_iou, _, mean_acc = evaluate( _, mean_iou, _, mean_acc = evaluate(
cfg=cfg, cfg=cfg,
arch=arch, arch=arch,
......
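The pretrained-weight loading that used to live inline (shape-checking each Parameter before fluid.io.load_vars) is now delegated to utils.load_model_utils.load_pretrained_weights, whose body is not part of this diff. Assuming it keeps the behaviour of the removed code, a minimal sketch would be:

import os
import paddle.fluid as fluid

def load_pretrained_weights_sketch(exe, program, weights_dir):
    # Collect only the parameters that exist on disk; a faithful port would
    # also parse the saved shape and skip mismatched variables, as the
    # removed var_shape_matched helper did.
    load_vars = [
        var for var in program.list_vars()
        if isinstance(var, fluid.framework.Parameter)
        and os.path.exists(os.path.join(weights_dir, var.name))
    ]
    fluid.io.load_vars(exe, dirname=weights_dir, vars=load_vars)
    print('{} pretrained parameters loaded'.format(len(load_vars)))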
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -46,6 +46,7 @@ from models.model_builder import parse_shape_from_file ...@@ -46,6 +46,7 @@ from models.model_builder import parse_shape_from_file
from eval_prune import evaluate from eval_prune import evaluate
from vis import visualize from vis import visualize
from utils import dist_utils from utils import dist_utils
from utils.load_model_utils import load_pretrained_weights
from paddleslim.prune import Pruner, save_model from paddleslim.prune import Pruner, save_model
from paddleslim.analysis import flops from paddleslim.analysis import flops
...@@ -285,42 +286,7 @@ def train(cfg): ...@@ -285,42 +286,7 @@ def train(cfg):
begin_epoch = load_checkpoint(exe, train_prog) begin_epoch = load_checkpoint(exe, train_prog)
# Load pretrained model # Load pretrained model
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR): elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR) load_pretrained_weights(exe, train_prog, cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_vars = []
load_fail_vars = []
def var_shape_matched(var, shape):
"""
Check whether the persistable variable shape matches the current network
"""
var_exist = os.path.exists(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
return var_shape == shape
return False
for x in train_prog.list_vars():
if isinstance(x, fluid.framework.Parameter):
shape = tuple(fluid.global_scope().find_var(
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print_info(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print_info("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else: else:
print_info( print_info(
'Pretrained model dir {} not exists, training from scratch...'. 'Pretrained model dir {} not exists, training from scratch...'.
...@@ -409,12 +375,9 @@ def train(cfg): ...@@ -409,12 +375,9 @@ def train(cfg):
step) step)
log_writer.add_scalar('Train/mean_acc', mean_acc, log_writer.add_scalar('Train/mean_acc', mean_acc,
step) step)
log_writer.add_scalar('Train/loss', avg_loss, log_writer.add_scalar('Train/loss', avg_loss, step)
step) log_writer.add_scalar('Train/lr', lr[0], step)
log_writer.add_scalar('Train/lr', lr[0], log_writer.add_scalar('Train/step/sec', speed, step)
step)
log_writer.add_scalar('Train/step/sec', speed,
step)
sys.stdout.flush() sys.stdout.flush()
avg_loss = 0.0 avg_loss = 0.0
cm.zero_matrix() cm.zero_matrix()
...@@ -436,12 +399,9 @@ def train(cfg): ...@@ -436,12 +399,9 @@ def train(cfg):
).format(epoch, step, lr[0], avg_loss, speed, ).format(epoch, step, lr[0], avg_loss, speed,
calculate_eta(all_step - step, speed))) calculate_eta(all_step - step, speed)))
if args.use_vdl: if args.use_vdl:
log_writer.add_scalar('Train/loss', avg_loss, log_writer.add_scalar('Train/loss', avg_loss, step)
step) log_writer.add_scalar('Train/lr', lr[0], step)
log_writer.add_scalar('Train/lr', lr[0], log_writer.add_scalar('Train/speed', speed, step)
step)
log_writer.add_scalar('Train/speed', speed,
step)
sys.stdout.flush() sys.stdout.flush()
avg_loss = 0.0 avg_loss = 0.0
timer.restart() timer.restart()
...@@ -464,10 +424,8 @@ def train(cfg): ...@@ -464,10 +424,8 @@ def train(cfg):
use_gpu=args.use_gpu, use_gpu=args.use_gpu,
use_mpio=args.use_mpio) use_mpio=args.use_mpio)
if args.use_vdl: if args.use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou, log_writer.add_scalar('Evaluate/mean_iou', mean_iou, step)
step) log_writer.add_scalar('Evaluate/mean_acc', mean_acc, step)
log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
step)
# Use VisualDL to visualize results # Use VisualDL to visualize results
if args.use_vdl and cfg.DATASET.VIS_FILE_LIST is not None: if args.use_vdl and cfg.DATASET.VIS_FILE_LIST is not None:
......
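The reflowed logging calls above use VisualDL's LogWriter.add_scalar(tag, value, step); the prefix before the slash groups curves in the VisualDL UI. A minimal standalone sketch with an illustrative log directory:

from visualdl import LogWriter

log_writer = LogWriter(logdir='vdl_log')            # illustrative log directory
for step, loss in enumerate([0.9, 0.7, 0.5]):
    log_writer.add_scalar('Train/loss', loss, step)  # tag, value, step
log_writer.close()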
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding: utf8 # coding: utf8
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -40,7 +40,8 @@ from models.model_builder import parse_shape_from_file ...@@ -40,7 +40,8 @@ from models.model_builder import parse_shape_from_file
from eval_quant import evaluate from eval_quant import evaluate
from vis import visualize from vis import visualize
from utils import dist_utils from utils import dist_utils
from train import save_vars, save_checkpoint, load_checkpoint, update_best_model, print_info from utils.load_model_utils import load_pretrained_weights
from train import update_best_model, print_info
from paddleslim.quant import quant_aware from paddleslim.quant import quant_aware
...@@ -103,6 +104,55 @@ def parse_args(): ...@@ -103,6 +104,55 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def save_checkpoint(exe, program, ckpt_name):
"""
Save checkpoint for evaluation or resume training
"""
ckpt_dir = os.path.join(cfg.TRAIN.MODEL_SAVE_DIR, str(ckpt_name))
print("Save model checkpoint to {}".format(ckpt_dir))
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
fluid.io.save_vars(
exe,
ckpt_dir,
program,
vars=list(filter(fluid.io.is_persistable, program.list_vars())))
return ckpt_dir
def load_checkpoint(exe, program):
"""
Load checkpoint for resuming training
"""
print('Resume model training from:', cfg.TRAIN.RESUME_MODEL_DIR)
if not os.path.exists(cfg.TRAIN.RESUME_MODEL_DIR):
raise ValueError("TRAIN.PRETRAIN_MODEL {} not exist!".format(
cfg.TRAIN.RESUME_MODEL_DIR))
fluid.io.load_persistables(
exe, cfg.TRAIN.RESUME_MODEL_DIR, main_program=program)
model_path = cfg.TRAIN.RESUME_MODEL_DIR
# Check whether the path ends with a path separator
if model_path[-1] == os.sep:
model_path = model_path[0:-1]
epoch_name = os.path.basename(model_path)
# If resume model is final model
if epoch_name == 'final':
begin_epoch = cfg.SOLVER.NUM_EPOCHS
# If resume model path is end of digit, restore epoch status
elif epoch_name.isdigit():
epoch = int(epoch_name)
begin_epoch = epoch + 1
else:
raise ValueError("Resume model path is not valid!")
print("Model checkpoint loaded successfully!")
return begin_epoch
def train_quant(cfg): def train_quant(cfg):
startup_prog = fluid.Program() startup_prog = fluid.Program()
train_prog = fluid.Program() train_prog = fluid.Program()
...@@ -182,42 +232,7 @@ def train_quant(cfg): ...@@ -182,42 +232,7 @@ def train_quant(cfg):
begin_epoch = load_checkpoint(exe, train_prog) begin_epoch = load_checkpoint(exe, train_prog)
# Load pretrained model # Load pretrained model
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR): elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR) load_pretrained_weights(exe, train_prog, cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_vars = []
load_fail_vars = []
def var_shape_matched(var, shape):
"""
Check whether the persistable variable shape matches the current network
"""
var_exist = os.path.exists(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
if var_exist:
var_shape = parse_shape_from_file(
os.path.join(cfg.TRAIN.PRETRAINED_MODEL_DIR, var.name))
return var_shape == shape
return False
for x in train_prog.list_vars():
if isinstance(x, fluid.framework.Parameter):
shape = tuple(fluid.global_scope().find_var(
x.name).get_tensor().shape())
if var_shape_matched(x, shape):
load_vars.append(x)
else:
load_fail_vars.append(x)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print_info(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print_info("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else: else:
print_info( print_info(
'Pretrained model dir {} not exists, training from scratch...'. 'Pretrained model dir {} not exists, training from scratch...'.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
......