Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
61645b1d
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
289
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
61645b1d
编写于
5月 25, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/PaddleSeg
into develop
上级
d3ef4d2e
867d4a5b
变更
156
隐藏空白更改
内联
并排
Showing
156 changed file
with
2009 addition
and
1127 deletion
+2009
-1127
README.md
README.md
+1
-1
contrib/ACE2P/__init__.py
contrib/ACE2P/__init__.py
+14
-0
contrib/ACE2P/config.py
contrib/ACE2P/config.py
+17
-3
contrib/ACE2P/download_ACE2P.py
contrib/ACE2P/download_ACE2P.py
+4
-3
contrib/ACE2P/infer.py
contrib/ACE2P/infer.py
+37
-17
contrib/ACE2P/reader.py
contrib/ACE2P/reader.py
+31
-17
contrib/ACE2P/utils/__init__.py
contrib/ACE2P/utils/__init__.py
+14
-0
contrib/ACE2P/utils/palette.py
contrib/ACE2P/utils/palette.py
+1
-0
contrib/ACE2P/utils/util.py
contrib/ACE2P/utils/util.py
+22
-8
contrib/HumanSeg/datasets/__init__.py
contrib/HumanSeg/datasets/__init__.py
+4
-3
contrib/HumanSeg/datasets/dataset.py
contrib/HumanSeg/datasets/dataset.py
+2
-1
contrib/HumanSeg/datasets/shared_queue/__init__.py
contrib/HumanSeg/datasets/shared_queue/__init__.py
+2
-1
contrib/HumanSeg/datasets/shared_queue/queue.py
contrib/HumanSeg/datasets/shared_queue/queue.py
+2
-1
contrib/HumanSeg/datasets/shared_queue/sharedmemory.py
contrib/HumanSeg/datasets/shared_queue/sharedmemory.py
+2
-4
contrib/HumanSeg/export.py
contrib/HumanSeg/export.py
+15
-0
contrib/HumanSeg/infer.py
contrib/HumanSeg/infer.py
+15
-0
contrib/HumanSeg/models/__init__.py
contrib/HumanSeg/models/__init__.py
+15
-0
contrib/HumanSeg/models/humanseg.py
contrib/HumanSeg/models/humanseg.py
+4
-3
contrib/HumanSeg/models/load_model.py
contrib/HumanSeg/models/load_model.py
+2
-1
contrib/HumanSeg/nets/__init__.py
contrib/HumanSeg/nets/__init__.py
+15
-0
contrib/HumanSeg/nets/backbone/__init__.py
contrib/HumanSeg/nets/backbone/__init__.py
+15
-0
contrib/HumanSeg/nets/backbone/mobilenet_v2.py
contrib/HumanSeg/nets/backbone/mobilenet_v2.py
+3
-1
contrib/HumanSeg/nets/backbone/xception.py
contrib/HumanSeg/nets/backbone/xception.py
+1
-1
contrib/HumanSeg/nets/deeplabv3p.py
contrib/HumanSeg/nets/deeplabv3p.py
+1
-1
contrib/HumanSeg/nets/hrnet.py
contrib/HumanSeg/nets/hrnet.py
+1
-1
contrib/HumanSeg/nets/libs.py
contrib/HumanSeg/nets/libs.py
+1
-1
contrib/HumanSeg/nets/seg_modules.py
contrib/HumanSeg/nets/seg_modules.py
+2
-1
contrib/HumanSeg/nets/shufflenet_slim.py
contrib/HumanSeg/nets/shufflenet_slim.py
+15
-0
contrib/HumanSeg/pretrained_weights/download_pretrained_weights.py
...umanSeg/pretrained_weights/download_pretrained_weights.py
+4
-3
contrib/HumanSeg/quant_offline.py
contrib/HumanSeg/quant_offline.py
+15
-0
contrib/HumanSeg/quant_online.py
contrib/HumanSeg/quant_online.py
+15
-0
contrib/HumanSeg/train.py
contrib/HumanSeg/train.py
+15
-0
contrib/HumanSeg/transforms/__init__.py
contrib/HumanSeg/transforms/__init__.py
+4
-3
contrib/HumanSeg/transforms/functional.py
contrib/HumanSeg/transforms/functional.py
+4
-3
contrib/HumanSeg/transforms/transforms.py
contrib/HumanSeg/transforms/transforms.py
+4
-3
contrib/HumanSeg/utils/__init__.py
contrib/HumanSeg/utils/__init__.py
+4
-3
contrib/HumanSeg/utils/humanseg_postprocess.py
contrib/HumanSeg/utils/humanseg_postprocess.py
+15
-0
contrib/HumanSeg/utils/logging.py
contrib/HumanSeg/utils/logging.py
+2
-1
contrib/HumanSeg/utils/metrics.py
contrib/HumanSeg/utils/metrics.py
+1
-1
contrib/HumanSeg/utils/post_quantization.py
contrib/HumanSeg/utils/post_quantization.py
+2
-1
contrib/HumanSeg/utils/utils.py
contrib/HumanSeg/utils/utils.py
+5
-6
contrib/HumanSeg/val.py
contrib/HumanSeg/val.py
+15
-0
contrib/HumanSeg/video_infer.py
contrib/HumanSeg/video_infer.py
+15
-0
contrib/LaneNet/data_aug.py
contrib/LaneNet/data_aug.py
+4
-2
contrib/LaneNet/dataset/download_tusimple.py
contrib/LaneNet/dataset/download_tusimple.py
+4
-3
contrib/LaneNet/eval.py
contrib/LaneNet/eval.py
+5
-2
contrib/LaneNet/loss.py
contrib/LaneNet/loss.py
+1
-1
contrib/LaneNet/models/__init__.py
contrib/LaneNet/models/__init__.py
+1
-1
contrib/LaneNet/models/model_builder.py
contrib/LaneNet/models/model_builder.py
+1
-1
contrib/LaneNet/models/modeling/lanenet.py
contrib/LaneNet/models/modeling/lanenet.py
+179
-58
contrib/LaneNet/reader.py
contrib/LaneNet/reader.py
+21
-13
contrib/LaneNet/train.py
contrib/LaneNet/train.py
+13
-91
contrib/LaneNet/utils/__init__.py
contrib/LaneNet/utils/__init__.py
+14
-0
contrib/LaneNet/utils/config.py
contrib/LaneNet/utils/config.py
+6
-6
contrib/LaneNet/utils/dist_utils.py
contrib/LaneNet/utils/dist_utils.py
+10
-9
contrib/LaneNet/utils/generate_tusimple_dataset.py
contrib/LaneNet/utils/generate_tusimple_dataset.py
+65
-21
contrib/LaneNet/utils/lanenet_postprocess.py
contrib/LaneNet/utils/lanenet_postprocess.py
+75
-41
contrib/LaneNet/utils/load_model_utils.py
contrib/LaneNet/utils/load_model_utils.py
+126
-0
contrib/LaneNet/vis.py
contrib/LaneNet/vis.py
+21
-13
contrib/MechanicalIndustryMeter/download_mini_mechanical_industry_meter.py
...lIndustryMeter/download_mini_mechanical_industry_meter.py
+4
-3
contrib/MechanicalIndustryMeter/download_unet_mechanical_industry_meter.py
...lIndustryMeter/download_unet_mechanical_industry_meter.py
+6
-4
contrib/RemoteSensing/__init__.py
contrib/RemoteSensing/__init__.py
+4
-3
contrib/RemoteSensing/models/__init__.py
contrib/RemoteSensing/models/__init__.py
+15
-0
contrib/RemoteSensing/models/base.py
contrib/RemoteSensing/models/base.py
+10
-9
contrib/RemoteSensing/models/load_model.py
contrib/RemoteSensing/models/load_model.py
+2
-1
contrib/RemoteSensing/models/unet.py
contrib/RemoteSensing/models/unet.py
+10
-9
contrib/RemoteSensing/nets/__init__.py
contrib/RemoteSensing/nets/__init__.py
+15
-0
contrib/RemoteSensing/nets/libs.py
contrib/RemoteSensing/nets/libs.py
+1
-1
contrib/RemoteSensing/nets/loss.py
contrib/RemoteSensing/nets/loss.py
+2
-1
contrib/RemoteSensing/nets/unet.py
contrib/RemoteSensing/nets/unet.py
+1
-1
contrib/RemoteSensing/predict_demo.py
contrib/RemoteSensing/predict_demo.py
+15
-0
contrib/RemoteSensing/readers/__init__.py
contrib/RemoteSensing/readers/__init__.py
+2
-1
contrib/RemoteSensing/readers/base.py
contrib/RemoteSensing/readers/base.py
+2
-1
contrib/RemoteSensing/readers/reader.py
contrib/RemoteSensing/readers/reader.py
+2
-1
contrib/RemoteSensing/tools/create_dataset_list.py
contrib/RemoteSensing/tools/create_dataset_list.py
+1
-1
contrib/RemoteSensing/tools/split_dataset_list.py
contrib/RemoteSensing/tools/split_dataset_list.py
+1
-1
contrib/RemoteSensing/train_demo.py
contrib/RemoteSensing/train_demo.py
+15
-0
contrib/RemoteSensing/transforms/__init__.py
contrib/RemoteSensing/transforms/__init__.py
+4
-3
contrib/RemoteSensing/transforms/ops.py
contrib/RemoteSensing/transforms/ops.py
+2
-1
contrib/RemoteSensing/transforms/transforms.py
contrib/RemoteSensing/transforms/transforms.py
+1
-1
contrib/RemoteSensing/utils/__init__.py
contrib/RemoteSensing/utils/__init__.py
+4
-3
contrib/RemoteSensing/utils/logging.py
contrib/RemoteSensing/utils/logging.py
+2
-1
contrib/RemoteSensing/utils/metrics.py
contrib/RemoteSensing/utils/metrics.py
+1
-1
contrib/RemoteSensing/utils/pretrain_weights.py
contrib/RemoteSensing/utils/pretrain_weights.py
+15
-0
contrib/RemoteSensing/utils/utils.py
contrib/RemoteSensing/utils/utils.py
+5
-6
contrib/RoadLine/__init__.py
contrib/RoadLine/__init__.py
+14
-0
contrib/RoadLine/config.py
contrib/RoadLine/config.py
+20
-6
contrib/RoadLine/download_RoadLine.py
contrib/RoadLine/download_RoadLine.py
+4
-3
contrib/RoadLine/infer.py
contrib/RoadLine/infer.py
+39
-19
contrib/RoadLine/utils/__init__.py
contrib/RoadLine/utils/__init__.py
+14
-0
contrib/RoadLine/utils/palette.py
contrib/RoadLine/utils/palette.py
+15
-9
contrib/RoadLine/utils/util.py
contrib/RoadLine/utils/util.py
+22
-8
dataset/convert_voc2012.py
dataset/convert_voc2012.py
+11
-6
dataset/download_and_convert_voc2012.py
dataset/download_and_convert_voc2012.py
+9
-8
dataset/download_cityscapes.py
dataset/download_cityscapes.py
+4
-3
dataset/download_mini_deepglobe_road_extraction.py
dataset/download_mini_deepglobe_road_extraction.py
+4
-3
dataset/download_optic.py
dataset/download_optic.py
+4
-3
dataset/download_pet.py
dataset/download_pet.py
+4
-3
deploy/python/infer.py
deploy/python/infer.py
+10
-6
pdseg/__init__.py
pdseg/__init__.py
+2
-2
pdseg/check.py
pdseg/check.py
+21
-3
pdseg/data_aug.py
pdseg/data_aug.py
+3
-3
pdseg/data_utils.py
pdseg/data_utils.py
+16
-2
pdseg/eval.py
pdseg/eval.py
+5
-6
pdseg/export_model.py
pdseg/export_model.py
+11
-5
pdseg/loss.py
pdseg/loss.py
+1
-1
pdseg/lovasz_losses.py
pdseg/lovasz_losses.py
+1
-1
pdseg/metrics.py
pdseg/metrics.py
+1
-1
pdseg/models/__init__.py
pdseg/models/__init__.py
+1
-1
pdseg/models/backbone/__init__.py
pdseg/models/backbone/__init__.py
+14
-0
pdseg/models/backbone/mobilenet_v2.py
pdseg/models/backbone/mobilenet_v2.py
+1
-1
pdseg/models/backbone/resnet.py
pdseg/models/backbone/resnet.py
+15
-13
pdseg/models/backbone/vgg.py
pdseg/models/backbone/vgg.py
+3
-2
pdseg/models/backbone/xception.py
pdseg/models/backbone/xception.py
+1
-1
pdseg/models/libs/__init__.py
pdseg/models/libs/__init__.py
+14
-0
pdseg/models/libs/model_libs.py
pdseg/models/libs/model_libs.py
+2
-2
pdseg/models/model_builder.py
pdseg/models/model_builder.py
+1
-1
pdseg/models/modeling/__init__.py
pdseg/models/modeling/__init__.py
+14
-0
pdseg/models/modeling/deeplab.py
pdseg/models/modeling/deeplab.py
+1
-1
pdseg/models/modeling/fast_scnn.py
pdseg/models/modeling/fast_scnn.py
+1
-1
pdseg/models/modeling/hrnet.py
pdseg/models/modeling/hrnet.py
+153
-52
pdseg/models/modeling/icnet.py
pdseg/models/modeling/icnet.py
+1
-1
pdseg/models/modeling/pspnet.py
pdseg/models/modeling/pspnet.py
+37
-34
pdseg/models/modeling/unet.py
pdseg/models/modeling/unet.py
+1
-1
pdseg/reader.py
pdseg/reader.py
+5
-3
pdseg/solver.py
pdseg/solver.py
+1
-1
pdseg/tools/__init__.py
pdseg/tools/__init__.py
+1
-1
pdseg/tools/create_dataset_list.py
pdseg/tools/create_dataset_list.py
+21
-26
pdseg/tools/gray2pseudo_color.py
pdseg/tools/gray2pseudo_color.py
+21
-11
pdseg/tools/jingling2seg.py
pdseg/tools/jingling2seg.py
+14
-1
pdseg/tools/labelme2seg.py
pdseg/tools/labelme2seg.py
+14
-1
pdseg/train.py
pdseg/train.py
+22
-108
pdseg/utils/__init__.py
pdseg/utils/__init__.py
+14
-0
pdseg/utils/collect.py
pdseg/utils/collect.py
+14
-9
pdseg/utils/config.py
pdseg/utils/config.py
+4
-4
pdseg/utils/dist_utils.py
pdseg/utils/dist_utils.py
+1
-1
pdseg/utils/fp16_utils.py
pdseg/utils/fp16_utils.py
+17
-1
pdseg/utils/load_model_utils.py
pdseg/utils/load_model_utils.py
+128
-0
pdseg/utils/timer.py
pdseg/utils/timer.py
+4
-3
pdseg/vis.py
pdseg/vis.py
+14
-16
pretrained_model/download_model.py
pretrained_model/download_model.py
+5
-4
slim/distillation/model_builder.py
slim/distillation/model_builder.py
+1
-1
slim/distillation/train_distill.py
slim/distillation/train_distill.py
+31
-107
slim/nas/deeplab.py
slim/nas/deeplab.py
+7
-4
slim/nas/eval_nas.py
slim/nas/eval_nas.py
+5
-2
slim/nas/mobilenetv2_search_space.py
slim/nas/mobilenetv2_search_space.py
+10
-9
slim/nas/model_builder.py
slim/nas/model_builder.py
+1
-1
slim/nas/train_nas.py
slim/nas/train_nas.py
+14
-90
slim/prune/eval_prune.py
slim/prune/eval_prune.py
+1
-1
slim/prune/train_prune.py
slim/prune/train_prune.py
+11
-53
slim/quantization/eval_quant.py
slim/quantization/eval_quant.py
+1
-1
slim/quantization/export_model.py
slim/quantization/export_model.py
+1
-1
slim/quantization/train_quant.py
slim/quantization/train_quant.py
+53
-38
test/local_test_cityscapes.py
test/local_test_cityscapes.py
+4
-3
test/local_test_pet.py
test/local_test_pet.py
+4
-3
test/test_utils.py
test/test_utils.py
+4
-3
未找到文件。
README.md
浏览文件 @
61645b1d
...
@@ -35,7 +35,7 @@ PaddleSeg是基于[PaddlePaddle](https://www.paddlepaddle.org.cn)开发的端到
...
@@ -35,7 +35,7 @@ PaddleSeg是基于[PaddlePaddle](https://www.paddlepaddle.org.cn)开发的端到
-
**高性能**
-
**高性能**
PaddleSeg支持多进程I/O、多卡并行
、跨卡Batch Norm同步
等训练加速策略,结合飞桨核心框架的显存优化功能,可大幅度减少分割模型的显存开销,让开发者更低成本、更高效地完成图像分割训练。
PaddleSeg支持多进程I/O、多卡并行等训练加速策略,结合飞桨核心框架的显存优化功能,可大幅度减少分割模型的显存开销,让开发者更低成本、更高效地完成图像分割训练。
-
**工业级部署**
-
**工业级部署**
...
...
contrib/ACE2P/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/ACE2P/config.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
utils.util
import
AttrDict
,
merge_cfg_from_args
,
get_arguments
from
utils.util
import
AttrDict
,
merge_cfg_from_args
,
get_arguments
import
os
import
os
...
@@ -19,10 +33,10 @@ cfg.class_num = 20
...
@@ -19,10 +33,10 @@ cfg.class_num = 20
# 均值, 图像预处理减去的均值
# 均值, 图像预处理减去的均值
cfg
.
MEAN
=
0.406
,
0.456
,
0.485
cfg
.
MEAN
=
0.406
,
0.456
,
0.485
# 标准差,图像预处理除以标准差
# 标准差,图像预处理除以标准差
cfg
.
STD
=
0.225
,
0.224
,
0.229
cfg
.
STD
=
0.225
,
0.224
,
0.229
# 多尺度预测时图像尺寸
# 多尺度预测时图像尺寸
cfg
.
multi_scales
=
(
377
,
377
),
(
473
,
473
),
(
567
,
567
)
cfg
.
multi_scales
=
(
377
,
377
),
(
473
,
473
),
(
567
,
567
)
# 多尺度预测时图像是否水平翻转
# 多尺度预测时图像是否水平翻转
cfg
.
flip
=
True
cfg
.
flip
=
True
...
...
contrib/ACE2P/download_ACE2P.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/ACE2P/infer.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
@@ -12,18 +26,19 @@ config = importlib.import_module('config')
...
@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg
=
getattr
(
config
,
'cfg'
)
cfg
=
getattr
(
config
,
'cfg'
)
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
# 预测数据集类
# 预测数据集类
class
TestDataSet
():
class
TestDataSet
():
def
__init__
(
self
):
def
__init__
(
self
):
self
.
data_dir
=
cfg
.
data_dir
self
.
data_dir
=
cfg
.
data_dir
self
.
data_list_file
=
cfg
.
data_list_file
self
.
data_list_file
=
cfg
.
data_list_file
self
.
data_list
=
self
.
get_data_list
()
self
.
data_list
=
self
.
get_data_list
()
self
.
data_num
=
len
(
self
.
data_list
)
self
.
data_num
=
len
(
self
.
data_list
)
def
get_data_list
(
self
):
def
get_data_list
(
self
):
# 获取预测图像路径列表
# 获取预测图像路径列表
data_list
=
[]
data_list
=
[]
...
@@ -56,10 +71,10 @@ class TestDataSet():
...
@@ -56,10 +71,10 @@ class TestDataSet():
img_path
=
self
.
data_list
[
index
]
img_path
=
self
.
data_list
[
index
]
img
=
cv2
.
imread
(
img_path
,
cv2
.
IMREAD_COLOR
)
img
=
cv2
.
imread
(
img_path
,
cv2
.
IMREAD_COLOR
)
if
img
is
None
:
if
img
is
None
:
return
img
,
img
,
img_path
,
None
return
img
,
img
,
img_path
,
None
img_name
=
img_path
.
split
(
os
.
sep
)[
-
1
]
img_name
=
img_path
.
split
(
os
.
sep
)[
-
1
]
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
img_shape
=
img
.
shape
[:
2
]
img_shape
=
img
.
shape
[:
2
]
img_process
=
self
.
preprocess
(
img
)
img_process
=
self
.
preprocess
(
img
)
...
@@ -90,39 +105,44 @@ def infer():
...
@@ -90,39 +105,44 @@ def infer():
if
image
is
None
:
if
image
is
None
:
print
(
im_name
,
'is None'
)
print
(
im_name
,
'is None'
)
continue
continue
# 预测
# 预测
if
cfg
.
example
==
'ACE2P'
:
if
cfg
.
example
==
'ACE2P'
:
# ACE2P模型使用多尺度预测
# ACE2P模型使用多尺度预测
reader
=
importlib
.
import_module
(
'reader'
)
reader
=
importlib
.
import_module
(
'reader'
)
multi_scale_test
=
getattr
(
reader
,
'multi_scale_test'
)
multi_scale_test
=
getattr
(
reader
,
'multi_scale_test'
)
parsing
,
logits
=
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
image
,
im_shape
)
parsing
,
logits
=
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
image
,
im_shape
)
else
:
else
:
# HumanSeg,RoadLine模型单尺度预测
# HumanSeg,RoadLine模型单尺度预测
result
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
result
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
parsing
=
np
.
argmax
(
result
[
0
][
0
],
axis
=
0
)
parsing
=
np
.
argmax
(
result
[
0
][
0
],
axis
=
0
)
parsing
=
cv2
.
resize
(
parsing
.
astype
(
np
.
uint8
),
im_shape
[::
-
1
])
parsing
=
cv2
.
resize
(
parsing
.
astype
(
np
.
uint8
),
im_shape
[::
-
1
])
# 预测结果保存
# 预测结果保存
result_path
=
os
.
path
.
join
(
cfg
.
vis_dir
,
im_name
+
'.png'
)
result_path
=
os
.
path
.
join
(
cfg
.
vis_dir
,
im_name
+
'.png'
)
if
cfg
.
example
==
'HumanSeg'
:
if
cfg
.
example
==
'HumanSeg'
:
logits
=
result
[
0
][
0
][
1
]
*
255
logits
=
result
[
0
][
0
][
1
]
*
255
logits
=
cv2
.
resize
(
logits
,
im_shape
[::
-
1
])
logits
=
cv2
.
resize
(
logits
,
im_shape
[::
-
1
])
ret
,
logits
=
cv2
.
threshold
(
logits
,
thresh
,
0
,
cv2
.
THRESH_TOZERO
)
ret
,
logits
=
cv2
.
threshold
(
logits
,
thresh
,
0
,
cv2
.
THRESH_TOZERO
)
logits
=
255
*
(
logits
-
thresh
)
/
(
255
-
thresh
)
logits
=
255
*
(
logits
-
thresh
)
/
(
255
-
thresh
)
# 将分割结果添加到alpha通道
# 将分割结果添加到alpha通道
rgba
=
np
.
concatenate
((
ori_img
,
np
.
expand_dims
(
logits
,
axis
=
2
)),
axis
=
2
)
rgba
=
np
.
concatenate
((
ori_img
,
np
.
expand_dims
(
logits
,
axis
=
2
)),
axis
=
2
)
cv2
.
imwrite
(
result_path
,
rgba
)
cv2
.
imwrite
(
result_path
,
rgba
)
else
:
else
:
output_im
=
PILImage
.
fromarray
(
np
.
asarray
(
parsing
,
dtype
=
np
.
uint8
))
output_im
=
PILImage
.
fromarray
(
np
.
asarray
(
parsing
,
dtype
=
np
.
uint8
))
output_im
.
putpalette
(
palette
)
output_im
.
putpalette
(
palette
)
output_im
.
save
(
result_path
)
output_im
.
save
(
result_path
)
if
(
idx
+
1
)
%
100
==
0
:
if
(
idx
+
1
)
%
100
==
0
:
print
(
'%d processd'
%
(
idx
+
1
))
print
(
'%d processd'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
return
0
return
0
...
...
contrib/ACE2P/reader.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
config
import
cfg
from
config
import
cfg
import
cv2
import
cv2
def
get_affine_points
(
src_shape
,
dst_shape
,
rot_grad
=
0
):
def
get_affine_points
(
src_shape
,
dst_shape
,
rot_grad
=
0
):
# 获取图像和仿射后图像的三组对应点坐标
# 获取图像和仿射后图像的三组对应点坐标
# 三组点为仿射变换后图像的中心点, [w/2,0], [0,0],及对应原始图像的点
# 三组点为仿射变换后图像的中心点, [w/2,0], [0,0],及对应原始图像的点
...
@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
...
@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
# 原始图像三组点
# 原始图像三组点
points
=
[[
0
,
0
]]
*
3
points
=
[[
0
,
0
]]
*
3
points
[
0
]
=
(
np
.
array
([
w
,
h
])
-
1
)
*
0.5
points
[
0
]
=
(
np
.
array
([
w
,
h
])
-
1
)
*
0.5
points
[
1
]
=
points
[
0
]
+
0.5
*
affine_shape
[
0
]
*
np
.
array
([
sin_v
,
-
cos_v
])
points
[
1
]
=
points
[
0
]
+
0.5
*
affine_shape
[
0
]
*
np
.
array
([
sin_v
,
-
cos_v
])
points
[
2
]
=
points
[
1
]
-
0.5
*
affine_shape
[
1
]
*
np
.
array
([
cos_v
,
sin_v
])
points
[
2
]
=
points
[
1
]
-
0.5
*
affine_shape
[
1
]
*
np
.
array
([
cos_v
,
sin_v
])
...
@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
...
@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
return
points
,
points_trans
return
points
,
points_trans
def
preprocess
(
im
):
def
preprocess
(
im
):
# ACE2P模型数据预处理
# ACE2P模型数据预处理
im_shape
=
im
.
shape
[:
2
]
im_shape
=
im
.
shape
[:
2
]
...
@@ -42,13 +58,10 @@ def preprocess(im):
...
@@ -42,13 +58,10 @@ def preprocess(im):
# 获取图像和仿射变换后图像的对应点坐标
# 获取图像和仿射变换后图像的对应点坐标
points
,
points_trans
=
get_affine_points
(
im_shape
,
scale
)
points
,
points_trans
=
get_affine_points
(
im_shape
,
scale
)
# 根据对应点集获得仿射矩阵
# 根据对应点集获得仿射矩阵
trans
=
cv2
.
getAffineTransform
(
np
.
float32
(
points
),
trans
=
cv2
.
getAffineTransform
(
np
.
float32
(
points_trans
))
np
.
float32
(
points
),
np
.
float32
(
points_trans
))
# 根据仿射矩阵对图像进行仿射
# 根据仿射矩阵对图像进行仿射
input
=
cv2
.
warpAffine
(
im
,
input
=
cv2
.
warpAffine
(
im
,
trans
,
scale
[::
-
1
],
flags
=
cv2
.
INTER_LINEAR
)
trans
,
scale
[::
-
1
],
flags
=
cv2
.
INTER_LINEAR
)
# 减均值测,除以方差,转换数据格式为NCHW
# 减均值测,除以方差,转换数据格式为NCHW
input
=
input
.
astype
(
np
.
float32
)
input
=
input
.
astype
(
np
.
float32
)
...
@@ -66,19 +79,20 @@ def preprocess(im):
...
@@ -66,19 +79,20 @@ def preprocess(im):
return
input_images
return
input_images
def
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
def
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
input_ims
,
input_ims
,
im_shape
):
im_shape
):
# 由于部分类别分左右部位, flipped_idx为其水平翻转后对应的标签
# 由于部分类别分左右部位, flipped_idx为其水平翻转后对应的标签
flipped_idx
=
(
15
,
14
,
17
,
16
,
19
,
18
)
flipped_idx
=
(
15
,
14
,
17
,
16
,
19
,
18
)
ms_outputs
=
[]
ms_outputs
=
[]
# 多尺度预测
# 多尺度预测
for
idx
,
scale
in
enumerate
(
cfg
.
multi_scales
):
for
idx
,
scale
in
enumerate
(
cfg
.
multi_scales
):
input_im
=
input_ims
[
idx
]
input_im
=
input_ims
[
idx
]
parsing_output
=
exe
.
run
(
program
=
test_prog
,
parsing_output
=
exe
.
run
(
feed
=
{
feed_name
[
0
]:
input_im
},
program
=
test_prog
,
fetch_list
=
fetch_list
)
feed
=
{
feed_name
[
0
]:
input_im
},
fetch_list
=
fetch_list
)
output
=
parsing_output
[
0
][
0
]
output
=
parsing_output
[
0
][
0
]
if
cfg
.
flip
:
if
cfg
.
flip
:
# 若水平翻转,对部分类别进行翻转,与原始预测结果取均值
# 若水平翻转,对部分类别进行翻转,与原始预测结果取均值
...
@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
...
@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
# 仿射变换回图像原始尺寸
# 仿射变换回图像原始尺寸
points
,
points_trans
=
get_affine_points
(
im_shape
,
scale
)
points
,
points_trans
=
get_affine_points
(
im_shape
,
scale
)
M
=
cv2
.
getAffineTransform
(
np
.
float32
(
points_trans
),
np
.
float32
(
points
))
M
=
cv2
.
getAffineTransform
(
np
.
float32
(
points_trans
),
np
.
float32
(
points
))
logits_result
=
cv2
.
warpAffine
(
output
,
M
,
im_shape
[::
-
1
],
flags
=
cv2
.
INTER_LINEAR
)
logits_result
=
cv2
.
warpAffine
(
output
,
M
,
im_shape
[::
-
1
],
flags
=
cv2
.
INTER_LINEAR
)
ms_outputs
.
append
(
logits_result
)
ms_outputs
.
append
(
logits_result
)
# 多尺度预测结果求均值,求预测概率最大的类别
# 多尺度预测结果求均值,求预测概率最大的类别
...
@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
...
@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
ms_fused_parsing_output
=
np
.
mean
(
ms_fused_parsing_output
,
axis
=
0
)
ms_fused_parsing_output
=
np
.
mean
(
ms_fused_parsing_output
,
axis
=
0
)
parsing
=
np
.
argmax
(
ms_fused_parsing_output
,
axis
=
2
)
parsing
=
np
.
argmax
(
ms_fused_parsing_output
,
axis
=
2
)
return
parsing
,
ms_fused_parsing_output
return
parsing
,
ms_fused_parsing_output
contrib/ACE2P/utils/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/ACE2P/utils/palette.py
浏览文件 @
61645b1d
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
## This source code is licensed under the MIT-style license found in the
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
...
...
contrib/ACE2P/utils/util.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
import
argparse
import
argparse
import
os
import
os
def
get_arguments
():
def
get_arguments
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--use_gpu"
,
parser
.
add_argument
(
action
=
"store_true"
,
"--use_gpu"
,
action
=
"store_true"
,
help
=
"Use gpu or cpu to test."
)
help
=
"Use gpu or cpu to test."
)
parser
.
add_argument
(
parser
.
add_argument
(
'--example'
,
'--example'
,
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -34,6 +48,7 @@ class AttrDict(dict):
...
@@ -34,6 +48,7 @@ class AttrDict(dict):
else
:
else
:
self
[
name
]
=
value
self
[
name
]
=
value
def
merge_cfg_from_args
(
args
,
cfg
):
def
merge_cfg_from_args
(
args
,
cfg
):
"""Merge config keys, values in args into the global config."""
"""Merge config keys, values in args into the global config."""
for
k
,
v
in
vars
(
args
).
items
():
for
k
,
v
in
vars
(
args
).
items
():
...
@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
...
@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value
=
v
value
=
v
if
value
is
not
None
:
if
value
is
not
None
:
cfg
[
k
]
=
value
cfg
[
k
]
=
value
contrib/HumanSeg/datasets/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/datasets/dataset.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/datasets/shared_queue/__init__.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/datasets/shared_queue/queue.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/datasets/shared_queue/sharedmemory.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,9 +13,6 @@
...
@@ -12,9 +13,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# utils for memory management which is allocated on sharedmemory,
# note that these structures may not be thread-safe
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
...
...
contrib/HumanSeg/export.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
models
import
models
import
argparse
import
argparse
...
...
contrib/HumanSeg/infer.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
argparse
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
...
...
contrib/HumanSeg/models/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.humanseg
import
HumanSegMobile
from
.humanseg
import
HumanSegMobile
from
.humanseg
import
HumanSegServer
from
.humanseg
import
HumanSegServer
from
.humanseg
import
HumanSegLite
from
.humanseg
import
HumanSegLite
...
...
contrib/HumanSeg/models/humanseg.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/models/load_model.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.backbone
import
mobilenet_v2
from
.backbone
import
mobilenet_v2
from
.backbone
import
xception
from
.backbone
import
xception
from
.deeplabv3p
import
DeepLabv3p
from
.deeplabv3p
import
DeepLabv3p
...
...
contrib/HumanSeg/nets/backbone/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.mobilenet_v2
import
MobileNetV2
from
.mobilenet_v2
import
MobileNetV2
from
.xception
import
Xception
from
.xception
import
Xception
contrib/HumanSeg/nets/backbone/mobilenet_v2.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -10,6 +11,7 @@
...
@@ -10,6 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
...
...
contrib/HumanSeg/nets/backbone/xception.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/deeplabv3p.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/hrnet.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/libs.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/seg_modules.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/shufflenet_slim.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
...
...
contrib/HumanSeg/pretrained_weights/download_pretrained_weights.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/quant_offline.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
argparse
from
datasets.dataset
import
Dataset
from
datasets.dataset
import
Dataset
import
transforms
import
transforms
...
...
contrib/HumanSeg/quant_online.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
argparse
from
datasets.dataset
import
Dataset
from
datasets.dataset
import
Dataset
from
models
import
HumanSegMobile
,
HumanSegLite
,
HumanSegServer
from
models
import
HumanSegMobile
,
HumanSegLite
,
HumanSegServer
...
...
contrib/HumanSeg/train.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
argparse
from
datasets.dataset
import
Dataset
from
datasets.dataset
import
Dataset
from
models
import
HumanSegMobile
,
HumanSegLite
,
HumanSegServer
from
models
import
HumanSegMobile
,
HumanSegLite
,
HumanSegServer
...
...
contrib/HumanSeg/transforms/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/transforms/functional.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/transforms/transforms.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/utils/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/utils/humanseg_postprocess.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
numpy
as
np
import
cv2
import
cv2
import
os
import
os
...
...
contrib/HumanSeg/utils/logging.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/utils/metrics.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/utils/post_quantization.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/utils/utils.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -205,11 +206,9 @@ def load_pretrained_weights(exe, main_prog, weights_dir, fuse_bn=False):
...
@@ -205,11 +206,9 @@ def load_pretrained_weights(exe, main_prog, weights_dir, fuse_bn=False):
vars_to_load
.
append
(
var
)
vars_to_load
.
append
(
var
)
logging
.
debug
(
"Weight {} will be load"
.
format
(
var
.
name
))
logging
.
debug
(
"Weight {} will be load"
.
format
(
var
.
name
))
fluid
.
io
.
load_vars
(
params_dict
=
fluid
.
io
.
load_program_state
(
executor
=
exe
,
weights_dir
,
var_list
=
vars_to_load
)
dirname
=
weights_dir
,
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
main_program
=
main_prog
,
vars
=
vars_to_load
)
if
len
(
vars_to_load
)
==
0
:
if
len
(
vars_to_load
)
==
0
:
logging
.
warning
(
logging
.
warning
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
...
...
contrib/HumanSeg/val.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
argparse
from
datasets.dataset
import
Dataset
from
datasets.dataset
import
Dataset
import
transforms
import
transforms
...
...
contrib/HumanSeg/video_infer.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
argparse
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
...
...
contrib/LaneNet/data_aug.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase
...
@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase
from
pdseg.data_aug
import
get_random_scale
,
randomly_scale_image_and_label
,
random_rotation
,
\
from
pdseg.data_aug
import
get_random_scale
,
randomly_scale_image_and_label
,
random_rotation
,
\
rand_scale_aspect
,
hsv_color_jitter
,
rand_crop
rand_scale_aspect
,
hsv_color_jitter
,
rand_crop
def
resize
(
img
,
grt
=
None
,
grt_instance
=
None
,
mode
=
ModelPhase
.
TRAIN
):
def
resize
(
img
,
grt
=
None
,
grt_instance
=
None
,
mode
=
ModelPhase
.
TRAIN
):
"""
"""
改变图像及标签图像尺寸
改变图像及标签图像尺寸
...
@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
...
@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
if
grt
is
not
None
:
if
grt
is
not
None
:
grt
=
cv2
.
resize
(
grt
,
target_size
,
interpolation
=
cv2
.
INTER_NEAREST
)
grt
=
cv2
.
resize
(
grt
,
target_size
,
interpolation
=
cv2
.
INTER_NEAREST
)
if
grt_instance
is
not
None
:
if
grt_instance
is
not
None
:
grt_instance
=
cv2
.
resize
(
grt_instance
,
target_size
,
interpolation
=
cv2
.
INTER_NEAREST
)
grt_instance
=
cv2
.
resize
(
grt_instance
,
target_size
,
interpolation
=
cv2
.
INTER_NEAREST
)
elif
cfg
.
AUG
.
AUG_METHOD
==
'stepscaling'
:
elif
cfg
.
AUG
.
AUG_METHOD
==
'stepscaling'
:
if
mode
==
ModelPhase
.
TRAIN
:
if
mode
==
ModelPhase
.
TRAIN
:
min_scale_factor
=
cfg
.
AUG
.
MIN_SCALE_FACTOR
min_scale_factor
=
cfg
.
AUG
.
MIN_SCALE_FACTOR
...
...
contrib/LaneNet/dataset/download_tusimple.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/LaneNet/eval.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -122,7 +122,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
...
@@ -122,7 +122,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if
ckpt_dir
is
not
None
:
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
print
(
'load test model:'
,
ckpt_dir
)
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
# Use streaming confusion matrix to calculate mean_iou
# Use streaming confusion matrix to calculate mean_iou
np
.
set_printoptions
(
np
.
set_printoptions
(
...
...
contrib/LaneNet/loss.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/LaneNet/models/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/LaneNet/models/model_builder.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/LaneNet/models/modeling/lanenet.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -18,7 +18,6 @@ from __future__ import print_function
...
@@ -18,7 +18,6 @@ from __future__ import print_function
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
utils.config
import
cfg
from
utils.config
import
cfg
from
pdseg.models.libs.model_libs
import
scope
,
name_scope
from
pdseg.models.libs.model_libs
import
scope
,
name_scope
from
pdseg.models.libs.model_libs
import
bn
,
bn_relu
,
relu
from
pdseg.models.libs.model_libs
import
bn
,
bn_relu
,
relu
...
@@ -86,7 +85,12 @@ def bottleneck(inputs,
...
@@ -86,7 +85,12 @@ def bottleneck(inputs,
with
scope
(
'down_sample'
):
with
scope
(
'down_sample'
):
inputs_shape
=
inputs
.
shape
inputs_shape
=
inputs
.
shape
with
scope
(
'main_max_pool'
):
with
scope
(
'main_max_pool'
):
net_main
=
fluid
.
layers
.
conv2d
(
inputs
,
inputs_shape
[
1
],
filter_size
=
3
,
stride
=
2
,
padding
=
'SAME'
)
net_main
=
fluid
.
layers
.
conv2d
(
inputs
,
inputs_shape
[
1
],
filter_size
=
3
,
stride
=
2
,
padding
=
'SAME'
)
#First get the difference in depth to pad, then pad with zeros only on the last dimension.
#First get the difference in depth to pad, then pad with zeros only on the last dimension.
depth_to_pad
=
abs
(
inputs_shape
[
1
]
-
output_depth
)
depth_to_pad
=
abs
(
inputs_shape
[
1
]
-
output_depth
)
...
@@ -95,12 +99,16 @@ def bottleneck(inputs,
...
@@ -95,12 +99,16 @@ def bottleneck(inputs,
net_main
=
fluid
.
layers
.
pad
(
net_main
,
paddings
=
paddings
)
net_main
=
fluid
.
layers
.
pad
(
net_main
,
paddings
=
paddings
)
with
scope
(
'block1'
):
with
scope
(
'block1'
):
net
=
conv
(
inputs
,
reduced_depth
,
[
2
,
2
],
stride
=
2
,
padding
=
'same'
)
net
=
conv
(
inputs
,
reduced_depth
,
[
2
,
2
],
stride
=
2
,
padding
=
'same'
)
net
=
bn
(
net
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
net
=
prelu
(
net
,
decoder
=
decoder
)
with
scope
(
'block2'
):
with
scope
(
'block2'
):
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
filter_size
],
padding
=
'same'
)
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
filter_size
],
padding
=
'same'
)
net
=
bn
(
net
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
@@ -137,13 +145,18 @@ def bottleneck(inputs,
...
@@ -137,13 +145,18 @@ def bottleneck(inputs,
# Second conv block --- apply dilated convolution here
# Second conv block --- apply dilated convolution here
with
scope
(
'block2'
):
with
scope
(
'block2'
):
net
=
conv
(
net
,
reduced_depth
,
filter_size
,
padding
=
'SAME'
,
dilation
=
dilation_rate
)
net
=
conv
(
net
,
reduced_depth
,
filter_size
,
padding
=
'SAME'
,
dilation
=
dilation_rate
)
net
=
bn
(
net
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
net
=
prelu
(
net
,
decoder
=
decoder
)
# Final projection with 1x1 kernel (Expansion)
# Final projection with 1x1 kernel (Expansion)
with
scope
(
'block3'
):
with
scope
(
'block3'
):
net
=
conv
(
net
,
output_depth
,
[
1
,
1
])
net
=
conv
(
net
,
output_depth
,
[
1
,
1
])
net
=
bn
(
net
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
@@ -172,9 +185,11 @@ def bottleneck(inputs,
...
@@ -172,9 +185,11 @@ def bottleneck(inputs,
# Second conv block --- apply asymmetric conv here
# Second conv block --- apply asymmetric conv here
with
scope
(
'block2'
):
with
scope
(
'block2'
):
with
scope
(
'asymmetric_conv2a'
):
with
scope
(
'asymmetric_conv2a'
):
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
1
],
padding
=
'same'
)
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
1
],
padding
=
'same'
)
with
scope
(
'asymmetric_conv2b'
):
with
scope
(
'asymmetric_conv2b'
):
net
=
conv
(
net
,
reduced_depth
,
[
1
,
filter_size
],
padding
=
'same'
)
net
=
conv
(
net
,
reduced_depth
,
[
1
,
filter_size
],
padding
=
'same'
)
net
=
bn
(
net
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
@@ -211,7 +226,8 @@ def bottleneck(inputs,
...
@@ -211,7 +226,8 @@ def bottleneck(inputs,
with
scope
(
'unpool'
):
with
scope
(
'unpool'
):
net_unpool
=
conv
(
inputs
,
output_depth
,
[
1
,
1
])
net_unpool
=
conv
(
inputs
,
output_depth
,
[
1
,
1
])
net_unpool
=
bn
(
net_unpool
)
net_unpool
=
bn
(
net_unpool
)
net_unpool
=
fluid
.
layers
.
resize_bilinear
(
net_unpool
,
out_shape
=
output_shape
[
2
:])
net_unpool
=
fluid
.
layers
.
resize_bilinear
(
net_unpool
,
out_shape
=
output_shape
[
2
:])
# First 1x1 projection to reduce depth
# First 1x1 projection to reduce depth
with
scope
(
'block1'
):
with
scope
(
'block1'
):
...
@@ -220,7 +236,12 @@ def bottleneck(inputs,
...
@@ -220,7 +236,12 @@ def bottleneck(inputs,
net
=
prelu
(
net
,
decoder
=
decoder
)
net
=
prelu
(
net
,
decoder
=
decoder
)
with
scope
(
'block2'
):
with
scope
(
'block2'
):
net
=
deconv
(
net
,
reduced_depth
,
filter_size
=
filter_size
,
stride
=
2
,
padding
=
'same'
)
net
=
deconv
(
net
,
reduced_depth
,
filter_size
=
filter_size
,
stride
=
2
,
padding
=
'same'
)
net
=
bn
(
net
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
@@ -253,7 +274,10 @@ def bottleneck(inputs,
...
@@ -253,7 +274,10 @@ def bottleneck(inputs,
# Second conv block
# Second conv block
with
scope
(
'block2'
):
with
scope
(
'block2'
):
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
filter_size
],
padding
=
'same'
)
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
filter_size
],
padding
=
'same'
)
net
=
bn
(
net
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'):
...
@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'):
=
bottleneck
(
inputs
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
type
=
DOWNSAMPLING
,
=
bottleneck
(
inputs
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
type
=
DOWNSAMPLING
,
name_scope
=
'bottleneck1_0'
)
name_scope
=
'bottleneck1_0'
)
with
scope
(
'bottleneck1_1'
):
with
scope
(
'bottleneck1_1'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
net
=
bottleneck
(
name_scope
=
'bottleneck1_1'
)
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_1'
)
with
scope
(
'bottleneck1_2'
):
with
scope
(
'bottleneck1_2'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
net
=
bottleneck
(
name_scope
=
'bottleneck1_2'
)
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_2'
)
with
scope
(
'bottleneck1_3'
):
with
scope
(
'bottleneck1_3'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
net
=
bottleneck
(
name_scope
=
'bottleneck1_3'
)
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_3'
)
with
scope
(
'bottleneck1_4'
):
with
scope
(
'bottleneck1_4'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
net
=
bottleneck
(
name_scope
=
'bottleneck1_4'
)
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_4'
)
return
net
,
inputs_shape_1
return
net
,
inputs_shape_1
...
@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'):
...
@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'):
name_scope
=
'bottleneck2_0'
)
name_scope
=
'bottleneck2_0'
)
for
i
in
range
(
2
):
for
i
in
range
(
2
):
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
1
))):
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
1
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
net
=
bottleneck
(
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
1
)))
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
1
)))
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
2
))):
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
2
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
1
)),
net
=
bottleneck
(
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
2
)))
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
1
)),
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
2
)))
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
3
))):
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
3
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
5
,
regularizer_prob
=
0.1
,
type
=
ASYMMETRIC
,
net
=
bottleneck
(
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
3
)))
net
,
output_depth
=
128
,
filter_size
=
5
,
regularizer_prob
=
0.1
,
type
=
ASYMMETRIC
,
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
3
)))
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
4
))):
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
4
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
2
)),
net
=
bottleneck
(
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
4
)))
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
2
)),
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
4
)))
return
net
,
inputs_shape_2
return
net
,
inputs_shape_2
...
@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'):
...
@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'):
with
scope
(
name_scope
):
with
scope
(
name_scope
):
for
i
in
range
(
2
):
for
i
in
range
(
2
):
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
0
))):
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
0
))):
net
=
bottleneck
(
inputs
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
net
=
bottleneck
(
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
0
)))
inputs
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
0
)))
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
1
))):
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
1
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
1
)),
net
=
bottleneck
(
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
1
)))
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
1
)),
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
1
)))
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
2
))):
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
2
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
5
,
regularizer_prob
=
0.1
,
type
=
ASYMMETRIC
,
net
=
bottleneck
(
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
2
)))
net
,
output_depth
=
128
,
filter_size
=
5
,
regularizer_prob
=
0.1
,
type
=
ASYMMETRIC
,
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
2
)))
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
3
))):
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
3
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
2
)),
net
=
bottleneck
(
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
3
)))
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
2
)),
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
3
)))
return
net
return
net
def
ENet_stage4
(
inputs
,
inputs_shape
,
connect_tensor
,
def
ENet_stage4
(
inputs
,
skip_connections
=
True
,
name_scope
=
'stage4_block'
):
inputs_shape
,
connect_tensor
,
skip_connections
=
True
,
name_scope
=
'stage4_block'
):
with
scope
(
name_scope
):
with
scope
(
name_scope
):
with
scope
(
'bottleneck4_0'
):
with
scope
(
'bottleneck4_0'
):
net
=
bottleneck
(
inputs
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
net
=
bottleneck
(
type
=
UPSAMPLING
,
decoder
=
True
,
output_shape
=
inputs_shape
,
inputs
,
name_scope
=
'bottleneck4_0'
)
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
UPSAMPLING
,
decoder
=
True
,
output_shape
=
inputs_shape
,
name_scope
=
'bottleneck4_0'
)
if
skip_connections
:
if
skip_connections
:
net
=
fluid
.
layers
.
elementwise_add
(
net
,
connect_tensor
)
net
=
fluid
.
layers
.
elementwise_add
(
net
,
connect_tensor
)
with
scope
(
'bottleneck4_1'
):
with
scope
(
'bottleneck4_1'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
net
=
bottleneck
(
name_scope
=
'bottleneck4_1'
)
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck4_1'
)
with
scope
(
'bottleneck4_2'
):
with
scope
(
'bottleneck4_2'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
net
=
bottleneck
(
name_scope
=
'bottleneck4_2'
)
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck4_2'
)
return
net
return
net
def
ENet_stage5
(
inputs
,
inputs_shape
,
connect_tensor
,
skip_connections
=
True
,
def
ENet_stage5
(
inputs
,
inputs_shape
,
connect_tensor
,
skip_connections
=
True
,
name_scope
=
'stage5_block'
):
name_scope
=
'stage5_block'
):
with
scope
(
name_scope
):
with
scope
(
name_scope
):
net
=
bottleneck
(
inputs
,
output_depth
=
16
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
UPSAMPLING
,
net
=
bottleneck
(
decoder
=
True
,
output_shape
=
inputs_shape
,
inputs
,
name_scope
=
'bottleneck5_0'
)
output_depth
=
16
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
UPSAMPLING
,
decoder
=
True
,
output_shape
=
inputs_shape
,
name_scope
=
'bottleneck5_0'
)
if
skip_connections
:
if
skip_connections
:
net
=
fluid
.
layers
.
elementwise_add
(
net
,
connect_tensor
)
net
=
fluid
.
layers
.
elementwise_add
(
net
,
connect_tensor
)
with
scope
(
'bottleneck5_1'
):
with
scope
(
'bottleneck5_1'
):
net
=
bottleneck
(
net
,
output_depth
=
16
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
net
=
bottleneck
(
name_scope
=
'bottleneck5_1'
)
net
,
output_depth
=
16
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck5_1'
)
return
net
return
net
...
@@ -378,14 +493,16 @@ def decoder(input, num_classes):
...
@@ -378,14 +493,16 @@ def decoder(input, num_classes):
segStage3
=
ENet_stage3
(
stage2
)
segStage3
=
ENet_stage3
(
stage2
)
segStage4
=
ENet_stage4
(
segStage3
,
inputs_shape_2
,
stage1
)
segStage4
=
ENet_stage4
(
segStage3
,
inputs_shape_2
,
stage1
)
segStage5
=
ENet_stage5
(
segStage4
,
inputs_shape_1
,
initial
)
segStage5
=
ENet_stage5
(
segStage4
,
inputs_shape_1
,
initial
)
segLogits
=
deconv
(
segStage5
,
num_classes
,
filter_size
=
2
,
stride
=
2
,
padding
=
'SAME'
)
segLogits
=
deconv
(
segStage5
,
num_classes
,
filter_size
=
2
,
stride
=
2
,
padding
=
'SAME'
)
# Embedding branch
# Embedding branch
with
scope
(
'LaneNetEm'
):
with
scope
(
'LaneNetEm'
):
emStage3
=
ENet_stage3
(
stage2
)
emStage3
=
ENet_stage3
(
stage2
)
emStage4
=
ENet_stage4
(
emStage3
,
inputs_shape_2
,
stage1
)
emStage4
=
ENet_stage4
(
emStage3
,
inputs_shape_2
,
stage1
)
emStage5
=
ENet_stage5
(
emStage4
,
inputs_shape_1
,
initial
)
emStage5
=
ENet_stage5
(
emStage4
,
inputs_shape_1
,
initial
)
emLogits
=
deconv
(
emStage5
,
4
,
filter_size
=
2
,
stride
=
2
,
padding
=
'SAME'
)
emLogits
=
deconv
(
emStage5
,
4
,
filter_size
=
2
,
stride
=
2
,
padding
=
'SAME'
)
elif
'vgg'
in
cfg
.
MODEL
.
LANENET
.
BACKBONE
:
elif
'vgg'
in
cfg
.
MODEL
.
LANENET
.
BACKBONE
:
encoder_list
=
[
'pool5'
,
'pool4'
,
'pool3'
]
encoder_list
=
[
'pool5'
,
'pool4'
,
'pool3'
]
...
@@ -396,14 +513,16 @@ def decoder(input, num_classes):
...
@@ -396,14 +513,16 @@ def decoder(input, num_classes):
encoder_list
=
encoder_list
[
1
:]
encoder_list
=
encoder_list
[
1
:]
for
i
in
range
(
len
(
encoder_list
)):
for
i
in
range
(
len
(
encoder_list
)):
with
scope
(
'deconv_{:d}'
.
format
(
i
+
1
)):
with
scope
(
'deconv_{:d}'
.
format
(
i
+
1
)):
deconv_out
=
deconv
(
score
,
64
,
filter_size
=
4
,
stride
=
2
,
padding
=
'SAME'
)
deconv_out
=
deconv
(
score
,
64
,
filter_size
=
4
,
stride
=
2
,
padding
=
'SAME'
)
input_tensor
=
input
[
encoder_list
[
i
]]
input_tensor
=
input
[
encoder_list
[
i
]]
with
scope
(
'score_{:d}'
.
format
(
i
+
1
)):
with
scope
(
'score_{:d}'
.
format
(
i
+
1
)):
score
=
conv
(
input_tensor
,
64
,
1
)
score
=
conv
(
input_tensor
,
64
,
1
)
score
=
fluid
.
layers
.
elementwise_add
(
deconv_out
,
score
)
score
=
fluid
.
layers
.
elementwise_add
(
deconv_out
,
score
)
with
scope
(
'deconv_final'
):
with
scope
(
'deconv_final'
):
emLogits
=
deconv
(
score
,
64
,
filter_size
=
16
,
stride
=
8
,
padding
=
'SAME'
)
emLogits
=
deconv
(
score
,
64
,
filter_size
=
16
,
stride
=
8
,
padding
=
'SAME'
)
with
scope
(
'score_final'
):
with
scope
(
'score_final'
):
segLogits
=
conv
(
emLogits
,
num_classes
,
1
)
segLogits
=
conv
(
emLogits
,
num_classes
,
1
)
emLogits
=
relu
(
conv
(
emLogits
,
4
,
1
))
emLogits
=
relu
(
conv
(
emLogits
,
4
,
1
))
...
@@ -415,7 +534,8 @@ def encoder(input):
...
@@ -415,7 +534,8 @@ def encoder(input):
model
=
vgg_backbone
(
layers
=
16
)
model
=
vgg_backbone
(
layers
=
16
)
#output = model.net(input)
#output = model.net(input)
_
,
encode_feature_dict
=
model
.
net
(
input
,
end_points
=
13
,
decode_points
=
[
7
,
10
,
13
])
_
,
encode_feature_dict
=
model
.
net
(
input
,
end_points
=
13
,
decode_points
=
[
7
,
10
,
13
])
output
=
{}
output
=
{}
output
[
'pool3'
]
=
encode_feature_dict
[
7
]
output
[
'pool3'
]
=
encode_feature_dict
[
7
]
output
[
'pool4'
]
=
encode_feature_dict
[
10
]
output
[
'pool4'
]
=
encode_feature_dict
[
10
]
...
@@ -427,8 +547,9 @@ def encoder(input):
...
@@ -427,8 +547,9 @@ def encoder(input):
stage2
,
inputs_shape_2
=
ENet_stage2
(
stage1
)
stage2
,
inputs_shape_2
=
ENet_stage2
(
stage1
)
output
=
(
initial
,
stage1
,
stage2
,
inputs_shape_1
,
inputs_shape_2
)
output
=
(
initial
,
stage1
,
stage2
,
inputs_shape_1
,
inputs_shape_2
)
else
:
else
:
raise
Exception
(
"LaneNet expect enet and vgg backbone, but received {}"
.
raise
Exception
(
format
(
cfg
.
MODEL
.
LANENET
.
BACKBONE
))
"LaneNet expect enet and vgg backbone, but received {}"
.
format
(
cfg
.
MODEL
.
LANENET
.
BACKBONE
))
return
output
return
output
...
...
contrib/LaneNet/reader.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -58,7 +58,8 @@ class LaneNetDataset():
...
@@ -58,7 +58,8 @@ class LaneNetDataset():
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
shuffle_seed
+=
1
self
.
shuffle_seed
+=
1
elif
self
.
shuffle
:
elif
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
lines
)
np
.
random
.
shuffle
(
self
.
lines
)
...
@@ -86,7 +87,8 @@ class LaneNetDataset():
...
@@ -86,7 +87,8 @@ class LaneNetDataset():
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
num_lines
=
len
(
self
.
all_lines
)
//
self
.
num_trainers
num_lines
=
len
(
self
.
all_lines
)
//
self
.
num_trainers
self
.
lines
=
self
.
all_lines
[
num_lines
*
self
.
trainer_id
:
num_lines
*
(
self
.
trainer_id
+
1
)]
self
.
lines
=
self
.
all_lines
[
num_lines
*
self
.
trainer_id
:
num_lines
*
(
self
.
trainer_id
+
1
)]
self
.
shuffle_seed
+=
1
self
.
shuffle_seed
+=
1
elif
self
.
shuffle
:
elif
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
lines
)
np
.
random
.
shuffle
(
self
.
lines
)
...
@@ -118,7 +120,8 @@ class LaneNetDataset():
...
@@ -118,7 +120,8 @@ class LaneNetDataset():
def
batch_reader
(
is_test
=
False
,
drop_last
=
drop_last
):
def
batch_reader
(
is_test
=
False
,
drop_last
=
drop_last
):
if
is_test
:
if
is_test
:
imgs
,
grts
,
grts_instance
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[],
[],
[]
imgs
,
grts
,
grts_instance
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[],
[],
[]
for
img
,
grt
,
grt_instance
,
img_name
,
valid_shape
,
org_shape
in
reader
():
for
img
,
grt
,
grt_instance
,
img_name
,
valid_shape
,
org_shape
in
reader
(
):
imgs
.
append
(
img
)
imgs
.
append
(
img
)
grts
.
append
(
grt
)
grts
.
append
(
grt
)
grts_instance
.
append
(
grt_instance
)
grts_instance
.
append
(
grt_instance
)
...
@@ -126,14 +129,15 @@ class LaneNetDataset():
...
@@ -126,14 +129,15 @@ class LaneNetDataset():
valid_shapes
.
append
(
valid_shape
)
valid_shapes
.
append
(
valid_shape
)
org_shapes
.
append
(
org_shape
)
org_shapes
.
append
(
org_shape
)
if
len
(
imgs
)
==
batch_size
:
if
len
(
imgs
)
==
batch_size
:
yield
np
.
array
(
imgs
),
np
.
array
(
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
np
.
array
(
grts
),
np
.
array
(
grts_instance
),
img_names
,
np
.
array
(
valid_shapes
)
,
np
.
array
(
grts
_instance
),
img_names
,
np
.
array
(
org_shapes
)
valid_shapes
),
np
.
array
(
org_shapes
)
imgs
,
grts
,
grts_instance
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[],
[],
[]
imgs
,
grts
,
grts_instance
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[],
[],
[]
if
not
drop_last
and
len
(
imgs
)
>
0
:
if
not
drop_last
and
len
(
imgs
)
>
0
:
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
np
.
array
(
grts_instance
),
img_names
,
np
.
array
(
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
np
.
array
(
valid_shapes
),
np
.
array
(
org_shapes
)
grts_instance
),
img_names
,
np
.
array
(
valid_shapes
),
np
.
array
(
org_shapes
)
else
:
else
:
imgs
,
labs
,
labs_instance
,
ignore
=
[],
[],
[],
[]
imgs
,
labs
,
labs_instance
,
ignore
=
[],
[],
[],
[]
bs
=
0
bs
=
0
...
@@ -144,12 +148,14 @@ class LaneNetDataset():
...
@@ -144,12 +148,14 @@ class LaneNetDataset():
ignore
.
append
(
ig
)
ignore
.
append
(
ig
)
bs
+=
1
bs
+=
1
if
bs
==
batch_size
:
if
bs
==
batch_size
:
yield
np
.
array
(
imgs
),
np
.
array
(
labs
),
np
.
array
(
labs_instance
),
np
.
array
(
ignore
)
yield
np
.
array
(
imgs
),
np
.
array
(
labs
),
np
.
array
(
labs_instance
),
np
.
array
(
ignore
)
bs
=
0
bs
=
0
imgs
,
labs
,
labs_instance
,
ignore
=
[],
[],
[],
[]
imgs
,
labs
,
labs_instance
,
ignore
=
[],
[],
[],
[]
if
not
drop_last
and
bs
>
0
:
if
not
drop_last
and
bs
>
0
:
yield
np
.
array
(
imgs
),
np
.
array
(
labs
),
np
.
array
(
labs_instance
),
np
.
array
(
ignore
)
yield
np
.
array
(
imgs
),
np
.
array
(
labs
),
np
.
array
(
labs_instance
),
np
.
array
(
ignore
)
return
batch_reader
(
is_test
,
drop_last
)
return
batch_reader
(
is_test
,
drop_last
)
...
@@ -299,10 +305,12 @@ class LaneNetDataset():
...
@@ -299,10 +305,12 @@ class LaneNetDataset():
img
,
grt
=
aug
.
rand_crop
(
img
,
grt
,
mode
=
mode
)
img
,
grt
=
aug
.
rand_crop
(
img
,
grt
,
mode
=
mode
)
elif
ModelPhase
.
is_eval
(
mode
):
elif
ModelPhase
.
is_eval
(
mode
):
img
,
grt
,
grt_instance
=
aug
.
resize
(
img
,
grt
,
grt_instance
,
mode
=
mode
)
img
,
grt
,
grt_instance
=
aug
.
resize
(
img
,
grt
,
grt_instance
,
mode
=
mode
)
elif
ModelPhase
.
is_visual
(
mode
):
elif
ModelPhase
.
is_visual
(
mode
):
ori_img
=
img
.
copy
()
ori_img
=
img
.
copy
()
img
,
grt
,
grt_instance
=
aug
.
resize
(
img
,
grt
,
grt_instance
,
mode
=
mode
)
img
,
grt
,
grt_instance
=
aug
.
resize
(
img
,
grt
,
grt_instance
,
mode
=
mode
)
valid_shape
=
[
img
.
shape
[
0
],
img
.
shape
[
1
]]
valid_shape
=
[
img
.
shape
[
0
],
img
.
shape
[
1
]]
else
:
else
:
raise
ValueError
(
"Dataset mode={} Error!"
.
format
(
mode
))
raise
ValueError
(
"Dataset mode={} Error!"
.
format
(
mode
))
...
...
contrib/LaneNet/train.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -40,10 +40,10 @@ from pdseg.utils.timer import Timer, calculate_eta
...
@@ -40,10 +40,10 @@ from pdseg.utils.timer import Timer, calculate_eta
from
reader
import
LaneNetDataset
from
reader
import
LaneNetDataset
from
models.model_builder
import
build_model
from
models.model_builder
import
build_model
from
models.model_builder
import
ModelPhase
from
models.model_builder
import
ModelPhase
from
models.model_builder
import
parse_shape_from_file
from
eval
import
evaluate
from
eval
import
evaluate
from
vis
import
visualize
from
vis
import
visualize
from
utils
import
dist_utils
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
def
parse_args
():
def
parse_args
():
...
@@ -101,37 +101,6 @@ def parse_args():
...
@@ -101,37 +101,6 @@ def parse_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program
=
fluid
.
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
fluid
.
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
save_block
.
create_var
(
name
=
each_var
.
name
,
shape
=
each_var
.
shape
,
dtype
=
each_var
.
dtype
,
type
=
each_var
.
type
,
lod_level
=
each_var
.
lod_level
,
persistable
=
True
)
file_path
=
os
.
path
.
join
(
dirname
,
new_var
.
name
)
file_path
=
os
.
path
.
normpath
(
file_path
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
file_path
})
executor
.
run
(
save_program
)
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
"""
"""
Save checkpoint for evaluation or resume training
Save checkpoint for evaluation or resume training
...
@@ -141,29 +110,22 @@ def save_checkpoint(exe, program, ckpt_name):
...
@@ -141,29 +110,22 @@ def save_checkpoint(exe, program, ckpt_name):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
os
.
makedirs
(
ckpt_dir
)
save_vars
(
fluid
.
save
(
program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
return
ckpt_dir
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
def
load_checkpoint
(
exe
,
program
):
"""
"""
Load checkpoiont f
rom pretrained model directory for resume
training
Load checkpoiont f
or resuming
training
"""
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
print
(
'Resume model training from:'
,
model_path
)
if
not
os
.
path
.
exists
(
model_path
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
model_path
))
fluid
.
load
(
program
,
os
.
path
.
join
(
model_path
,
'model'
),
exe
)
# Check is path ended by path spearator
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
model_path
=
model_path
[
0
:
-
1
]
...
@@ -178,7 +140,6 @@ def load_checkpoint(exe, program):
...
@@ -178,7 +140,6 @@ def load_checkpoint(exe, program):
else
:
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
return
begin_epoch
...
@@ -271,44 +232,7 @@ def train(cfg):
...
@@ -271,44 +232,7 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_shape
!=
shape
:
print
(
var
.
name
,
var_shape
,
shape
)
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
else
:
else
:
print_info
(
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
'Pretrained model dir {} not exists, training from scratch...'
.
...
@@ -393,8 +317,7 @@ def train(cfg):
...
@@ -393,8 +317,7 @@ def train(cfg):
avg_emb_loss
,
avg_acc
,
avg_fp
,
avg_fn
,
speed
,
avg_emb_loss
,
avg_acc
,
avg_fp
,
avg_fn
,
speed
,
calculate_eta
(
all_step
-
step
,
speed
)))
calculate_eta
(
all_step
-
step
,
speed
)))
if
args
.
use_vdl
:
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
...
@@ -423,8 +346,7 @@ def train(cfg):
...
@@ -423,8 +346,7 @@ def train(cfg):
use_gpu
=
args
.
use_gpu
,
use_gpu
=
args
.
use_gpu
,
use_mpio
=
args
.
use_mpio
)
use_mpio
=
args
.
use_mpio
)
if
args
.
use_vdl
:
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/accuracy'
,
accuracy
,
log_writer
.
add_scalar
(
'Evaluate/accuracy'
,
accuracy
,
step
)
step
)
log_writer
.
add_scalar
(
'Evaluate/fp'
,
fp
,
step
)
log_writer
.
add_scalar
(
'Evaluate/fp'
,
fp
,
step
)
log_writer
.
add_scalar
(
'Evaluate/fn'
,
fn
,
step
)
log_writer
.
add_scalar
(
'Evaluate/fn'
,
fn
,
step
)
...
...
contrib/LaneNet/utils/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/LaneNet/utils/config.py
浏览文件 @
61645b1d
#
-*- coding: utf-8 -*-
#
coding: utf8
#
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
.
#
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve
.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3
...
@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3
cfg
.
DATASET
.
SEPARATOR
=
' '
cfg
.
DATASET
.
SEPARATOR
=
' '
# 忽略的像素标签值, 默认为255,一般无需改动
# 忽略的像素标签值, 默认为255,一般无需改动
cfg
.
DATASET
.
IGNORE_INDEX
=
255
cfg
.
DATASET
.
IGNORE_INDEX
=
255
# 数据增强是图像的padding值
# 数据增强是图像的padding值
cfg
.
DATASET
.
PADDING_VALUE
=
[
127.5
,
127.5
,
127.5
]
cfg
.
DATASET
.
PADDING_VALUE
=
[
127.5
,
127.5
,
127.5
]
########################### 数据增强配置 ######################################
########################### 数据增强配置 ######################################
# 图像镜像左右翻转
# 图像镜像左右翻转
...
...
contrib/LaneNet/utils/dist_utils.py
浏览文件 @
61645b1d
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
#Licensed under the Apache License, Version 2.0 (the "License");
#
Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#
you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
#Unless required by applicable law or agreed to in writing, software
#
Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#
distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#
See the License for the specific language governing permissions and
#limitations under the License.
#
limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
...
...
contrib/LaneNet/utils/generate_tusimple_dataset.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
generate tusimple training dataset
generate tusimple training dataset
"""
"""
...
@@ -14,12 +28,16 @@ import numpy as np
...
@@ -14,12 +28,16 @@ import numpy as np
def
init_args
():
def
init_args
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--src_dir'
,
type
=
str
,
help
=
'The origin path of unzipped tusimple dataset'
)
parser
.
add_argument
(
'--src_dir'
,
type
=
str
,
help
=
'The origin path of unzipped tusimple dataset'
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
process_json_file
(
json_file_path
,
src_dir
,
ori_dst_dir
,
binary_dst_dir
,
instance_dst_dir
):
def
process_json_file
(
json_file_path
,
src_dir
,
ori_dst_dir
,
binary_dst_dir
,
instance_dst_dir
):
assert
ops
.
exists
(
json_file_path
),
'{:s} not exist'
.
format
(
json_file_path
)
assert
ops
.
exists
(
json_file_path
),
'{:s} not exist'
.
format
(
json_file_path
)
...
@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
...
@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
h_samples
=
info_dict
[
'h_samples'
]
h_samples
=
info_dict
[
'h_samples'
]
lanes
=
info_dict
[
'lanes'
]
lanes
=
info_dict
[
'lanes'
]
image_name_new
=
'{:s}.png'
.
format
(
'{:d}'
.
format
(
line_index
+
image_nums
).
zfill
(
4
))
image_name_new
=
'{:s}.png'
.
format
(
'{:d}'
.
format
(
line_index
+
image_nums
).
zfill
(
4
))
src_image
=
cv2
.
imread
(
image_path
,
cv2
.
IMREAD_COLOR
)
src_image
=
cv2
.
imread
(
image_path
,
cv2
.
IMREAD_COLOR
)
dst_binary_image
=
np
.
zeros
([
src_image
.
shape
[
0
],
src_image
.
shape
[
1
]],
np
.
uint8
)
dst_binary_image
=
np
.
zeros
(
dst_instance_image
=
np
.
zeros
([
src_image
.
shape
[
0
],
src_image
.
shape
[
1
]],
np
.
uint8
)
[
src_image
.
shape
[
0
],
src_image
.
shape
[
1
]],
np
.
uint8
)
dst_instance_image
=
np
.
zeros
(
[
src_image
.
shape
[
0
],
src_image
.
shape
[
1
]],
np
.
uint8
)
for
lane_index
,
lane
in
enumerate
(
lanes
):
for
lane_index
,
lane
in
enumerate
(
lanes
):
assert
len
(
h_samples
)
==
len
(
lane
)
assert
len
(
h_samples
)
==
len
(
lane
)
...
@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
...
@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
lane_pts
=
np
.
vstack
((
lane_x
,
lane_y
)).
transpose
()
lane_pts
=
np
.
vstack
((
lane_x
,
lane_y
)).
transpose
()
lane_pts
=
np
.
array
([
lane_pts
],
np
.
int64
)
lane_pts
=
np
.
array
([
lane_pts
],
np
.
int64
)
cv2
.
polylines
(
dst_binary_image
,
lane_pts
,
isClosed
=
False
,
cv2
.
polylines
(
color
=
255
,
thickness
=
5
)
dst_binary_image
,
cv2
.
polylines
(
dst_instance_image
,
lane_pts
,
isClosed
=
False
,
lane_pts
,
color
=
lane_index
*
50
+
20
,
thickness
=
5
)
isClosed
=
False
,
color
=
255
,
dst_binary_image_path
=
ops
.
join
(
src_dir
,
binary_dst_dir
,
image_name_new
)
thickness
=
5
)
dst_instance_image_path
=
ops
.
join
(
src_dir
,
instance_dst_dir
,
image_name_new
)
cv2
.
polylines
(
dst_instance_image
,
lane_pts
,
isClosed
=
False
,
color
=
lane_index
*
50
+
20
,
thickness
=
5
)
dst_binary_image_path
=
ops
.
join
(
src_dir
,
binary_dst_dir
,
image_name_new
)
dst_instance_image_path
=
ops
.
join
(
src_dir
,
instance_dst_dir
,
image_name_new
)
dst_rgb_image_path
=
ops
.
join
(
src_dir
,
ori_dst_dir
,
image_name_new
)
dst_rgb_image_path
=
ops
.
join
(
src_dir
,
ori_dst_dir
,
image_name_new
)
cv2
.
imwrite
(
dst_binary_image_path
,
dst_binary_image
)
cv2
.
imwrite
(
dst_binary_image_path
,
dst_binary_image
)
...
@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
...
@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
print
(
'Process {:s} success'
.
format
(
image_name
))
print
(
'Process {:s} success'
.
format
(
image_name
))
def
gen_sample
(
src_dir
,
b_gt_image_dir
,
i_gt_image_dir
,
image_dir
,
phase
=
'train'
,
split
=
False
):
def
gen_sample
(
src_dir
,
b_gt_image_dir
,
i_gt_image_dir
,
image_dir
,
phase
=
'train'
,
split
=
False
):
label_list
=
[]
label_list
=
[]
with
open
(
'{:s}/{}ing/{}.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
with
open
(
'{:s}/{}ing/{}.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
...
@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
...
@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
image_path
=
ops
.
join
(
image_dir
,
image_name
)
image_path
=
ops
.
join
(
image_dir
,
image_name
)
assert
ops
.
exists
(
image_path
),
'{:s} not exist'
.
format
(
image_path
)
assert
ops
.
exists
(
image_path
),
'{:s} not exist'
.
format
(
image_path
)
assert
ops
.
exists
(
instance_gt_image_path
),
'{:s} not exist'
.
format
(
instance_gt_image_path
)
assert
ops
.
exists
(
instance_gt_image_path
),
'{:s} not exist'
.
format
(
instance_gt_image_path
)
b_gt_image
=
cv2
.
imread
(
binary_gt_image_path
,
cv2
.
IMREAD_COLOR
)
b_gt_image
=
cv2
.
imread
(
binary_gt_image_path
,
cv2
.
IMREAD_COLOR
)
i_gt_image
=
cv2
.
imread
(
instance_gt_image_path
,
cv2
.
IMREAD_COLOR
)
i_gt_image
=
cv2
.
imread
(
instance_gt_image_path
,
cv2
.
IMREAD_COLOR
)
...
@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
...
@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
print
(
'image: {:s} corrupt'
.
format
(
image_name
))
print
(
'image: {:s} corrupt'
.
format
(
image_name
))
continue
continue
else
:
else
:
info
=
'{:s} {:s} {:s}'
.
format
(
image_path
,
binary_gt_image_path
,
instance_gt_image_path
)
info
=
'{:s} {:s} {:s}'
.
format
(
image_path
,
binary_gt_image_path
,
instance_gt_image_path
)
file
.
write
(
info
+
'
\n
'
)
file
.
write
(
info
+
'
\n
'
)
label_list
.
append
(
info
)
label_list
.
append
(
info
)
if
phase
==
'train'
and
split
:
if
phase
==
'train'
and
split
:
...
@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
...
@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
val_list_len
=
len
(
label_list
)
//
10
val_list_len
=
len
(
label_list
)
//
10
val_label_list
=
label_list
[:
val_list_len
]
val_label_list
=
label_list
[:
val_list_len
]
train_label_list
=
label_list
[
val_list_len
:]
train_label_list
=
label_list
[
val_list_len
:]
with
open
(
'{:s}/{}ing/train_part.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
with
open
(
'{:s}/{}ing/train_part.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
for
info
in
train_label_list
:
for
info
in
train_label_list
:
file
.
write
(
info
+
'
\n
'
)
file
.
write
(
info
+
'
\n
'
)
with
open
(
'{:s}/{}ing/val_part.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
with
open
(
'{:s}/{}ing/val_part.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
for
info
in
val_label_list
:
for
info
in
val_label_list
:
file
.
write
(
info
+
'
\n
'
)
file
.
write
(
info
+
'
\n
'
)
return
return
...
@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir):
...
@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir):
for
json_label_path
in
glob
.
glob
(
'{:s}/label*.json'
.
format
(
src_dir
)):
for
json_label_path
in
glob
.
glob
(
'{:s}/label*.json'
.
format
(
src_dir
)):
json_label_name
=
ops
.
split
(
json_label_path
)[
1
]
json_label_name
=
ops
.
split
(
json_label_path
)[
1
]
shutil
.
copyfile
(
json_label_path
,
ops
.
join
(
traing_folder_path
,
json_label_name
))
shutil
.
copyfile
(
json_label_path
,
ops
.
join
(
traing_folder_path
,
json_label_name
))
for
json_label_path
in
glob
.
glob
(
'{:s}/test_label.json'
.
format
(
src_dir
)):
for
json_label_path
in
glob
.
glob
(
'{:s}/test_label.json'
.
format
(
src_dir
)):
json_label_name
=
ops
.
split
(
json_label_path
)[
1
]
json_label_name
=
ops
.
split
(
json_label_path
)[
1
]
shutil
.
copyfile
(
json_label_path
,
ops
.
join
(
testing_folder_path
,
json_label_name
))
shutil
.
copyfile
(
json_label_path
,
ops
.
join
(
testing_folder_path
,
json_label_name
))
train_gt_image_dir
=
ops
.
join
(
'training'
,
'gt_image'
)
train_gt_image_dir
=
ops
.
join
(
'training'
,
'gt_image'
)
train_gt_binary_dir
=
ops
.
join
(
'training'
,
'gt_binary_image'
)
train_gt_binary_dir
=
ops
.
join
(
'training'
,
'gt_binary_image'
)
...
@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir):
...
@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir):
os
.
makedirs
(
os
.
path
.
join
(
src_dir
,
test_gt_instance_dir
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
join
(
src_dir
,
test_gt_instance_dir
),
exist_ok
=
True
)
for
json_label_path
in
glob
.
glob
(
'{:s}/*.json'
.
format
(
traing_folder_path
)):
for
json_label_path
in
glob
.
glob
(
'{:s}/*.json'
.
format
(
traing_folder_path
)):
process_json_file
(
json_label_path
,
src_dir
,
train_gt_image_dir
,
train_gt_binary_dir
,
train_gt_instance_dir
)
process_json_file
(
json_label_path
,
src_dir
,
train_gt_image_dir
,
train_gt_binary_dir
,
train_gt_instance_dir
)
gen_sample
(
src_dir
,
train_gt_binary_dir
,
train_gt_instance_dir
,
train_gt_image_dir
,
'train'
,
True
)
gen_sample
(
src_dir
,
train_gt_binary_dir
,
train_gt_instance_dir
,
train_gt_image_dir
,
'train'
,
True
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
contrib/LaneNet/utils/lanenet_postprocess.py
浏览文件 @
61645b1d
#!/usr/bin/env python3
# coding: utf8
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code heavily base on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
# this code heavily base on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
"""
"""
LaneNet model post process
LaneNet model post process
...
@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5):
...
@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5):
:return:
:return:
"""
"""
if
len
(
image
.
shape
)
==
3
:
if
len
(
image
.
shape
)
==
3
:
raise
ValueError
(
'Binary segmentation result image should be a single channel image'
)
raise
ValueError
(
'Binary segmentation result image should be a single channel image'
)
if
image
.
dtype
is
not
np
.
uint8
:
if
image
.
dtype
is
not
np
.
uint8
:
image
=
np
.
array
(
image
,
np
.
uint8
)
image
=
np
.
array
(
image
,
np
.
uint8
)
kernel
=
cv2
.
getStructuringElement
(
shape
=
cv2
.
MORPH_ELLIPSE
,
ksize
=
(
kernel_size
,
kernel_size
))
kernel
=
cv2
.
getStructuringElement
(
shape
=
cv2
.
MORPH_ELLIPSE
,
ksize
=
(
kernel_size
,
kernel_size
))
# close operation fille hole
# close operation fille hole
closing
=
cv2
.
morphologyEx
(
image
,
cv2
.
MORPH_CLOSE
,
kernel
,
iterations
=
1
)
closing
=
cv2
.
morphologyEx
(
image
,
cv2
.
MORPH_CLOSE
,
kernel
,
iterations
=
1
)
...
@@ -46,13 +61,15 @@ def _connect_components_analysis(image):
...
@@ -46,13 +61,15 @@ def _connect_components_analysis(image):
else
:
else
:
gray_image
=
image
gray_image
=
image
return
cv2
.
connectedComponentsWithStats
(
gray_image
,
connectivity
=
8
,
ltype
=
cv2
.
CV_32S
)
return
cv2
.
connectedComponentsWithStats
(
gray_image
,
connectivity
=
8
,
ltype
=
cv2
.
CV_32S
)
class
_LaneFeat
(
object
):
class
_LaneFeat
(
object
):
"""
"""
"""
"""
def
__init__
(
self
,
feat
,
coord
,
class_id
=-
1
):
def
__init__
(
self
,
feat
,
coord
,
class_id
=-
1
):
"""
"""
lane feat object
lane feat object
...
@@ -108,18 +125,21 @@ class _LaneNetCluster(object):
...
@@ -108,18 +125,21 @@ class _LaneNetCluster(object):
"""
"""
Instance segmentation result cluster
Instance segmentation result cluster
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
"""
"""
"""
"""
self
.
_color_map
=
[
np
.
array
([
255
,
0
,
0
]),
self
.
_color_map
=
[
np
.
array
([
0
,
255
,
0
]),
np
.
array
([
255
,
0
,
0
]),
np
.
array
([
0
,
0
,
255
]),
np
.
array
([
0
,
255
,
0
]),
np
.
array
([
125
,
125
,
0
]),
np
.
array
([
0
,
0
,
255
]),
np
.
array
([
0
,
125
,
125
]),
np
.
array
([
125
,
125
,
0
]),
np
.
array
([
125
,
0
,
125
]),
np
.
array
([
0
,
125
,
125
]),
np
.
array
([
50
,
100
,
50
]),
np
.
array
([
125
,
0
,
125
]),
np
.
array
([
100
,
50
,
100
])]
np
.
array
([
50
,
100
,
50
]),
np
.
array
([
100
,
50
,
100
])
]
@
staticmethod
@
staticmethod
def
_embedding_feats_dbscan_cluster
(
embedding_image_feats
):
def
_embedding_feats_dbscan_cluster
(
embedding_image_feats
):
...
@@ -186,15 +206,16 @@ class _LaneNetCluster(object):
...
@@ -186,15 +206,16 @@ class _LaneNetCluster(object):
# get embedding feats and coords
# get embedding feats and coords
get_lane_embedding_feats_result
=
self
.
_get_lane_embedding_feats
(
get_lane_embedding_feats_result
=
self
.
_get_lane_embedding_feats
(
binary_seg_ret
=
binary_seg_result
,
binary_seg_ret
=
binary_seg_result
,
instance_seg_ret
=
instance_seg_result
instance_seg_ret
=
instance_seg_result
)
)
# dbscan cluster
# dbscan cluster
dbscan_cluster_result
=
self
.
_embedding_feats_dbscan_cluster
(
dbscan_cluster_result
=
self
.
_embedding_feats_dbscan_cluster
(
embedding_image_feats
=
get_lane_embedding_feats_result
[
'lane_embedding_feats'
]
embedding_image_feats
=
get_lane_embedding_feats_result
[
)
'lane_embedding_feats'
]
)
mask
=
np
.
zeros
(
shape
=
[
binary_seg_result
.
shape
[
0
],
binary_seg_result
.
shape
[
1
],
3
],
dtype
=
np
.
uint8
)
mask
=
np
.
zeros
(
shape
=
[
binary_seg_result
.
shape
[
0
],
binary_seg_result
.
shape
[
1
],
3
],
dtype
=
np
.
uint8
)
db_labels
=
dbscan_cluster_result
[
'db_labels'
]
db_labels
=
dbscan_cluster_result
[
'db_labels'
]
unique_labels
=
dbscan_cluster_result
[
'unique_labels'
]
unique_labels
=
dbscan_cluster_result
[
'unique_labels'
]
coord
=
get_lane_embedding_feats_result
[
'lane_coordinates'
]
coord
=
get_lane_embedding_feats_result
[
'lane_coordinates'
]
...
@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object):
...
@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object):
"""
"""
lanenet post process for lane generation
lanenet post process for lane generation
"""
"""
def
__init__
(
self
,
ipm_remap_file_path
=
'./utils/tusimple_ipm_remap.yml'
):
def
__init__
(
self
,
ipm_remap_file_path
=
'./utils/tusimple_ipm_remap.yml'
):
"""
"""
convert front car view to bird view
convert front car view to bird view
"""
"""
assert
ops
.
exists
(
ipm_remap_file_path
),
'{:s} not exist'
.
format
(
ipm_remap_file_path
)
assert
ops
.
exists
(
ipm_remap_file_path
),
'{:s} not exist'
.
format
(
ipm_remap_file_path
)
self
.
_cluster
=
_LaneNetCluster
()
self
.
_cluster
=
_LaneNetCluster
()
self
.
_ipm_remap_file_path
=
ipm_remap_file_path
self
.
_ipm_remap_file_path
=
ipm_remap_file_path
...
@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object):
...
@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object):
self
.
_remap_to_ipm_x
=
remap_file_load_ret
[
'remap_to_ipm_x'
]
self
.
_remap_to_ipm_x
=
remap_file_load_ret
[
'remap_to_ipm_x'
]
self
.
_remap_to_ipm_y
=
remap_file_load_ret
[
'remap_to_ipm_y'
]
self
.
_remap_to_ipm_y
=
remap_file_load_ret
[
'remap_to_ipm_y'
]
self
.
_color_map
=
[
np
.
array
([
255
,
0
,
0
]),
self
.
_color_map
=
[
np
.
array
([
0
,
255
,
0
]),
np
.
array
([
255
,
0
,
0
]),
np
.
array
([
0
,
0
,
255
]),
np
.
array
([
0
,
255
,
0
]),
np
.
array
([
125
,
125
,
0
]),
np
.
array
([
0
,
0
,
255
]),
np
.
array
([
0
,
125
,
125
]),
np
.
array
([
125
,
125
,
0
]),
np
.
array
([
125
,
0
,
125
]),
np
.
array
([
0
,
125
,
125
]),
np
.
array
([
50
,
100
,
50
]),
np
.
array
([
125
,
0
,
125
]),
np
.
array
([
100
,
50
,
100
])]
np
.
array
([
50
,
100
,
50
]),
np
.
array
([
100
,
50
,
100
])
]
def
_load_remap_matrix
(
self
):
def
_load_remap_matrix
(
self
):
fs
=
cv2
.
FileStorage
(
self
.
_ipm_remap_file_path
,
cv2
.
FILE_STORAGE_READ
)
fs
=
cv2
.
FileStorage
(
self
.
_ipm_remap_file_path
,
cv2
.
FILE_STORAGE_READ
)
...
@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object):
...
@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object):
return
ret
return
ret
def
postprocess
(
self
,
binary_seg_result
,
instance_seg_result
=
None
,
def
postprocess
(
self
,
min_area_threshold
=
100
,
source_image
=
None
,
binary_seg_result
,
instance_seg_result
=
None
,
min_area_threshold
=
100
,
source_image
=
None
,
data_source
=
'tusimple'
):
data_source
=
'tusimple'
):
# convert binary_seg_result
# convert binary_seg_result
binary_seg_result
=
np
.
array
(
binary_seg_result
*
255
,
dtype
=
np
.
uint8
)
binary_seg_result
=
np
.
array
(
binary_seg_result
*
255
,
dtype
=
np
.
uint8
)
# apply image morphology operation to fill in the hold and reduce the small area
# apply image morphology operation to fill in the hold and reduce the small area
morphological_ret
=
_morphological_process
(
binary_seg_result
,
kernel_size
=
5
)
morphological_ret
=
_morphological_process
(
connect_components_analysis_ret
=
_connect_components_analysis
(
image
=
morphological_ret
)
binary_seg_result
,
kernel_size
=
5
)
connect_components_analysis_ret
=
_connect_components_analysis
(
image
=
morphological_ret
)
labels
=
connect_components_analysis_ret
[
1
]
labels
=
connect_components_analysis_ret
[
1
]
stats
=
connect_components_analysis_ret
[
2
]
stats
=
connect_components_analysis_ret
[
2
]
...
@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object):
...
@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object):
# apply embedding features cluster
# apply embedding features cluster
mask_image
,
lane_coords
=
self
.
_cluster
.
apply_lane_feats_cluster
(
mask_image
,
lane_coords
=
self
.
_cluster
.
apply_lane_feats_cluster
(
binary_seg_result
=
morphological_ret
,
binary_seg_result
=
morphological_ret
,
instance_seg_result
=
instance_seg_result
instance_seg_result
=
instance_seg_result
)
)
if
mask_image
is
None
:
if
mask_image
is
None
:
return
{
return
{
...
@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object):
...
@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object):
for
lane_index
,
coords
in
enumerate
(
lane_coords
):
for
lane_index
,
coords
in
enumerate
(
lane_coords
):
if
data_source
==
'tusimple'
:
if
data_source
==
'tusimple'
:
tmp_mask
=
np
.
zeros
(
shape
=
(
720
,
1280
),
dtype
=
np
.
uint8
)
tmp_mask
=
np
.
zeros
(
shape
=
(
720
,
1280
),
dtype
=
np
.
uint8
)
tmp_mask
[
tuple
((
np
.
int_
(
coords
[:,
1
]
*
720
/
256
),
np
.
int_
(
coords
[:,
0
]
*
1280
/
512
)))]
=
255
tmp_mask
[
tuple
((
np
.
int_
(
coords
[:,
1
]
*
720
/
256
),
np
.
int_
(
coords
[:,
0
]
*
1280
/
512
)))]
=
255
else
:
else
:
raise
ValueError
(
'Wrong data source now only support tusimple'
)
raise
ValueError
(
'Wrong data source now only support tusimple'
)
tmp_ipm_mask
=
cv2
.
remap
(
tmp_ipm_mask
=
cv2
.
remap
(
tmp_mask
,
tmp_mask
,
self
.
_remap_to_ipm_x
,
self
.
_remap_to_ipm_x
,
self
.
_remap_to_ipm_y
,
self
.
_remap_to_ipm_y
,
interpolation
=
cv2
.
INTER_NEAREST
interpolation
=
cv2
.
INTER_NEAREST
)
)
nonzero_y
=
np
.
array
(
tmp_ipm_mask
.
nonzero
()[
0
])
nonzero_y
=
np
.
array
(
tmp_ipm_mask
.
nonzero
()[
0
])
nonzero_x
=
np
.
array
(
tmp_ipm_mask
.
nonzero
()[
1
])
nonzero_x
=
np
.
array
(
tmp_ipm_mask
.
nonzero
()[
1
])
...
@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object):
...
@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object):
[
ipm_image_height
,
ipm_image_width
]
=
tmp_ipm_mask
.
shape
[
ipm_image_height
,
ipm_image_width
]
=
tmp_ipm_mask
.
shape
plot_y
=
np
.
linspace
(
10
,
ipm_image_height
,
ipm_image_height
-
10
)
plot_y
=
np
.
linspace
(
10
,
ipm_image_height
,
ipm_image_height
-
10
)
fit_x
=
fit_param
[
0
]
*
plot_y
**
2
+
fit_param
[
1
]
*
plot_y
+
fit_param
[
2
]
fit_x
=
fit_param
[
0
]
*
plot_y
**
2
+
fit_param
[
1
]
*
plot_y
+
fit_param
[
2
]
lane_pts
=
[]
lane_pts
=
[]
for
index
in
range
(
0
,
plot_y
.
shape
[
0
],
5
):
for
index
in
range
(
0
,
plot_y
.
shape
[
0
],
5
):
src_x
=
self
.
_remap_to_ipm_x
[
src_x
=
self
.
_remap_to_ipm_x
[
int
(
plot_y
[
index
]),
int
(
np
.
clip
(
fit_x
[
index
],
0
,
ipm_image_width
-
1
))]
int
(
plot_y
[
index
]),
int
(
np
.
clip
(
fit_x
[
index
],
0
,
ipm_image_width
-
1
))]
if
src_x
<=
0
:
if
src_x
<=
0
:
continue
continue
src_y
=
self
.
_remap_to_ipm_y
[
src_y
=
self
.
_remap_to_ipm_y
[
int
(
plot_y
[
index
]),
int
(
np
.
clip
(
fit_x
[
index
],
0
,
ipm_image_width
-
1
))]
int
(
plot_y
[
index
]),
int
(
np
.
clip
(
fit_x
[
index
],
0
,
ipm_image_width
-
1
))]
src_y
=
src_y
if
src_y
>
0
else
0
src_y
=
src_y
if
src_y
>
0
else
0
lane_pts
.
append
([
src_x
,
src_y
])
lane_pts
.
append
([
src_x
,
src_y
])
...
@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object):
...
@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object):
continue
continue
lane_color
=
self
.
_color_map
[
index
].
tolist
()
lane_color
=
self
.
_color_map
[
index
].
tolist
()
cv2
.
circle
(
source_image
,
(
int
(
interpolation_src_pt_x
),
cv2
.
circle
(
int
(
interpolation_src_pt_y
)),
5
,
lane_color
,
-
1
)
source_image
,
(
int
(
interpolation_src_pt_x
),
int
(
interpolation_src_pt_y
)),
5
,
lane_color
,
-
1
)
ret
=
{
ret
=
{
'mask_image'
:
mask_image
,
'mask_image'
:
mask_image
,
'fit_params'
:
fit_params
,
'fit_params'
:
fit_params
,
...
...
contrib/LaneNet/utils/load_model_utils.py
0 → 100644
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os.path
as
osp
import
six
import
numpy
as
np
def
parse_param_file
(
param_file
,
return_shape
=
True
):
from
paddle.fluid.proto.framework_pb2
import
VarType
f
=
open
(
param_file
,
'rb'
)
version
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
lod_level
=
np
.
fromstring
(
f
.
read
(
8
),
dtype
=
'int64'
)
for
i
in
range
(
int
(
lod_level
)):
_size
=
np
.
fromstring
(
f
.
read
(
8
),
dtype
=
'int64'
)
_
=
f
.
read
(
_size
)
version
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
tensor_desc
=
VarType
.
TensorDesc
()
tensor_desc_size
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
tensor_desc
.
ParseFromString
(
f
.
read
(
int
(
tensor_desc_size
)))
tensor_shape
=
tuple
(
tensor_desc
.
dims
)
if
return_shape
:
f
.
close
()
return
tuple
(
tensor_desc
.
dims
)
if
tensor_desc
.
data_type
!=
5
:
raise
Exception
(
"Unexpected data type while parse {}"
.
format
(
param_file
))
data_size
=
4
for
i
in
range
(
len
(
tensor_shape
)):
data_size
*=
tensor_shape
[
i
]
weight
=
np
.
fromstring
(
f
.
read
(
data_size
),
dtype
=
'float32'
)
f
.
close
()
return
np
.
reshape
(
weight
,
tensor_shape
)
def
load_pdparams
(
exe
,
main_prog
,
model_dir
):
import
paddle.fluid
as
fluid
from
paddle.fluid.proto.framework_pb2
import
VarType
from
paddle.fluid.framework
import
Program
vars_to_load
=
list
()
vars_not_load
=
list
()
import
pickle
with
open
(
osp
.
join
(
model_dir
,
'model.pdparams'
),
'rb'
)
as
f
:
params_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
f
,
encoding
=
'latin1'
)
unused_vars
=
list
()
for
var
in
main_prog
.
list_vars
():
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
continue
if
var
.
name
not
in
params_dict
:
print
(
"{} is not in saved model"
.
format
(
var
.
name
))
vars_not_load
.
append
(
var
.
name
)
continue
if
var
.
shape
!=
params_dict
[
var
.
name
].
shape
:
unused_vars
.
append
(
var
.
name
)
vars_not_load
.
append
(
var
.
name
)
print
(
"[SKIP] Shape of pretrained weight {} doesn't match.(Pretrained: {}, Actual: {})"
.
format
(
var
.
name
,
params_dict
[
var
.
name
].
shape
,
var
.
shape
))
continue
vars_to_load
.
append
(
var
)
for
var_name
in
unused_vars
:
del
params_dict
[
var_name
]
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
print
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else
:
print
(
"There are {}/{} varaibles in {} are loaded."
.
format
(
len
(
vars_to_load
),
len
(
vars_to_load
)
+
len
(
vars_not_load
),
model_dir
))
def
load_pretrained_weights
(
exe
,
main_prog
,
weights_dir
):
if
not
osp
.
exists
(
weights_dir
):
raise
Exception
(
"Path {} not exists."
.
format
(
weights_dir
))
if
osp
.
exists
(
osp
.
join
(
weights_dir
,
"model.pdparams"
)):
return
load_pdparams
(
exe
,
main_prog
,
weights_dir
)
import
paddle.fluid
as
fluid
vars_to_load
=
list
()
vars_not_load
=
list
()
for
var
in
main_prog
.
list_vars
():
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
continue
if
not
osp
.
exists
(
osp
.
join
(
weights_dir
,
var
.
name
)):
print
(
"[SKIP] Pretrained weight {}/{} doesn't exist"
.
format
(
weights_dir
,
var
.
name
))
vars_not_load
.
append
(
var
)
continue
pretrained_shape
=
parse_param_file
(
osp
.
join
(
weights_dir
,
var
.
name
))
actual_shape
=
tuple
(
var
.
shape
)
if
pretrained_shape
!=
actual_shape
:
print
(
"[SKIP] Shape of pretrained weight {}/{} doesn't match.(Pretrained: {}, Actual: {})"
.
format
(
weights_dir
,
var
.
name
,
pretrained_shape
,
actual_shape
))
vars_not_load
.
append
(
var
)
continue
vars_to_load
.
append
(
var
)
params_dict
=
fluid
.
io
.
load_program_state
(
weights_dir
,
var_list
=
vars_to_load
)
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
print
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else
:
print
(
"There are {}/{} varaibles in {} are loaded."
.
format
(
len
(
vars_to_load
),
len
(
vars_to_load
)
+
len
(
vars_not_load
),
weights_dir
))
contrib/LaneNet/vis.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -45,6 +45,7 @@ from models.model_builder import ModelPhase
...
@@ -45,6 +45,7 @@ from models.model_builder import ModelPhase
from
utils
import
lanenet_postprocess
from
utils
import
lanenet_postprocess
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'PaddeSeg visualization tools'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'PaddeSeg visualization tools'
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -106,7 +107,6 @@ def minmax_scale(input_arr):
...
@@ -106,7 +107,6 @@ def minmax_scale(input_arr):
return
output_arr
return
output_arr
def
visualize
(
cfg
,
def
visualize
(
cfg
,
vis_file_list
=
None
,
vis_file_list
=
None
,
use_gpu
=
False
,
use_gpu
=
False
,
...
@@ -119,7 +119,6 @@ def visualize(cfg,
...
@@ -119,7 +119,6 @@ def visualize(cfg,
if
vis_file_list
is
None
:
if
vis_file_list
is
None
:
vis_file_list
=
cfg
.
DATASET
.
TEST_FILE_LIST
vis_file_list
=
cfg
.
DATASET
.
TEST_FILE_LIST
dataset
=
LaneNetDataset
(
dataset
=
LaneNetDataset
(
file_list
=
vis_file_list
,
file_list
=
vis_file_list
,
mode
=
ModelPhase
.
VISUAL
,
mode
=
ModelPhase
.
VISUAL
,
...
@@ -139,7 +138,12 @@ def visualize(cfg,
...
@@ -139,7 +138,12 @@ def visualize(cfg,
ckpt_dir
=
cfg
.
TEST
.
TEST_MODEL
if
not
ckpt_dir
else
ckpt_dir
ckpt_dir
=
cfg
.
TEST
.
TEST_MODEL
if
not
ckpt_dir
else
ckpt_dir
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
save_dir
=
os
.
path
.
join
(
vis_dir
,
'visual_results'
)
save_dir
=
os
.
path
.
join
(
vis_dir
,
'visual_results'
)
makedirs
(
save_dir
)
makedirs
(
save_dir
)
...
@@ -161,22 +165,26 @@ def visualize(cfg,
...
@@ -161,22 +165,26 @@ def visualize(cfg,
for
i
in
range
(
num_imgs
):
for
i
in
range
(
num_imgs
):
gt_image
=
org_imgs
[
i
]
gt_image
=
org_imgs
[
i
]
binary_seg_image
,
instance_seg_image
=
segLogits
[
i
].
squeeze
(
-
1
),
emLogits
[
i
].
transpose
((
1
,
2
,
0
))
binary_seg_image
,
instance_seg_image
=
segLogits
[
i
].
squeeze
(
-
1
),
emLogits
[
i
].
transpose
((
1
,
2
,
0
))
postprocess_result
=
postprocessor
.
postprocess
(
postprocess_result
=
postprocessor
.
postprocess
(
binary_seg_result
=
binary_seg_image
,
binary_seg_result
=
binary_seg_image
,
instance_seg_result
=
instance_seg_image
,
instance_seg_result
=
instance_seg_image
,
source_image
=
gt_image
source_image
=
gt_image
)
)
pred_binary_fn
=
os
.
path
.
join
(
pred_binary_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_binary'
))
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_binary'
))
pred_lane_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_lane'
))
pred_lane_fn
=
os
.
path
.
join
(
pred_instance_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_instance'
))
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_lane'
))
pred_instance_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_instance'
))
dirname
=
os
.
path
.
dirname
(
pred_binary_fn
)
dirname
=
os
.
path
.
dirname
(
pred_binary_fn
)
makedirs
(
dirname
)
makedirs
(
dirname
)
mask_image
=
postprocess_result
[
'mask_image'
]
mask_image
=
postprocess_result
[
'mask_image'
]
for
i
in
range
(
4
):
for
i
in
range
(
4
):
instance_seg_image
[:,
:,
i
]
=
minmax_scale
(
instance_seg_image
[:,
:,
i
])
instance_seg_image
[:,
:,
i
]
=
minmax_scale
(
instance_seg_image
[:,
:,
i
])
embedding_image
=
np
.
array
(
instance_seg_image
).
astype
(
np
.
uint8
)
embedding_image
=
np
.
array
(
instance_seg_image
).
astype
(
np
.
uint8
)
plt
.
figure
(
'mask_image'
)
plt
.
figure
(
'mask_image'
)
...
@@ -189,13 +197,13 @@ def visualize(cfg,
...
@@ -189,13 +197,13 @@ def visualize(cfg,
plt
.
imshow
(
binary_seg_image
*
255
,
cmap
=
'gray'
)
plt
.
imshow
(
binary_seg_image
*
255
,
cmap
=
'gray'
)
plt
.
show
()
plt
.
show
()
cv2
.
imwrite
(
pred_binary_fn
,
np
.
array
(
binary_seg_image
*
255
).
astype
(
np
.
uint8
))
cv2
.
imwrite
(
pred_binary_fn
,
np
.
array
(
binary_seg_image
*
255
).
astype
(
np
.
uint8
))
cv2
.
imwrite
(
pred_lane_fn
,
postprocess_result
[
'source_image'
])
cv2
.
imwrite
(
pred_lane_fn
,
postprocess_result
[
'source_image'
])
cv2
.
imwrite
(
pred_instance_fn
,
mask_image
)
cv2
.
imwrite
(
pred_instance_fn
,
mask_image
)
print
(
pred_lane_fn
,
'saved!'
)
print
(
pred_lane_fn
,
'saved!'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
args
=
parse_args
()
args
=
parse_args
()
if
args
.
cfg_file
is
not
None
:
if
args
.
cfg_file
is
not
None
:
...
...
contrib/MechanicalIndustryMeter/download_mini_mechanical_industry_meter.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/MechanicalIndustryMeter/download_unet_mechanical_industry_meter.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress
...
@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
download_file_and_uncompress
(
download_file_and_uncompress
(
url
=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar'
,
url
=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar'
,
savepath
=
LOCAL_PATH
,
savepath
=
LOCAL_PATH
,
extrapath
=
LOCAL_PATH
)
extrapath
=
LOCAL_PATH
)
...
...
contrib/RemoteSensing/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/RemoteSensing/models/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.load_model
import
*
from
.load_model
import
*
from
.unet
import
*
from
.unet
import
*
from
.hrnet
import
*
from
.hrnet
import
*
contrib/RemoteSensing/models/base.py
浏览文件 @
61645b1d
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
#Licensed under the Apache License, Version 2.0 (the "License");
#
Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#
you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
#Unless required by applicable law or agreed to in writing, software
#
Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#
distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#
See the License for the specific language governing permissions and
#limitations under the License.
#
limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
...
...
contrib/RemoteSensing/models/load_model.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/models/unet.py
浏览文件 @
61645b1d
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
#Licensed under the Apache License, Version 2.0 (the "License");
#
Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#
you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
#Unless required by applicable law or agreed to in writing, software
#
Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#
distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#
See the License for the specific language governing permissions and
#limitations under the License.
#
limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
numpy
as
np
import
numpy
as
np
...
...
contrib/RemoteSensing/nets/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.unet
import
UNet
from
.unet
import
UNet
from
.hrnet
import
HRNet
from
.hrnet
import
HRNet
contrib/RemoteSensing/nets/libs.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/nets/loss.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/nets/unet.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/predict_demo.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
import
sys
import
sys
...
...
contrib/RemoteSensing/readers/__init__.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/readers/base.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/readers/reader.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/tools/create_dataset_list.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/tools/split_dataset_list.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/train_demo.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os.path
as
osp
import
os.path
as
osp
import
argparse
import
argparse
import
transforms.transforms
as
T
import
transforms.transforms
as
T
...
...
contrib/RemoteSensing/transforms/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/RemoteSensing/transforms/ops.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/transforms/transforms.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/utils/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/RemoteSensing/utils/logging.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/utils/metrics.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/utils/pretrain_weights.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os.path
as
osp
import
os.path
as
osp
...
...
contrib/RemoteSensing/utils/utils.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -201,11 +202,9 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
...
@@ -201,11 +202,9 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
vars_to_load
.
append
(
var
)
vars_to_load
.
append
(
var
)
logging
.
debug
(
"Weight {} will be load"
.
format
(
var
.
name
))
logging
.
debug
(
"Weight {} will be load"
.
format
(
var
.
name
))
fluid
.
io
.
load_vars
(
params_dict
=
fluid
.
io
.
load_program_state
(
executor
=
exe
,
weights_dir
,
var_list
=
vars_to_load
)
dirname
=
weights_dir
,
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
main_program
=
main_prog
,
vars
=
vars_to_load
)
if
len
(
vars_to_load
)
==
0
:
if
len
(
vars_to_load
)
==
0
:
logging
.
warning
(
logging
.
warning
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
...
...
contrib/RoadLine/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/RoadLine/config.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
utils.util
import
AttrDict
,
merge_cfg_from_args
,
get_arguments
from
utils.util
import
AttrDict
,
merge_cfg_from_args
,
get_arguments
import
os
import
os
...
@@ -6,20 +20,20 @@ args = get_arguments()
...
@@ -6,20 +20,20 @@ args = get_arguments()
cfg
=
AttrDict
()
cfg
=
AttrDict
()
# 待预测图像所在路径
# 待预测图像所在路径
cfg
.
data_dir
=
os
.
path
.
join
(
args
.
example
,
"data"
,
"test_images"
)
cfg
.
data_dir
=
os
.
path
.
join
(
args
.
example
,
"data"
,
"test_images"
)
# 待预测图像名称列表
# 待预测图像名称列表
cfg
.
data_list_file
=
os
.
path
.
join
(
args
.
example
,
"data"
,
"test.txt"
)
cfg
.
data_list_file
=
os
.
path
.
join
(
args
.
example
,
"data"
,
"test.txt"
)
# 模型加载路径
# 模型加载路径
cfg
.
model_path
=
os
.
path
.
join
(
args
.
example
,
"model"
)
cfg
.
model_path
=
os
.
path
.
join
(
args
.
example
,
"model"
)
# 预测结果保存路径
# 预测结果保存路径
cfg
.
vis_dir
=
os
.
path
.
join
(
args
.
example
,
"result"
)
cfg
.
vis_dir
=
os
.
path
.
join
(
args
.
example
,
"result"
)
# 预测类别数
# 预测类别数
cfg
.
class_num
=
2
cfg
.
class_num
=
2
# 均值, 图像预处理减去的均值
# 均值, 图像预处理减去的均值
cfg
.
MEAN
=
127.5
,
127.5
,
127.5
cfg
.
MEAN
=
127.5
,
127.5
,
127.5
# 标准差,图像预处理除以标准差
# 标准差,图像预处理除以标准差
cfg
.
STD
=
127.5
,
127.5
,
127.5
cfg
.
STD
=
127.5
,
127.5
,
127.5
# 待预测图像输入尺寸
# 待预测图像输入尺寸
cfg
.
input_size
=
1536
,
576
cfg
.
input_size
=
1536
,
576
...
...
contrib/RoadLine/download_RoadLine.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/RoadLine/infer.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
@@ -12,18 +26,19 @@ config = importlib.import_module('config')
...
@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg
=
getattr
(
config
,
'cfg'
)
cfg
=
getattr
(
config
,
'cfg'
)
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
# 预测数据集类
# 预测数据集类
class
TestDataSet
():
class
TestDataSet
():
def
__init__
(
self
):
def
__init__
(
self
):
self
.
data_dir
=
cfg
.
data_dir
self
.
data_dir
=
cfg
.
data_dir
self
.
data_list_file
=
cfg
.
data_list_file
self
.
data_list_file
=
cfg
.
data_list_file
self
.
data_list
=
self
.
get_data_list
()
self
.
data_list
=
self
.
get_data_list
()
self
.
data_num
=
len
(
self
.
data_list
)
self
.
data_num
=
len
(
self
.
data_list
)
def
get_data_list
(
self
):
def
get_data_list
(
self
):
# 获取预测图像路径列表
# 获取预测图像路径列表
data_list
=
[]
data_list
=
[]
...
@@ -40,7 +55,7 @@ class TestDataSet():
...
@@ -40,7 +55,7 @@ class TestDataSet():
def
preprocess
(
self
,
img
):
def
preprocess
(
self
,
img
):
# 图像预处理
# 图像预处理
if
cfg
.
example
==
'ACE2P'
:
if
cfg
.
example
==
'ACE2P'
:
reader
=
importlib
.
import_module
(
args
.
example
+
'.reader'
)
reader
=
importlib
.
import_module
(
args
.
example
+
'.reader'
)
ACE2P_preprocess
=
getattr
(
reader
,
'preprocess'
)
ACE2P_preprocess
=
getattr
(
reader
,
'preprocess'
)
img
=
ACE2P_preprocess
(
img
)
img
=
ACE2P_preprocess
(
img
)
else
:
else
:
...
@@ -56,10 +71,10 @@ class TestDataSet():
...
@@ -56,10 +71,10 @@ class TestDataSet():
img_path
=
self
.
data_list
[
index
]
img_path
=
self
.
data_list
[
index
]
img
=
cv2
.
imread
(
img_path
,
cv2
.
IMREAD_COLOR
)
img
=
cv2
.
imread
(
img_path
,
cv2
.
IMREAD_COLOR
)
if
img
is
None
:
if
img
is
None
:
return
img
,
img
,
img_path
,
None
return
img
,
img
,
img_path
,
None
img_name
=
img_path
.
split
(
os
.
sep
)[
-
1
]
img_name
=
img_path
.
split
(
os
.
sep
)[
-
1
]
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
img_shape
=
img
.
shape
[:
2
]
img_shape
=
img
.
shape
[:
2
]
img_process
=
self
.
preprocess
(
img
)
img_process
=
self
.
preprocess
(
img
)
...
@@ -90,39 +105,44 @@ def infer():
...
@@ -90,39 +105,44 @@ def infer():
if
image
is
None
:
if
image
is
None
:
print
(
im_name
,
'is None'
)
print
(
im_name
,
'is None'
)
continue
continue
# 预测
# 预测
if
cfg
.
example
==
'ACE2P'
:
if
cfg
.
example
==
'ACE2P'
:
# ACE2P模型使用多尺度预测
# ACE2P模型使用多尺度预测
reader
=
importlib
.
import_module
(
args
.
example
+
'.reader'
)
reader
=
importlib
.
import_module
(
args
.
example
+
'.reader'
)
multi_scale_test
=
getattr
(
reader
,
'multi_scale_test'
)
multi_scale_test
=
getattr
(
reader
,
'multi_scale_test'
)
parsing
,
logits
=
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
image
,
im_shape
)
parsing
,
logits
=
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
image
,
im_shape
)
else
:
else
:
# HumanSeg,RoadLine模型单尺度预测
# HumanSeg,RoadLine模型单尺度预测
result
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
result
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
parsing
=
np
.
argmax
(
result
[
0
][
0
],
axis
=
0
)
parsing
=
np
.
argmax
(
result
[
0
][
0
],
axis
=
0
)
parsing
=
cv2
.
resize
(
parsing
.
astype
(
np
.
uint8
),
im_shape
[::
-
1
])
parsing
=
cv2
.
resize
(
parsing
.
astype
(
np
.
uint8
),
im_shape
[::
-
1
])
# 预测结果保存
# 预测结果保存
result_path
=
os
.
path
.
join
(
cfg
.
vis_dir
,
im_name
+
'.png'
)
result_path
=
os
.
path
.
join
(
cfg
.
vis_dir
,
im_name
+
'.png'
)
if
cfg
.
example
==
'HumanSeg'
:
if
cfg
.
example
==
'HumanSeg'
:
logits
=
result
[
0
][
0
][
1
]
*
255
logits
=
result
[
0
][
0
][
1
]
*
255
logits
=
cv2
.
resize
(
logits
,
im_shape
[::
-
1
])
logits
=
cv2
.
resize
(
logits
,
im_shape
[::
-
1
])
ret
,
logits
=
cv2
.
threshold
(
logits
,
thresh
,
0
,
cv2
.
THRESH_TOZERO
)
ret
,
logits
=
cv2
.
threshold
(
logits
,
thresh
,
0
,
cv2
.
THRESH_TOZERO
)
logits
=
255
*
(
logits
-
thresh
)
/
(
255
-
thresh
)
logits
=
255
*
(
logits
-
thresh
)
/
(
255
-
thresh
)
# 将分割结果添加到alpha通道
# 将分割结果添加到alpha通道
rgba
=
np
.
concatenate
((
ori_img
,
np
.
expand_dims
(
logits
,
axis
=
2
)),
axis
=
2
)
rgba
=
np
.
concatenate
((
ori_img
,
np
.
expand_dims
(
logits
,
axis
=
2
)),
axis
=
2
)
cv2
.
imwrite
(
result_path
,
rgba
)
cv2
.
imwrite
(
result_path
,
rgba
)
else
:
else
:
output_im
=
PILImage
.
fromarray
(
np
.
asarray
(
parsing
,
dtype
=
np
.
uint8
))
output_im
=
PILImage
.
fromarray
(
np
.
asarray
(
parsing
,
dtype
=
np
.
uint8
))
output_im
.
putpalette
(
palette
)
output_im
.
putpalette
(
palette
)
output_im
.
save
(
result_path
)
output_im
.
save
(
result_path
)
if
(
idx
+
1
)
%
100
==
0
:
if
(
idx
+
1
)
%
100
==
0
:
print
(
'%d processd'
%
(
idx
+
1
))
print
(
'%d processd'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
return
0
return
0
...
...
contrib/RoadLine/utils/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/RoadLine/utils/palette.py
浏览文件 @
61645b1d
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# coding: utf8
## Created by: RainbowSecret
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
## Microsoft Research
#
## yuyua@microsoft.com
# Licensed under the Apache License, Version 2.0 (the "License");
## Copyright (c) 2018
# you may not use this file except in compliance with the License.
##
# You may obtain a copy of the License at
## This source code is licensed under the MIT-style license found in the
#
## LICENSE file in the root directory of this source tree
# http://www.apache.org/licenses/LICENSE-2.0
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
...
...
contrib/RoadLine/utils/util.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
import
argparse
import
argparse
import
os
import
os
def
get_arguments
():
def
get_arguments
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--use_gpu"
,
parser
.
add_argument
(
action
=
"store_true"
,
"--use_gpu"
,
action
=
"store_true"
,
help
=
"Use gpu or cpu to test."
)
help
=
"Use gpu or cpu to test."
)
parser
.
add_argument
(
parser
.
add_argument
(
'--example'
,
'--example'
,
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -34,6 +48,7 @@ class AttrDict(dict):
...
@@ -34,6 +48,7 @@ class AttrDict(dict):
else
:
else
:
self
[
name
]
=
value
self
[
name
]
=
value
def
merge_cfg_from_args
(
args
,
cfg
):
def
merge_cfg_from_args
(
args
,
cfg
):
"""Merge config keys, values in args into the global config."""
"""Merge config keys, values in args into the global config."""
for
k
,
v
in
vars
(
args
).
items
():
for
k
,
v
in
vars
(
args
).
items
():
...
@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
...
@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value
=
v
value
=
v
if
value
is
not
None
:
if
value
is
not
None
:
cfg
[
k
]
=
value
cfg
[
k
]
=
value
dataset/convert_voc2012.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -20,6 +21,8 @@ from PIL import Image
...
@@ -20,6 +21,8 @@ from PIL import Image
import
glob
import
glob
LOCAL_PATH
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
LOCAL_PATH
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
def
remove_colormap
(
filename
):
def
remove_colormap
(
filename
):
gray_anno
=
np
.
array
(
Image
.
open
(
filename
))
gray_anno
=
np
.
array
(
Image
.
open
(
filename
))
return
gray_anno
return
gray_anno
...
@@ -30,6 +33,7 @@ def save_annotation(annotation, filename):
...
@@ -30,6 +33,7 @@ def save_annotation(annotation, filename):
annotation
=
Image
.
fromarray
(
annotation
)
annotation
=
Image
.
fromarray
(
annotation
)
annotation
.
save
(
filename
)
annotation
.
save
(
filename
)
def
convert_list
(
origin_file
,
seg_file
,
output_folder
):
def
convert_list
(
origin_file
,
seg_file
,
output_folder
):
with
open
(
seg_file
,
'w'
)
as
fid_seg
:
with
open
(
seg_file
,
'w'
)
as
fid_seg
:
with
open
(
origin_file
)
as
fid_ori
:
with
open
(
origin_file
)
as
fid_ori
:
...
@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder):
...
@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder):
new_line
=
' '
.
join
([
img_name
,
anno_name
])
new_line
=
' '
.
join
([
img_name
,
anno_name
])
fid_seg
.
write
(
new_line
+
"
\n
"
)
fid_seg
.
write
(
new_line
+
"
\n
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pascal_root
=
"./VOCtrainval_11-May-2012/VOC2012"
pascal_root
=
"./VOCtrainval_11-May-2012/VOC2012"
pascal_root
=
os
.
path
.
join
(
LOCAL_PATH
,
pascal_root
)
pascal_root
=
os
.
path
.
join
(
LOCAL_PATH
,
pascal_root
)
...
@@ -54,7 +59,7 @@ if __name__ == "__main__":
...
@@ -54,7 +59,7 @@ if __name__ == "__main__":
# 标注图转换后存储目录
# 标注图转换后存储目录
output_folder
=
os
.
path
.
join
(
pascal_root
,
"SegmentationClassAug"
)
output_folder
=
os
.
path
.
join
(
pascal_root
,
"SegmentationClassAug"
)
print
(
"annotation convert and file list convert"
)
print
(
"annotation convert and file list convert"
)
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
LOCAL_PATH
,
output_folder
)):
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
LOCAL_PATH
,
output_folder
)):
os
.
mkdir
(
os
.
path
.
join
(
LOCAL_PATH
,
output_folder
))
os
.
mkdir
(
os
.
path
.
join
(
LOCAL_PATH
,
output_folder
))
...
@@ -67,5 +72,5 @@ if __name__ == "__main__":
...
@@ -67,5 +72,5 @@ if __name__ == "__main__":
convert_list
(
train_path
,
train_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
train_path
,
train_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
val_path
,
val_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
val_path
,
val_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
trainval_path
,
trainval_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
trainval_path
,
trainval_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
dataset/download_and_convert_voc2012.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap
...
@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap
from
convert_voc2012
import
save_annotation
from
convert_voc2012
import
save_annotation
def
download_VOC_dataset
(
savepath
,
extrapath
):
def
download_VOC_dataset
(
savepath
,
extrapath
):
url
=
"https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
url
=
"https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
download_file_and_uncompress
(
download_file_and_uncompress
(
url
=
url
,
savepath
=
savepath
,
extrapath
=
extrapath
)
url
=
url
,
savepath
=
savepath
,
extrapath
=
extrapath
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
download_VOC_dataset
(
LOCAL_PATH
,
LOCAL_PATH
)
download_VOC_dataset
(
LOCAL_PATH
,
LOCAL_PATH
)
print
(
"Dataset download finish!"
)
print
(
"Dataset download finish!"
)
...
@@ -45,10 +46,10 @@ if __name__ == "__main__":
...
@@ -45,10 +46,10 @@ if __name__ == "__main__":
train_path
=
os
.
path
.
join
(
txt_folder
,
"train.txt"
)
train_path
=
os
.
path
.
join
(
txt_folder
,
"train.txt"
)
val_path
=
os
.
path
.
join
(
txt_folder
,
"val.txt"
)
val_path
=
os
.
path
.
join
(
txt_folder
,
"val.txt"
)
trainval_path
=
os
.
path
.
join
(
txt_folder
,
"trainval.txt"
)
trainval_path
=
os
.
path
.
join
(
txt_folder
,
"trainval.txt"
)
# 标注图转换后存储目录
# 标注图转换后存储目录
output_folder
=
os
.
path
.
join
(
pascal_root
,
"SegmentationClassAug"
)
output_folder
=
os
.
path
.
join
(
pascal_root
,
"SegmentationClassAug"
)
print
(
"annotation convert and file list convert"
)
print
(
"annotation convert and file list convert"
)
if
not
os
.
path
.
exists
(
output_folder
):
if
not
os
.
path
.
exists
(
output_folder
):
os
.
mkdir
(
output_folder
)
os
.
mkdir
(
output_folder
)
...
@@ -61,5 +62,5 @@ if __name__ == "__main__":
...
@@ -61,5 +62,5 @@ if __name__ == "__main__":
convert_list
(
train_path
,
train_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
train_path
,
train_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
val_path
,
val_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
val_path
,
val_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
trainval_path
,
trainval_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
trainval_path
,
trainval_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
dataset/download_cityscapes.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
dataset/download_mini_deepglobe_road_extraction.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
dataset/download_optic.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
dataset/download_pet.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
deploy/python/infer.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -29,9 +29,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
...
@@ -29,9 +29,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
gflags
.
DEFINE_string
(
"conf"
,
default
=
""
,
help
=
"Configuration File Path"
)
gflags
.
DEFINE_string
(
"conf"
,
default
=
""
,
help
=
"Configuration File Path"
)
gflags
.
DEFINE_string
(
"input_dir"
,
default
=
""
,
help
=
"Directory of Input Images"
)
gflags
.
DEFINE_string
(
"input_dir"
,
default
=
""
,
help
=
"Directory of Input Images"
)
gflags
.
DEFINE_boolean
(
"use_pr"
,
default
=
False
,
help
=
"Use optimized model"
)
gflags
.
DEFINE_string
(
"trt_mode"
,
default
=
""
,
help
=
"Use optimized model"
)
gflags
.
DEFINE_string
(
"trt_mode"
,
default
=
""
,
help
=
"Use optimized model"
)
gflags
.
DEFINE_string
(
"ext"
,
default
=
".jpeg|.jpg"
,
help
=
"Input Image File Extensions"
)
gflags
.
DEFINE_string
(
"ext"
,
default
=
".jpeg|.jpg"
,
help
=
"Input Image File Extensions"
)
gflags
.
FLAGS
=
gflags
.
FLAGS
gflags
.
FLAGS
=
gflags
.
FLAGS
...
@@ -103,6 +103,9 @@ class DeployConfig:
...
@@ -103,6 +103,9 @@ class DeployConfig:
self
.
batch_size
=
deploy_conf
[
"BATCH_SIZE"
]
self
.
batch_size
=
deploy_conf
[
"BATCH_SIZE"
]
# 9. channels
# 9. channels
self
.
channels
=
deploy_conf
[
"CHANNELS"
]
self
.
channels
=
deploy_conf
[
"CHANNELS"
]
# 10. use_pr
self
.
use_pr
=
deploy_conf
[
"USE_PR"
]
class
ImageReader
:
class
ImageReader
:
...
@@ -257,23 +260,24 @@ class Predictor:
...
@@ -257,23 +260,24 @@ class Predictor:
# record starting time point
# record starting time point
total_start
=
time
.
time
()
total_start
=
time
.
time
()
batch_size
=
self
.
config
.
batch_size
batch_size
=
self
.
config
.
batch_size
use_pr
=
self
.
config
.
use_pr
for
i
in
range
(
0
,
len
(
images
),
batch_size
):
for
i
in
range
(
0
,
len
(
images
),
batch_size
):
real_batch_size
=
batch_size
real_batch_size
=
batch_size
if
i
+
batch_size
>=
len
(
images
):
if
i
+
batch_size
>=
len
(
images
):
real_batch_size
=
len
(
images
)
-
i
real_batch_size
=
len
(
images
)
-
i
reader_start
=
time
.
time
()
reader_start
=
time
.
time
()
img_datas
=
self
.
image_reader
.
process
(
images
[
i
:
i
+
real_batch_size
],
img_datas
=
self
.
image_reader
.
process
(
images
[
i
:
i
+
real_batch_size
],
gflags
.
FLAGS
.
use_pr
)
use_pr
)
input_data
=
np
.
concatenate
([
item
[
1
]
for
item
in
img_datas
])
input_data
=
np
.
concatenate
([
item
[
1
]
for
item
in
img_datas
])
input_data
=
self
.
create_tensor
(
input_data
=
self
.
create_tensor
(
input_data
,
real_batch_size
,
use_pr
=
gflags
.
FLAGS
.
use_pr
)
input_data
,
real_batch_size
,
use_pr
=
use_pr
)
reader_end
=
time
.
time
()
reader_end
=
time
.
time
()
infer_start
=
time
.
time
()
infer_start
=
time
.
time
()
output_data
=
self
.
predictor
.
run
(
input_data
)[
0
]
output_data
=
self
.
predictor
.
run
(
input_data
)[
0
]
infer_end
=
time
.
time
()
infer_end
=
time
.
time
()
output_data
=
output_data
.
as_ndarray
()
output_data
=
output_data
.
as_ndarray
()
post_start
=
time
.
time
()
post_start
=
time
.
time
()
self
.
output_result
(
img_datas
,
output_data
,
gflags
.
FLAGS
.
use_pr
)
self
.
output_result
(
img_datas
,
output_data
,
use_pr
)
post_end
=
time
.
time
()
post_end
=
time
.
time
()
reader_time
+=
(
reader_end
-
reader_start
)
reader_time
+=
(
reader_end
-
reader_start
)
infer_time
+=
(
infer_end
-
infer_start
)
infer_time
+=
(
infer_end
-
infer_start
)
...
...
pdseg/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -14,4 +14,4 @@
...
@@ -14,4 +14,4 @@
# limitations under the License.
# limitations under the License.
import
models
import
models
import
utils
import
utils
from
.
import
tools
from
.
import
tools
\ No newline at end of file
pdseg/check.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
...
@@ -427,12 +440,17 @@ def max_img_size_statistics():
...
@@ -427,12 +440,17 @@ def max_img_size_statistics():
logger
.
info
(
"max width and max height of images are ({},{})"
.
format
(
logger
.
info
(
"max width and max height of images are ({},{})"
.
format
(
max_width
,
max_height
))
max_width
,
max_height
))
def
num_classes_loss_matching_check
():
def
num_classes_loss_matching_check
():
loss_type
=
cfg
.
SOLVER
.
LOSS
loss_type
=
cfg
.
SOLVER
.
LOSS
num_classes
=
cfg
.
DATASET
.
NUM_CLASSES
num_classes
=
cfg
.
DATASET
.
NUM_CLASSES
if
num_classes
>
2
and
((
"dice_loss"
in
loss_type
)
or
(
"bce_loss"
in
loss_type
)):
if
num_classes
>
2
and
((
"dice_loss"
in
loss_type
)
or
logger
.
info
(
error_print
(
"loss check."
(
"bce_loss"
in
loss_type
)):
" Dice loss and bce loss is only applicable to binary classfication"
))
logger
.
info
(
error_print
(
"loss check."
" Dice loss and bce loss is only applicable to binary classfication"
))
else
:
else
:
logger
.
info
(
correct_print
(
"loss check"
))
logger
.
info
(
correct_print
(
"loss check"
))
...
...
pdseg/data_aug.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -362,7 +362,7 @@ def hsv_color_jitter(crop_img,
...
@@ -362,7 +362,7 @@ def hsv_color_jitter(crop_img,
saturation_jitter_ratio
>
0
or
\
saturation_jitter_ratio
>
0
or
\
contrast_jitter_ratio
>
0
:
contrast_jitter_ratio
>
0
:
crop_img
=
random_jitter
(
crop_img
,
saturation_jitter_ratio
,
crop_img
=
random_jitter
(
crop_img
,
saturation_jitter_ratio
,
brightness_jitter_ratio
,
contrast_jitter_ratio
)
brightness_jitter_ratio
,
contrast_jitter_ratio
)
return
crop_img
return
crop_img
...
@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN):
...
@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN):
crop_width
=
cfg
.
EVAL_CROP_SIZE
[
0
]
crop_width
=
cfg
.
EVAL_CROP_SIZE
[
0
]
crop_height
=
cfg
.
EVAL_CROP_SIZE
[
1
]
crop_height
=
cfg
.
EVAL_CROP_SIZE
[
1
]
if
not
ModelPhase
.
is_train
(
mode
):
if
not
ModelPhase
.
is_train
(
mode
):
if
(
crop_height
<
img_height
or
crop_width
<
img_width
):
if
(
crop_height
<
img_height
or
crop_width
<
img_width
):
raise
Exception
(
raise
Exception
(
"Crop size({},{}) must large than img size({},{}) when in EvalPhase."
"Crop size({},{}) must large than img size({},{}) when in EvalPhase."
...
...
pdseg/data_utils.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
"""
...
@@ -14,10 +28,10 @@ except ImportError:
...
@@ -14,10 +28,10 @@ except ImportError:
class
GeneratorEnqueuer
(
object
):
class
GeneratorEnqueuer
(
object
):
"""
"""
Multiple generators
Multiple generators
Args:
Args:
generators:
generators:
wait_time (float): time to sleep in-between calls to `put()`.
wait_time (float): time to sleep in-between calls to `put()`.
"""
"""
...
...
pdseg/eval.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -22,13 +22,9 @@ import os
...
@@ -22,13 +22,9 @@ import os
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
"0.0"
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
"0.0"
import
sys
import
sys
import
time
import
argparse
import
argparse
import
functools
import
pprint
import
pprint
import
cv2
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
utils.config
import
cfg
from
utils.config
import
cfg
...
@@ -116,7 +112,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
...
@@ -116,7 +112,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if
ckpt_dir
is
not
None
:
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
print
(
'load test model:'
,
ckpt_dir
)
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
# Use streaming confusion matrix to calculate mean_iou
# Use streaming confusion matrix to calculate mean_iou
np
.
set_printoptions
(
np
.
set_printoptions
(
...
...
pdseg/export_model.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -49,6 +49,7 @@ def parse_args():
...
@@ -49,6 +49,7 @@ def parse_args():
sys
.
exit
(
1
)
sys
.
exit
(
1
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
export_inference_config
():
def
export_inference_config
():
deploy_cfg
=
'''DEPLOY:
deploy_cfg
=
'''DEPLOY:
USE_GPU : 1
USE_GPU : 1
...
@@ -66,9 +67,8 @@ def export_inference_config():
...
@@ -66,9 +67,8 @@ def export_inference_config():
PREDICTOR_MODE : "ANALYSIS"
PREDICTOR_MODE : "ANALYSIS"
BATCH_SIZE : 1
BATCH_SIZE : 1
'''
%
(
cfg
.
FREEZE
.
SAVE_DIR
,
cfg
.
FREEZE
.
MODEL_FILENAME
,
'''
%
(
cfg
.
FREEZE
.
SAVE_DIR
,
cfg
.
FREEZE
.
MODEL_FILENAME
,
cfg
.
FREEZE
.
PARAMS_FILENAME
,
cfg
.
EVAL_CROP_SIZE
,
cfg
.
FREEZE
.
PARAMS_FILENAME
,
cfg
.
EVAL_CROP_SIZE
,
cfg
.
MEAN
,
cfg
.
STD
,
cfg
.
MEAN
,
cfg
.
STD
,
cfg
.
DATASET
.
IMAGE_TYPE
,
cfg
.
DATASET
.
IMAGE_TYPE
,
cfg
.
DATASET
.
NUM_CLASSES
,
len
(
cfg
.
STD
))
cfg
.
DATASET
.
NUM_CLASSES
,
len
(
cfg
.
STD
))
if
not
os
.
path
.
exists
(
cfg
.
FREEZE
.
SAVE_DIR
):
if
not
os
.
path
.
exists
(
cfg
.
FREEZE
.
SAVE_DIR
):
os
.
mkdir
(
cfg
.
FREEZE
.
SAVE_DIR
)
os
.
mkdir
(
cfg
.
FREEZE
.
SAVE_DIR
)
yaml_path
=
os
.
path
.
join
(
cfg
.
FREEZE
.
SAVE_DIR
,
'deploy.yaml'
)
yaml_path
=
os
.
path
.
join
(
cfg
.
FREEZE
.
SAVE_DIR
,
'deploy.yaml'
)
...
@@ -94,7 +94,13 @@ def export_inference_model(args):
...
@@ -94,7 +94,13 @@ def export_inference_model(args):
infer_prog
=
infer_prog
.
clone
(
for_test
=
True
)
infer_prog
=
infer_prog
.
clone
(
for_test
=
True
)
if
os
.
path
.
exists
(
cfg
.
TEST
.
TEST_MODEL
):
if
os
.
path
.
exists
(
cfg
.
TEST
.
TEST_MODEL
):
fluid
.
io
.
load_params
(
exe
,
cfg
.
TEST
.
TEST_MODEL
,
main_program
=
infer_prog
)
print
(
'load test model:'
,
cfg
.
TEST
.
TEST_MODEL
)
try
:
fluid
.
load
(
infer_prog
,
os
.
path
.
join
(
cfg
.
TEST
.
TEST_MODEL
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
cfg
.
TEST
.
TEST_MODEL
,
main_program
=
infer_prog
)
else
:
else
:
print
(
"TEST.TEST_MODEL diretory is empty!"
)
print
(
"TEST.TEST_MODEL diretory is empty!"
)
exit
(
-
1
)
exit
(
-
1
)
...
...
pdseg/loss.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/lovasz_losses.py
浏览文件 @
61645b1d
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/metrics.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/models/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/models/backbone/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pdseg/models/backbone/mobilenet_v2.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/models/backbone/resnet.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -141,7 +141,7 @@ class ResNet():
...
@@ -141,7 +141,7 @@ class ResNet():
else
:
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
dilation_rate
=
get_dilated_rate
(
dilation_dict
,
block
)
dilation_rate
=
get_dilated_rate
(
dilation_dict
,
block
)
conv
=
self
.
bottleneck_block
(
conv
=
self
.
bottleneck_block
(
input
=
conv
,
input
=
conv
,
num_filters
=
int
(
num_filters
[
block
]
*
self
.
scale
),
num_filters
=
int
(
num_filters
[
block
]
*
self
.
scale
),
...
@@ -215,11 +215,11 @@ class ResNet():
...
@@ -215,11 +215,11 @@ class ResNet():
groups
=
1
,
groups
=
1
,
act
=
None
,
act
=
None
,
name
=
None
):
name
=
None
):
if
self
.
stem
==
'pspnet'
:
if
self
.
stem
==
'pspnet'
:
bias_attr
=
ParamAttr
(
name
=
name
+
"_biases"
)
bias_attr
=
ParamAttr
(
name
=
name
+
"_biases"
)
else
:
else
:
bias_attr
=
False
bias_attr
=
False
conv
=
fluid
.
layers
.
conv2d
(
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
input
=
input
,
...
@@ -238,13 +238,15 @@ class ResNet():
...
@@ -238,13 +238,15 @@ class ResNet():
bn_name
=
"bn_"
+
name
bn_name
=
"bn_"
+
name
else
:
else
:
bn_name
=
"bn"
+
name
[
3
:]
bn_name
=
"bn"
+
name
[
3
:]
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
return
fluid
.
layers
.
batch_norm
(
act
=
act
,
input
=
conv
,
name
=
bn_name
+
'.output.1'
,
act
=
act
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
name
=
bn_name
+
'.output.1'
,
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
moving_mean_name
=
bn_name
+
'_mean'
,
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
moving_variance_name
=
bn_name
+
'_variance'
,
)
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
,
)
def
shortcut
(
self
,
input
,
ch_out
,
stride
,
is_first
,
name
):
def
shortcut
(
self
,
input
,
ch_out
,
stride
,
is_first
,
name
):
ch_in
=
input
.
shape
[
1
]
ch_in
=
input
.
shape
[
1
]
...
@@ -258,7 +260,7 @@ class ResNet():
...
@@ -258,7 +260,7 @@ class ResNet():
strides
=
[
1
,
stride
]
strides
=
[
1
,
stride
]
else
:
else
:
strides
=
[
stride
,
1
]
strides
=
[
stride
,
1
]
conv0
=
self
.
conv_bn_layer
(
conv0
=
self
.
conv_bn_layer
(
input
=
input
,
input
=
input
,
num_filters
=
num_filters
,
num_filters
=
num_filters
,
...
...
pdseg/models/backbone/vgg.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -55,7 +55,8 @@ class VGGNet():
...
@@ -55,7 +55,8 @@ class VGGNet():
channels
=
[
64
,
128
,
256
,
512
,
512
]
channels
=
[
64
,
128
,
256
,
512
,
512
]
conv
=
input
conv
=
input
for
i
in
range
(
len
(
nums
)):
for
i
in
range
(
len
(
nums
)):
conv
=
self
.
conv_block
(
conv
,
channels
[
i
],
nums
[
i
],
name
=
"conv"
+
str
(
i
+
1
)
+
"_"
)
conv
=
self
.
conv_block
(
conv
,
channels
[
i
],
nums
[
i
],
name
=
"conv"
+
str
(
i
+
1
)
+
"_"
)
layers_count
+=
nums
[
i
]
layers_count
+=
nums
[
i
]
if
check_points
(
layers_count
,
decode_points
):
if
check_points
(
layers_count
,
decode_points
):
short_cuts
[
layers_count
]
=
conv
short_cuts
[
layers_count
]
=
conv
...
...
pdseg/models/backbone/xception.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/models/libs/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pdseg/models/libs/model_libs.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -197,4 +197,4 @@ def conv_bn_layer(input,
...
@@ -197,4 +197,4 @@ def conv_bn_layer(input,
if
if_act
:
if
if_act
:
return
fluid
.
layers
.
relu6
(
bn
)
return
fluid
.
layers
.
relu6
(
bn
)
else
:
else
:
return
bn
return
bn
\ No newline at end of file
pdseg/models/model_builder.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/models/modeling/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pdseg/models/modeling/deeplab.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/models/modeling/fast_scnn.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/models/modeling/hrnet.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -25,7 +25,14 @@ from paddle.fluid.param_attr import ParamAttr
...
@@ -25,7 +25,14 @@ from paddle.fluid.param_attr import ParamAttr
from
utils.config
import
cfg
from
utils.config
import
cfg
def
conv_bn_layer
(
input
,
filter_size
,
num_filters
,
stride
=
1
,
padding
=
1
,
num_groups
=
1
,
if_act
=
True
,
name
=
None
):
def
conv_bn_layer
(
input
,
filter_size
,
num_filters
,
stride
=
1
,
padding
=
1
,
num_groups
=
1
,
if_act
=
True
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
input
=
input
,
num_filters
=
num_filters
,
num_filters
=
num_filters
,
...
@@ -37,37 +44,74 @@ def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_grou
...
@@ -37,37 +44,74 @@ def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_grou
param_attr
=
ParamAttr
(
initializer
=
MSRA
(),
name
=
name
+
'_weights'
),
param_attr
=
ParamAttr
(
initializer
=
MSRA
(),
name
=
name
+
'_weights'
),
bias_attr
=
False
)
bias_attr
=
False
)
bn_name
=
name
+
'_bn'
bn_name
=
name
+
'_bn'
bn
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
bn
=
fluid
.
layers
.
batch_norm
(
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_scale"
,
input
=
conv
,
initializer
=
fluid
.
initializer
.
Constant
(
1.0
)),
param_attr
=
ParamAttr
(
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_offset"
,
name
=
bn_name
+
"_scale"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.0
)),
initializer
=
fluid
.
initializer
.
Constant
(
1.0
)),
moving_mean_name
=
bn_name
+
'_mean'
,
bias_attr
=
ParamAttr
(
moving_variance_name
=
bn_name
+
'_variance'
)
name
=
bn_name
+
"_offset"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.0
)),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
)
if
if_act
:
if
if_act
:
bn
=
fluid
.
layers
.
relu
(
bn
)
bn
=
fluid
.
layers
.
relu
(
bn
)
return
bn
return
bn
def
basic_block
(
input
,
num_filters
,
stride
=
1
,
downsample
=
False
,
name
=
None
):
def
basic_block
(
input
,
num_filters
,
stride
=
1
,
downsample
=
False
,
name
=
None
):
residual
=
input
residual
=
input
conv
=
conv_bn_layer
(
input
=
input
,
filter_size
=
3
,
num_filters
=
num_filters
,
stride
=
stride
,
name
=
name
+
'_conv1'
)
conv
=
conv_bn_layer
(
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
3
,
num_filters
=
num_filters
,
if_act
=
False
,
name
=
name
+
'_conv2'
)
input
=
input
,
filter_size
=
3
,
num_filters
=
num_filters
,
stride
=
stride
,
name
=
name
+
'_conv1'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
3
,
num_filters
=
num_filters
,
if_act
=
False
,
name
=
name
+
'_conv2'
)
if
downsample
:
if
downsample
:
residual
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
,
if_act
=
False
,
residual
=
conv_bn_layer
(
name
=
name
+
'_downsample'
)
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
,
if_act
=
False
,
name
=
name
+
'_downsample'
)
return
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
conv
,
act
=
'relu'
)
return
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
conv
,
act
=
'relu'
)
def
bottleneck_block
(
input
,
num_filters
,
stride
=
1
,
downsample
=
False
,
name
=
None
):
def
bottleneck_block
(
input
,
num_filters
,
stride
=
1
,
downsample
=
False
,
name
=
None
):
residual
=
input
residual
=
input
conv
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
,
name
=
name
+
'_conv1'
)
conv
=
conv_bn_layer
(
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
3
,
num_filters
=
num_filters
,
stride
=
stride
,
name
=
name
+
'_conv2'
)
input
=
input
,
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
1
,
num_filters
=
num_filters
*
4
,
if_act
=
False
,
filter_size
=
1
,
name
=
name
+
'_conv3'
)
num_filters
=
num_filters
,
name
=
name
+
'_conv1'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
3
,
num_filters
=
num_filters
,
stride
=
stride
,
name
=
name
+
'_conv2'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
1
,
num_filters
=
num_filters
*
4
,
if_act
=
False
,
name
=
name
+
'_conv3'
)
if
downsample
:
if
downsample
:
residual
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
*
4
,
if_act
=
False
,
residual
=
conv_bn_layer
(
name
=
name
+
'_downsample'
)
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
*
4
,
if_act
=
False
,
name
=
name
+
'_downsample'
)
return
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
conv
,
act
=
'relu'
)
return
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
conv
,
act
=
'relu'
)
def
fuse_layers
(
x
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
def
fuse_layers
(
x
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
out
=
[]
out
=
[]
for
i
in
range
(
len
(
channels
)
if
multi_scale_output
else
1
):
for
i
in
range
(
len
(
channels
)
if
multi_scale_output
else
1
):
...
@@ -77,40 +121,64 @@ def fuse_layers(x, channels, multi_scale_output=True, name=None):
...
@@ -77,40 +121,64 @@ def fuse_layers(x, channels, multi_scale_output=True, name=None):
height
=
shape
[
-
2
]
height
=
shape
[
-
2
]
for
j
in
range
(
len
(
channels
)):
for
j
in
range
(
len
(
channels
)):
if
j
>
i
:
if
j
>
i
:
y
=
conv_bn_layer
(
x
[
j
],
filter_size
=
1
,
num_filters
=
channels
[
i
],
if_act
=
False
,
y
=
conv_bn_layer
(
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
))
x
[
j
],
y
=
fluid
.
layers
.
resize_bilinear
(
input
=
y
,
out_shape
=
[
height
,
width
])
filter_size
=
1
,
residual
=
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
y
,
act
=
None
)
num_filters
=
channels
[
i
],
if_act
=
False
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
))
y
=
fluid
.
layers
.
resize_bilinear
(
input
=
y
,
out_shape
=
[
height
,
width
])
residual
=
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
y
,
act
=
None
)
elif
j
<
i
:
elif
j
<
i
:
y
=
x
[
j
]
y
=
x
[
j
]
for
k
in
range
(
i
-
j
):
for
k
in
range
(
i
-
j
):
if
k
==
i
-
j
-
1
:
if
k
==
i
-
j
-
1
:
y
=
conv_bn_layer
(
y
,
filter_size
=
3
,
num_filters
=
channels
[
i
],
stride
=
2
,
if_act
=
False
,
y
=
conv_bn_layer
(
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
)
+
'_'
+
str
(
k
+
1
))
y
,
filter_size
=
3
,
num_filters
=
channels
[
i
],
stride
=
2
,
if_act
=
False
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
)
+
'_'
+
str
(
k
+
1
))
else
:
else
:
y
=
conv_bn_layer
(
y
,
filter_size
=
3
,
num_filters
=
channels
[
j
],
stride
=
2
,
y
=
conv_bn_layer
(
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
)
+
'_'
+
str
(
k
+
1
))
y
,
residual
=
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
y
,
act
=
None
)
filter_size
=
3
,
num_filters
=
channels
[
j
],
stride
=
2
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
)
+
'_'
+
str
(
k
+
1
))
residual
=
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
y
,
act
=
None
)
residual
=
fluid
.
layers
.
relu
(
residual
)
residual
=
fluid
.
layers
.
relu
(
residual
)
out
.
append
(
residual
)
out
.
append
(
residual
)
return
out
return
out
def
branches
(
x
,
block_num
,
channels
,
name
=
None
):
def
branches
(
x
,
block_num
,
channels
,
name
=
None
):
out
=
[]
out
=
[]
for
i
in
range
(
len
(
channels
)):
for
i
in
range
(
len
(
channels
)):
residual
=
x
[
i
]
residual
=
x
[
i
]
for
j
in
range
(
block_num
):
for
j
in
range
(
block_num
):
residual
=
basic_block
(
residual
,
channels
[
i
],
residual
=
basic_block
(
name
=
name
+
'_branch_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
))
residual
,
channels
[
i
],
name
=
name
+
'_branch_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
))
out
.
append
(
residual
)
out
.
append
(
residual
)
return
out
return
out
def
high_resolution_module
(
x
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
def
high_resolution_module
(
x
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
residual
=
branches
(
x
,
4
,
channels
,
name
=
name
)
residual
=
branches
(
x
,
4
,
channels
,
name
=
name
)
out
=
fuse_layers
(
residual
,
channels
,
multi_scale_output
=
multi_scale_output
,
name
=
name
)
out
=
fuse_layers
(
residual
,
channels
,
multi_scale_output
=
multi_scale_output
,
name
=
name
)
return
out
return
out
def
transition_layer
(
x
,
in_channels
,
out_channels
,
name
=
None
):
def
transition_layer
(
x
,
in_channels
,
out_channels
,
name
=
None
):
num_in
=
len
(
in_channels
)
num_in
=
len
(
in_channels
)
num_out
=
len
(
out_channels
)
num_out
=
len
(
out_channels
)
...
@@ -118,46 +186,76 @@ def transition_layer(x, in_channels, out_channels, name=None):
...
@@ -118,46 +186,76 @@ def transition_layer(x, in_channels, out_channels, name=None):
for
i
in
range
(
num_out
):
for
i
in
range
(
num_out
):
if
i
<
num_in
:
if
i
<
num_in
:
if
in_channels
[
i
]
!=
out_channels
[
i
]:
if
in_channels
[
i
]
!=
out_channels
[
i
]:
residual
=
conv_bn_layer
(
x
[
i
],
filter_size
=
3
,
num_filters
=
out_channels
[
i
],
residual
=
conv_bn_layer
(
name
=
name
+
'_layer_'
+
str
(
i
+
1
))
x
[
i
],
filter_size
=
3
,
num_filters
=
out_channels
[
i
],
name
=
name
+
'_layer_'
+
str
(
i
+
1
))
out
.
append
(
residual
)
out
.
append
(
residual
)
else
:
else
:
out
.
append
(
x
[
i
])
out
.
append
(
x
[
i
])
else
:
else
:
residual
=
conv_bn_layer
(
x
[
-
1
],
filter_size
=
3
,
num_filters
=
out_channels
[
i
],
stride
=
2
,
residual
=
conv_bn_layer
(
name
=
name
+
'_layer_'
+
str
(
i
+
1
))
x
[
-
1
],
filter_size
=
3
,
num_filters
=
out_channels
[
i
],
stride
=
2
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
))
out
.
append
(
residual
)
out
.
append
(
residual
)
return
out
return
out
def
stage
(
x
,
num_modules
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
def
stage
(
x
,
num_modules
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
out
=
x
out
=
x
for
i
in
range
(
num_modules
):
for
i
in
range
(
num_modules
):
if
i
==
num_modules
-
1
and
multi_scale_output
==
False
:
if
i
==
num_modules
-
1
and
multi_scale_output
==
False
:
out
=
high_resolution_module
(
out
,
channels
,
multi_scale_output
=
False
,
name
=
name
+
'_'
+
str
(
i
+
1
))
out
=
high_resolution_module
(
out
,
channels
,
multi_scale_output
=
False
,
name
=
name
+
'_'
+
str
(
i
+
1
))
else
:
else
:
out
=
high_resolution_module
(
out
,
channels
,
name
=
name
+
'_'
+
str
(
i
+
1
))
out
=
high_resolution_module
(
out
,
channels
,
name
=
name
+
'_'
+
str
(
i
+
1
))
return
out
return
out
def
layer1
(
input
,
name
=
None
):
def
layer1
(
input
,
name
=
None
):
conv
=
input
conv
=
input
for
i
in
range
(
4
):
for
i
in
range
(
4
):
conv
=
bottleneck_block
(
conv
,
num_filters
=
64
,
downsample
=
True
if
i
==
0
else
False
,
conv
=
bottleneck_block
(
name
=
name
+
'_'
+
str
(
i
+
1
))
conv
,
num_filters
=
64
,
downsample
=
True
if
i
==
0
else
False
,
name
=
name
+
'_'
+
str
(
i
+
1
))
return
conv
return
conv
def
high_resolution_net
(
input
,
num_classes
):
def
high_resolution_net
(
input
,
num_classes
):
channels_2
=
cfg
.
MODEL
.
HRNET
.
STAGE2
.
NUM_CHANNELS
channels_2
=
cfg
.
MODEL
.
HRNET
.
STAGE2
.
NUM_CHANNELS
channels_3
=
cfg
.
MODEL
.
HRNET
.
STAGE3
.
NUM_CHANNELS
channels_3
=
cfg
.
MODEL
.
HRNET
.
STAGE3
.
NUM_CHANNELS
channels_4
=
cfg
.
MODEL
.
HRNET
.
STAGE4
.
NUM_CHANNELS
channels_4
=
cfg
.
MODEL
.
HRNET
.
STAGE4
.
NUM_CHANNELS
num_modules_2
=
cfg
.
MODEL
.
HRNET
.
STAGE2
.
NUM_MODULES
num_modules_2
=
cfg
.
MODEL
.
HRNET
.
STAGE2
.
NUM_MODULES
num_modules_3
=
cfg
.
MODEL
.
HRNET
.
STAGE3
.
NUM_MODULES
num_modules_3
=
cfg
.
MODEL
.
HRNET
.
STAGE3
.
NUM_MODULES
num_modules_4
=
cfg
.
MODEL
.
HRNET
.
STAGE4
.
NUM_MODULES
num_modules_4
=
cfg
.
MODEL
.
HRNET
.
STAGE4
.
NUM_MODULES
x
=
conv_bn_layer
(
input
=
input
,
filter_size
=
3
,
num_filters
=
64
,
stride
=
2
,
if_act
=
True
,
name
=
'layer1_1'
)
x
=
conv_bn_layer
(
x
=
conv_bn_layer
(
input
=
x
,
filter_size
=
3
,
num_filters
=
64
,
stride
=
2
,
if_act
=
True
,
name
=
'layer1_2'
)
input
=
input
,
filter_size
=
3
,
num_filters
=
64
,
stride
=
2
,
if_act
=
True
,
name
=
'layer1_1'
)
x
=
conv_bn_layer
(
input
=
x
,
filter_size
=
3
,
num_filters
=
64
,
stride
=
2
,
if_act
=
True
,
name
=
'layer1_2'
)
la1
=
layer1
(
x
,
name
=
'layer2'
)
la1
=
layer1
(
x
,
name
=
'layer2'
)
tr1
=
transition_layer
([
la1
],
[
256
],
channels_2
,
name
=
'tr1'
)
tr1
=
transition_layer
([
la1
],
[
256
],
channels_2
,
name
=
'tr1'
)
...
@@ -170,18 +268,21 @@ def high_resolution_net(input, num_classes):
...
@@ -170,18 +268,21 @@ def high_resolution_net(input, num_classes):
# upsample
# upsample
shape
=
st4
[
0
].
shape
shape
=
st4
[
0
].
shape
height
,
width
=
shape
[
-
2
],
shape
[
-
1
]
height
,
width
=
shape
[
-
2
],
shape
[
-
1
]
st4
[
1
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
1
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
1
],
out_shape
=
[
height
,
width
])
st4
[
1
],
out_shape
=
[
height
,
width
])
st4
[
2
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
2
],
out_shape
=
[
height
,
width
])
st4
[
2
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
3
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
3
],
out_shape
=
[
height
,
width
])
st4
[
2
],
out_shape
=
[
height
,
width
])
st4
[
3
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
3
],
out_shape
=
[
height
,
width
])
out
=
fluid
.
layers
.
concat
(
st4
,
axis
=
1
)
out
=
fluid
.
layers
.
concat
(
st4
,
axis
=
1
)
last_channels
=
sum
(
channels_4
)
last_channels
=
sum
(
channels_4
)
out
=
conv_bn_layer
(
input
=
out
,
filter_size
=
1
,
num_filters
=
last_channels
,
stride
=
1
,
if_act
=
True
,
name
=
'conv-2'
)
out
=
conv_bn_layer
(
out
=
fluid
.
layers
.
conv2d
(
input
=
out
,
filter_size
=
1
,
num_filters
=
last_channels
,
stride
=
1
,
if_act
=
True
,
name
=
'conv-2'
)
out
=
fluid
.
layers
.
conv2d
(
input
=
out
,
input
=
out
,
num_filters
=
num_classes
,
num_filters
=
num_classes
,
filter_size
=
1
,
filter_size
=
1
,
...
@@ -193,7 +294,6 @@ def high_resolution_net(input, num_classes):
...
@@ -193,7 +294,6 @@ def high_resolution_net(input, num_classes):
out
=
fluid
.
layers
.
resize_bilinear
(
out
,
input
.
shape
[
2
:])
out
=
fluid
.
layers
.
resize_bilinear
(
out
,
input
.
shape
[
2
:])
return
out
return
out
...
@@ -201,6 +301,7 @@ def hrnet(input, num_classes):
...
@@ -201,6 +301,7 @@ def hrnet(input, num_classes):
logit
=
high_resolution_net
(
input
,
num_classes
)
logit
=
high_resolution_net
(
input
,
num_classes
)
return
logit
return
logit
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
image_shape
=
[
-
1
,
3
,
769
,
769
]
image_shape
=
[
-
1
,
3
,
769
,
769
]
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
...
...
pdseg/models/modeling/icnet.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/models/modeling/pspnet.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -24,6 +24,7 @@ from models.libs.model_libs import avg_pool, conv, bn
...
@@ -24,6 +24,7 @@ from models.libs.model_libs import avg_pool, conv, bn
from
models.backbone.resnet
import
ResNet
as
resnet_backbone
from
models.backbone.resnet
import
ResNet
as
resnet_backbone
from
utils.config
import
cfg
from
utils.config
import
cfg
def
get_logit_interp
(
input
,
num_classes
,
out_shape
,
name
=
"logit"
):
def
get_logit_interp
(
input
,
num_classes
,
out_shape
,
name
=
"logit"
):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
param_attr
=
fluid
.
ParamAttr
(
param_attr
=
fluid
.
ParamAttr
(
...
@@ -33,16 +34,15 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"):
...
@@ -33,16 +34,15 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"):
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
0.01
))
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
0.01
))
with
scope
(
name
):
with
scope
(
name
):
logit
=
conv
(
input
,
logit
=
conv
(
num_classes
,
input
,
filter_size
=
1
,
num_classes
,
param_attr
=
param_attr
,
filter_size
=
1
,
bias_attr
=
True
,
param_attr
=
param_attr
,
name
=
name
+
'_conv'
)
bias_attr
=
True
,
name
=
name
+
'_conv'
)
logit_interp
=
fluid
.
layers
.
resize_bilinear
(
logit_interp
=
fluid
.
layers
.
resize_bilinear
(
logit
,
logit
,
out_shape
=
out_shape
,
name
=
name
+
'_interp'
)
out_shape
=
out_shape
,
name
=
name
+
'_interp'
)
return
logit_interp
return
logit_interp
...
@@ -51,40 +51,44 @@ def psp_module(input, out_features):
...
@@ -51,40 +51,44 @@ def psp_module(input, out_features):
# 输入:backbone输出的特征
# 输入:backbone输出的特征
# 输出:对输入进行不同尺度pooling, 卷积操作后插值回原始尺寸,并concat
# 输出:对输入进行不同尺度pooling, 卷积操作后插值回原始尺寸,并concat
# 最后进行一个卷积及BN操作
# 最后进行一个卷积及BN操作
cat_layers
=
[]
cat_layers
=
[]
sizes
=
(
1
,
2
,
3
,
6
)
sizes
=
(
1
,
2
,
3
,
6
)
for
size
in
sizes
:
for
size
in
sizes
:
psp_name
=
"psp"
+
str
(
size
)
psp_name
=
"psp"
+
str
(
size
)
with
scope
(
psp_name
):
with
scope
(
psp_name
):
pool
=
fluid
.
layers
.
adaptive_pool2d
(
input
,
pool
=
fluid
.
layers
.
adaptive_pool2d
(
pool_size
=
[
size
,
size
],
input
,
pool_type
=
'avg'
,
pool_size
=
[
size
,
size
],
name
=
psp_name
+
'_adapool'
)
pool_type
=
'avg'
,
data
=
conv
(
pool
,
out_features
,
name
=
psp_name
+
'_adapool'
)
filter_size
=
1
,
data
=
conv
(
bias_attr
=
True
,
pool
,
name
=
psp_name
+
'_conv'
)
out_features
,
filter_size
=
1
,
bias_attr
=
True
,
name
=
psp_name
+
'_conv'
)
data_bn
=
bn
(
data
,
act
=
'relu'
)
data_bn
=
bn
(
data
,
act
=
'relu'
)
interp
=
fluid
.
layers
.
resize_bilinear
(
data_bn
,
interp
=
fluid
.
layers
.
resize_bilinear
(
out_shape
=
input
.
shape
[
2
:],
data_bn
,
out_shape
=
input
.
shape
[
2
:],
name
=
psp_name
+
'_interp'
)
name
=
psp_name
+
'_interp'
)
cat_layers
.
append
(
interp
)
cat_layers
.
append
(
interp
)
cat_layers
=
[
input
]
+
cat_layers
[::
-
1
]
cat_layers
=
[
input
]
+
cat_layers
[::
-
1
]
cat
=
fluid
.
layers
.
concat
(
cat_layers
,
axis
=
1
,
name
=
'psp_cat'
)
cat
=
fluid
.
layers
.
concat
(
cat_layers
,
axis
=
1
,
name
=
'psp_cat'
)
psp_end_name
=
"psp_end"
psp_end_name
=
"psp_end"
with
scope
(
psp_end_name
):
with
scope
(
psp_end_name
):
data
=
conv
(
cat
,
data
=
conv
(
out_features
,
cat
,
filter_size
=
3
,
out_features
,
padding
=
1
,
filter_size
=
3
,
bias_attr
=
True
,
padding
=
1
,
name
=
psp_end_name
)
bias_attr
=
True
,
name
=
psp_end_name
)
out
=
bn
(
data
,
act
=
'relu'
)
out
=
bn
(
data
,
act
=
'relu'
)
return
out
return
out
def
resnet
(
input
):
def
resnet
(
input
):
# PSPNET backbone: resnet, 默认resnet50
# PSPNET backbone: resnet, 默认resnet50
# end_points: resnet终止层数
# end_points: resnet终止层数
...
@@ -92,14 +96,14 @@ def resnet(input):
...
@@ -92,14 +96,14 @@ def resnet(input):
scale
=
cfg
.
MODEL
.
PSPNET
.
DEPTH_MULTIPLIER
scale
=
cfg
.
MODEL
.
PSPNET
.
DEPTH_MULTIPLIER
layers
=
cfg
.
MODEL
.
PSPNET
.
LAYERS
layers
=
cfg
.
MODEL
.
PSPNET
.
LAYERS
end_points
=
layers
-
1
end_points
=
layers
-
1
dilation_dict
=
{
2
:
2
,
3
:
4
}
dilation_dict
=
{
2
:
2
,
3
:
4
}
model
=
resnet_backbone
(
layers
,
scale
,
stem
=
'pspnet'
)
model
=
resnet_backbone
(
layers
,
scale
,
stem
=
'pspnet'
)
data
,
_
=
model
.
net
(
input
,
data
,
_
=
model
.
net
(
end_points
=
end_points
,
input
,
end_points
=
end_points
,
dilation_dict
=
dilation_dict
)
dilation_dict
=
dilation_dict
)
return
data
return
data
def
pspnet
(
input
,
num_classes
):
def
pspnet
(
input
,
num_classes
):
# Backbone: ResNet
# Backbone: ResNet
res
=
resnet
(
input
)
res
=
resnet
(
input
)
...
@@ -109,4 +113,3 @@ def pspnet(input, num_classes):
...
@@ -109,4 +113,3 @@ def pspnet(input, num_classes):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
logit
=
get_logit_interp
(
dropout
,
num_classes
,
input
.
shape
[
2
:])
logit
=
get_logit_interp
(
dropout
,
num_classes
,
input
.
shape
[
2
:])
return
logit
return
logit
pdseg/models/modeling/unet.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/reader.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -71,7 +71,8 @@ class SegDataset(object):
...
@@ -71,7 +71,8 @@ class SegDataset(object):
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
shuffle_seed
+=
1
self
.
shuffle_seed
+=
1
elif
self
.
shuffle
:
elif
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
lines
)
np
.
random
.
shuffle
(
self
.
lines
)
...
@@ -99,7 +100,8 @@ class SegDataset(object):
...
@@ -99,7 +100,8 @@ class SegDataset(object):
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
shuffle_seed
+=
1
self
.
shuffle_seed
+=
1
elif
self
.
shuffle
:
elif
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
lines
)
np
.
random
.
shuffle
(
self
.
lines
)
...
...
pdseg/solver.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/tools/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
pdseg/tools/create_dataset_list.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -21,55 +21,48 @@ import warnings
...
@@ -21,55 +21,48 @@ import warnings
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
'PaddleSeg generate file list on cityscapes or your customized dataset.'
)
description
=
parser
.
add_argument
(
'PaddleSeg generate file list on cityscapes or your customized dataset.'
'dataset_root'
,
help
=
'dataset root directory'
,
type
=
str
)
)
parser
.
add_argument
(
'dataset_root'
,
help
=
'dataset root directory'
,
type
=
str
)
parser
.
add_argument
(
parser
.
add_argument
(
'--type'
,
'--type'
,
help
=
'dataset type:
\n
'
help
=
'dataset type:
\n
'
'- cityscapes
\n
'
'- cityscapes
\n
'
'- custom(default)'
,
'- custom(default)'
,
default
=
"custom"
,
default
=
"custom"
,
type
=
str
type
=
str
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'--separator'
,
'--separator'
,
dest
=
'separator'
,
dest
=
'separator'
,
help
=
'file list separator'
,
help
=
'file list separator'
,
default
=
"|"
,
default
=
"|"
,
type
=
str
type
=
str
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'--folder'
,
'--folder'
,
help
=
'the folder names of images and labels'
,
help
=
'the folder names of images and labels'
,
type
=
str
,
type
=
str
,
nargs
=
2
,
nargs
=
2
,
default
=
[
'images'
,
'annotations'
]
default
=
[
'images'
,
'annotations'
])
)
parser
.
add_argument
(
parser
.
add_argument
(
'--second_folder'
,
'--second_folder'
,
help
=
'the second-level folder names of train set, validation set, test set'
,
help
=
'the second-level folder names of train set, validation set, test set'
,
type
=
str
,
type
=
str
,
nargs
=
'*'
,
nargs
=
'*'
,
default
=
[
'train'
,
'val'
,
'test'
]
default
=
[
'train'
,
'val'
,
'test'
])
)
parser
.
add_argument
(
parser
.
add_argument
(
'--format'
,
'--format'
,
help
=
'data format of images and labels, e.g. jpg or png.'
,
help
=
'data format of images and labels, e.g. jpg or png.'
,
type
=
str
,
type
=
str
,
nargs
=
2
,
nargs
=
2
,
default
=
[
'jpg'
,
'png'
]
default
=
[
'jpg'
,
'png'
])
)
parser
.
add_argument
(
parser
.
add_argument
(
'--postfix'
,
'--postfix'
,
help
=
'postfix of images or labels'
,
help
=
'postfix of images or labels'
,
type
=
str
,
type
=
str
,
nargs
=
2
,
nargs
=
2
,
default
=
[
''
,
''
]
default
=
[
''
,
''
])
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -120,15 +113,17 @@ def generate_list(args):
...
@@ -120,15 +113,17 @@ def generate_list(args):
num_images
=
len
(
image_files
)
num_images
=
len
(
image_files
)
if
not
label_files
:
if
not
label_files
:
label_dir
=
os
.
path
.
join
(
dataset_root
,
args
.
folder
[
1
],
dataset_split
)
label_dir
=
os
.
path
.
join
(
dataset_root
,
args
.
folder
[
1
],
dataset_split
)
warnings
.
warn
(
"No labels in {} !!!"
.
format
(
label_dir
))
warnings
.
warn
(
"No labels in {} !!!"
.
format
(
label_dir
))
num_label
=
len
(
label_files
)
num_label
=
len
(
label_files
)
if
num_images
!=
num_label
and
num_label
>
0
:
if
num_images
!=
num_label
and
num_label
>
0
:
raise
Exception
(
"Number of images = {} number of labels = {}
\n
"
raise
Exception
(
"Either number of images is equal to number of labels, "
"Number of images = {} number of labels = {}
\n
"
"or number of labels is equal to 0.
\n
"
"Either number of images is equal to number of labels, "
"Please check your dataset!"
.
format
(
num_images
,
num_label
))
"or number of labels is equal to 0.
\n
"
"Please check your dataset!"
.
format
(
num_images
,
num_label
))
file_list
=
os
.
path
.
join
(
dataset_root
,
dataset_split
+
'.txt'
)
file_list
=
os
.
path
.
join
(
dataset_root
,
dataset_split
+
'.txt'
)
with
open
(
file_list
,
"w"
)
as
f
:
with
open
(
file_list
,
"w"
)
as
f
:
...
...
pdseg/tools/gray2pseudo_color.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
print_function
import
argparse
import
argparse
...
@@ -11,16 +25,12 @@ from PIL import Image
...
@@ -11,16 +25,12 @@ from PIL import Image
def
parse_args
():
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'dir_or_file'
,
'dir_or_file'
,
help
=
'input gray label directory or file list path'
)
help
=
'input gray label directory or file list path'
)
parser
.
add_argument
(
'output_dir'
,
help
=
'output colorful label directory'
)
parser
.
add_argument
(
'output_dir'
,
parser
.
add_argument
(
'--dataset_dir'
,
help
=
'dataset directory'
)
help
=
'output colorful label directory'
)
parser
.
add_argument
(
'--file_separator'
,
help
=
'file list separator'
)
parser
.
add_argument
(
'--dataset_dir'
,
help
=
'dataset directory'
)
parser
.
add_argument
(
'--file_separator'
,
help
=
'file list separator'
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
...
pdseg/tools/jingling2seg.py
浏览文件 @
61645b1d
#!/usr/bin/env python
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
print_function
...
...
pdseg/tools/labelme2seg.py
浏览文件 @
61645b1d
#!/usr/bin/env python
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
__future__
import
print_function
...
...
pdseg/train.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -26,9 +26,7 @@ import argparse
...
@@ -26,9 +26,7 @@ import argparse
import
pprint
import
pprint
import
random
import
random
import
shutil
import
shutil
import
functools
import
paddle
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid
import
profiler
from
paddle.fluid
import
profiler
...
@@ -39,10 +37,10 @@ from metrics import ConfusionMatrix
...
@@ -39,10 +37,10 @@ from metrics import ConfusionMatrix
from
reader
import
SegDataset
from
reader
import
SegDataset
from
models.model_builder
import
build_model
from
models.model_builder
import
build_model
from
models.model_builder
import
ModelPhase
from
models.model_builder
import
ModelPhase
from
models.model_builder
import
parse_shape_from_file
from
eval
import
evaluate
from
eval
import
evaluate
from
vis
import
visualize
from
vis
import
visualize
from
utils
import
dist_utils
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
def
parse_args
():
def
parse_args
():
...
@@ -118,38 +116,7 @@ def parse_args():
...
@@ -118,38 +116,7 @@ def parse_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
):
def
save_checkpoint
(
program
,
ckpt_name
):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program
=
fluid
.
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
fluid
.
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
save_block
.
create_var
(
name
=
each_var
.
name
,
shape
=
each_var
.
shape
,
dtype
=
each_var
.
dtype
,
type
=
each_var
.
type
,
lod_level
=
each_var
.
lod_level
,
persistable
=
True
)
file_path
=
os
.
path
.
join
(
dirname
,
new_var
.
name
)
file_path
=
os
.
path
.
normpath
(
file_path
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
file_path
})
executor
.
run
(
save_program
)
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
"""
"""
Save checkpoint for evaluation or resume training
Save checkpoint for evaluation or resume training
"""
"""
...
@@ -158,29 +125,22 @@ def save_checkpoint(exe, program, ckpt_name):
...
@@ -158,29 +125,22 @@ def save_checkpoint(exe, program, ckpt_name):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
os
.
makedirs
(
ckpt_dir
)
save_vars
(
fluid
.
save
(
program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
return
ckpt_dir
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
def
load_checkpoint
(
exe
,
program
):
"""
"""
Load checkpoiont f
rom pretrained model directory for resume
training
Load checkpoiont f
or resuming
training
"""
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
print
(
'Resume model training from:'
,
model_path
)
if
not
os
.
path
.
exists
(
model_path
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
model_path
))
fluid
.
load
(
program
,
os
.
path
.
join
(
model_path
,
'model'
),
exe
)
# Check is path ended by path spearator
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
model_path
=
model_path
[
0
:
-
1
]
...
@@ -195,7 +155,6 @@ def load_checkpoint(exe, program):
...
@@ -195,7 +155,6 @@ def load_checkpoint(exe, program):
else
:
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
return
begin_epoch
...
@@ -247,8 +206,6 @@ def train(cfg):
...
@@ -247,8 +206,6 @@ def train(cfg):
yield
item
[
0
],
item
[
1
],
item
[
2
]
yield
item
[
0
],
item
[
1
],
item
[
2
]
# Get device environment
# Get device environment
# places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
# place = places[0]
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
place
=
fluid
.
CUDAPlace
(
gpu_id
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
gpu_id
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
places
=
fluid
.
cuda_places
()
if
args
.
use_gpu
else
fluid
.
cpu_places
()
places
=
fluid
.
cuda_places
()
if
args
.
use_gpu
else
fluid
.
cpu_places
()
...
@@ -304,42 +261,7 @@ def train(cfg):
...
@@ -304,42 +261,7 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
else
:
else
:
print_info
(
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
'Pretrained model dir {} not exists, training from scratch...'
.
...
@@ -418,12 +340,9 @@ def train(cfg):
...
@@ -418,12 +340,9 @@ def train(cfg):
step
)
step
)
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
avg_loss
=
0.0
avg_loss
=
0.0
cm
.
zero_matrix
()
cm
.
zero_matrix
()
...
@@ -445,12 +364,9 @@ def train(cfg):
...
@@ -445,12 +364,9 @@ def train(cfg):
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
speed
,
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
speed
,
calculate_eta
(
all_step
-
step
,
speed
)))
calculate_eta
(
all_step
-
step
,
speed
)))
if
args
.
use_vdl
:
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
avg_loss
=
0.0
avg_loss
=
0.0
timer
.
restart
()
timer
.
restart
()
...
@@ -470,7 +386,7 @@ def train(cfg):
...
@@ -470,7 +386,7 @@ def train(cfg):
if
(
epoch
%
cfg
.
TRAIN
.
SNAPSHOT_EPOCH
==
0
if
(
epoch
%
cfg
.
TRAIN
.
SNAPSHOT_EPOCH
==
0
or
epoch
==
cfg
.
SOLVER
.
NUM_EPOCHS
)
and
cfg
.
TRAINER_ID
==
0
:
or
epoch
==
cfg
.
SOLVER
.
NUM_EPOCHS
)
and
cfg
.
TRAINER_ID
==
0
:
ckpt_dir
=
save_checkpoint
(
exe
,
train_prog
,
epoch
)
ckpt_dir
=
save_checkpoint
(
train_prog
,
epoch
)
if
args
.
do_eval
:
if
args
.
do_eval
:
print
(
"Evaluation start"
)
print
(
"Evaluation start"
)
...
@@ -480,10 +396,8 @@ def train(cfg):
...
@@ -480,10 +396,8 @@ def train(cfg):
use_gpu
=
args
.
use_gpu
,
use_gpu
=
args
.
use_gpu
,
use_mpio
=
args
.
use_mpio
)
use_mpio
=
args
.
use_mpio
)
if
args
.
use_vdl
:
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
if
mean_iou
>
best_mIoU
:
if
mean_iou
>
best_mIoU
:
best_mIoU
=
mean_iou
best_mIoU
=
mean_iou
...
@@ -505,7 +419,7 @@ def train(cfg):
...
@@ -505,7 +419,7 @@ def train(cfg):
# save final model
# save final model
if
cfg
.
TRAINER_ID
==
0
:
if
cfg
.
TRAINER_ID
==
0
:
save_checkpoint
(
exe
,
train_prog
,
'final'
)
save_checkpoint
(
train_prog
,
'final'
)
def
main
(
args
):
def
main
(
args
):
...
...
pdseg/utils/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pdseg/utils/collect.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -98,7 +99,7 @@ class SegConfig(dict):
...
@@ -98,7 +99,7 @@ class SegConfig(dict):
'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`'
'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`'
)
)
if
self
.
MEAN
is
not
None
:
if
self
.
MEAN
is
not
None
:
self
.
DATASET
.
PADDING_VALUE
=
[
x
*
255.0
for
x
in
self
.
MEAN
]
self
.
DATASET
.
PADDING_VALUE
=
[
x
*
255.0
for
x
in
self
.
MEAN
]
if
not
self
.
TRAIN_CROP_SIZE
:
if
not
self
.
TRAIN_CROP_SIZE
:
raise
ValueError
(
raise
ValueError
(
...
@@ -111,9 +112,12 @@ class SegConfig(dict):
...
@@ -111,9 +112,12 @@ class SegConfig(dict):
)
)
# Ensure file list is use UTF-8 encoding
# Ensure file list is use UTF-8 encoding
train_sets
=
codecs
.
open
(
self
.
DATASET
.
TRAIN_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
train_sets
=
codecs
.
open
(
self
.
DATASET
.
TRAIN_FILE_LIST
,
'r'
,
val_sets
=
codecs
.
open
(
self
.
DATASET
.
VAL_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
'utf-8'
).
readlines
()
test_sets
=
codecs
.
open
(
self
.
DATASET
.
TEST_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
val_sets
=
codecs
.
open
(
self
.
DATASET
.
VAL_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
test_sets
=
codecs
.
open
(
self
.
DATASET
.
TEST_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
self
.
DATASET
.
TRAIN_TOTAL_IMAGES
=
len
(
train_sets
)
self
.
DATASET
.
TRAIN_TOTAL_IMAGES
=
len
(
train_sets
)
self
.
DATASET
.
VAL_TOTAL_IMAGES
=
len
(
val_sets
)
self
.
DATASET
.
VAL_TOTAL_IMAGES
=
len
(
val_sets
)
self
.
DATASET
.
TEST_TOTAL_IMAGES
=
len
(
test_sets
)
self
.
DATASET
.
TEST_TOTAL_IMAGES
=
len
(
test_sets
)
...
@@ -122,12 +126,13 @@ class SegConfig(dict):
...
@@ -122,12 +126,13 @@ class SegConfig(dict):
len
(
self
.
MODEL
.
MULTI_LOSS_WEIGHT
)
!=
3
:
len
(
self
.
MODEL
.
MULTI_LOSS_WEIGHT
)
!=
3
:
self
.
MODEL
.
MULTI_LOSS_WEIGHT
=
[
1.0
,
0.4
,
0.16
]
self
.
MODEL
.
MULTI_LOSS_WEIGHT
=
[
1.0
,
0.4
,
0.16
]
if
self
.
AUG
.
AUG_METHOD
not
in
[
'unpadding'
,
'stepscaling'
,
'rangescaling'
]:
if
self
.
AUG
.
AUG_METHOD
not
in
[
'unpadding'
,
'stepscaling'
,
'rangescaling'
]:
raise
ValueError
(
raise
ValueError
(
'AUG.AUG_METHOD config error, only support `unpadding`, `unpadding` and `rangescaling`'
'AUG.AUG_METHOD config error, only support `unpadding`, `unpadding` and `rangescaling`'
)
)
def
update_from_list
(
self
,
config_list
):
def
update_from_list
(
self
,
config_list
):
if
len
(
config_list
)
%
2
!=
0
:
if
len
(
config_list
)
%
2
!=
0
:
raise
ValueError
(
raise
ValueError
(
...
...
pdseg/utils/config.py
浏览文件 @
61645b1d
#
-*- coding: utf-8 -*-
#
coding: utf8
#
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
.
#
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve
.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
pdseg/utils/dist_utils.py
浏览文件 @
61645b1d
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
#Licensed under the Apache License, Version 2.0 (the "License");
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#you may not use this file except in compliance with the License.
...
...
pdseg/utils/fp16_utils.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os
from
paddle
import
fluid
from
paddle
import
fluid
def
load_fp16_vars
(
executor
,
dirname
,
program
):
def
load_fp16_vars
(
executor
,
dirname
,
program
):
load_dirname
=
os
.
path
.
normpath
(
dirname
)
load_dirname
=
os
.
path
.
normpath
(
dirname
)
...
@@ -28,4 +44,4 @@ def load_fp16_vars(executor, dirname, program):
...
@@ -28,4 +44,4 @@ def load_fp16_vars(executor, dirname, program):
'load_as_fp16'
:
var
.
dtype
==
fluid
.
core
.
VarDesc
.
VarType
.
FP16
'load_as_fp16'
:
var
.
dtype
==
fluid
.
core
.
VarDesc
.
VarType
.
FP16
})
})
executor
.
run
(
load_prog
)
executor
.
run
(
load_prog
)
\ No newline at end of file
pdseg/utils/load_model_utils.py
0 → 100644
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os.path
as
osp
import
six
import
numpy
as
np
def
parse_param_file
(
param_file
,
return_shape
=
True
):
from
paddle.fluid.proto.framework_pb2
import
VarType
f
=
open
(
param_file
,
'rb'
)
version
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
lod_level
=
np
.
fromstring
(
f
.
read
(
8
),
dtype
=
'int64'
)
for
i
in
range
(
int
(
lod_level
)):
_size
=
np
.
fromstring
(
f
.
read
(
8
),
dtype
=
'int64'
)
_
=
f
.
read
(
_size
)
version
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
tensor_desc
=
VarType
.
TensorDesc
()
tensor_desc_size
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
tensor_desc
.
ParseFromString
(
f
.
read
(
int
(
tensor_desc_size
)))
tensor_shape
=
tuple
(
tensor_desc
.
dims
)
if
return_shape
:
f
.
close
()
return
tuple
(
tensor_desc
.
dims
)
if
tensor_desc
.
data_type
!=
5
:
raise
Exception
(
"Unexpected data type while parse {}"
.
format
(
param_file
))
data_size
=
4
for
i
in
range
(
len
(
tensor_shape
)):
data_size
*=
tensor_shape
[
i
]
weight
=
np
.
fromstring
(
f
.
read
(
data_size
),
dtype
=
'float32'
)
f
.
close
()
return
np
.
reshape
(
weight
,
tensor_shape
)
def
load_pdparams
(
exe
,
main_prog
,
model_dir
):
import
paddle.fluid
as
fluid
from
paddle.fluid.proto.framework_pb2
import
VarType
from
paddle.fluid.framework
import
Program
vars_to_load
=
list
()
vars_not_load
=
list
()
import
pickle
with
open
(
osp
.
join
(
model_dir
,
'model.pdparams'
),
'rb'
)
as
f
:
params_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
f
,
encoding
=
'latin1'
)
unused_vars
=
list
()
for
var
in
main_prog
.
list_vars
():
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
continue
if
var
.
name
not
in
params_dict
:
print
(
"{} is not in saved model"
.
format
(
var
.
name
))
vars_not_load
.
append
(
var
.
name
)
continue
if
var
.
shape
!=
params_dict
[
var
.
name
].
shape
:
unused_vars
.
append
(
var
.
name
)
vars_not_load
.
append
(
var
.
name
)
print
(
"[SKIP] Shape of pretrained weight {} doesn't match.(Pretrained: {}, Actual: {})"
.
format
(
var
.
name
,
params_dict
[
var
.
name
].
shape
,
var
.
shape
))
continue
vars_to_load
.
append
(
var
)
for
var_name
in
unused_vars
:
del
params_dict
[
var_name
]
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
print
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else
:
print
(
"There are {}/{} varaibles in {} are loaded."
.
format
(
len
(
vars_to_load
),
len
(
vars_to_load
)
+
len
(
vars_not_load
),
model_dir
))
def
load_pretrained_weights
(
exe
,
main_prog
,
weights_dir
):
if
not
osp
.
exists
(
weights_dir
):
raise
Exception
(
"Path {} not exists."
.
format
(
weights_dir
))
if
osp
.
exists
(
osp
.
join
(
weights_dir
,
"model.pdparams"
)):
return
load_pdparams
(
exe
,
main_prog
,
weights_dir
)
import
paddle.fluid
as
fluid
vars_to_load
=
list
()
vars_not_load
=
list
()
for
var
in
main_prog
.
list_vars
():
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
continue
if
not
osp
.
exists
(
osp
.
join
(
weights_dir
,
var
.
name
)):
print
(
"[SKIP] Pretrained weight {}/{} doesn't exist"
.
format
(
weights_dir
,
var
.
name
))
vars_not_load
.
append
(
var
)
continue
pretrained_shape
=
parse_param_file
(
osp
.
join
(
weights_dir
,
var
.
name
))
actual_shape
=
tuple
(
var
.
shape
)
if
pretrained_shape
!=
actual_shape
:
print
(
"[SKIP] Shape of pretrained weight {}/{} doesn't match.(Pretrained: {}, Actual: {})"
.
format
(
weights_dir
,
var
.
name
,
pretrained_shape
,
actual_shape
))
vars_not_load
.
append
(
var
)
continue
vars_to_load
.
append
(
var
)
params_dict
=
fluid
.
io
.
load_program_state
(
weights_dir
,
var_list
=
vars_to_load
)
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
print
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else
:
print
(
"There are {}/{} varaibles in {} are loaded."
.
format
(
len
(
vars_to_load
),
len
(
vars_to_load
)
+
len
(
vars_not_load
),
weights_dir
))
pdseg/utils/timer.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
pdseg/vis.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -115,7 +115,12 @@ def visualize(cfg,
...
@@ -115,7 +115,12 @@ def visualize(cfg,
ckpt_dir
=
cfg
.
TEST
.
TEST_MODEL
if
not
ckpt_dir
else
ckpt_dir
ckpt_dir
=
cfg
.
TEST
.
TEST_MODEL
if
not
ckpt_dir
else
ckpt_dir
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
save_dir
=
vis_dir
save_dir
=
vis_dir
makedirs
(
save_dir
)
makedirs
(
save_dir
)
...
@@ -169,18 +174,13 @@ def visualize(cfg,
...
@@ -169,18 +174,13 @@ def visualize(cfg,
print
(
"VisualDL visualization epoch"
,
epoch
)
print
(
"VisualDL visualization epoch"
,
epoch
)
pred_mask_np
=
np
.
array
(
pred_mask
.
convert
(
"RGB"
))
pred_mask_np
=
np
.
array
(
pred_mask
.
convert
(
"RGB"
))
log_writer
.
add_image
(
log_writer
.
add_image
(
"Predict/{}"
.
format
(
img_name
),
"Predict/{}"
.
format
(
img_name
),
pred_mask_np
,
epoch
)
pred_mask_np
,
epoch
)
# Original image
# Original image
# BGR->RGB
# BGR->RGB
img
=
cv2
.
imread
(
img
=
cv2
.
imread
(
os
.
path
.
join
(
cfg
.
DATASET
.
DATA_DIR
,
os
.
path
.
join
(
cfg
.
DATASET
.
DATA_DIR
,
img_name
))[...,
::
-
1
]
img_name
))[...,
::
-
1
]
log_writer
.
add_image
(
log_writer
.
add_image
(
"Images/{}"
.
format
(
img_name
),
img
,
epoch
)
"Images/{}"
.
format
(
img_name
),
img
,
epoch
)
# add ground truth (label) images
# add ground truth (label) images
grt
=
grts
[
i
]
grt
=
grts
[
i
]
if
grt
is
not
None
:
if
grt
is
not
None
:
...
@@ -189,10 +189,8 @@ def visualize(cfg,
...
@@ -189,10 +189,8 @@ def visualize(cfg,
grt_pil
.
putpalette
(
color_map
)
grt_pil
.
putpalette
(
color_map
)
grt_pil
=
grt_pil
.
resize
((
org_shape
[
1
],
org_shape
[
0
]))
grt_pil
=
grt_pil
.
resize
((
org_shape
[
1
],
org_shape
[
0
]))
grt
=
np
.
array
(
grt_pil
.
convert
(
"RGB"
))
grt
=
np
.
array
(
grt_pil
.
convert
(
"RGB"
))
log_writer
.
add_image
(
log_writer
.
add_image
(
"Label/{}"
.
format
(
img_name
),
grt
,
"Label/{}"
.
format
(
img_name
),
epoch
)
grt
,
epoch
)
# If in local_test mode, only visualize 5 images just for testing
# If in local_test mode, only visualize 5 images just for testing
# procedure
# procedure
...
...
pretrained_model/download_model.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -42,7 +43,7 @@ model_urls = {
...
@@ -42,7 +43,7 @@ model_urls = {
"hrnet_w30_bn_imagenet"
:
"hrnet_w30_bn_imagenet"
:
"https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar"
,
"https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar"
,
"hrnet_w32_bn_imagenet"
:
"hrnet_w32_bn_imagenet"
:
"https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar"
,
"https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar"
,
"hrnet_w40_bn_imagenet"
:
"hrnet_w40_bn_imagenet"
:
"https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar"
,
"https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar"
,
"hrnet_w44_bn_imagenet"
:
"hrnet_w44_bn_imagenet"
:
...
...
slim/distillation/model_builder.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
slim/distillation/train_distill.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -44,6 +44,7 @@ from model_builder import parse_shape_from_file
...
@@ -44,6 +44,7 @@ from model_builder import parse_shape_from_file
from
eval
import
evaluate
from
eval
import
evaluate
from
vis
import
visualize
from
vis
import
visualize
from
utils
import
dist_utils
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
import
solver
import
solver
from
paddleslim.dist.single_distiller
import
merge
,
l2_loss
from
paddleslim.dist.single_distiller
import
merge
,
l2_loss
...
@@ -116,38 +117,7 @@ def parse_args():
...
@@ -116,38 +117,7 @@ def parse_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
):
def
save_checkpoint
(
program
,
ckpt_name
):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program
=
fluid
.
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
fluid
.
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
save_block
.
create_var
(
name
=
each_var
.
name
,
shape
=
each_var
.
shape
,
dtype
=
each_var
.
dtype
,
type
=
each_var
.
type
,
lod_level
=
each_var
.
lod_level
,
persistable
=
True
)
file_path
=
os
.
path
.
join
(
dirname
,
new_var
.
name
)
file_path
=
os
.
path
.
normpath
(
file_path
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
file_path
})
executor
.
run
(
save_program
)
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
"""
"""
Save checkpoint for evaluation or resume training
Save checkpoint for evaluation or resume training
"""
"""
...
@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name):
...
@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
os
.
makedirs
(
ckpt_dir
)
save_vars
(
fluid
.
save
(
program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
return
ckpt_dir
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
def
load_checkpoint
(
exe
,
program
):
"""
"""
Load checkpoiont f
rom pretrained model directory for resume
training
Load checkpoiont f
or resuming
training
"""
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
print
(
'Resume model training from:'
,
model_path
)
if
not
os
.
path
.
exists
(
model_path
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
model_path
))
fluid
.
load
(
program
,
os
.
path
.
join
(
model_path
,
'model'
),
exe
)
# Check is path ended by path spearator
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
model_path
=
model_path
[
0
:
-
1
]
...
@@ -193,7 +156,6 @@ def load_checkpoint(exe, program):
...
@@ -193,7 +156,6 @@ def load_checkpoint(exe, program):
else
:
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
return
begin_epoch
...
@@ -289,7 +251,11 @@ def train(cfg):
...
@@ -289,7 +251,11 @@ def train(cfg):
ckpt_dir
=
cfg
.
SLIM
.
KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR
ckpt_dir
=
cfg
.
SLIM
.
KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR
assert
ckpt_dir
is
not
None
assert
ckpt_dir
is
not
None
print
(
'load teacher model:'
,
ckpt_dir
)
print
(
'load teacher model:'
,
ckpt_dir
)
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
teacher_program
)
if
os
.
path
.
exists
(
ckpt_dir
):
try
:
fluid
.
load
(
teacher_program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
teacher_program
)
# cfg = load_config(FLAGS.config)
# cfg = load_config(FLAGS.config)
cfg
.
update_from_file
(
args
.
cfg_file
)
cfg
.
update_from_file
(
args
.
cfg_file
)
...
@@ -355,42 +321,8 @@ def train(cfg):
...
@@ -355,42 +321,8 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
fluid
.
default_main_program
())
begin_epoch
=
load_checkpoint
(
exe
,
fluid
.
default_main_program
())
# Load pretrained model
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_pretrained_weights
(
exe
,
fluid
.
default_main_program
(),
load_vars
=
[]
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
fluid
.
default_main_program
().
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
else
:
else
:
print_info
(
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
'Pretrained model dir {} not exists, training from scratch...'
.
...
@@ -475,12 +407,9 @@ def train(cfg):
...
@@ -475,12 +407,9 @@ def train(cfg):
step
)
step
)
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
avg_loss
=
0.0
avg_loss
=
0.0
cm
.
zero_matrix
()
cm
.
zero_matrix
()
...
@@ -503,16 +432,13 @@ def train(cfg):
...
@@ -503,16 +432,13 @@ def train(cfg):
speed
=
args
.
log_steps
/
timer
.
elapsed_time
()
speed
=
args
.
log_steps
/
timer
.
elapsed_time
()
print
((
print
((
"epoch={} step={} lr={:.5f} loss={:.4f} teacher loss={:.4f} distill loss={:.4f} step/sec={:.3f} | ETA {}"
"epoch={} step={} lr={:.5f} loss={:.4f} teacher loss={:.4f} distill loss={:.4f} step/sec={:.3f} | ETA {}"
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
avg_t_loss
,
avg_
t_loss
,
avg_
d_loss
,
speed
,
avg_d_loss
,
speed
,
calculate_eta
(
all_step
-
step
,
speed
)))
calculate_eta
(
all_step
-
step
,
speed
)))
if
args
.
use_vdl
:
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
avg_loss
=
0.0
avg_loss
=
0.0
avg_t_loss
=
0.0
avg_t_loss
=
0.0
...
@@ -527,7 +453,7 @@ def train(cfg):
...
@@ -527,7 +453,7 @@ def train(cfg):
if
(
epoch
%
cfg
.
TRAIN
.
SNAPSHOT_EPOCH
==
0
if
(
epoch
%
cfg
.
TRAIN
.
SNAPSHOT_EPOCH
==
0
or
epoch
==
cfg
.
SOLVER
.
NUM_EPOCHS
)
and
cfg
.
TRAINER_ID
==
0
:
or
epoch
==
cfg
.
SOLVER
.
NUM_EPOCHS
)
and
cfg
.
TRAINER_ID
==
0
:
ckpt_dir
=
save_checkpoint
(
exe
,
fluid
.
default_main_program
(),
epoch
)
ckpt_dir
=
save_checkpoint
(
fluid
.
default_main_program
(),
epoch
)
if
args
.
do_eval
:
if
args
.
do_eval
:
print
(
"Evaluation start"
)
print
(
"Evaluation start"
)
...
@@ -537,10 +463,8 @@ def train(cfg):
...
@@ -537,10 +463,8 @@ def train(cfg):
use_gpu
=
args
.
use_gpu
,
use_gpu
=
args
.
use_gpu
,
use_mpio
=
args
.
use_mpio
)
use_mpio
=
args
.
use_mpio
)
if
args
.
use_vdl
:
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
if
mean_iou
>
best_mIoU
:
if
mean_iou
>
best_mIoU
:
best_mIoU
=
mean_iou
best_mIoU
=
mean_iou
...
@@ -560,11 +484,11 @@ def train(cfg):
...
@@ -560,11 +484,11 @@ def train(cfg):
ckpt_dir
=
ckpt_dir
,
ckpt_dir
=
ckpt_dir
,
log_writer
=
log_writer
)
log_writer
=
log_writer
)
if
cfg
.
TRAINER_ID
==
0
:
if
cfg
.
TRAINER_ID
==
0
:
ckpt_dir
=
save_checkpoint
(
exe
,
fluid
.
default_main_program
(),
epoch
)
ckpt_dir
=
save_checkpoint
(
fluid
.
default_main_program
(),
epoch
)
# save final model
# save final model
if
cfg
.
TRAINER_ID
==
0
:
if
cfg
.
TRAINER_ID
==
0
:
save_checkpoint
(
exe
,
fluid
.
default_main_program
(),
'final'
)
save_checkpoint
(
fluid
.
default_main_program
(),
'final'
)
def
main
(
args
):
def
main
(
args
):
...
...
slim/nas/deeplab.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv
...
@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv
from
models.backbone.mobilenet_v2
import
MobileNetV2
as
mobilenet_backbone
from
models.backbone.mobilenet_v2
import
MobileNetV2
as
mobilenet_backbone
from
models.backbone.xception
import
Xception
as
xception_backbone
from
models.backbone.xception
import
Xception
as
xception_backbone
def
encoder
(
input
):
def
encoder
(
input
):
# 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
# 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
# ASPP_WITH_SEP_CONV:默认为真,使用depthwise可分离卷积,否则使用普通卷积
# ASPP_WITH_SEP_CONV:默认为真,使用depthwise可分离卷积,否则使用普通卷积
...
@@ -47,8 +48,7 @@ def encoder(input):
...
@@ -47,8 +48,7 @@ def encoder(input):
with
scope
(
'encoder'
):
with
scope
(
'encoder'
):
channel
=
256
channel
=
256
with
scope
(
"image_pool"
):
with
scope
(
"image_pool"
):
image_avg
=
fluid
.
layers
.
reduce_mean
(
image_avg
=
fluid
.
layers
.
reduce_mean
(
input
,
[
2
,
3
],
keep_dim
=
True
)
input
,
[
2
,
3
],
keep_dim
=
True
)
image_avg
=
bn_relu
(
image_avg
=
bn_relu
(
conv
(
conv
(
image_avg
,
image_avg
,
...
@@ -191,7 +191,10 @@ def nas_backbone(input, arch):
...
@@ -191,7 +191,10 @@ def nas_backbone(input, arch):
end_points
=
8
end_points
=
8
decode_point
=
3
decode_point
=
3
data
,
decode_shortcuts
=
arch
(
data
,
decode_shortcuts
=
arch
(
input
,
end_points
=
end_points
,
return_block
=
decode_point
,
output_stride
=
16
)
input
,
end_points
=
end_points
,
return_block
=
decode_point
,
output_stride
=
16
)
decode_shortcut
=
decode_shortcuts
[
decode_point
]
decode_shortcut
=
decode_shortcuts
[
decode_point
]
return
data
,
decode_shortcut
return
data
,
decode_shortcut
...
...
slim/nas/eval_nas.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -123,7 +123,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
...
@@ -123,7 +123,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if
ckpt_dir
is
not
None
:
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
print
(
'load test model:'
,
ckpt_dir
)
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
# Use streaming confusion matrix to calculate mean_iou
# Use streaming confusion matrix to calculate mean_iou
np
.
set_printoptions
(
np
.
set_printoptions
(
...
...
slim/nas/mobilenetv2_search_space.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -31,7 +32,7 @@ __all__ = ["MobileNetV2SpaceSeg"]
...
@@ -31,7 +32,7 @@ __all__ = ["MobileNetV2SpaceSeg"]
class
MobileNetV2SpaceSeg
(
SearchSpaceBase
):
class
MobileNetV2SpaceSeg
(
SearchSpaceBase
):
def
__init__
(
self
,
input_size
,
output_size
,
block_num
,
block_mask
=
None
):
def
__init__
(
self
,
input_size
,
output_size
,
block_num
,
block_mask
=
None
):
super
(
MobileNetV2SpaceSeg
,
self
).
__init__
(
input_size
,
output_size
,
super
(
MobileNetV2SpaceSeg
,
self
).
__init__
(
input_size
,
output_size
,
block_num
,
block_mask
)
block_num
,
block_mask
)
# self.head_num means the first convolution channel
# self.head_num means the first convolution channel
self
.
head_num
=
np
.
array
([
3
,
4
,
8
,
12
,
16
,
24
,
32
])
#7
self
.
head_num
=
np
.
array
([
3
,
4
,
8
,
12
,
16
,
24
,
32
])
#7
# self.filter_num1 ~ self.filter_num6 means following convlution channel
# self.filter_num1 ~ self.filter_num6 means following convlution channel
...
@@ -48,7 +49,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
...
@@ -48,7 +49,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
self
.
k_size
=
np
.
array
([
3
,
5
])
#2
self
.
k_size
=
np
.
array
([
3
,
5
])
#2
# self.multiply means expansion_factor of each _inverted_residual_unit
# self.multiply means expansion_factor of each _inverted_residual_unit
self
.
multiply
=
np
.
array
([
1
,
2
,
3
,
4
,
6
])
#5
self
.
multiply
=
np
.
array
([
1
,
2
,
3
,
4
,
6
])
#5
# self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks
# self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks
self
.
repeat
=
np
.
array
([
1
,
2
,
3
,
4
,
5
,
6
])
#6
self
.
repeat
=
np
.
array
([
1
,
2
,
3
,
4
,
5
,
6
])
#6
def
init_tokens
(
self
):
def
init_tokens
(
self
):
...
@@ -72,7 +73,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
...
@@ -72,7 +73,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
def
range_table
(
self
):
def
range_table
(
self
):
"""
"""
Get range table of current search space, constrains the range of tokens.
Get range table of current search space, constrains the range of tokens.
"""
"""
# head_num + 6 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# head_num + 6 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# yapf: disable
# yapf: disable
...
@@ -95,8 +96,8 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
...
@@ -95,8 +96,8 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
tokens
=
self
.
init_tokens
()
tokens
=
self
.
init_tokens
()
self
.
bottleneck_params_list
=
[]
self
.
bottleneck_params_list
=
[]
self
.
bottleneck_params_list
.
append
(
self
.
bottleneck_params_list
.
append
(
(
1
,
self
.
head_num
[
tokens
[
0
]],
1
,
1
,
(
1
,
self
.
head_num
[
tokens
[
0
]],
1
,
1
,
3
))
3
))
self
.
bottleneck_params_list
.
append
(
self
.
bottleneck_params_list
.
append
(
(
self
.
multiply
[
tokens
[
1
]],
self
.
filter_num1
[
tokens
[
2
]],
(
self
.
multiply
[
tokens
[
1
]],
self
.
filter_num1
[
tokens
[
2
]],
self
.
repeat
[
tokens
[
3
]],
2
,
self
.
k_size
[
tokens
[
4
]]))
self
.
repeat
[
tokens
[
3
]],
2
,
self
.
k_size
[
tokens
[
4
]]))
...
@@ -150,7 +151,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
...
@@ -150,7 +151,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
return
(
True
if
count
==
points
else
False
)
return
(
True
if
count
==
points
else
False
)
#conv1
#conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
input
=
conv_bn_layer
(
input
=
conv_bn_layer
(
input
,
input
,
num_filters
=
int
(
32
*
self
.
scale
),
num_filters
=
int
(
32
*
self
.
scale
),
...
...
slim/nas/model_builder.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
slim/nas/train_nas.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -47,6 +47,7 @@ from model_builder import parse_shape_from_file
...
@@ -47,6 +47,7 @@ from model_builder import parse_shape_from_file
from
eval_nas
import
evaluate
from
eval_nas
import
evaluate
from
vis
import
visualize
from
vis
import
visualize
from
utils
import
dist_utils
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
from
mobilenetv2_search_space
import
MobileNetV2SpaceSeg
from
mobilenetv2_search_space
import
MobileNetV2SpaceSeg
from
paddleslim.nas.search_space.search_space_factory
import
SearchSpaceFactory
from
paddleslim.nas.search_space.search_space_factory
import
SearchSpaceFactory
...
@@ -116,38 +117,7 @@ def parse_args():
...
@@ -116,38 +117,7 @@ def parse_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
):
def
save_checkpoint
(
program
,
ckpt_name
):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program
=
fluid
.
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
fluid
.
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
save_block
.
create_var
(
name
=
each_var
.
name
,
shape
=
each_var
.
shape
,
dtype
=
each_var
.
dtype
,
type
=
each_var
.
type
,
lod_level
=
each_var
.
lod_level
,
persistable
=
True
)
file_path
=
os
.
path
.
join
(
dirname
,
new_var
.
name
)
file_path
=
os
.
path
.
normpath
(
file_path
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
file_path
})
executor
.
run
(
save_program
)
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
"""
"""
Save checkpoint for evaluation or resume training
Save checkpoint for evaluation or resume training
"""
"""
...
@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name):
...
@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
os
.
makedirs
(
ckpt_dir
)
save_vars
(
fluid
.
save
(
program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
return
ckpt_dir
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
def
load_checkpoint
(
exe
,
program
):
"""
"""
Load checkpoiont f
rom pretrained model directory for resume
training
Load checkpoiont f
or resuming
training
"""
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
print
(
'Resume model training from:'
,
model_path
)
if
not
os
.
path
.
exists
(
model_path
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
model_path
))
fluid
.
load
(
program
,
os
.
path
.
join
(
model_path
,
'model'
),
exe
)
# Check is path ended by path spearator
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
model_path
=
model_path
[
0
:
-
1
]
...
@@ -193,7 +156,6 @@ def load_checkpoint(exe, program):
...
@@ -193,7 +156,6 @@ def load_checkpoint(exe, program):
else
:
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
return
begin_epoch
...
@@ -245,8 +207,6 @@ def train(cfg):
...
@@ -245,8 +207,6 @@ def train(cfg):
yield
item
[
0
],
item
[
1
],
item
[
2
]
yield
item
[
0
],
item
[
1
],
item
[
2
]
# Get device environment
# Get device environment
# places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
# place = places[0]
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
place
=
fluid
.
CUDAPlace
(
gpu_id
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
gpu_id
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
places
=
fluid
.
cuda_places
()
if
args
.
use_gpu
else
fluid
.
cpu_places
()
places
=
fluid
.
cuda_places
()
if
args
.
use_gpu
else
fluid
.
cpu_places
()
...
@@ -326,43 +286,8 @@ def train(cfg):
...
@@ -326,43 +286,8 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_pretrained_weights
(
exe
,
train_prog
,
load_vars
=
[]
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
else
:
else
:
print_info
(
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
'Pretrained model dir {} not exists, training from scratch...'
.
...
@@ -419,8 +344,7 @@ def train(cfg):
...
@@ -419,8 +344,7 @@ def train(cfg):
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
print
(
e
)
if
epoch
>
cfg
.
SLIM
.
NAS_START_EVAL_EPOCH
:
if
epoch
>
cfg
.
SLIM
.
NAS_START_EVAL_EPOCH
:
ckpt_dir
=
save_checkpoint
(
exe
,
train_prog
,
ckpt_dir
=
save_checkpoint
(
train_prog
,
'{}_tmp'
.
format
(
port
))
'{}_tmp'
.
format
(
port
))
_
,
mean_iou
,
_
,
mean_acc
=
evaluate
(
_
,
mean_iou
,
_
,
mean_acc
=
evaluate
(
cfg
=
cfg
,
cfg
=
cfg
,
arch
=
arch
,
arch
=
arch
,
...
...
slim/prune/eval_prune.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
slim/prune/train_prune.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -46,6 +46,7 @@ from models.model_builder import parse_shape_from_file
...
@@ -46,6 +46,7 @@ from models.model_builder import parse_shape_from_file
from
eval_prune
import
evaluate
from
eval_prune
import
evaluate
from
vis
import
visualize
from
vis
import
visualize
from
utils
import
dist_utils
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
from
paddleslim.prune
import
Pruner
,
save_model
from
paddleslim.prune
import
Pruner
,
save_model
from
paddleslim.analysis
import
flops
from
paddleslim.analysis
import
flops
...
@@ -285,42 +286,7 @@ def train(cfg):
...
@@ -285,42 +286,7 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
else
:
else
:
print_info
(
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
'Pretrained model dir {} not exists, training from scratch...'
.
...
@@ -409,12 +375,9 @@ def train(cfg):
...
@@ -409,12 +375,9 @@ def train(cfg):
step
)
step
)
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
avg_loss
=
0.0
avg_loss
=
0.0
cm
.
zero_matrix
()
cm
.
zero_matrix
()
...
@@ -436,12 +399,9 @@ def train(cfg):
...
@@ -436,12 +399,9 @@ def train(cfg):
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
speed
,
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
speed
,
calculate_eta
(
all_step
-
step
,
speed
)))
calculate_eta
(
all_step
-
step
,
speed
)))
if
args
.
use_vdl
:
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
avg_loss
=
0.0
avg_loss
=
0.0
timer
.
restart
()
timer
.
restart
()
...
@@ -464,10 +424,8 @@ def train(cfg):
...
@@ -464,10 +424,8 @@ def train(cfg):
use_gpu
=
args
.
use_gpu
,
use_gpu
=
args
.
use_gpu
,
use_mpio
=
args
.
use_mpio
)
use_mpio
=
args
.
use_mpio
)
if
args
.
use_vdl
:
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
# Use VisualDL to visualize results
# Use VisualDL to visualize results
if
args
.
use_vdl
and
cfg
.
DATASET
.
VIS_FILE_LIST
is
not
None
:
if
args
.
use_vdl
and
cfg
.
DATASET
.
VIS_FILE_LIST
is
not
None
:
...
...
slim/quantization/eval_quant.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
slim/quantization/export_model.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
slim/quantization/train_quant.py
浏览文件 @
61645b1d
# coding: utf8
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -40,7 +40,8 @@ from models.model_builder import parse_shape_from_file
...
@@ -40,7 +40,8 @@ from models.model_builder import parse_shape_from_file
from
eval_quant
import
evaluate
from
eval_quant
import
evaluate
from
vis
import
visualize
from
vis
import
visualize
from
utils
import
dist_utils
from
utils
import
dist_utils
from
train
import
save_vars
,
save_checkpoint
,
load_checkpoint
,
update_best_model
,
print_info
from
utils.load_model_utils
import
load_pretrained_weights
from
train
import
update_best_model
,
print_info
from
paddleslim.quant
import
quant_aware
from
paddleslim.quant
import
quant_aware
...
@@ -103,6 +104,55 @@ def parse_args():
...
@@ -103,6 +104,55 @@ def parse_args():
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
"""
Save checkpoint for evaluation or resume training
"""
ckpt_dir
=
os
.
path
.
join
(
cfg
.
TRAIN
.
MODEL_SAVE_DIR
,
str
(
ckpt_name
))
print
(
"Save model checkpoint to {}"
.
format
(
ckpt_dir
))
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
fluid
.
io
.
save_vars
(
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
"""
Load checkpoiont from pretrained model directory for resume training
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
epoch_name
=
os
.
path
.
basename
(
model_path
)
# If resume model is final model
if
epoch_name
==
'final'
:
begin_epoch
=
cfg
.
SOLVER
.
NUM_EPOCHS
# If resume model path is end of digit, restore epoch status
elif
epoch_name
.
isdigit
():
epoch
=
int
(
epoch_name
)
begin_epoch
=
epoch
+
1
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
def
train_quant
(
cfg
):
def
train_quant
(
cfg
):
startup_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
...
@@ -182,42 +232,7 @@ def train_quant(cfg):
...
@@ -182,42 +232,7 @@ def train_quant(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
else
:
else
:
print_info
(
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
'Pretrained model dir {} not exists, training from scratch...'
.
...
...
test/local_test_cityscapes.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
test/local_test_pet.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
test/test_utils.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录