Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
61645b1d
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
289
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
61645b1d
编写于
5月 25, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/PaddleSeg
into develop
上级
d3ef4d2e
867d4a5b
变更
156
隐藏空白更改
内联
并排
Showing
156 changed file
with
2009 addition
and
1127 deletion
+2009
-1127
README.md
README.md
+1
-1
contrib/ACE2P/__init__.py
contrib/ACE2P/__init__.py
+14
-0
contrib/ACE2P/config.py
contrib/ACE2P/config.py
+17
-3
contrib/ACE2P/download_ACE2P.py
contrib/ACE2P/download_ACE2P.py
+4
-3
contrib/ACE2P/infer.py
contrib/ACE2P/infer.py
+37
-17
contrib/ACE2P/reader.py
contrib/ACE2P/reader.py
+31
-17
contrib/ACE2P/utils/__init__.py
contrib/ACE2P/utils/__init__.py
+14
-0
contrib/ACE2P/utils/palette.py
contrib/ACE2P/utils/palette.py
+1
-0
contrib/ACE2P/utils/util.py
contrib/ACE2P/utils/util.py
+22
-8
contrib/HumanSeg/datasets/__init__.py
contrib/HumanSeg/datasets/__init__.py
+4
-3
contrib/HumanSeg/datasets/dataset.py
contrib/HumanSeg/datasets/dataset.py
+2
-1
contrib/HumanSeg/datasets/shared_queue/__init__.py
contrib/HumanSeg/datasets/shared_queue/__init__.py
+2
-1
contrib/HumanSeg/datasets/shared_queue/queue.py
contrib/HumanSeg/datasets/shared_queue/queue.py
+2
-1
contrib/HumanSeg/datasets/shared_queue/sharedmemory.py
contrib/HumanSeg/datasets/shared_queue/sharedmemory.py
+2
-4
contrib/HumanSeg/export.py
contrib/HumanSeg/export.py
+15
-0
contrib/HumanSeg/infer.py
contrib/HumanSeg/infer.py
+15
-0
contrib/HumanSeg/models/__init__.py
contrib/HumanSeg/models/__init__.py
+15
-0
contrib/HumanSeg/models/humanseg.py
contrib/HumanSeg/models/humanseg.py
+4
-3
contrib/HumanSeg/models/load_model.py
contrib/HumanSeg/models/load_model.py
+2
-1
contrib/HumanSeg/nets/__init__.py
contrib/HumanSeg/nets/__init__.py
+15
-0
contrib/HumanSeg/nets/backbone/__init__.py
contrib/HumanSeg/nets/backbone/__init__.py
+15
-0
contrib/HumanSeg/nets/backbone/mobilenet_v2.py
contrib/HumanSeg/nets/backbone/mobilenet_v2.py
+3
-1
contrib/HumanSeg/nets/backbone/xception.py
contrib/HumanSeg/nets/backbone/xception.py
+1
-1
contrib/HumanSeg/nets/deeplabv3p.py
contrib/HumanSeg/nets/deeplabv3p.py
+1
-1
contrib/HumanSeg/nets/hrnet.py
contrib/HumanSeg/nets/hrnet.py
+1
-1
contrib/HumanSeg/nets/libs.py
contrib/HumanSeg/nets/libs.py
+1
-1
contrib/HumanSeg/nets/seg_modules.py
contrib/HumanSeg/nets/seg_modules.py
+2
-1
contrib/HumanSeg/nets/shufflenet_slim.py
contrib/HumanSeg/nets/shufflenet_slim.py
+15
-0
contrib/HumanSeg/pretrained_weights/download_pretrained_weights.py
...umanSeg/pretrained_weights/download_pretrained_weights.py
+4
-3
contrib/HumanSeg/quant_offline.py
contrib/HumanSeg/quant_offline.py
+15
-0
contrib/HumanSeg/quant_online.py
contrib/HumanSeg/quant_online.py
+15
-0
contrib/HumanSeg/train.py
contrib/HumanSeg/train.py
+15
-0
contrib/HumanSeg/transforms/__init__.py
contrib/HumanSeg/transforms/__init__.py
+4
-3
contrib/HumanSeg/transforms/functional.py
contrib/HumanSeg/transforms/functional.py
+4
-3
contrib/HumanSeg/transforms/transforms.py
contrib/HumanSeg/transforms/transforms.py
+4
-3
contrib/HumanSeg/utils/__init__.py
contrib/HumanSeg/utils/__init__.py
+4
-3
contrib/HumanSeg/utils/humanseg_postprocess.py
contrib/HumanSeg/utils/humanseg_postprocess.py
+15
-0
contrib/HumanSeg/utils/logging.py
contrib/HumanSeg/utils/logging.py
+2
-1
contrib/HumanSeg/utils/metrics.py
contrib/HumanSeg/utils/metrics.py
+1
-1
contrib/HumanSeg/utils/post_quantization.py
contrib/HumanSeg/utils/post_quantization.py
+2
-1
contrib/HumanSeg/utils/utils.py
contrib/HumanSeg/utils/utils.py
+5
-6
contrib/HumanSeg/val.py
contrib/HumanSeg/val.py
+15
-0
contrib/HumanSeg/video_infer.py
contrib/HumanSeg/video_infer.py
+15
-0
contrib/LaneNet/data_aug.py
contrib/LaneNet/data_aug.py
+4
-2
contrib/LaneNet/dataset/download_tusimple.py
contrib/LaneNet/dataset/download_tusimple.py
+4
-3
contrib/LaneNet/eval.py
contrib/LaneNet/eval.py
+5
-2
contrib/LaneNet/loss.py
contrib/LaneNet/loss.py
+1
-1
contrib/LaneNet/models/__init__.py
contrib/LaneNet/models/__init__.py
+1
-1
contrib/LaneNet/models/model_builder.py
contrib/LaneNet/models/model_builder.py
+1
-1
contrib/LaneNet/models/modeling/lanenet.py
contrib/LaneNet/models/modeling/lanenet.py
+179
-58
contrib/LaneNet/reader.py
contrib/LaneNet/reader.py
+21
-13
contrib/LaneNet/train.py
contrib/LaneNet/train.py
+13
-91
contrib/LaneNet/utils/__init__.py
contrib/LaneNet/utils/__init__.py
+14
-0
contrib/LaneNet/utils/config.py
contrib/LaneNet/utils/config.py
+6
-6
contrib/LaneNet/utils/dist_utils.py
contrib/LaneNet/utils/dist_utils.py
+10
-9
contrib/LaneNet/utils/generate_tusimple_dataset.py
contrib/LaneNet/utils/generate_tusimple_dataset.py
+65
-21
contrib/LaneNet/utils/lanenet_postprocess.py
contrib/LaneNet/utils/lanenet_postprocess.py
+75
-41
contrib/LaneNet/utils/load_model_utils.py
contrib/LaneNet/utils/load_model_utils.py
+126
-0
contrib/LaneNet/vis.py
contrib/LaneNet/vis.py
+21
-13
contrib/MechanicalIndustryMeter/download_mini_mechanical_industry_meter.py
...lIndustryMeter/download_mini_mechanical_industry_meter.py
+4
-3
contrib/MechanicalIndustryMeter/download_unet_mechanical_industry_meter.py
...lIndustryMeter/download_unet_mechanical_industry_meter.py
+6
-4
contrib/RemoteSensing/__init__.py
contrib/RemoteSensing/__init__.py
+4
-3
contrib/RemoteSensing/models/__init__.py
contrib/RemoteSensing/models/__init__.py
+15
-0
contrib/RemoteSensing/models/base.py
contrib/RemoteSensing/models/base.py
+10
-9
contrib/RemoteSensing/models/load_model.py
contrib/RemoteSensing/models/load_model.py
+2
-1
contrib/RemoteSensing/models/unet.py
contrib/RemoteSensing/models/unet.py
+10
-9
contrib/RemoteSensing/nets/__init__.py
contrib/RemoteSensing/nets/__init__.py
+15
-0
contrib/RemoteSensing/nets/libs.py
contrib/RemoteSensing/nets/libs.py
+1
-1
contrib/RemoteSensing/nets/loss.py
contrib/RemoteSensing/nets/loss.py
+2
-1
contrib/RemoteSensing/nets/unet.py
contrib/RemoteSensing/nets/unet.py
+1
-1
contrib/RemoteSensing/predict_demo.py
contrib/RemoteSensing/predict_demo.py
+15
-0
contrib/RemoteSensing/readers/__init__.py
contrib/RemoteSensing/readers/__init__.py
+2
-1
contrib/RemoteSensing/readers/base.py
contrib/RemoteSensing/readers/base.py
+2
-1
contrib/RemoteSensing/readers/reader.py
contrib/RemoteSensing/readers/reader.py
+2
-1
contrib/RemoteSensing/tools/create_dataset_list.py
contrib/RemoteSensing/tools/create_dataset_list.py
+1
-1
contrib/RemoteSensing/tools/split_dataset_list.py
contrib/RemoteSensing/tools/split_dataset_list.py
+1
-1
contrib/RemoteSensing/train_demo.py
contrib/RemoteSensing/train_demo.py
+15
-0
contrib/RemoteSensing/transforms/__init__.py
contrib/RemoteSensing/transforms/__init__.py
+4
-3
contrib/RemoteSensing/transforms/ops.py
contrib/RemoteSensing/transforms/ops.py
+2
-1
contrib/RemoteSensing/transforms/transforms.py
contrib/RemoteSensing/transforms/transforms.py
+1
-1
contrib/RemoteSensing/utils/__init__.py
contrib/RemoteSensing/utils/__init__.py
+4
-3
contrib/RemoteSensing/utils/logging.py
contrib/RemoteSensing/utils/logging.py
+2
-1
contrib/RemoteSensing/utils/metrics.py
contrib/RemoteSensing/utils/metrics.py
+1
-1
contrib/RemoteSensing/utils/pretrain_weights.py
contrib/RemoteSensing/utils/pretrain_weights.py
+15
-0
contrib/RemoteSensing/utils/utils.py
contrib/RemoteSensing/utils/utils.py
+5
-6
contrib/RoadLine/__init__.py
contrib/RoadLine/__init__.py
+14
-0
contrib/RoadLine/config.py
contrib/RoadLine/config.py
+20
-6
contrib/RoadLine/download_RoadLine.py
contrib/RoadLine/download_RoadLine.py
+4
-3
contrib/RoadLine/infer.py
contrib/RoadLine/infer.py
+39
-19
contrib/RoadLine/utils/__init__.py
contrib/RoadLine/utils/__init__.py
+14
-0
contrib/RoadLine/utils/palette.py
contrib/RoadLine/utils/palette.py
+15
-9
contrib/RoadLine/utils/util.py
contrib/RoadLine/utils/util.py
+22
-8
dataset/convert_voc2012.py
dataset/convert_voc2012.py
+11
-6
dataset/download_and_convert_voc2012.py
dataset/download_and_convert_voc2012.py
+9
-8
dataset/download_cityscapes.py
dataset/download_cityscapes.py
+4
-3
dataset/download_mini_deepglobe_road_extraction.py
dataset/download_mini_deepglobe_road_extraction.py
+4
-3
dataset/download_optic.py
dataset/download_optic.py
+4
-3
dataset/download_pet.py
dataset/download_pet.py
+4
-3
deploy/python/infer.py
deploy/python/infer.py
+10
-6
pdseg/__init__.py
pdseg/__init__.py
+2
-2
pdseg/check.py
pdseg/check.py
+21
-3
pdseg/data_aug.py
pdseg/data_aug.py
+3
-3
pdseg/data_utils.py
pdseg/data_utils.py
+16
-2
pdseg/eval.py
pdseg/eval.py
+5
-6
pdseg/export_model.py
pdseg/export_model.py
+11
-5
pdseg/loss.py
pdseg/loss.py
+1
-1
pdseg/lovasz_losses.py
pdseg/lovasz_losses.py
+1
-1
pdseg/metrics.py
pdseg/metrics.py
+1
-1
pdseg/models/__init__.py
pdseg/models/__init__.py
+1
-1
pdseg/models/backbone/__init__.py
pdseg/models/backbone/__init__.py
+14
-0
pdseg/models/backbone/mobilenet_v2.py
pdseg/models/backbone/mobilenet_v2.py
+1
-1
pdseg/models/backbone/resnet.py
pdseg/models/backbone/resnet.py
+15
-13
pdseg/models/backbone/vgg.py
pdseg/models/backbone/vgg.py
+3
-2
pdseg/models/backbone/xception.py
pdseg/models/backbone/xception.py
+1
-1
pdseg/models/libs/__init__.py
pdseg/models/libs/__init__.py
+14
-0
pdseg/models/libs/model_libs.py
pdseg/models/libs/model_libs.py
+2
-2
pdseg/models/model_builder.py
pdseg/models/model_builder.py
+1
-1
pdseg/models/modeling/__init__.py
pdseg/models/modeling/__init__.py
+14
-0
pdseg/models/modeling/deeplab.py
pdseg/models/modeling/deeplab.py
+1
-1
pdseg/models/modeling/fast_scnn.py
pdseg/models/modeling/fast_scnn.py
+1
-1
pdseg/models/modeling/hrnet.py
pdseg/models/modeling/hrnet.py
+153
-52
pdseg/models/modeling/icnet.py
pdseg/models/modeling/icnet.py
+1
-1
pdseg/models/modeling/pspnet.py
pdseg/models/modeling/pspnet.py
+37
-34
pdseg/models/modeling/unet.py
pdseg/models/modeling/unet.py
+1
-1
pdseg/reader.py
pdseg/reader.py
+5
-3
pdseg/solver.py
pdseg/solver.py
+1
-1
pdseg/tools/__init__.py
pdseg/tools/__init__.py
+1
-1
pdseg/tools/create_dataset_list.py
pdseg/tools/create_dataset_list.py
+21
-26
pdseg/tools/gray2pseudo_color.py
pdseg/tools/gray2pseudo_color.py
+21
-11
pdseg/tools/jingling2seg.py
pdseg/tools/jingling2seg.py
+14
-1
pdseg/tools/labelme2seg.py
pdseg/tools/labelme2seg.py
+14
-1
pdseg/train.py
pdseg/train.py
+22
-108
pdseg/utils/__init__.py
pdseg/utils/__init__.py
+14
-0
pdseg/utils/collect.py
pdseg/utils/collect.py
+14
-9
pdseg/utils/config.py
pdseg/utils/config.py
+4
-4
pdseg/utils/dist_utils.py
pdseg/utils/dist_utils.py
+1
-1
pdseg/utils/fp16_utils.py
pdseg/utils/fp16_utils.py
+17
-1
pdseg/utils/load_model_utils.py
pdseg/utils/load_model_utils.py
+128
-0
pdseg/utils/timer.py
pdseg/utils/timer.py
+4
-3
pdseg/vis.py
pdseg/vis.py
+14
-16
pretrained_model/download_model.py
pretrained_model/download_model.py
+5
-4
slim/distillation/model_builder.py
slim/distillation/model_builder.py
+1
-1
slim/distillation/train_distill.py
slim/distillation/train_distill.py
+31
-107
slim/nas/deeplab.py
slim/nas/deeplab.py
+7
-4
slim/nas/eval_nas.py
slim/nas/eval_nas.py
+5
-2
slim/nas/mobilenetv2_search_space.py
slim/nas/mobilenetv2_search_space.py
+10
-9
slim/nas/model_builder.py
slim/nas/model_builder.py
+1
-1
slim/nas/train_nas.py
slim/nas/train_nas.py
+14
-90
slim/prune/eval_prune.py
slim/prune/eval_prune.py
+1
-1
slim/prune/train_prune.py
slim/prune/train_prune.py
+11
-53
slim/quantization/eval_quant.py
slim/quantization/eval_quant.py
+1
-1
slim/quantization/export_model.py
slim/quantization/export_model.py
+1
-1
slim/quantization/train_quant.py
slim/quantization/train_quant.py
+53
-38
test/local_test_cityscapes.py
test/local_test_cityscapes.py
+4
-3
test/local_test_pet.py
test/local_test_pet.py
+4
-3
test/test_utils.py
test/test_utils.py
+4
-3
未找到文件。
README.md
浏览文件 @
61645b1d
...
...
@@ -35,7 +35,7 @@ PaddleSeg是基于[PaddlePaddle](https://www.paddlepaddle.org.cn)开发的端到
-
**高性能**
PaddleSeg支持多进程I/O、多卡并行
、跨卡Batch Norm同步
等训练加速策略,结合飞桨核心框架的显存优化功能,可大幅度减少分割模型的显存开销,让开发者更低成本、更高效地完成图像分割训练。
PaddleSeg支持多进程I/O、多卡并行等训练加速策略,结合飞桨核心框架的显存优化功能,可大幅度减少分割模型的显存开销,让开发者更低成本、更高效地完成图像分割训练。
-
**工业级部署**
...
...
contrib/ACE2P/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/ACE2P/config.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
utils.util
import
AttrDict
,
merge_cfg_from_args
,
get_arguments
import
os
...
...
@@ -19,10 +33,10 @@ cfg.class_num = 20
# 均值, 图像预处理减去的均值
cfg
.
MEAN
=
0.406
,
0.456
,
0.485
# 标准差,图像预处理除以标准差
cfg
.
STD
=
0.225
,
0.224
,
0.229
cfg
.
STD
=
0.225
,
0.224
,
0.229
# 多尺度预测时图像尺寸
cfg
.
multi_scales
=
(
377
,
377
),
(
473
,
473
),
(
567
,
567
)
cfg
.
multi_scales
=
(
377
,
377
),
(
473
,
473
),
(
567
,
567
)
# 多尺度预测时图像是否水平翻转
cfg
.
flip
=
True
...
...
contrib/ACE2P/download_ACE2P.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/ACE2P/infer.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
numpy
as
np
...
...
@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg
=
getattr
(
config
,
'cfg'
)
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
import
paddle.fluid
as
fluid
# 预测数据集类
class
TestDataSet
():
def
__init__
(
self
):
self
.
data_dir
=
cfg
.
data_dir
self
.
data_dir
=
cfg
.
data_dir
self
.
data_list_file
=
cfg
.
data_list_file
self
.
data_list
=
self
.
get_data_list
()
self
.
data_num
=
len
(
self
.
data_list
)
def
get_data_list
(
self
):
# 获取预测图像路径列表
data_list
=
[]
...
...
@@ -56,10 +71,10 @@ class TestDataSet():
img_path
=
self
.
data_list
[
index
]
img
=
cv2
.
imread
(
img_path
,
cv2
.
IMREAD_COLOR
)
if
img
is
None
:
return
img
,
img
,
img_path
,
None
return
img
,
img
,
img_path
,
None
img_name
=
img_path
.
split
(
os
.
sep
)[
-
1
]
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
img_shape
=
img
.
shape
[:
2
]
img_process
=
self
.
preprocess
(
img
)
...
...
@@ -90,39 +105,44 @@ def infer():
if
image
is
None
:
print
(
im_name
,
'is None'
)
continue
# 预测
if
cfg
.
example
==
'ACE2P'
:
# ACE2P模型使用多尺度预测
reader
=
importlib
.
import_module
(
'reader'
)
multi_scale_test
=
getattr
(
reader
,
'multi_scale_test'
)
parsing
,
logits
=
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
image
,
im_shape
)
parsing
,
logits
=
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
image
,
im_shape
)
else
:
# HumanSeg,RoadLine模型单尺度预测
result
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
result
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
parsing
=
np
.
argmax
(
result
[
0
][
0
],
axis
=
0
)
parsing
=
cv2
.
resize
(
parsing
.
astype
(
np
.
uint8
),
im_shape
[::
-
1
])
# 预测结果保存
result_path
=
os
.
path
.
join
(
cfg
.
vis_dir
,
im_name
+
'.png'
)
if
cfg
.
example
==
'HumanSeg'
:
logits
=
result
[
0
][
0
][
1
]
*
255
logits
=
result
[
0
][
0
][
1
]
*
255
logits
=
cv2
.
resize
(
logits
,
im_shape
[::
-
1
])
ret
,
logits
=
cv2
.
threshold
(
logits
,
thresh
,
0
,
cv2
.
THRESH_TOZERO
)
logits
=
255
*
(
logits
-
thresh
)
/
(
255
-
thresh
)
logits
=
255
*
(
logits
-
thresh
)
/
(
255
-
thresh
)
# 将分割结果添加到alpha通道
rgba
=
np
.
concatenate
((
ori_img
,
np
.
expand_dims
(
logits
,
axis
=
2
)),
axis
=
2
)
rgba
=
np
.
concatenate
((
ori_img
,
np
.
expand_dims
(
logits
,
axis
=
2
)),
axis
=
2
)
cv2
.
imwrite
(
result_path
,
rgba
)
else
:
else
:
output_im
=
PILImage
.
fromarray
(
np
.
asarray
(
parsing
,
dtype
=
np
.
uint8
))
output_im
.
putpalette
(
palette
)
output_im
.
save
(
result_path
)
if
(
idx
+
1
)
%
100
==
0
:
print
(
'%d processd'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
return
0
...
...
contrib/ACE2P/reader.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle.fluid
as
fluid
from
config
import
cfg
import
cv2
def
get_affine_points
(
src_shape
,
dst_shape
,
rot_grad
=
0
):
# 获取图像和仿射后图像的三组对应点坐标
# 三组点为仿射变换后图像的中心点, [w/2,0], [0,0],及对应原始图像的点
...
...
@@ -23,7 +38,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
# 原始图像三组点
points
=
[[
0
,
0
]]
*
3
points
[
0
]
=
(
np
.
array
([
w
,
h
])
-
1
)
*
0.5
points
[
0
]
=
(
np
.
array
([
w
,
h
])
-
1
)
*
0.5
points
[
1
]
=
points
[
0
]
+
0.5
*
affine_shape
[
0
]
*
np
.
array
([
sin_v
,
-
cos_v
])
points
[
2
]
=
points
[
1
]
-
0.5
*
affine_shape
[
1
]
*
np
.
array
([
cos_v
,
sin_v
])
...
...
@@ -34,6 +49,7 @@ def get_affine_points(src_shape, dst_shape, rot_grad=0):
return
points
,
points_trans
def
preprocess
(
im
):
# ACE2P模型数据预处理
im_shape
=
im
.
shape
[:
2
]
...
...
@@ -42,13 +58,10 @@ def preprocess(im):
# 获取图像和仿射变换后图像的对应点坐标
points
,
points_trans
=
get_affine_points
(
im_shape
,
scale
)
# 根据对应点集获得仿射矩阵
trans
=
cv2
.
getAffineTransform
(
np
.
float32
(
points
),
np
.
float32
(
points_trans
))
trans
=
cv2
.
getAffineTransform
(
np
.
float32
(
points
),
np
.
float32
(
points_trans
))
# 根据仿射矩阵对图像进行仿射
input
=
cv2
.
warpAffine
(
im
,
trans
,
scale
[::
-
1
],
flags
=
cv2
.
INTER_LINEAR
)
input
=
cv2
.
warpAffine
(
im
,
trans
,
scale
[::
-
1
],
flags
=
cv2
.
INTER_LINEAR
)
# 减均值测,除以方差,转换数据格式为NCHW
input
=
input
.
astype
(
np
.
float32
)
...
...
@@ -66,19 +79,20 @@ def preprocess(im):
return
input_images
def
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
input_ims
,
im_shape
):
def
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
input_ims
,
im_shape
):
# 由于部分类别分左右部位, flipped_idx为其水平翻转后对应的标签
flipped_idx
=
(
15
,
14
,
17
,
16
,
19
,
18
)
ms_outputs
=
[]
# 多尺度预测
for
idx
,
scale
in
enumerate
(
cfg
.
multi_scales
):
input_im
=
input_ims
[
idx
]
parsing_output
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
input_im
},
fetch_list
=
fetch_list
)
parsing_output
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
input_im
},
fetch_list
=
fetch_list
)
output
=
parsing_output
[
0
][
0
]
if
cfg
.
flip
:
# 若水平翻转,对部分类别进行翻转,与原始预测结果取均值
...
...
@@ -92,7 +106,8 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
# 仿射变换回图像原始尺寸
points
,
points_trans
=
get_affine_points
(
im_shape
,
scale
)
M
=
cv2
.
getAffineTransform
(
np
.
float32
(
points_trans
),
np
.
float32
(
points
))
logits_result
=
cv2
.
warpAffine
(
output
,
M
,
im_shape
[::
-
1
],
flags
=
cv2
.
INTER_LINEAR
)
logits_result
=
cv2
.
warpAffine
(
output
,
M
,
im_shape
[::
-
1
],
flags
=
cv2
.
INTER_LINEAR
)
ms_outputs
.
append
(
logits_result
)
# 多尺度预测结果求均值,求预测概率最大的类别
...
...
@@ -100,4 +115,3 @@ def multi_scale_test(exe, test_prog, feed_name, fetch_list,
ms_fused_parsing_output
=
np
.
mean
(
ms_fused_parsing_output
,
axis
=
0
)
parsing
=
np
.
argmax
(
ms_fused_parsing_output
,
axis
=
2
)
return
parsing
,
ms_fused_parsing_output
contrib/ACE2P/utils/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/ACE2P/utils/palette.py
浏览文件 @
61645b1d
...
...
@@ -7,6 +7,7 @@
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
contrib/ACE2P/utils/util.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
argparse
import
os
def
get_arguments
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--use_gpu"
,
action
=
"store_true"
,
help
=
"Use gpu or cpu to test."
)
parser
.
add_argument
(
'--example'
,
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
parser
.
add_argument
(
"--use_gpu"
,
action
=
"store_true"
,
help
=
"Use gpu or cpu to test."
)
parser
.
add_argument
(
'--example'
,
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
return
parser
.
parse_args
()
...
...
@@ -34,6 +48,7 @@ class AttrDict(dict):
else
:
self
[
name
]
=
value
def
merge_cfg_from_args
(
args
,
cfg
):
"""Merge config keys, values in args into the global config."""
for
k
,
v
in
vars
(
args
).
items
():
...
...
@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value
=
v
if
value
is
not
None
:
cfg
[
k
]
=
value
contrib/HumanSeg/datasets/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/datasets/dataset.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/datasets/shared_queue/__init__.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/datasets/shared_queue/queue.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/datasets/shared_queue/sharedmemory.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,9 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# utils for memory management which is allocated on sharedmemory,
# note that these structures may not be thread-safe
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
contrib/HumanSeg/export.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
models
import
argparse
...
...
contrib/HumanSeg/infer.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
os.path
as
osp
...
...
contrib/HumanSeg/models/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.humanseg
import
HumanSegMobile
from
.humanseg
import
HumanSegServer
from
.humanseg
import
HumanSegLite
...
...
contrib/HumanSeg/models/humanseg.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/models/load_model.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.backbone
import
mobilenet_v2
from
.backbone
import
xception
from
.deeplabv3p
import
DeepLabv3p
...
...
contrib/HumanSeg/nets/backbone/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.mobilenet_v2
import
MobileNetV2
from
.xception
import
Xception
contrib/HumanSeg/nets/backbone/mobilenet_v2.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -10,6 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
...
...
contrib/HumanSeg/nets/backbone/xception.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/deeplabv3p.py
浏览文件 @
61645b1d
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/hrnet.py
浏览文件 @
61645b1d
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/libs.py
浏览文件 @
61645b1d
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/seg_modules.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/nets/shufflenet_slim.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
contrib/HumanSeg/pretrained_weights/download_pretrained_weights.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/quant_offline.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
datasets.dataset
import
Dataset
import
transforms
...
...
contrib/HumanSeg/quant_online.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
datasets.dataset
import
Dataset
from
models
import
HumanSegMobile
,
HumanSegLite
,
HumanSegServer
...
...
contrib/HumanSeg/train.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
datasets.dataset
import
Dataset
from
models
import
HumanSegMobile
,
HumanSegLite
,
HumanSegServer
...
...
contrib/HumanSeg/transforms/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/transforms/functional.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/transforms/transforms.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/utils/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/HumanSeg/utils/humanseg_postprocess.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
cv2
import
os
...
...
contrib/HumanSeg/utils/logging.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/utils/metrics.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/utils/post_quantization.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/HumanSeg/utils/utils.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -205,11 +206,9 @@ def load_pretrained_weights(exe, main_prog, weights_dir, fuse_bn=False):
vars_to_load
.
append
(
var
)
logging
.
debug
(
"Weight {} will be load"
.
format
(
var
.
name
))
fluid
.
io
.
load_vars
(
executor
=
exe
,
dirname
=
weights_dir
,
main_program
=
main_prog
,
vars
=
vars_to_load
)
params_dict
=
fluid
.
io
.
load_program_state
(
weights_dir
,
var_list
=
vars_to_load
)
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
logging
.
warning
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
...
...
contrib/HumanSeg/val.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
from
datasets.dataset
import
Dataset
import
transforms
...
...
contrib/HumanSeg/video_infer.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
os.path
as
osp
...
...
contrib/LaneNet/data_aug.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -21,6 +21,7 @@ from models.model_builder import ModelPhase
from
pdseg.data_aug
import
get_random_scale
,
randomly_scale_image_and_label
,
random_rotation
,
\
rand_scale_aspect
,
hsv_color_jitter
,
rand_crop
def
resize
(
img
,
grt
=
None
,
grt_instance
=
None
,
mode
=
ModelPhase
.
TRAIN
):
"""
改变图像及标签图像尺寸
...
...
@@ -44,7 +45,8 @@ def resize(img, grt=None, grt_instance=None, mode=ModelPhase.TRAIN):
if
grt
is
not
None
:
grt
=
cv2
.
resize
(
grt
,
target_size
,
interpolation
=
cv2
.
INTER_NEAREST
)
if
grt_instance
is
not
None
:
grt_instance
=
cv2
.
resize
(
grt_instance
,
target_size
,
interpolation
=
cv2
.
INTER_NEAREST
)
grt_instance
=
cv2
.
resize
(
grt_instance
,
target_size
,
interpolation
=
cv2
.
INTER_NEAREST
)
elif
cfg
.
AUG
.
AUG_METHOD
==
'stepscaling'
:
if
mode
==
ModelPhase
.
TRAIN
:
min_scale_factor
=
cfg
.
AUG
.
MIN_SCALE_FACTOR
...
...
contrib/LaneNet/dataset/download_tusimple.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/LaneNet/eval.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -122,7 +122,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
# Use streaming confusion matrix to calculate mean_iou
np
.
set_printoptions
(
...
...
contrib/LaneNet/loss.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/LaneNet/models/__init__.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/LaneNet/models/model_builder.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/LaneNet/models/modeling/lanenet.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,7 +18,6 @@ from __future__ import print_function
import
paddle.fluid
as
fluid
from
utils.config
import
cfg
from
pdseg.models.libs.model_libs
import
scope
,
name_scope
from
pdseg.models.libs.model_libs
import
bn
,
bn_relu
,
relu
...
...
@@ -86,7 +85,12 @@ def bottleneck(inputs,
with
scope
(
'down_sample'
):
inputs_shape
=
inputs
.
shape
with
scope
(
'main_max_pool'
):
net_main
=
fluid
.
layers
.
conv2d
(
inputs
,
inputs_shape
[
1
],
filter_size
=
3
,
stride
=
2
,
padding
=
'SAME'
)
net_main
=
fluid
.
layers
.
conv2d
(
inputs
,
inputs_shape
[
1
],
filter_size
=
3
,
stride
=
2
,
padding
=
'SAME'
)
#First get the difference in depth to pad, then pad with zeros only on the last dimension.
depth_to_pad
=
abs
(
inputs_shape
[
1
]
-
output_depth
)
...
...
@@ -95,12 +99,16 @@ def bottleneck(inputs,
net_main
=
fluid
.
layers
.
pad
(
net_main
,
paddings
=
paddings
)
with
scope
(
'block1'
):
net
=
conv
(
inputs
,
reduced_depth
,
[
2
,
2
],
stride
=
2
,
padding
=
'same'
)
net
=
conv
(
inputs
,
reduced_depth
,
[
2
,
2
],
stride
=
2
,
padding
=
'same'
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
with
scope
(
'block2'
):
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
filter_size
],
padding
=
'same'
)
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
filter_size
],
padding
=
'same'
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
...
@@ -137,13 +145,18 @@ def bottleneck(inputs,
# Second conv block --- apply dilated convolution here
with
scope
(
'block2'
):
net
=
conv
(
net
,
reduced_depth
,
filter_size
,
padding
=
'SAME'
,
dilation
=
dilation_rate
)
net
=
conv
(
net
,
reduced_depth
,
filter_size
,
padding
=
'SAME'
,
dilation
=
dilation_rate
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
# Final projection with 1x1 kernel (Expansion)
with
scope
(
'block3'
):
net
=
conv
(
net
,
output_depth
,
[
1
,
1
])
net
=
conv
(
net
,
output_depth
,
[
1
,
1
])
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
...
@@ -172,9 +185,11 @@ def bottleneck(inputs,
# Second conv block --- apply asymmetric conv here
with
scope
(
'block2'
):
with
scope
(
'asymmetric_conv2a'
):
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
1
],
padding
=
'same'
)
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
1
],
padding
=
'same'
)
with
scope
(
'asymmetric_conv2b'
):
net
=
conv
(
net
,
reduced_depth
,
[
1
,
filter_size
],
padding
=
'same'
)
net
=
conv
(
net
,
reduced_depth
,
[
1
,
filter_size
],
padding
=
'same'
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
...
@@ -211,7 +226,8 @@ def bottleneck(inputs,
with
scope
(
'unpool'
):
net_unpool
=
conv
(
inputs
,
output_depth
,
[
1
,
1
])
net_unpool
=
bn
(
net_unpool
)
net_unpool
=
fluid
.
layers
.
resize_bilinear
(
net_unpool
,
out_shape
=
output_shape
[
2
:])
net_unpool
=
fluid
.
layers
.
resize_bilinear
(
net_unpool
,
out_shape
=
output_shape
[
2
:])
# First 1x1 projection to reduce depth
with
scope
(
'block1'
):
...
...
@@ -220,7 +236,12 @@ def bottleneck(inputs,
net
=
prelu
(
net
,
decoder
=
decoder
)
with
scope
(
'block2'
):
net
=
deconv
(
net
,
reduced_depth
,
filter_size
=
filter_size
,
stride
=
2
,
padding
=
'same'
)
net
=
deconv
(
net
,
reduced_depth
,
filter_size
=
filter_size
,
stride
=
2
,
padding
=
'same'
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
...
@@ -253,7 +274,10 @@ def bottleneck(inputs,
# Second conv block
with
scope
(
'block2'
):
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
filter_size
],
padding
=
'same'
)
net
=
conv
(
net
,
reduced_depth
,
[
filter_size
,
filter_size
],
padding
=
'same'
)
net
=
bn
(
net
)
net
=
prelu
(
net
,
decoder
=
decoder
)
...
...
@@ -281,17 +305,33 @@ def ENet_stage1(inputs, name_scope='stage1_block'):
=
bottleneck
(
inputs
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
type
=
DOWNSAMPLING
,
name_scope
=
'bottleneck1_0'
)
with
scope
(
'bottleneck1_1'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_1'
)
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_1'
)
with
scope
(
'bottleneck1_2'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_2'
)
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_2'
)
with
scope
(
'bottleneck1_3'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_3'
)
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_3'
)
with
scope
(
'bottleneck1_4'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_4'
)
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.01
,
name_scope
=
'bottleneck1_4'
)
return
net
,
inputs_shape_1
...
...
@@ -302,17 +342,38 @@ def ENet_stage2(inputs, name_scope='stage2_block'):
name_scope
=
'bottleneck2_0'
)
for
i
in
range
(
2
):
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
1
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
1
)))
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
1
)))
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
2
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
1
)),
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
2
)))
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
1
)),
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
2
)))
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
3
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
5
,
regularizer_prob
=
0.1
,
type
=
ASYMMETRIC
,
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
3
)))
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
5
,
regularizer_prob
=
0.1
,
type
=
ASYMMETRIC
,
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
3
)))
with
scope
(
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
4
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
2
)),
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
4
)))
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
2
)),
name_scope
=
'bottleneck2_{}'
.
format
(
str
(
4
*
i
+
4
)))
return
net
,
inputs_shape_2
...
...
@@ -320,52 +381,106 @@ def ENet_stage3(inputs, name_scope='stage3_block'):
with
scope
(
name_scope
):
for
i
in
range
(
2
):
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
0
))):
net
=
bottleneck
(
inputs
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
0
)))
net
=
bottleneck
(
inputs
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
0
)))
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
1
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
1
)),
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
1
)))
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
1
)),
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
1
)))
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
2
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
5
,
regularizer_prob
=
0.1
,
type
=
ASYMMETRIC
,
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
2
)))
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
5
,
regularizer_prob
=
0.1
,
type
=
ASYMMETRIC
,
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
2
)))
with
scope
(
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
3
))):
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
2
)),
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
3
)))
net
=
bottleneck
(
net
,
output_depth
=
128
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
DILATED
,
dilation_rate
=
(
2
**
(
2
*
i
+
2
)),
name_scope
=
'bottleneck3_{}'
.
format
(
str
(
4
*
i
+
3
)))
return
net
def
ENet_stage4
(
inputs
,
inputs_shape
,
connect_tensor
,
skip_connections
=
True
,
name_scope
=
'stage4_block'
):
def
ENet_stage4
(
inputs
,
inputs_shape
,
connect_tensor
,
skip_connections
=
True
,
name_scope
=
'stage4_block'
):
with
scope
(
name_scope
):
with
scope
(
'bottleneck4_0'
):
net
=
bottleneck
(
inputs
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
UPSAMPLING
,
decoder
=
True
,
output_shape
=
inputs_shape
,
name_scope
=
'bottleneck4_0'
)
net
=
bottleneck
(
inputs
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
UPSAMPLING
,
decoder
=
True
,
output_shape
=
inputs_shape
,
name_scope
=
'bottleneck4_0'
)
if
skip_connections
:
net
=
fluid
.
layers
.
elementwise_add
(
net
,
connect_tensor
)
with
scope
(
'bottleneck4_1'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck4_1'
)
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck4_1'
)
with
scope
(
'bottleneck4_2'
):
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck4_2'
)
net
=
bottleneck
(
net
,
output_depth
=
64
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck4_2'
)
return
net
def
ENet_stage5
(
inputs
,
inputs_shape
,
connect_tensor
,
skip_connections
=
True
,
def
ENet_stage5
(
inputs
,
inputs_shape
,
connect_tensor
,
skip_connections
=
True
,
name_scope
=
'stage5_block'
):
with
scope
(
name_scope
):
net
=
bottleneck
(
inputs
,
output_depth
=
16
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
UPSAMPLING
,
decoder
=
True
,
output_shape
=
inputs_shape
,
name_scope
=
'bottleneck5_0'
)
net
=
bottleneck
(
inputs
,
output_depth
=
16
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
type
=
UPSAMPLING
,
decoder
=
True
,
output_shape
=
inputs_shape
,
name_scope
=
'bottleneck5_0'
)
if
skip_connections
:
net
=
fluid
.
layers
.
elementwise_add
(
net
,
connect_tensor
)
with
scope
(
'bottleneck5_1'
):
net
=
bottleneck
(
net
,
output_depth
=
16
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck5_1'
)
net
=
bottleneck
(
net
,
output_depth
=
16
,
filter_size
=
3
,
regularizer_prob
=
0.1
,
decoder
=
True
,
name_scope
=
'bottleneck5_1'
)
return
net
...
...
@@ -378,14 +493,16 @@ def decoder(input, num_classes):
segStage3
=
ENet_stage3
(
stage2
)
segStage4
=
ENet_stage4
(
segStage3
,
inputs_shape_2
,
stage1
)
segStage5
=
ENet_stage5
(
segStage4
,
inputs_shape_1
,
initial
)
segLogits
=
deconv
(
segStage5
,
num_classes
,
filter_size
=
2
,
stride
=
2
,
padding
=
'SAME'
)
segLogits
=
deconv
(
segStage5
,
num_classes
,
filter_size
=
2
,
stride
=
2
,
padding
=
'SAME'
)
# Embedding branch
with
scope
(
'LaneNetEm'
):
emStage3
=
ENet_stage3
(
stage2
)
emStage4
=
ENet_stage4
(
emStage3
,
inputs_shape_2
,
stage1
)
emStage5
=
ENet_stage5
(
emStage4
,
inputs_shape_1
,
initial
)
emLogits
=
deconv
(
emStage5
,
4
,
filter_size
=
2
,
stride
=
2
,
padding
=
'SAME'
)
emLogits
=
deconv
(
emStage5
,
4
,
filter_size
=
2
,
stride
=
2
,
padding
=
'SAME'
)
elif
'vgg'
in
cfg
.
MODEL
.
LANENET
.
BACKBONE
:
encoder_list
=
[
'pool5'
,
'pool4'
,
'pool3'
]
...
...
@@ -396,14 +513,16 @@ def decoder(input, num_classes):
encoder_list
=
encoder_list
[
1
:]
for
i
in
range
(
len
(
encoder_list
)):
with
scope
(
'deconv_{:d}'
.
format
(
i
+
1
)):
deconv_out
=
deconv
(
score
,
64
,
filter_size
=
4
,
stride
=
2
,
padding
=
'SAME'
)
deconv_out
=
deconv
(
score
,
64
,
filter_size
=
4
,
stride
=
2
,
padding
=
'SAME'
)
input_tensor
=
input
[
encoder_list
[
i
]]
with
scope
(
'score_{:d}'
.
format
(
i
+
1
)):
score
=
conv
(
input_tensor
,
64
,
1
)
score
=
fluid
.
layers
.
elementwise_add
(
deconv_out
,
score
)
with
scope
(
'deconv_final'
):
emLogits
=
deconv
(
score
,
64
,
filter_size
=
16
,
stride
=
8
,
padding
=
'SAME'
)
emLogits
=
deconv
(
score
,
64
,
filter_size
=
16
,
stride
=
8
,
padding
=
'SAME'
)
with
scope
(
'score_final'
):
segLogits
=
conv
(
emLogits
,
num_classes
,
1
)
emLogits
=
relu
(
conv
(
emLogits
,
4
,
1
))
...
...
@@ -415,7 +534,8 @@ def encoder(input):
model
=
vgg_backbone
(
layers
=
16
)
#output = model.net(input)
_
,
encode_feature_dict
=
model
.
net
(
input
,
end_points
=
13
,
decode_points
=
[
7
,
10
,
13
])
_
,
encode_feature_dict
=
model
.
net
(
input
,
end_points
=
13
,
decode_points
=
[
7
,
10
,
13
])
output
=
{}
output
[
'pool3'
]
=
encode_feature_dict
[
7
]
output
[
'pool4'
]
=
encode_feature_dict
[
10
]
...
...
@@ -427,8 +547,9 @@ def encoder(input):
stage2
,
inputs_shape_2
=
ENet_stage2
(
stage1
)
output
=
(
initial
,
stage1
,
stage2
,
inputs_shape_1
,
inputs_shape_2
)
else
:
raise
Exception
(
"LaneNet expect enet and vgg backbone, but received {}"
.
format
(
cfg
.
MODEL
.
LANENET
.
BACKBONE
))
raise
Exception
(
"LaneNet expect enet and vgg backbone, but received {}"
.
format
(
cfg
.
MODEL
.
LANENET
.
BACKBONE
))
return
output
...
...
contrib/LaneNet/reader.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -58,7 +58,8 @@ class LaneNetDataset():
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
shuffle_seed
+=
1
elif
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
lines
)
...
...
@@ -86,7 +87,8 @@ class LaneNetDataset():
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
num_lines
=
len
(
self
.
all_lines
)
//
self
.
num_trainers
self
.
lines
=
self
.
all_lines
[
num_lines
*
self
.
trainer_id
:
num_lines
*
(
self
.
trainer_id
+
1
)]
self
.
lines
=
self
.
all_lines
[
num_lines
*
self
.
trainer_id
:
num_lines
*
(
self
.
trainer_id
+
1
)]
self
.
shuffle_seed
+=
1
elif
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
lines
)
...
...
@@ -118,7 +120,8 @@ class LaneNetDataset():
def
batch_reader
(
is_test
=
False
,
drop_last
=
drop_last
):
if
is_test
:
imgs
,
grts
,
grts_instance
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[],
[],
[]
for
img
,
grt
,
grt_instance
,
img_name
,
valid_shape
,
org_shape
in
reader
():
for
img
,
grt
,
grt_instance
,
img_name
,
valid_shape
,
org_shape
in
reader
(
):
imgs
.
append
(
img
)
grts
.
append
(
grt
)
grts_instance
.
append
(
grt_instance
)
...
...
@@ -126,14 +129,15 @@ class LaneNetDataset():
valid_shapes
.
append
(
valid_shape
)
org_shapes
.
append
(
org_shape
)
if
len
(
imgs
)
==
batch_size
:
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
np
.
array
(
grts_instance
),
img_names
,
np
.
array
(
valid_shapes
)
,
np
.
array
(
org_shapes
)
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
np
.
array
(
grts
_instance
),
img_names
,
np
.
array
(
valid_shapes
),
np
.
array
(
org_shapes
)
imgs
,
grts
,
grts_instance
,
img_names
,
valid_shapes
,
org_shapes
=
[],
[],
[],
[],
[],
[]
if
not
drop_last
and
len
(
imgs
)
>
0
:
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
np
.
array
(
grts_instance
),
img_names
,
np
.
array
(
valid_shapes
),
np
.
array
(
org_shapes
)
yield
np
.
array
(
imgs
),
np
.
array
(
grts
),
np
.
array
(
grts_instance
),
img_names
,
np
.
array
(
valid_shapes
),
np
.
array
(
org_shapes
)
else
:
imgs
,
labs
,
labs_instance
,
ignore
=
[],
[],
[],
[]
bs
=
0
...
...
@@ -144,12 +148,14 @@ class LaneNetDataset():
ignore
.
append
(
ig
)
bs
+=
1
if
bs
==
batch_size
:
yield
np
.
array
(
imgs
),
np
.
array
(
labs
),
np
.
array
(
labs_instance
),
np
.
array
(
ignore
)
yield
np
.
array
(
imgs
),
np
.
array
(
labs
),
np
.
array
(
labs_instance
),
np
.
array
(
ignore
)
bs
=
0
imgs
,
labs
,
labs_instance
,
ignore
=
[],
[],
[],
[]
if
not
drop_last
and
bs
>
0
:
yield
np
.
array
(
imgs
),
np
.
array
(
labs
),
np
.
array
(
labs_instance
),
np
.
array
(
ignore
)
yield
np
.
array
(
imgs
),
np
.
array
(
labs
),
np
.
array
(
labs_instance
),
np
.
array
(
ignore
)
return
batch_reader
(
is_test
,
drop_last
)
...
...
@@ -299,10 +305,12 @@ class LaneNetDataset():
img
,
grt
=
aug
.
rand_crop
(
img
,
grt
,
mode
=
mode
)
elif
ModelPhase
.
is_eval
(
mode
):
img
,
grt
,
grt_instance
=
aug
.
resize
(
img
,
grt
,
grt_instance
,
mode
=
mode
)
img
,
grt
,
grt_instance
=
aug
.
resize
(
img
,
grt
,
grt_instance
,
mode
=
mode
)
elif
ModelPhase
.
is_visual
(
mode
):
ori_img
=
img
.
copy
()
img
,
grt
,
grt_instance
=
aug
.
resize
(
img
,
grt
,
grt_instance
,
mode
=
mode
)
img
,
grt
,
grt_instance
=
aug
.
resize
(
img
,
grt
,
grt_instance
,
mode
=
mode
)
valid_shape
=
[
img
.
shape
[
0
],
img
.
shape
[
1
]]
else
:
raise
ValueError
(
"Dataset mode={} Error!"
.
format
(
mode
))
...
...
contrib/LaneNet/train.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -40,10 +40,10 @@ from pdseg.utils.timer import Timer, calculate_eta
from
reader
import
LaneNetDataset
from
models.model_builder
import
build_model
from
models.model_builder
import
ModelPhase
from
models.model_builder
import
parse_shape_from_file
from
eval
import
evaluate
from
vis
import
visualize
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
def
parse_args
():
...
...
@@ -101,37 +101,6 @@ def parse_args():
return
parser
.
parse_args
()
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program
=
fluid
.
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
fluid
.
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
save_block
.
create_var
(
name
=
each_var
.
name
,
shape
=
each_var
.
shape
,
dtype
=
each_var
.
dtype
,
type
=
each_var
.
type
,
lod_level
=
each_var
.
lod_level
,
persistable
=
True
)
file_path
=
os
.
path
.
join
(
dirname
,
new_var
.
name
)
file_path
=
os
.
path
.
normpath
(
file_path
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
file_path
})
executor
.
run
(
save_program
)
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
"""
Save checkpoint for evaluation or resume training
...
...
@@ -141,29 +110,22 @@ def save_checkpoint(exe, program, ckpt_name):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
save_vars
(
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
fluid
.
save
(
program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
"""
Load checkpoiont f
rom pretrained model directory for resume
training
Load checkpoiont f
or resuming
training
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
print
(
'Resume model training from:'
,
model_path
)
if
not
os
.
path
.
exists
(
model_path
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
model_path
))
fluid
.
load
(
program
,
os
.
path
.
join
(
model_path
,
'model'
),
exe
)
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
...
...
@@ -178,7 +140,6 @@ def load_checkpoint(exe, program):
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
...
...
@@ -271,44 +232,7 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_shape
!=
shape
:
print
(
var
.
name
,
var_shape
,
shape
)
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
else
:
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
...
...
@@ -393,8 +317,7 @@ def train(cfg):
avg_emb_loss
,
avg_acc
,
avg_fp
,
avg_fn
,
speed
,
calculate_eta
(
all_step
-
step
,
speed
)))
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
sys
.
stdout
.
flush
()
...
...
@@ -423,8 +346,7 @@ def train(cfg):
use_gpu
=
args
.
use_gpu
,
use_mpio
=
args
.
use_mpio
)
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/accuracy'
,
accuracy
,
step
)
log_writer
.
add_scalar
(
'Evaluate/accuracy'
,
accuracy
,
step
)
log_writer
.
add_scalar
(
'Evaluate/fp'
,
fp
,
step
)
log_writer
.
add_scalar
(
'Evaluate/fn'
,
fn
,
step
)
...
...
contrib/LaneNet/utils/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/LaneNet/utils/config.py
浏览文件 @
61645b1d
#
-*- coding: utf-8 -*-
#
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
.
#
coding: utf8
#
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve
.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -80,8 +80,8 @@ cfg.DATASET.DATA_DIM = 3
cfg
.
DATASET
.
SEPARATOR
=
' '
# 忽略的像素标签值, 默认为255,一般无需改动
cfg
.
DATASET
.
IGNORE_INDEX
=
255
# 数据增强是图像的padding值
cfg
.
DATASET
.
PADDING_VALUE
=
[
127.5
,
127.5
,
127.5
]
# 数据增强是图像的padding值
cfg
.
DATASET
.
PADDING_VALUE
=
[
127.5
,
127.5
,
127.5
]
########################### 数据增强配置 ######################################
# 图像镜像左右翻转
...
...
contrib/LaneNet/utils/dist_utils.py
浏览文件 @
61645b1d
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
Licensed under the Apache License, Version 2.0 (the "License");
#
you may not use this file except in compliance with the License.
#
You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
Unless required by applicable law or agreed to in writing, software
#
distributed under the License is distributed on an "AS IS" BASIS,
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
See the License for the specific language governing permissions and
#
limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
...
...
contrib/LaneNet/utils/generate_tusimple_dataset.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
generate tusimple training dataset
"""
...
...
@@ -14,12 +28,16 @@ import numpy as np
def
init_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--src_dir'
,
type
=
str
,
help
=
'The origin path of unzipped tusimple dataset'
)
parser
.
add_argument
(
'--src_dir'
,
type
=
str
,
help
=
'The origin path of unzipped tusimple dataset'
)
return
parser
.
parse_args
()
def
process_json_file
(
json_file_path
,
src_dir
,
ori_dst_dir
,
binary_dst_dir
,
instance_dst_dir
):
def
process_json_file
(
json_file_path
,
src_dir
,
ori_dst_dir
,
binary_dst_dir
,
instance_dst_dir
):
assert
ops
.
exists
(
json_file_path
),
'{:s} not exist'
.
format
(
json_file_path
)
...
...
@@ -39,11 +57,14 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
h_samples
=
info_dict
[
'h_samples'
]
lanes
=
info_dict
[
'lanes'
]
image_name_new
=
'{:s}.png'
.
format
(
'{:d}'
.
format
(
line_index
+
image_nums
).
zfill
(
4
))
image_name_new
=
'{:s}.png'
.
format
(
'{:d}'
.
format
(
line_index
+
image_nums
).
zfill
(
4
))
src_image
=
cv2
.
imread
(
image_path
,
cv2
.
IMREAD_COLOR
)
dst_binary_image
=
np
.
zeros
([
src_image
.
shape
[
0
],
src_image
.
shape
[
1
]],
np
.
uint8
)
dst_instance_image
=
np
.
zeros
([
src_image
.
shape
[
0
],
src_image
.
shape
[
1
]],
np
.
uint8
)
dst_binary_image
=
np
.
zeros
(
[
src_image
.
shape
[
0
],
src_image
.
shape
[
1
]],
np
.
uint8
)
dst_instance_image
=
np
.
zeros
(
[
src_image
.
shape
[
0
],
src_image
.
shape
[
1
]],
np
.
uint8
)
for
lane_index
,
lane
in
enumerate
(
lanes
):
assert
len
(
h_samples
)
==
len
(
lane
)
...
...
@@ -62,13 +83,23 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
lane_pts
=
np
.
vstack
((
lane_x
,
lane_y
)).
transpose
()
lane_pts
=
np
.
array
([
lane_pts
],
np
.
int64
)
cv2
.
polylines
(
dst_binary_image
,
lane_pts
,
isClosed
=
False
,
color
=
255
,
thickness
=
5
)
cv2
.
polylines
(
dst_instance_image
,
lane_pts
,
isClosed
=
False
,
color
=
lane_index
*
50
+
20
,
thickness
=
5
)
dst_binary_image_path
=
ops
.
join
(
src_dir
,
binary_dst_dir
,
image_name_new
)
dst_instance_image_path
=
ops
.
join
(
src_dir
,
instance_dst_dir
,
image_name_new
)
cv2
.
polylines
(
dst_binary_image
,
lane_pts
,
isClosed
=
False
,
color
=
255
,
thickness
=
5
)
cv2
.
polylines
(
dst_instance_image
,
lane_pts
,
isClosed
=
False
,
color
=
lane_index
*
50
+
20
,
thickness
=
5
)
dst_binary_image_path
=
ops
.
join
(
src_dir
,
binary_dst_dir
,
image_name_new
)
dst_instance_image_path
=
ops
.
join
(
src_dir
,
instance_dst_dir
,
image_name_new
)
dst_rgb_image_path
=
ops
.
join
(
src_dir
,
ori_dst_dir
,
image_name_new
)
cv2
.
imwrite
(
dst_binary_image_path
,
dst_binary_image
)
...
...
@@ -78,7 +109,12 @@ def process_json_file(json_file_path, src_dir, ori_dst_dir, binary_dst_dir, inst
print
(
'Process {:s} success'
.
format
(
image_name
))
def
gen_sample
(
src_dir
,
b_gt_image_dir
,
i_gt_image_dir
,
image_dir
,
phase
=
'train'
,
split
=
False
):
def
gen_sample
(
src_dir
,
b_gt_image_dir
,
i_gt_image_dir
,
image_dir
,
phase
=
'train'
,
split
=
False
):
label_list
=
[]
with
open
(
'{:s}/{}ing/{}.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
...
...
@@ -92,7 +128,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
image_path
=
ops
.
join
(
image_dir
,
image_name
)
assert
ops
.
exists
(
image_path
),
'{:s} not exist'
.
format
(
image_path
)
assert
ops
.
exists
(
instance_gt_image_path
),
'{:s} not exist'
.
format
(
instance_gt_image_path
)
assert
ops
.
exists
(
instance_gt_image_path
),
'{:s} not exist'
.
format
(
instance_gt_image_path
)
b_gt_image
=
cv2
.
imread
(
binary_gt_image_path
,
cv2
.
IMREAD_COLOR
)
i_gt_image
=
cv2
.
imread
(
instance_gt_image_path
,
cv2
.
IMREAD_COLOR
)
...
...
@@ -102,7 +139,8 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
print
(
'image: {:s} corrupt'
.
format
(
image_name
))
continue
else
:
info
=
'{:s} {:s} {:s}'
.
format
(
image_path
,
binary_gt_image_path
,
instance_gt_image_path
)
info
=
'{:s} {:s} {:s}'
.
format
(
image_path
,
binary_gt_image_path
,
instance_gt_image_path
)
file
.
write
(
info
+
'
\n
'
)
label_list
.
append
(
info
)
if
phase
==
'train'
and
split
:
...
...
@@ -110,10 +148,12 @@ def gen_sample(src_dir, b_gt_image_dir, i_gt_image_dir, image_dir, phase='train'
val_list_len
=
len
(
label_list
)
//
10
val_label_list
=
label_list
[:
val_list_len
]
train_label_list
=
label_list
[
val_list_len
:]
with
open
(
'{:s}/{}ing/train_part.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
with
open
(
'{:s}/{}ing/train_part.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
for
info
in
train_label_list
:
file
.
write
(
info
+
'
\n
'
)
with
open
(
'{:s}/{}ing/val_part.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
with
open
(
'{:s}/{}ing/val_part.txt'
.
format
(
src_dir
,
phase
,
phase
),
'w'
)
as
file
:
for
info
in
val_label_list
:
file
.
write
(
info
+
'
\n
'
)
return
...
...
@@ -130,12 +170,14 @@ def process_tusimple_dataset(src_dir):
for
json_label_path
in
glob
.
glob
(
'{:s}/label*.json'
.
format
(
src_dir
)):
json_label_name
=
ops
.
split
(
json_label_path
)[
1
]
shutil
.
copyfile
(
json_label_path
,
ops
.
join
(
traing_folder_path
,
json_label_name
))
shutil
.
copyfile
(
json_label_path
,
ops
.
join
(
traing_folder_path
,
json_label_name
))
for
json_label_path
in
glob
.
glob
(
'{:s}/test_label.json'
.
format
(
src_dir
)):
json_label_name
=
ops
.
split
(
json_label_path
)[
1
]
shutil
.
copyfile
(
json_label_path
,
ops
.
join
(
testing_folder_path
,
json_label_name
))
shutil
.
copyfile
(
json_label_path
,
ops
.
join
(
testing_folder_path
,
json_label_name
))
train_gt_image_dir
=
ops
.
join
(
'training'
,
'gt_image'
)
train_gt_binary_dir
=
ops
.
join
(
'training'
,
'gt_binary_image'
)
...
...
@@ -154,9 +196,11 @@ def process_tusimple_dataset(src_dir):
os
.
makedirs
(
os
.
path
.
join
(
src_dir
,
test_gt_instance_dir
),
exist_ok
=
True
)
for
json_label_path
in
glob
.
glob
(
'{:s}/*.json'
.
format
(
traing_folder_path
)):
process_json_file
(
json_label_path
,
src_dir
,
train_gt_image_dir
,
train_gt_binary_dir
,
train_gt_instance_dir
)
process_json_file
(
json_label_path
,
src_dir
,
train_gt_image_dir
,
train_gt_binary_dir
,
train_gt_instance_dir
)
gen_sample
(
src_dir
,
train_gt_binary_dir
,
train_gt_instance_dir
,
train_gt_image_dir
,
'train'
,
True
)
gen_sample
(
src_dir
,
train_gt_binary_dir
,
train_gt_instance_dir
,
train_gt_image_dir
,
'train'
,
True
)
if
__name__
==
'__main__'
:
...
...
contrib/LaneNet/utils/lanenet_postprocess.py
浏览文件 @
61645b1d
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this code heavily base on https://github.com/MaybeShewill-CV/lanenet-lane-detection/blob/master/lanenet_model/lanenet_postprocess.py
"""
LaneNet model post process
...
...
@@ -22,12 +35,14 @@ def _morphological_process(image, kernel_size=5):
:return:
"""
if
len
(
image
.
shape
)
==
3
:
raise
ValueError
(
'Binary segmentation result image should be a single channel image'
)
raise
ValueError
(
'Binary segmentation result image should be a single channel image'
)
if
image
.
dtype
is
not
np
.
uint8
:
image
=
np
.
array
(
image
,
np
.
uint8
)
kernel
=
cv2
.
getStructuringElement
(
shape
=
cv2
.
MORPH_ELLIPSE
,
ksize
=
(
kernel_size
,
kernel_size
))
kernel
=
cv2
.
getStructuringElement
(
shape
=
cv2
.
MORPH_ELLIPSE
,
ksize
=
(
kernel_size
,
kernel_size
))
# close operation fille hole
closing
=
cv2
.
morphologyEx
(
image
,
cv2
.
MORPH_CLOSE
,
kernel
,
iterations
=
1
)
...
...
@@ -46,13 +61,15 @@ def _connect_components_analysis(image):
else
:
gray_image
=
image
return
cv2
.
connectedComponentsWithStats
(
gray_image
,
connectivity
=
8
,
ltype
=
cv2
.
CV_32S
)
return
cv2
.
connectedComponentsWithStats
(
gray_image
,
connectivity
=
8
,
ltype
=
cv2
.
CV_32S
)
class
_LaneFeat
(
object
):
"""
"""
def
__init__
(
self
,
feat
,
coord
,
class_id
=-
1
):
"""
lane feat object
...
...
@@ -108,18 +125,21 @@ class _LaneNetCluster(object):
"""
Instance segmentation result cluster
"""
def
__init__
(
self
):
"""
"""
self
.
_color_map
=
[
np
.
array
([
255
,
0
,
0
]),
np
.
array
([
0
,
255
,
0
]),
np
.
array
([
0
,
0
,
255
]),
np
.
array
([
125
,
125
,
0
]),
np
.
array
([
0
,
125
,
125
]),
np
.
array
([
125
,
0
,
125
]),
np
.
array
([
50
,
100
,
50
]),
np
.
array
([
100
,
50
,
100
])]
self
.
_color_map
=
[
np
.
array
([
255
,
0
,
0
]),
np
.
array
([
0
,
255
,
0
]),
np
.
array
([
0
,
0
,
255
]),
np
.
array
([
125
,
125
,
0
]),
np
.
array
([
0
,
125
,
125
]),
np
.
array
([
125
,
0
,
125
]),
np
.
array
([
50
,
100
,
50
]),
np
.
array
([
100
,
50
,
100
])
]
@
staticmethod
def
_embedding_feats_dbscan_cluster
(
embedding_image_feats
):
...
...
@@ -186,15 +206,16 @@ class _LaneNetCluster(object):
# get embedding feats and coords
get_lane_embedding_feats_result
=
self
.
_get_lane_embedding_feats
(
binary_seg_ret
=
binary_seg_result
,
instance_seg_ret
=
instance_seg_result
)
instance_seg_ret
=
instance_seg_result
)
# dbscan cluster
dbscan_cluster_result
=
self
.
_embedding_feats_dbscan_cluster
(
embedding_image_feats
=
get_lane_embedding_feats_result
[
'lane_embedding_feats'
]
)
embedding_image_feats
=
get_lane_embedding_feats_result
[
'lane_embedding_feats'
]
)
mask
=
np
.
zeros
(
shape
=
[
binary_seg_result
.
shape
[
0
],
binary_seg_result
.
shape
[
1
],
3
],
dtype
=
np
.
uint8
)
mask
=
np
.
zeros
(
shape
=
[
binary_seg_result
.
shape
[
0
],
binary_seg_result
.
shape
[
1
],
3
],
dtype
=
np
.
uint8
)
db_labels
=
dbscan_cluster_result
[
'db_labels'
]
unique_labels
=
dbscan_cluster_result
[
'unique_labels'
]
coord
=
get_lane_embedding_feats_result
[
'lane_coordinates'
]
...
...
@@ -219,11 +240,13 @@ class LaneNetPostProcessor(object):
"""
lanenet post process for lane generation
"""
def
__init__
(
self
,
ipm_remap_file_path
=
'./utils/tusimple_ipm_remap.yml'
):
"""
convert front car view to bird view
"""
assert
ops
.
exists
(
ipm_remap_file_path
),
'{:s} not exist'
.
format
(
ipm_remap_file_path
)
assert
ops
.
exists
(
ipm_remap_file_path
),
'{:s} not exist'
.
format
(
ipm_remap_file_path
)
self
.
_cluster
=
_LaneNetCluster
()
self
.
_ipm_remap_file_path
=
ipm_remap_file_path
...
...
@@ -232,14 +255,16 @@ class LaneNetPostProcessor(object):
self
.
_remap_to_ipm_x
=
remap_file_load_ret
[
'remap_to_ipm_x'
]
self
.
_remap_to_ipm_y
=
remap_file_load_ret
[
'remap_to_ipm_y'
]
self
.
_color_map
=
[
np
.
array
([
255
,
0
,
0
]),
np
.
array
([
0
,
255
,
0
]),
np
.
array
([
0
,
0
,
255
]),
np
.
array
([
125
,
125
,
0
]),
np
.
array
([
0
,
125
,
125
]),
np
.
array
([
125
,
0
,
125
]),
np
.
array
([
50
,
100
,
50
]),
np
.
array
([
100
,
50
,
100
])]
self
.
_color_map
=
[
np
.
array
([
255
,
0
,
0
]),
np
.
array
([
0
,
255
,
0
]),
np
.
array
([
0
,
0
,
255
]),
np
.
array
([
125
,
125
,
0
]),
np
.
array
([
0
,
125
,
125
]),
np
.
array
([
125
,
0
,
125
]),
np
.
array
([
50
,
100
,
50
]),
np
.
array
([
100
,
50
,
100
])
]
def
_load_remap_matrix
(
self
):
fs
=
cv2
.
FileStorage
(
self
.
_ipm_remap_file_path
,
cv2
.
FILE_STORAGE_READ
)
...
...
@@ -256,15 +281,20 @@ class LaneNetPostProcessor(object):
return
ret
def
postprocess
(
self
,
binary_seg_result
,
instance_seg_result
=
None
,
min_area_threshold
=
100
,
source_image
=
None
,
def
postprocess
(
self
,
binary_seg_result
,
instance_seg_result
=
None
,
min_area_threshold
=
100
,
source_image
=
None
,
data_source
=
'tusimple'
):
# convert binary_seg_result
binary_seg_result
=
np
.
array
(
binary_seg_result
*
255
,
dtype
=
np
.
uint8
)
# apply image morphology operation to fill in the hold and reduce the small area
morphological_ret
=
_morphological_process
(
binary_seg_result
,
kernel_size
=
5
)
connect_components_analysis_ret
=
_connect_components_analysis
(
image
=
morphological_ret
)
morphological_ret
=
_morphological_process
(
binary_seg_result
,
kernel_size
=
5
)
connect_components_analysis_ret
=
_connect_components_analysis
(
image
=
morphological_ret
)
labels
=
connect_components_analysis_ret
[
1
]
stats
=
connect_components_analysis_ret
[
2
]
...
...
@@ -276,8 +306,7 @@ class LaneNetPostProcessor(object):
# apply embedding features cluster
mask_image
,
lane_coords
=
self
.
_cluster
.
apply_lane_feats_cluster
(
binary_seg_result
=
morphological_ret
,
instance_seg_result
=
instance_seg_result
)
instance_seg_result
=
instance_seg_result
)
if
mask_image
is
None
:
return
{
...
...
@@ -292,15 +321,15 @@ class LaneNetPostProcessor(object):
for
lane_index
,
coords
in
enumerate
(
lane_coords
):
if
data_source
==
'tusimple'
:
tmp_mask
=
np
.
zeros
(
shape
=
(
720
,
1280
),
dtype
=
np
.
uint8
)
tmp_mask
[
tuple
((
np
.
int_
(
coords
[:,
1
]
*
720
/
256
),
np
.
int_
(
coords
[:,
0
]
*
1280
/
512
)))]
=
255
tmp_mask
[
tuple
((
np
.
int_
(
coords
[:,
1
]
*
720
/
256
),
np
.
int_
(
coords
[:,
0
]
*
1280
/
512
)))]
=
255
else
:
raise
ValueError
(
'Wrong data source now only support tusimple'
)
tmp_ipm_mask
=
cv2
.
remap
(
tmp_mask
,
self
.
_remap_to_ipm_x
,
self
.
_remap_to_ipm_y
,
interpolation
=
cv2
.
INTER_NEAREST
)
interpolation
=
cv2
.
INTER_NEAREST
)
nonzero_y
=
np
.
array
(
tmp_ipm_mask
.
nonzero
()[
0
])
nonzero_x
=
np
.
array
(
tmp_ipm_mask
.
nonzero
()[
1
])
...
...
@@ -309,16 +338,19 @@ class LaneNetPostProcessor(object):
[
ipm_image_height
,
ipm_image_width
]
=
tmp_ipm_mask
.
shape
plot_y
=
np
.
linspace
(
10
,
ipm_image_height
,
ipm_image_height
-
10
)
fit_x
=
fit_param
[
0
]
*
plot_y
**
2
+
fit_param
[
1
]
*
plot_y
+
fit_param
[
2
]
fit_x
=
fit_param
[
0
]
*
plot_y
**
2
+
fit_param
[
1
]
*
plot_y
+
fit_param
[
2
]
lane_pts
=
[]
for
index
in
range
(
0
,
plot_y
.
shape
[
0
],
5
):
src_x
=
self
.
_remap_to_ipm_x
[
int
(
plot_y
[
index
]),
int
(
np
.
clip
(
fit_x
[
index
],
0
,
ipm_image_width
-
1
))]
int
(
plot_y
[
index
]),
int
(
np
.
clip
(
fit_x
[
index
],
0
,
ipm_image_width
-
1
))]
if
src_x
<=
0
:
continue
src_y
=
self
.
_remap_to_ipm_y
[
int
(
plot_y
[
index
]),
int
(
np
.
clip
(
fit_x
[
index
],
0
,
ipm_image_width
-
1
))]
int
(
plot_y
[
index
]),
int
(
np
.
clip
(
fit_x
[
index
],
0
,
ipm_image_width
-
1
))]
src_y
=
src_y
if
src_y
>
0
else
0
lane_pts
.
append
([
src_x
,
src_y
])
...
...
@@ -366,8 +398,10 @@ class LaneNetPostProcessor(object):
continue
lane_color
=
self
.
_color_map
[
index
].
tolist
()
cv2
.
circle
(
source_image
,
(
int
(
interpolation_src_pt_x
),
int
(
interpolation_src_pt_y
)),
5
,
lane_color
,
-
1
)
cv2
.
circle
(
source_image
,
(
int
(
interpolation_src_pt_x
),
int
(
interpolation_src_pt_y
)),
5
,
lane_color
,
-
1
)
ret
=
{
'mask_image'
:
mask_image
,
'fit_params'
:
fit_params
,
...
...
contrib/LaneNet/utils/load_model_utils.py
0 → 100644
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os.path
as
osp
import
six
import
numpy
as
np
def
parse_param_file
(
param_file
,
return_shape
=
True
):
from
paddle.fluid.proto.framework_pb2
import
VarType
f
=
open
(
param_file
,
'rb'
)
version
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
lod_level
=
np
.
fromstring
(
f
.
read
(
8
),
dtype
=
'int64'
)
for
i
in
range
(
int
(
lod_level
)):
_size
=
np
.
fromstring
(
f
.
read
(
8
),
dtype
=
'int64'
)
_
=
f
.
read
(
_size
)
version
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
tensor_desc
=
VarType
.
TensorDesc
()
tensor_desc_size
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
tensor_desc
.
ParseFromString
(
f
.
read
(
int
(
tensor_desc_size
)))
tensor_shape
=
tuple
(
tensor_desc
.
dims
)
if
return_shape
:
f
.
close
()
return
tuple
(
tensor_desc
.
dims
)
if
tensor_desc
.
data_type
!=
5
:
raise
Exception
(
"Unexpected data type while parse {}"
.
format
(
param_file
))
data_size
=
4
for
i
in
range
(
len
(
tensor_shape
)):
data_size
*=
tensor_shape
[
i
]
weight
=
np
.
fromstring
(
f
.
read
(
data_size
),
dtype
=
'float32'
)
f
.
close
()
return
np
.
reshape
(
weight
,
tensor_shape
)
def
load_pdparams
(
exe
,
main_prog
,
model_dir
):
import
paddle.fluid
as
fluid
from
paddle.fluid.proto.framework_pb2
import
VarType
from
paddle.fluid.framework
import
Program
vars_to_load
=
list
()
vars_not_load
=
list
()
import
pickle
with
open
(
osp
.
join
(
model_dir
,
'model.pdparams'
),
'rb'
)
as
f
:
params_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
f
,
encoding
=
'latin1'
)
unused_vars
=
list
()
for
var
in
main_prog
.
list_vars
():
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
continue
if
var
.
name
not
in
params_dict
:
print
(
"{} is not in saved model"
.
format
(
var
.
name
))
vars_not_load
.
append
(
var
.
name
)
continue
if
var
.
shape
!=
params_dict
[
var
.
name
].
shape
:
unused_vars
.
append
(
var
.
name
)
vars_not_load
.
append
(
var
.
name
)
print
(
"[SKIP] Shape of pretrained weight {} doesn't match.(Pretrained: {}, Actual: {})"
.
format
(
var
.
name
,
params_dict
[
var
.
name
].
shape
,
var
.
shape
))
continue
vars_to_load
.
append
(
var
)
for
var_name
in
unused_vars
:
del
params_dict
[
var_name
]
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
print
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else
:
print
(
"There are {}/{} varaibles in {} are loaded."
.
format
(
len
(
vars_to_load
),
len
(
vars_to_load
)
+
len
(
vars_not_load
),
model_dir
))
def
load_pretrained_weights
(
exe
,
main_prog
,
weights_dir
):
if
not
osp
.
exists
(
weights_dir
):
raise
Exception
(
"Path {} not exists."
.
format
(
weights_dir
))
if
osp
.
exists
(
osp
.
join
(
weights_dir
,
"model.pdparams"
)):
return
load_pdparams
(
exe
,
main_prog
,
weights_dir
)
import
paddle.fluid
as
fluid
vars_to_load
=
list
()
vars_not_load
=
list
()
for
var
in
main_prog
.
list_vars
():
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
continue
if
not
osp
.
exists
(
osp
.
join
(
weights_dir
,
var
.
name
)):
print
(
"[SKIP] Pretrained weight {}/{} doesn't exist"
.
format
(
weights_dir
,
var
.
name
))
vars_not_load
.
append
(
var
)
continue
pretrained_shape
=
parse_param_file
(
osp
.
join
(
weights_dir
,
var
.
name
))
actual_shape
=
tuple
(
var
.
shape
)
if
pretrained_shape
!=
actual_shape
:
print
(
"[SKIP] Shape of pretrained weight {}/{} doesn't match.(Pretrained: {}, Actual: {})"
.
format
(
weights_dir
,
var
.
name
,
pretrained_shape
,
actual_shape
))
vars_not_load
.
append
(
var
)
continue
vars_to_load
.
append
(
var
)
params_dict
=
fluid
.
io
.
load_program_state
(
weights_dir
,
var_list
=
vars_to_load
)
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
print
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else
:
print
(
"There are {}/{} varaibles in {} are loaded."
.
format
(
len
(
vars_to_load
),
len
(
vars_to_load
)
+
len
(
vars_not_load
),
weights_dir
))
contrib/LaneNet/vis.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -45,6 +45,7 @@ from models.model_builder import ModelPhase
from
utils
import
lanenet_postprocess
import
matplotlib.pyplot
as
plt
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'PaddeSeg visualization tools'
)
parser
.
add_argument
(
...
...
@@ -106,7 +107,6 @@ def minmax_scale(input_arr):
return
output_arr
def
visualize
(
cfg
,
vis_file_list
=
None
,
use_gpu
=
False
,
...
...
@@ -119,7 +119,6 @@ def visualize(cfg,
if
vis_file_list
is
None
:
vis_file_list
=
cfg
.
DATASET
.
TEST_FILE_LIST
dataset
=
LaneNetDataset
(
file_list
=
vis_file_list
,
mode
=
ModelPhase
.
VISUAL
,
...
...
@@ -139,7 +138,12 @@ def visualize(cfg,
ckpt_dir
=
cfg
.
TEST
.
TEST_MODEL
if
not
ckpt_dir
else
ckpt_dir
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
save_dir
=
os
.
path
.
join
(
vis_dir
,
'visual_results'
)
makedirs
(
save_dir
)
...
...
@@ -161,22 +165,26 @@ def visualize(cfg,
for
i
in
range
(
num_imgs
):
gt_image
=
org_imgs
[
i
]
binary_seg_image
,
instance_seg_image
=
segLogits
[
i
].
squeeze
(
-
1
),
emLogits
[
i
].
transpose
((
1
,
2
,
0
))
binary_seg_image
,
instance_seg_image
=
segLogits
[
i
].
squeeze
(
-
1
),
emLogits
[
i
].
transpose
((
1
,
2
,
0
))
postprocess_result
=
postprocessor
.
postprocess
(
binary_seg_result
=
binary_seg_image
,
instance_seg_result
=
instance_seg_image
,
source_image
=
gt_image
)
pred_binary_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_binary'
))
pred_lane_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_lane'
))
pred_instance_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_instance'
))
source_image
=
gt_image
)
pred_binary_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_binary'
))
pred_lane_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_lane'
))
pred_instance_fn
=
os
.
path
.
join
(
save_dir
,
to_png_fn
(
img_names
[
i
],
name
=
'_pred_instance'
))
dirname
=
os
.
path
.
dirname
(
pred_binary_fn
)
makedirs
(
dirname
)
mask_image
=
postprocess_result
[
'mask_image'
]
for
i
in
range
(
4
):
instance_seg_image
[:,
:,
i
]
=
minmax_scale
(
instance_seg_image
[:,
:,
i
])
instance_seg_image
[:,
:,
i
]
=
minmax_scale
(
instance_seg_image
[:,
:,
i
])
embedding_image
=
np
.
array
(
instance_seg_image
).
astype
(
np
.
uint8
)
plt
.
figure
(
'mask_image'
)
...
...
@@ -189,13 +197,13 @@ def visualize(cfg,
plt
.
imshow
(
binary_seg_image
*
255
,
cmap
=
'gray'
)
plt
.
show
()
cv2
.
imwrite
(
pred_binary_fn
,
np
.
array
(
binary_seg_image
*
255
).
astype
(
np
.
uint8
))
cv2
.
imwrite
(
pred_binary_fn
,
np
.
array
(
binary_seg_image
*
255
).
astype
(
np
.
uint8
))
cv2
.
imwrite
(
pred_lane_fn
,
postprocess_result
[
'source_image'
])
cv2
.
imwrite
(
pred_instance_fn
,
mask_image
)
print
(
pred_lane_fn
,
'saved!'
)
if
__name__
==
'__main__'
:
args
=
parse_args
()
if
args
.
cfg_file
is
not
None
:
...
...
contrib/MechanicalIndustryMeter/download_mini_mechanical_industry_meter.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/MechanicalIndustryMeter/download_unet_mechanical_industry_meter.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -23,7 +24,8 @@ from test_utils import download_file_and_uncompress
if
__name__
==
"__main__"
:
download_file_and_uncompress
(
url
=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar'
,
url
=
'https://paddleseg.bj.bcebos.com/models/unet_mechanical_industry_meter.tar'
,
savepath
=
LOCAL_PATH
,
extrapath
=
LOCAL_PATH
)
...
...
contrib/RemoteSensing/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/RemoteSensing/models/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.load_model
import
*
from
.unet
import
*
from
.hrnet
import
*
contrib/RemoteSensing/models/base.py
浏览文件 @
61645b1d
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
Licensed under the Apache License, Version 2.0 (the "License");
#
you may not use this file except in compliance with the License.
#
You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
Unless required by applicable law or agreed to in writing, software
#
distributed under the License is distributed on an "AS IS" BASIS,
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
See the License for the specific language governing permissions and
#
limitations under the License.
from
__future__
import
absolute_import
import
paddle.fluid
as
fluid
...
...
contrib/RemoteSensing/models/load_model.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/models/unet.py
浏览文件 @
61645b1d
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
Licensed under the Apache License, Version 2.0 (the "License");
#
you may not use this file except in compliance with the License.
#
You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#
Unless required by applicable law or agreed to in writing, software
#
distributed under the License is distributed on an "AS IS" BASIS,
#
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
See the License for the specific language governing permissions and
#
limitations under the License.
from
__future__
import
absolute_import
import
numpy
as
np
...
...
contrib/RemoteSensing/nets/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.unet
import
UNet
from
.hrnet
import
HRNet
contrib/RemoteSensing/nets/libs.py
浏览文件 @
61645b1d
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/nets/loss.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/nets/unet.py
浏览文件 @
61645b1d
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/predict_demo.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os.path
as
osp
import
sys
...
...
contrib/RemoteSensing/readers/__init__.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/readers/base.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/readers/reader.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/tools/create_dataset_list.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/tools/split_dataset_list.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/train_demo.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os.path
as
osp
import
argparse
import
transforms.transforms
as
T
...
...
contrib/RemoteSensing/transforms/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/RemoteSensing/transforms/ops.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/transforms/transforms.py
浏览文件 @
61645b1d
# coding: utf8
#
copyright (c) 2020
PaddlePaddle Authors. All Rights Reserve.
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/utils/__init__.py
浏览文件 @
61645b1d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/RemoteSensing/utils/logging.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/utils/metrics.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
contrib/RemoteSensing/utils/pretrain_weights.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os.path
as
osp
...
...
contrib/RemoteSensing/utils/utils.py
浏览文件 @
61645b1d
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -201,11 +202,9 @@ def load_pretrain_weights(exe, main_prog, weights_dir, fuse_bn=False):
vars_to_load
.
append
(
var
)
logging
.
debug
(
"Weight {} will be load"
.
format
(
var
.
name
))
fluid
.
io
.
load_vars
(
executor
=
exe
,
dirname
=
weights_dir
,
main_program
=
main_prog
,
vars
=
vars_to_load
)
params_dict
=
fluid
.
io
.
load_program_state
(
weights_dir
,
var_list
=
vars_to_load
)
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
logging
.
warning
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
...
...
contrib/RoadLine/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/RoadLine/config.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
utils.util
import
AttrDict
,
merge_cfg_from_args
,
get_arguments
import
os
...
...
@@ -6,20 +20,20 @@ args = get_arguments()
cfg
=
AttrDict
()
# 待预测图像所在路径
cfg
.
data_dir
=
os
.
path
.
join
(
args
.
example
,
"data"
,
"test_images"
)
cfg
.
data_dir
=
os
.
path
.
join
(
args
.
example
,
"data"
,
"test_images"
)
# 待预测图像名称列表
cfg
.
data_list_file
=
os
.
path
.
join
(
args
.
example
,
"data"
,
"test.txt"
)
cfg
.
data_list_file
=
os
.
path
.
join
(
args
.
example
,
"data"
,
"test.txt"
)
# 模型加载路径
cfg
.
model_path
=
os
.
path
.
join
(
args
.
example
,
"model"
)
cfg
.
model_path
=
os
.
path
.
join
(
args
.
example
,
"model"
)
# 预测结果保存路径
cfg
.
vis_dir
=
os
.
path
.
join
(
args
.
example
,
"result"
)
cfg
.
vis_dir
=
os
.
path
.
join
(
args
.
example
,
"result"
)
# 预测类别数
cfg
.
class_num
=
2
# 均值, 图像预处理减去的均值
cfg
.
MEAN
=
127.5
,
127.5
,
127.5
# 标准差,图像预处理除以标准差
cfg
.
STD
=
127.5
,
127.5
,
127.5
cfg
.
STD
=
127.5
,
127.5
,
127.5
# 待预测图像输入尺寸
cfg
.
input_size
=
1536
,
576
...
...
contrib/RoadLine/download_RoadLine.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
contrib/RoadLine/infer.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
numpy
as
np
...
...
@@ -12,18 +26,19 @@ config = importlib.import_module('config')
cfg
=
getattr
(
config
,
'cfg'
)
# paddle垃圾回收策略FLAG,ACE2P模型较大,当显存不够时建议开启
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
import
paddle.fluid
as
fluid
# 预测数据集类
class
TestDataSet
():
def
__init__
(
self
):
self
.
data_dir
=
cfg
.
data_dir
self
.
data_dir
=
cfg
.
data_dir
self
.
data_list_file
=
cfg
.
data_list_file
self
.
data_list
=
self
.
get_data_list
()
self
.
data_num
=
len
(
self
.
data_list
)
def
get_data_list
(
self
):
# 获取预测图像路径列表
data_list
=
[]
...
...
@@ -40,7 +55,7 @@ class TestDataSet():
def
preprocess
(
self
,
img
):
# 图像预处理
if
cfg
.
example
==
'ACE2P'
:
reader
=
importlib
.
import_module
(
args
.
example
+
'.reader'
)
reader
=
importlib
.
import_module
(
args
.
example
+
'.reader'
)
ACE2P_preprocess
=
getattr
(
reader
,
'preprocess'
)
img
=
ACE2P_preprocess
(
img
)
else
:
...
...
@@ -56,10 +71,10 @@ class TestDataSet():
img_path
=
self
.
data_list
[
index
]
img
=
cv2
.
imread
(
img_path
,
cv2
.
IMREAD_COLOR
)
if
img
is
None
:
return
img
,
img
,
img_path
,
None
return
img
,
img
,
img_path
,
None
img_name
=
img_path
.
split
(
os
.
sep
)[
-
1
]
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
img_shape
=
img
.
shape
[:
2
]
img_process
=
self
.
preprocess
(
img
)
...
...
@@ -90,39 +105,44 @@ def infer():
if
image
is
None
:
print
(
im_name
,
'is None'
)
continue
# 预测
if
cfg
.
example
==
'ACE2P'
:
# ACE2P模型使用多尺度预测
reader
=
importlib
.
import_module
(
args
.
example
+
'.reader'
)
reader
=
importlib
.
import_module
(
args
.
example
+
'.reader'
)
multi_scale_test
=
getattr
(
reader
,
'multi_scale_test'
)
parsing
,
logits
=
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
image
,
im_shape
)
parsing
,
logits
=
multi_scale_test
(
exe
,
test_prog
,
feed_name
,
fetch_list
,
image
,
im_shape
)
else
:
# HumanSeg,RoadLine模型单尺度预测
result
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
result
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
parsing
=
np
.
argmax
(
result
[
0
][
0
],
axis
=
0
)
parsing
=
cv2
.
resize
(
parsing
.
astype
(
np
.
uint8
),
im_shape
[::
-
1
])
# 预测结果保存
result_path
=
os
.
path
.
join
(
cfg
.
vis_dir
,
im_name
+
'.png'
)
if
cfg
.
example
==
'HumanSeg'
:
logits
=
result
[
0
][
0
][
1
]
*
255
logits
=
result
[
0
][
0
][
1
]
*
255
logits
=
cv2
.
resize
(
logits
,
im_shape
[::
-
1
])
ret
,
logits
=
cv2
.
threshold
(
logits
,
thresh
,
0
,
cv2
.
THRESH_TOZERO
)
logits
=
255
*
(
logits
-
thresh
)
/
(
255
-
thresh
)
logits
=
255
*
(
logits
-
thresh
)
/
(
255
-
thresh
)
# 将分割结果添加到alpha通道
rgba
=
np
.
concatenate
((
ori_img
,
np
.
expand_dims
(
logits
,
axis
=
2
)),
axis
=
2
)
rgba
=
np
.
concatenate
((
ori_img
,
np
.
expand_dims
(
logits
,
axis
=
2
)),
axis
=
2
)
cv2
.
imwrite
(
result_path
,
rgba
)
else
:
else
:
output_im
=
PILImage
.
fromarray
(
np
.
asarray
(
parsing
,
dtype
=
np
.
uint8
))
output_im
.
putpalette
(
palette
)
output_im
.
save
(
result_path
)
if
(
idx
+
1
)
%
100
==
0
:
print
(
'%d processd'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
return
0
...
...
contrib/RoadLine/utils/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
contrib/RoadLine/utils/palette.py
浏览文件 @
61645b1d
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: RainbowSecret
## Microsoft Research
## yuyua@microsoft.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
contrib/RoadLine/utils/util.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
argparse
import
os
def
get_arguments
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--use_gpu"
,
action
=
"store_true"
,
help
=
"Use gpu or cpu to test."
)
parser
.
add_argument
(
'--example'
,
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
parser
.
add_argument
(
"--use_gpu"
,
action
=
"store_true"
,
help
=
"Use gpu or cpu to test."
)
parser
.
add_argument
(
'--example'
,
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
return
parser
.
parse_args
()
...
...
@@ -34,6 +48,7 @@ class AttrDict(dict):
else
:
self
[
name
]
=
value
def
merge_cfg_from_args
(
args
,
cfg
):
"""Merge config keys, values in args into the global config."""
for
k
,
v
in
vars
(
args
).
items
():
...
...
@@ -44,4 +59,3 @@ def merge_cfg_from_args(args, cfg):
value
=
v
if
value
is
not
None
:
cfg
[
k
]
=
value
dataset/convert_voc2012.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -20,6 +21,8 @@ from PIL import Image
import
glob
LOCAL_PATH
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
def
remove_colormap
(
filename
):
gray_anno
=
np
.
array
(
Image
.
open
(
filename
))
return
gray_anno
...
...
@@ -30,6 +33,7 @@ def save_annotation(annotation, filename):
annotation
=
Image
.
fromarray
(
annotation
)
annotation
.
save
(
filename
)
def
convert_list
(
origin_file
,
seg_file
,
output_folder
):
with
open
(
seg_file
,
'w'
)
as
fid_seg
:
with
open
(
origin_file
)
as
fid_ori
:
...
...
@@ -43,6 +47,7 @@ def convert_list(origin_file, seg_file, output_folder):
new_line
=
' '
.
join
([
img_name
,
anno_name
])
fid_seg
.
write
(
new_line
+
"
\n
"
)
if
__name__
==
"__main__"
:
pascal_root
=
"./VOCtrainval_11-May-2012/VOC2012"
pascal_root
=
os
.
path
.
join
(
LOCAL_PATH
,
pascal_root
)
...
...
@@ -54,7 +59,7 @@ if __name__ == "__main__":
# 标注图转换后存储目录
output_folder
=
os
.
path
.
join
(
pascal_root
,
"SegmentationClassAug"
)
print
(
"annotation convert and file list convert"
)
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
LOCAL_PATH
,
output_folder
)):
os
.
mkdir
(
os
.
path
.
join
(
LOCAL_PATH
,
output_folder
))
...
...
@@ -67,5 +72,5 @@ if __name__ == "__main__":
convert_list
(
train_path
,
train_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
val_path
,
val_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
trainval_path
,
trainval_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
trainval_path
,
trainval_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
dataset/download_and_convert_voc2012.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -28,12 +29,12 @@ from convert_voc2012 import remove_colormap
from
convert_voc2012
import
save_annotation
def
download_VOC_dataset
(
savepath
,
extrapath
):
url
=
"https://paddleseg.bj.bcebos.com/dataset/VOCtrainval_11-May-2012.tar"
download_file_and_uncompress
(
url
=
url
,
savepath
=
savepath
,
extrapath
=
extrapath
)
if
__name__
==
"__main__"
:
download_VOC_dataset
(
LOCAL_PATH
,
LOCAL_PATH
)
print
(
"Dataset download finish!"
)
...
...
@@ -45,10 +46,10 @@ if __name__ == "__main__":
train_path
=
os
.
path
.
join
(
txt_folder
,
"train.txt"
)
val_path
=
os
.
path
.
join
(
txt_folder
,
"val.txt"
)
trainval_path
=
os
.
path
.
join
(
txt_folder
,
"trainval.txt"
)
# 标注图转换后存储目录
output_folder
=
os
.
path
.
join
(
pascal_root
,
"SegmentationClassAug"
)
print
(
"annotation convert and file list convert"
)
if
not
os
.
path
.
exists
(
output_folder
):
os
.
mkdir
(
output_folder
)
...
...
@@ -61,5 +62,5 @@ if __name__ == "__main__":
convert_list
(
train_path
,
train_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
val_path
,
val_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
trainval_path
,
trainval_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
convert_list
(
trainval_path
,
trainval_path
.
replace
(
'txt'
,
'list'
),
output_folder
)
dataset/download_cityscapes.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
dataset/download_mini_deepglobe_road_extraction.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
dataset/download_optic.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
dataset/download_pet.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
deploy/python/infer.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -29,9 +29,9 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
gflags
.
DEFINE_string
(
"conf"
,
default
=
""
,
help
=
"Configuration File Path"
)
gflags
.
DEFINE_string
(
"input_dir"
,
default
=
""
,
help
=
"Directory of Input Images"
)
gflags
.
DEFINE_boolean
(
"use_pr"
,
default
=
False
,
help
=
"Use optimized model"
)
gflags
.
DEFINE_string
(
"trt_mode"
,
default
=
""
,
help
=
"Use optimized model"
)
gflags
.
DEFINE_string
(
"ext"
,
default
=
".jpeg|.jpg"
,
help
=
"Input Image File Extensions"
)
gflags
.
DEFINE_string
(
"ext"
,
default
=
".jpeg|.jpg"
,
help
=
"Input Image File Extensions"
)
gflags
.
FLAGS
=
gflags
.
FLAGS
...
...
@@ -103,6 +103,9 @@ class DeployConfig:
self
.
batch_size
=
deploy_conf
[
"BATCH_SIZE"
]
# 9. channels
self
.
channels
=
deploy_conf
[
"CHANNELS"
]
# 10. use_pr
self
.
use_pr
=
deploy_conf
[
"USE_PR"
]
class
ImageReader
:
...
...
@@ -257,23 +260,24 @@ class Predictor:
# record starting time point
total_start
=
time
.
time
()
batch_size
=
self
.
config
.
batch_size
use_pr
=
self
.
config
.
use_pr
for
i
in
range
(
0
,
len
(
images
),
batch_size
):
real_batch_size
=
batch_size
if
i
+
batch_size
>=
len
(
images
):
real_batch_size
=
len
(
images
)
-
i
reader_start
=
time
.
time
()
img_datas
=
self
.
image_reader
.
process
(
images
[
i
:
i
+
real_batch_size
],
gflags
.
FLAGS
.
use_pr
)
use_pr
)
input_data
=
np
.
concatenate
([
item
[
1
]
for
item
in
img_datas
])
input_data
=
self
.
create_tensor
(
input_data
,
real_batch_size
,
use_pr
=
gflags
.
FLAGS
.
use_pr
)
input_data
,
real_batch_size
,
use_pr
=
use_pr
)
reader_end
=
time
.
time
()
infer_start
=
time
.
time
()
output_data
=
self
.
predictor
.
run
(
input_data
)[
0
]
infer_end
=
time
.
time
()
output_data
=
output_data
.
as_ndarray
()
post_start
=
time
.
time
()
self
.
output_result
(
img_datas
,
output_data
,
gflags
.
FLAGS
.
use_pr
)
self
.
output_result
(
img_datas
,
output_data
,
use_pr
)
post_end
=
time
.
time
()
reader_time
+=
(
reader_end
-
reader_start
)
infer_time
+=
(
infer_end
-
infer_start
)
...
...
pdseg/__init__.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -14,4 +14,4 @@
# limitations under the License.
import
models
import
utils
from
.
import
tools
\ No newline at end of file
from
.
import
tools
pdseg/check.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
...
...
@@ -427,12 +440,17 @@ def max_img_size_statistics():
logger
.
info
(
"max width and max height of images are ({},{})"
.
format
(
max_width
,
max_height
))
def
num_classes_loss_matching_check
():
loss_type
=
cfg
.
SOLVER
.
LOSS
num_classes
=
cfg
.
DATASET
.
NUM_CLASSES
if
num_classes
>
2
and
((
"dice_loss"
in
loss_type
)
or
(
"bce_loss"
in
loss_type
)):
logger
.
info
(
error_print
(
"loss check."
" Dice loss and bce loss is only applicable to binary classfication"
))
if
num_classes
>
2
and
((
"dice_loss"
in
loss_type
)
or
(
"bce_loss"
in
loss_type
)):
logger
.
info
(
error_print
(
"loss check."
" Dice loss and bce loss is only applicable to binary classfication"
))
else
:
logger
.
info
(
correct_print
(
"loss check"
))
...
...
pdseg/data_aug.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -362,7 +362,7 @@ def hsv_color_jitter(crop_img,
saturation_jitter_ratio
>
0
or
\
contrast_jitter_ratio
>
0
:
crop_img
=
random_jitter
(
crop_img
,
saturation_jitter_ratio
,
brightness_jitter_ratio
,
contrast_jitter_ratio
)
brightness_jitter_ratio
,
contrast_jitter_ratio
)
return
crop_img
...
...
@@ -391,7 +391,7 @@ def rand_crop(crop_img, crop_seg, mode=ModelPhase.TRAIN):
crop_width
=
cfg
.
EVAL_CROP_SIZE
[
0
]
crop_height
=
cfg
.
EVAL_CROP_SIZE
[
1
]
if
not
ModelPhase
.
is_train
(
mode
):
if
not
ModelPhase
.
is_train
(
mode
):
if
(
crop_height
<
img_height
or
crop_width
<
img_width
):
raise
Exception
(
"Crop size({},{}) must large than img size({},{}) when in EvalPhase."
...
...
pdseg/data_utils.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
"""
...
...
@@ -14,10 +28,10 @@ except ImportError:
class
GeneratorEnqueuer
(
object
):
"""
Multiple generators
Multiple generators
Args:
generators:
generators:
wait_time (float): time to sleep in-between calls to `put()`.
"""
...
...
pdseg/eval.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -22,13 +22,9 @@ import os
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
"0.0"
import
sys
import
time
import
argparse
import
functools
import
pprint
import
cv2
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
utils.config
import
cfg
...
...
@@ -116,7 +112,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
# Use streaming confusion matrix to calculate mean_iou
np
.
set_printoptions
(
...
...
pdseg/export_model.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -49,6 +49,7 @@ def parse_args():
sys
.
exit
(
1
)
return
parser
.
parse_args
()
def
export_inference_config
():
deploy_cfg
=
'''DEPLOY:
USE_GPU : 1
...
...
@@ -66,9 +67,8 @@ def export_inference_config():
PREDICTOR_MODE : "ANALYSIS"
BATCH_SIZE : 1
'''
%
(
cfg
.
FREEZE
.
SAVE_DIR
,
cfg
.
FREEZE
.
MODEL_FILENAME
,
cfg
.
FREEZE
.
PARAMS_FILENAME
,
cfg
.
EVAL_CROP_SIZE
,
cfg
.
MEAN
,
cfg
.
STD
,
cfg
.
DATASET
.
IMAGE_TYPE
,
cfg
.
DATASET
.
NUM_CLASSES
,
len
(
cfg
.
STD
))
cfg
.
FREEZE
.
PARAMS_FILENAME
,
cfg
.
EVAL_CROP_SIZE
,
cfg
.
MEAN
,
cfg
.
STD
,
cfg
.
DATASET
.
IMAGE_TYPE
,
cfg
.
DATASET
.
NUM_CLASSES
,
len
(
cfg
.
STD
))
if
not
os
.
path
.
exists
(
cfg
.
FREEZE
.
SAVE_DIR
):
os
.
mkdir
(
cfg
.
FREEZE
.
SAVE_DIR
)
yaml_path
=
os
.
path
.
join
(
cfg
.
FREEZE
.
SAVE_DIR
,
'deploy.yaml'
)
...
...
@@ -94,7 +94,13 @@ def export_inference_model(args):
infer_prog
=
infer_prog
.
clone
(
for_test
=
True
)
if
os
.
path
.
exists
(
cfg
.
TEST
.
TEST_MODEL
):
fluid
.
io
.
load_params
(
exe
,
cfg
.
TEST
.
TEST_MODEL
,
main_program
=
infer_prog
)
print
(
'load test model:'
,
cfg
.
TEST
.
TEST_MODEL
)
try
:
fluid
.
load
(
infer_prog
,
os
.
path
.
join
(
cfg
.
TEST
.
TEST_MODEL
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
cfg
.
TEST
.
TEST_MODEL
,
main_program
=
infer_prog
)
else
:
print
(
"TEST.TEST_MODEL diretory is empty!"
)
exit
(
-
1
)
...
...
pdseg/loss.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/lovasz_losses.py
浏览文件 @
61645b1d
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/metrics.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/models/__init__.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/models/backbone/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pdseg/models/backbone/mobilenet_v2.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/models/backbone/resnet.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -141,7 +141,7 @@ class ResNet():
else
:
conv_name
=
"res"
+
str
(
block
+
2
)
+
chr
(
97
+
i
)
dilation_rate
=
get_dilated_rate
(
dilation_dict
,
block
)
conv
=
self
.
bottleneck_block
(
input
=
conv
,
num_filters
=
int
(
num_filters
[
block
]
*
self
.
scale
),
...
...
@@ -215,11 +215,11 @@ class ResNet():
groups
=
1
,
act
=
None
,
name
=
None
):
if
self
.
stem
==
'pspnet'
:
bias_attr
=
ParamAttr
(
name
=
name
+
"_biases"
)
bias_attr
=
ParamAttr
(
name
=
name
+
"_biases"
)
else
:
bias_attr
=
False
bias_attr
=
False
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
...
...
@@ -238,13 +238,15 @@ class ResNet():
bn_name
=
"bn_"
+
name
else
:
bn_name
=
"bn"
+
name
[
3
:]
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
name
=
bn_name
+
'.output.1'
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
,
)
return
fluid
.
layers
.
batch_norm
(
input
=
conv
,
act
=
act
,
name
=
bn_name
+
'.output.1'
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
,
)
def
shortcut
(
self
,
input
,
ch_out
,
stride
,
is_first
,
name
):
ch_in
=
input
.
shape
[
1
]
...
...
@@ -258,7 +260,7 @@ class ResNet():
strides
=
[
1
,
stride
]
else
:
strides
=
[
stride
,
1
]
conv0
=
self
.
conv_bn_layer
(
input
=
input
,
num_filters
=
num_filters
,
...
...
pdseg/models/backbone/vgg.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -55,7 +55,8 @@ class VGGNet():
channels
=
[
64
,
128
,
256
,
512
,
512
]
conv
=
input
for
i
in
range
(
len
(
nums
)):
conv
=
self
.
conv_block
(
conv
,
channels
[
i
],
nums
[
i
],
name
=
"conv"
+
str
(
i
+
1
)
+
"_"
)
conv
=
self
.
conv_block
(
conv
,
channels
[
i
],
nums
[
i
],
name
=
"conv"
+
str
(
i
+
1
)
+
"_"
)
layers_count
+=
nums
[
i
]
if
check_points
(
layers_count
,
decode_points
):
short_cuts
[
layers_count
]
=
conv
...
...
pdseg/models/backbone/xception.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/models/libs/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pdseg/models/libs/model_libs.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -197,4 +197,4 @@ def conv_bn_layer(input,
if
if_act
:
return
fluid
.
layers
.
relu6
(
bn
)
else
:
return
bn
\ No newline at end of file
return
bn
pdseg/models/model_builder.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/models/modeling/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pdseg/models/modeling/deeplab.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/models/modeling/fast_scnn.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/models/modeling/hrnet.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -25,7 +25,14 @@ from paddle.fluid.param_attr import ParamAttr
from
utils.config
import
cfg
def
conv_bn_layer
(
input
,
filter_size
,
num_filters
,
stride
=
1
,
padding
=
1
,
num_groups
=
1
,
if_act
=
True
,
name
=
None
):
def
conv_bn_layer
(
input
,
filter_size
,
num_filters
,
stride
=
1
,
padding
=
1
,
num_groups
=
1
,
if_act
=
True
,
name
=
None
):
conv
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
num_filters
,
...
...
@@ -37,37 +44,74 @@ def conv_bn_layer(input, filter_size, num_filters, stride=1, padding=1, num_grou
param_attr
=
ParamAttr
(
initializer
=
MSRA
(),
name
=
name
+
'_weights'
),
bias_attr
=
False
)
bn_name
=
name
+
'_bn'
bn
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_scale"
,
initializer
=
fluid
.
initializer
.
Constant
(
1.0
)),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_offset"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.0
)),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
)
bn
=
fluid
.
layers
.
batch_norm
(
input
=
conv
,
param_attr
=
ParamAttr
(
name
=
bn_name
+
"_scale"
,
initializer
=
fluid
.
initializer
.
Constant
(
1.0
)),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
"_offset"
,
initializer
=
fluid
.
initializer
.
Constant
(
0.0
)),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
)
if
if_act
:
bn
=
fluid
.
layers
.
relu
(
bn
)
return
bn
def
basic_block
(
input
,
num_filters
,
stride
=
1
,
downsample
=
False
,
name
=
None
):
residual
=
input
conv
=
conv_bn_layer
(
input
=
input
,
filter_size
=
3
,
num_filters
=
num_filters
,
stride
=
stride
,
name
=
name
+
'_conv1'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
3
,
num_filters
=
num_filters
,
if_act
=
False
,
name
=
name
+
'_conv2'
)
conv
=
conv_bn_layer
(
input
=
input
,
filter_size
=
3
,
num_filters
=
num_filters
,
stride
=
stride
,
name
=
name
+
'_conv1'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
3
,
num_filters
=
num_filters
,
if_act
=
False
,
name
=
name
+
'_conv2'
)
if
downsample
:
residual
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
,
if_act
=
False
,
name
=
name
+
'_downsample'
)
residual
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
,
if_act
=
False
,
name
=
name
+
'_downsample'
)
return
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
conv
,
act
=
'relu'
)
def
bottleneck_block
(
input
,
num_filters
,
stride
=
1
,
downsample
=
False
,
name
=
None
):
residual
=
input
conv
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
,
name
=
name
+
'_conv1'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
3
,
num_filters
=
num_filters
,
stride
=
stride
,
name
=
name
+
'_conv2'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
1
,
num_filters
=
num_filters
*
4
,
if_act
=
False
,
name
=
name
+
'_conv3'
)
conv
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
,
name
=
name
+
'_conv1'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
3
,
num_filters
=
num_filters
,
stride
=
stride
,
name
=
name
+
'_conv2'
)
conv
=
conv_bn_layer
(
input
=
conv
,
filter_size
=
1
,
num_filters
=
num_filters
*
4
,
if_act
=
False
,
name
=
name
+
'_conv3'
)
if
downsample
:
residual
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
*
4
,
if_act
=
False
,
name
=
name
+
'_downsample'
)
residual
=
conv_bn_layer
(
input
=
input
,
filter_size
=
1
,
num_filters
=
num_filters
*
4
,
if_act
=
False
,
name
=
name
+
'_downsample'
)
return
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
conv
,
act
=
'relu'
)
def
fuse_layers
(
x
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
out
=
[]
for
i
in
range
(
len
(
channels
)
if
multi_scale_output
else
1
):
...
...
@@ -77,40 +121,64 @@ def fuse_layers(x, channels, multi_scale_output=True, name=None):
height
=
shape
[
-
2
]
for
j
in
range
(
len
(
channels
)):
if
j
>
i
:
y
=
conv_bn_layer
(
x
[
j
],
filter_size
=
1
,
num_filters
=
channels
[
i
],
if_act
=
False
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
))
y
=
fluid
.
layers
.
resize_bilinear
(
input
=
y
,
out_shape
=
[
height
,
width
])
residual
=
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
y
,
act
=
None
)
y
=
conv_bn_layer
(
x
[
j
],
filter_size
=
1
,
num_filters
=
channels
[
i
],
if_act
=
False
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
))
y
=
fluid
.
layers
.
resize_bilinear
(
input
=
y
,
out_shape
=
[
height
,
width
])
residual
=
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
y
,
act
=
None
)
elif
j
<
i
:
y
=
x
[
j
]
for
k
in
range
(
i
-
j
):
if
k
==
i
-
j
-
1
:
y
=
conv_bn_layer
(
y
,
filter_size
=
3
,
num_filters
=
channels
[
i
],
stride
=
2
,
if_act
=
False
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
)
+
'_'
+
str
(
k
+
1
))
y
=
conv_bn_layer
(
y
,
filter_size
=
3
,
num_filters
=
channels
[
i
],
stride
=
2
,
if_act
=
False
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
)
+
'_'
+
str
(
k
+
1
))
else
:
y
=
conv_bn_layer
(
y
,
filter_size
=
3
,
num_filters
=
channels
[
j
],
stride
=
2
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
)
+
'_'
+
str
(
k
+
1
))
residual
=
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
y
,
act
=
None
)
y
=
conv_bn_layer
(
y
,
filter_size
=
3
,
num_filters
=
channels
[
j
],
stride
=
2
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
)
+
'_'
+
str
(
k
+
1
))
residual
=
fluid
.
layers
.
elementwise_add
(
x
=
residual
,
y
=
y
,
act
=
None
)
residual
=
fluid
.
layers
.
relu
(
residual
)
out
.
append
(
residual
)
return
out
def
branches
(
x
,
block_num
,
channels
,
name
=
None
):
out
=
[]
for
i
in
range
(
len
(
channels
)):
residual
=
x
[
i
]
for
j
in
range
(
block_num
):
residual
=
basic_block
(
residual
,
channels
[
i
],
name
=
name
+
'_branch_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
))
residual
=
basic_block
(
residual
,
channels
[
i
],
name
=
name
+
'_branch_layer_'
+
str
(
i
+
1
)
+
'_'
+
str
(
j
+
1
))
out
.
append
(
residual
)
return
out
def
high_resolution_module
(
x
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
residual
=
branches
(
x
,
4
,
channels
,
name
=
name
)
out
=
fuse_layers
(
residual
,
channels
,
multi_scale_output
=
multi_scale_output
,
name
=
name
)
out
=
fuse_layers
(
residual
,
channels
,
multi_scale_output
=
multi_scale_output
,
name
=
name
)
return
out
def
transition_layer
(
x
,
in_channels
,
out_channels
,
name
=
None
):
num_in
=
len
(
in_channels
)
num_out
=
len
(
out_channels
)
...
...
@@ -118,46 +186,76 @@ def transition_layer(x, in_channels, out_channels, name=None):
for
i
in
range
(
num_out
):
if
i
<
num_in
:
if
in_channels
[
i
]
!=
out_channels
[
i
]:
residual
=
conv_bn_layer
(
x
[
i
],
filter_size
=
3
,
num_filters
=
out_channels
[
i
],
name
=
name
+
'_layer_'
+
str
(
i
+
1
))
residual
=
conv_bn_layer
(
x
[
i
],
filter_size
=
3
,
num_filters
=
out_channels
[
i
],
name
=
name
+
'_layer_'
+
str
(
i
+
1
))
out
.
append
(
residual
)
else
:
out
.
append
(
x
[
i
])
else
:
residual
=
conv_bn_layer
(
x
[
-
1
],
filter_size
=
3
,
num_filters
=
out_channels
[
i
],
stride
=
2
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
))
residual
=
conv_bn_layer
(
x
[
-
1
],
filter_size
=
3
,
num_filters
=
out_channels
[
i
],
stride
=
2
,
name
=
name
+
'_layer_'
+
str
(
i
+
1
))
out
.
append
(
residual
)
return
out
def
stage
(
x
,
num_modules
,
channels
,
multi_scale_output
=
True
,
name
=
None
):
out
=
x
for
i
in
range
(
num_modules
):
if
i
==
num_modules
-
1
and
multi_scale_output
==
False
:
out
=
high_resolution_module
(
out
,
channels
,
multi_scale_output
=
False
,
name
=
name
+
'_'
+
str
(
i
+
1
))
out
=
high_resolution_module
(
out
,
channels
,
multi_scale_output
=
False
,
name
=
name
+
'_'
+
str
(
i
+
1
))
else
:
out
=
high_resolution_module
(
out
,
channels
,
name
=
name
+
'_'
+
str
(
i
+
1
))
out
=
high_resolution_module
(
out
,
channels
,
name
=
name
+
'_'
+
str
(
i
+
1
))
return
out
def
layer1
(
input
,
name
=
None
):
conv
=
input
for
i
in
range
(
4
):
conv
=
bottleneck_block
(
conv
,
num_filters
=
64
,
downsample
=
True
if
i
==
0
else
False
,
name
=
name
+
'_'
+
str
(
i
+
1
))
conv
=
bottleneck_block
(
conv
,
num_filters
=
64
,
downsample
=
True
if
i
==
0
else
False
,
name
=
name
+
'_'
+
str
(
i
+
1
))
return
conv
def
high_resolution_net
(
input
,
num_classes
):
channels_2
=
cfg
.
MODEL
.
HRNET
.
STAGE2
.
NUM_CHANNELS
channels_3
=
cfg
.
MODEL
.
HRNET
.
STAGE3
.
NUM_CHANNELS
channels_4
=
cfg
.
MODEL
.
HRNET
.
STAGE4
.
NUM_CHANNELS
num_modules_2
=
cfg
.
MODEL
.
HRNET
.
STAGE2
.
NUM_MODULES
num_modules_3
=
cfg
.
MODEL
.
HRNET
.
STAGE3
.
NUM_MODULES
num_modules_4
=
cfg
.
MODEL
.
HRNET
.
STAGE4
.
NUM_MODULES
x
=
conv_bn_layer
(
input
=
input
,
filter_size
=
3
,
num_filters
=
64
,
stride
=
2
,
if_act
=
True
,
name
=
'layer1_1'
)
x
=
conv_bn_layer
(
input
=
x
,
filter_size
=
3
,
num_filters
=
64
,
stride
=
2
,
if_act
=
True
,
name
=
'layer1_2'
)
x
=
conv_bn_layer
(
input
=
input
,
filter_size
=
3
,
num_filters
=
64
,
stride
=
2
,
if_act
=
True
,
name
=
'layer1_1'
)
x
=
conv_bn_layer
(
input
=
x
,
filter_size
=
3
,
num_filters
=
64
,
stride
=
2
,
if_act
=
True
,
name
=
'layer1_2'
)
la1
=
layer1
(
x
,
name
=
'layer2'
)
tr1
=
transition_layer
([
la1
],
[
256
],
channels_2
,
name
=
'tr1'
)
...
...
@@ -170,18 +268,21 @@ def high_resolution_net(input, num_classes):
# upsample
shape
=
st4
[
0
].
shape
height
,
width
=
shape
[
-
2
],
shape
[
-
1
]
st4
[
1
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
1
],
out_shape
=
[
height
,
width
])
st4
[
2
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
2
],
out_shape
=
[
height
,
width
])
st4
[
3
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
3
],
out_shape
=
[
height
,
width
])
st4
[
1
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
1
],
out_shape
=
[
height
,
width
])
st4
[
2
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
2
],
out_shape
=
[
height
,
width
])
st4
[
3
]
=
fluid
.
layers
.
resize_bilinear
(
st4
[
3
],
out_shape
=
[
height
,
width
])
out
=
fluid
.
layers
.
concat
(
st4
,
axis
=
1
)
last_channels
=
sum
(
channels_4
)
out
=
conv_bn_layer
(
input
=
out
,
filter_size
=
1
,
num_filters
=
last_channels
,
stride
=
1
,
if_act
=
True
,
name
=
'conv-2'
)
out
=
fluid
.
layers
.
conv2d
(
out
=
conv_bn_layer
(
input
=
out
,
filter_size
=
1
,
num_filters
=
last_channels
,
stride
=
1
,
if_act
=
True
,
name
=
'conv-2'
)
out
=
fluid
.
layers
.
conv2d
(
input
=
out
,
num_filters
=
num_classes
,
filter_size
=
1
,
...
...
@@ -193,7 +294,6 @@ def high_resolution_net(input, num_classes):
out
=
fluid
.
layers
.
resize_bilinear
(
out
,
input
.
shape
[
2
:])
return
out
...
...
@@ -201,6 +301,7 @@ def hrnet(input, num_classes):
logit
=
high_resolution_net
(
input
,
num_classes
)
return
logit
if
__name__
==
'__main__'
:
image_shape
=
[
-
1
,
3
,
769
,
769
]
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
...
...
pdseg/models/modeling/icnet.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/models/modeling/pspnet.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -24,6 +24,7 @@ from models.libs.model_libs import avg_pool, conv, bn
from
models.backbone.resnet
import
ResNet
as
resnet_backbone
from
utils.config
import
cfg
def
get_logit_interp
(
input
,
num_classes
,
out_shape
,
name
=
"logit"
):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
param_attr
=
fluid
.
ParamAttr
(
...
...
@@ -33,16 +34,15 @@ def get_logit_interp(input, num_classes, out_shape, name="logit"):
initializer
=
fluid
.
initializer
.
TruncatedNormal
(
loc
=
0.0
,
scale
=
0.01
))
with
scope
(
name
):
logit
=
conv
(
input
,
num_classes
,
filter_size
=
1
,
param_attr
=
param_attr
,
bias_attr
=
True
,
name
=
name
+
'_conv'
)
logit
=
conv
(
input
,
num_classes
,
filter_size
=
1
,
param_attr
=
param_attr
,
bias_attr
=
True
,
name
=
name
+
'_conv'
)
logit_interp
=
fluid
.
layers
.
resize_bilinear
(
logit
,
out_shape
=
out_shape
,
name
=
name
+
'_interp'
)
logit
,
out_shape
=
out_shape
,
name
=
name
+
'_interp'
)
return
logit_interp
...
...
@@ -51,40 +51,44 @@ def psp_module(input, out_features):
# 输入:backbone输出的特征
# 输出:对输入进行不同尺度pooling, 卷积操作后插值回原始尺寸,并concat
# 最后进行一个卷积及BN操作
cat_layers
=
[]
sizes
=
(
1
,
2
,
3
,
6
)
sizes
=
(
1
,
2
,
3
,
6
)
for
size
in
sizes
:
psp_name
=
"psp"
+
str
(
size
)
with
scope
(
psp_name
):
pool
=
fluid
.
layers
.
adaptive_pool2d
(
input
,
pool_size
=
[
size
,
size
],
pool_type
=
'avg'
,
name
=
psp_name
+
'_adapool'
)
data
=
conv
(
pool
,
out_features
,
filter_size
=
1
,
bias_attr
=
True
,
name
=
psp_name
+
'_conv'
)
pool
=
fluid
.
layers
.
adaptive_pool2d
(
input
,
pool_size
=
[
size
,
size
],
pool_type
=
'avg'
,
name
=
psp_name
+
'_adapool'
)
data
=
conv
(
pool
,
out_features
,
filter_size
=
1
,
bias_attr
=
True
,
name
=
psp_name
+
'_conv'
)
data_bn
=
bn
(
data
,
act
=
'relu'
)
interp
=
fluid
.
layers
.
resize_bilinear
(
data_bn
,
out_shape
=
input
.
shape
[
2
:],
name
=
psp_name
+
'_interp'
)
interp
=
fluid
.
layers
.
resize_bilinear
(
data_bn
,
out_shape
=
input
.
shape
[
2
:],
name
=
psp_name
+
'_interp'
)
cat_layers
.
append
(
interp
)
cat_layers
=
[
input
]
+
cat_layers
[::
-
1
]
cat
=
fluid
.
layers
.
concat
(
cat_layers
,
axis
=
1
,
name
=
'psp_cat'
)
psp_end_name
=
"psp_end"
with
scope
(
psp_end_name
):
data
=
conv
(
cat
,
out_features
,
filter_size
=
3
,
padding
=
1
,
bias_attr
=
True
,
name
=
psp_end_name
)
data
=
conv
(
cat
,
out_features
,
filter_size
=
3
,
padding
=
1
,
bias_attr
=
True
,
name
=
psp_end_name
)
out
=
bn
(
data
,
act
=
'relu'
)
return
out
def
resnet
(
input
):
# PSPNET backbone: resnet, 默认resnet50
# end_points: resnet终止层数
...
...
@@ -92,14 +96,14 @@ def resnet(input):
scale
=
cfg
.
MODEL
.
PSPNET
.
DEPTH_MULTIPLIER
layers
=
cfg
.
MODEL
.
PSPNET
.
LAYERS
end_points
=
layers
-
1
dilation_dict
=
{
2
:
2
,
3
:
4
}
dilation_dict
=
{
2
:
2
,
3
:
4
}
model
=
resnet_backbone
(
layers
,
scale
,
stem
=
'pspnet'
)
data
,
_
=
model
.
net
(
input
,
end_points
=
end_points
,
dilation_dict
=
dilation_dict
)
data
,
_
=
model
.
net
(
input
,
end_points
=
end_points
,
dilation_dict
=
dilation_dict
)
return
data
def
pspnet
(
input
,
num_classes
):
# Backbone: ResNet
res
=
resnet
(
input
)
...
...
@@ -109,4 +113,3 @@ def pspnet(input, num_classes):
# 根据类别数决定最后一层卷积输出, 并插值回原始尺寸
logit
=
get_logit_interp
(
dropout
,
num_classes
,
input
.
shape
[
2
:])
return
logit
pdseg/models/modeling/unet.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/reader.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -71,7 +71,8 @@ class SegDataset(object):
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
shuffle_seed
+=
1
elif
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
lines
)
...
...
@@ -99,7 +100,8 @@ class SegDataset(object):
if
self
.
shuffle
and
cfg
.
NUM_TRAINERS
>
1
:
np
.
random
.
RandomState
(
self
.
shuffle_seed
).
shuffle
(
self
.
all_lines
)
num_lines
=
len
(
self
.
all_lines
)
//
cfg
.
NUM_TRAINERS
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
lines
=
self
.
all_lines
[
num_lines
*
cfg
.
TRAINER_ID
:
num_lines
*
(
cfg
.
TRAINER_ID
+
1
)]
self
.
shuffle_seed
+=
1
elif
self
.
shuffle
:
np
.
random
.
shuffle
(
self
.
lines
)
...
...
pdseg/solver.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/tools/__init__.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
pdseg/tools/create_dataset_list.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -21,55 +21,48 @@ import warnings
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'PaddleSeg generate file list on cityscapes or your customized dataset.'
)
parser
.
add_argument
(
'dataset_root'
,
help
=
'dataset root directory'
,
type
=
str
description
=
'PaddleSeg generate file list on cityscapes or your customized dataset.'
)
parser
.
add_argument
(
'dataset_root'
,
help
=
'dataset root directory'
,
type
=
str
)
parser
.
add_argument
(
'--type'
,
help
=
'dataset type:
\n
'
'- cityscapes
\n
'
'- custom(default)'
,
'- cityscapes
\n
'
'- custom(default)'
,
default
=
"custom"
,
type
=
str
)
type
=
str
)
parser
.
add_argument
(
'--separator'
,
dest
=
'separator'
,
help
=
'file list separator'
,
default
=
"|"
,
type
=
str
)
type
=
str
)
parser
.
add_argument
(
'--folder'
,
help
=
'the folder names of images and labels'
,
type
=
str
,
nargs
=
2
,
default
=
[
'images'
,
'annotations'
]
)
default
=
[
'images'
,
'annotations'
])
parser
.
add_argument
(
'--second_folder'
,
help
=
'the second-level folder names of train set, validation set, test set'
,
help
=
'the second-level folder names of train set, validation set, test set'
,
type
=
str
,
nargs
=
'*'
,
default
=
[
'train'
,
'val'
,
'test'
]
)
default
=
[
'train'
,
'val'
,
'test'
])
parser
.
add_argument
(
'--format'
,
help
=
'data format of images and labels, e.g. jpg or png.'
,
type
=
str
,
nargs
=
2
,
default
=
[
'jpg'
,
'png'
]
)
default
=
[
'jpg'
,
'png'
])
parser
.
add_argument
(
'--postfix'
,
help
=
'postfix of images or labels'
,
type
=
str
,
nargs
=
2
,
default
=
[
''
,
''
]
)
default
=
[
''
,
''
])
return
parser
.
parse_args
()
...
...
@@ -120,15 +113,17 @@ def generate_list(args):
num_images
=
len
(
image_files
)
if
not
label_files
:
label_dir
=
os
.
path
.
join
(
dataset_root
,
args
.
folder
[
1
],
dataset_split
)
label_dir
=
os
.
path
.
join
(
dataset_root
,
args
.
folder
[
1
],
dataset_split
)
warnings
.
warn
(
"No labels in {} !!!"
.
format
(
label_dir
))
num_label
=
len
(
label_files
)
if
num_images
!=
num_label
and
num_label
>
0
:
raise
Exception
(
"Number of images = {} number of labels = {}
\n
"
"Either number of images is equal to number of labels, "
"or number of labels is equal to 0.
\n
"
"Please check your dataset!"
.
format
(
num_images
,
num_label
))
raise
Exception
(
"Number of images = {} number of labels = {}
\n
"
"Either number of images is equal to number of labels, "
"or number of labels is equal to 0.
\n
"
"Please check your dataset!"
.
format
(
num_images
,
num_label
))
file_list
=
os
.
path
.
join
(
dataset_root
,
dataset_split
+
'.txt'
)
with
open
(
file_list
,
"w"
)
as
f
:
...
...
pdseg/tools/gray2pseudo_color.py
浏览文件 @
61645b1d
# -*- coding: utf-8 -*-
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
argparse
...
...
@@ -11,16 +25,12 @@ from PIL import Image
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
parser
.
add_argument
(
'dir_or_file'
,
help
=
'input gray label directory or file list path'
)
parser
.
add_argument
(
'output_dir'
,
help
=
'output colorful label directory'
)
parser
.
add_argument
(
'--dataset_dir'
,
help
=
'dataset directory'
)
parser
.
add_argument
(
'--file_separator'
,
help
=
'file list separator'
)
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
parser
.
add_argument
(
'dir_or_file'
,
help
=
'input gray label directory or file list path'
)
parser
.
add_argument
(
'output_dir'
,
help
=
'output colorful label directory'
)
parser
.
add_argument
(
'--dataset_dir'
,
help
=
'dataset directory'
)
parser
.
add_argument
(
'--file_separator'
,
help
=
'file list separator'
)
return
parser
.
parse_args
()
...
...
pdseg/tools/jingling2seg.py
浏览文件 @
61645b1d
#!/usr/bin/env python
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
...
...
pdseg/tools/labelme2seg.py
浏览文件 @
61645b1d
#!/usr/bin/env python
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
...
...
pdseg/train.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -26,9 +26,7 @@ import argparse
import
pprint
import
random
import
shutil
import
functools
import
paddle
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid
import
profiler
...
...
@@ -39,10 +37,10 @@ from metrics import ConfusionMatrix
from
reader
import
SegDataset
from
models.model_builder
import
build_model
from
models.model_builder
import
ModelPhase
from
models.model_builder
import
parse_shape_from_file
from
eval
import
evaluate
from
vis
import
visualize
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
def
parse_args
():
...
...
@@ -118,38 +116,7 @@ def parse_args():
return
parser
.
parse_args
()
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program
=
fluid
.
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
fluid
.
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
save_block
.
create_var
(
name
=
each_var
.
name
,
shape
=
each_var
.
shape
,
dtype
=
each_var
.
dtype
,
type
=
each_var
.
type
,
lod_level
=
each_var
.
lod_level
,
persistable
=
True
)
file_path
=
os
.
path
.
join
(
dirname
,
new_var
.
name
)
file_path
=
os
.
path
.
normpath
(
file_path
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
file_path
})
executor
.
run
(
save_program
)
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
def
save_checkpoint
(
program
,
ckpt_name
):
"""
Save checkpoint for evaluation or resume training
"""
...
...
@@ -158,29 +125,22 @@ def save_checkpoint(exe, program, ckpt_name):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
save_vars
(
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
fluid
.
save
(
program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
"""
Load checkpoiont f
rom pretrained model directory for resume
training
Load checkpoiont f
or resuming
training
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
print
(
'Resume model training from:'
,
model_path
)
if
not
os
.
path
.
exists
(
model_path
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
model_path
))
fluid
.
load
(
program
,
os
.
path
.
join
(
model_path
,
'model'
),
exe
)
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
...
...
@@ -195,7 +155,6 @@ def load_checkpoint(exe, program):
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
...
...
@@ -247,8 +206,6 @@ def train(cfg):
yield
item
[
0
],
item
[
1
],
item
[
2
]
# Get device environment
# places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
# place = places[0]
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
place
=
fluid
.
CUDAPlace
(
gpu_id
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
places
=
fluid
.
cuda_places
()
if
args
.
use_gpu
else
fluid
.
cpu_places
()
...
...
@@ -304,42 +261,7 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
else
:
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
...
...
@@ -418,12 +340,9 @@ def train(cfg):
step
)
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
sys
.
stdout
.
flush
()
avg_loss
=
0.0
cm
.
zero_matrix
()
...
...
@@ -445,12 +364,9 @@ def train(cfg):
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
speed
,
calculate_eta
(
all_step
-
step
,
speed
)))
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
sys
.
stdout
.
flush
()
avg_loss
=
0.0
timer
.
restart
()
...
...
@@ -470,7 +386,7 @@ def train(cfg):
if
(
epoch
%
cfg
.
TRAIN
.
SNAPSHOT_EPOCH
==
0
or
epoch
==
cfg
.
SOLVER
.
NUM_EPOCHS
)
and
cfg
.
TRAINER_ID
==
0
:
ckpt_dir
=
save_checkpoint
(
exe
,
train_prog
,
epoch
)
ckpt_dir
=
save_checkpoint
(
train_prog
,
epoch
)
if
args
.
do_eval
:
print
(
"Evaluation start"
)
...
...
@@ -480,10 +396,8 @@ def train(cfg):
use_gpu
=
args
.
use_gpu
,
use_mpio
=
args
.
use_mpio
)
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
if
mean_iou
>
best_mIoU
:
best_mIoU
=
mean_iou
...
...
@@ -505,7 +419,7 @@ def train(cfg):
# save final model
if
cfg
.
TRAINER_ID
==
0
:
save_checkpoint
(
exe
,
train_prog
,
'final'
)
save_checkpoint
(
train_prog
,
'final'
)
def
main
(
args
):
...
...
pdseg/utils/__init__.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pdseg/utils/collect.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -98,7 +99,7 @@ class SegConfig(dict):
'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`'
)
if
self
.
MEAN
is
not
None
:
self
.
DATASET
.
PADDING_VALUE
=
[
x
*
255.0
for
x
in
self
.
MEAN
]
self
.
DATASET
.
PADDING_VALUE
=
[
x
*
255.0
for
x
in
self
.
MEAN
]
if
not
self
.
TRAIN_CROP_SIZE
:
raise
ValueError
(
...
...
@@ -111,9 +112,12 @@ class SegConfig(dict):
)
# Ensure file list is use UTF-8 encoding
train_sets
=
codecs
.
open
(
self
.
DATASET
.
TRAIN_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
val_sets
=
codecs
.
open
(
self
.
DATASET
.
VAL_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
test_sets
=
codecs
.
open
(
self
.
DATASET
.
TEST_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
train_sets
=
codecs
.
open
(
self
.
DATASET
.
TRAIN_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
val_sets
=
codecs
.
open
(
self
.
DATASET
.
VAL_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
test_sets
=
codecs
.
open
(
self
.
DATASET
.
TEST_FILE_LIST
,
'r'
,
'utf-8'
).
readlines
()
self
.
DATASET
.
TRAIN_TOTAL_IMAGES
=
len
(
train_sets
)
self
.
DATASET
.
VAL_TOTAL_IMAGES
=
len
(
val_sets
)
self
.
DATASET
.
TEST_TOTAL_IMAGES
=
len
(
test_sets
)
...
...
@@ -122,12 +126,13 @@ class SegConfig(dict):
len
(
self
.
MODEL
.
MULTI_LOSS_WEIGHT
)
!=
3
:
self
.
MODEL
.
MULTI_LOSS_WEIGHT
=
[
1.0
,
0.4
,
0.16
]
if
self
.
AUG
.
AUG_METHOD
not
in
[
'unpadding'
,
'stepscaling'
,
'rangescaling'
]:
if
self
.
AUG
.
AUG_METHOD
not
in
[
'unpadding'
,
'stepscaling'
,
'rangescaling'
]:
raise
ValueError
(
'AUG.AUG_METHOD config error, only support `unpadding`, `unpadding` and `rangescaling`'
)
def
update_from_list
(
self
,
config_list
):
if
len
(
config_list
)
%
2
!=
0
:
raise
ValueError
(
...
...
pdseg/utils/config.py
浏览文件 @
61645b1d
#
-*- coding: utf-8 -*-
#
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
.
#
coding: utf8
#
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve
.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
pdseg/utils/dist_utils.py
浏览文件 @
61645b1d
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
...
...
pdseg/utils/fp16_utils.py
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
paddle
import
fluid
def
load_fp16_vars
(
executor
,
dirname
,
program
):
load_dirname
=
os
.
path
.
normpath
(
dirname
)
...
...
@@ -28,4 +44,4 @@ def load_fp16_vars(executor, dirname, program):
'load_as_fp16'
:
var
.
dtype
==
fluid
.
core
.
VarDesc
.
VarType
.
FP16
})
executor
.
run
(
load_prog
)
\ No newline at end of file
executor
.
run
(
load_prog
)
pdseg/utils/load_model_utils.py
0 → 100644
浏览文件 @
61645b1d
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
os.path
as
osp
import
six
import
numpy
as
np
def
parse_param_file
(
param_file
,
return_shape
=
True
):
from
paddle.fluid.proto.framework_pb2
import
VarType
f
=
open
(
param_file
,
'rb'
)
version
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
lod_level
=
np
.
fromstring
(
f
.
read
(
8
),
dtype
=
'int64'
)
for
i
in
range
(
int
(
lod_level
)):
_size
=
np
.
fromstring
(
f
.
read
(
8
),
dtype
=
'int64'
)
_
=
f
.
read
(
_size
)
version
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
tensor_desc
=
VarType
.
TensorDesc
()
tensor_desc_size
=
np
.
fromstring
(
f
.
read
(
4
),
dtype
=
'int32'
)
tensor_desc
.
ParseFromString
(
f
.
read
(
int
(
tensor_desc_size
)))
tensor_shape
=
tuple
(
tensor_desc
.
dims
)
if
return_shape
:
f
.
close
()
return
tuple
(
tensor_desc
.
dims
)
if
tensor_desc
.
data_type
!=
5
:
raise
Exception
(
"Unexpected data type while parse {}"
.
format
(
param_file
))
data_size
=
4
for
i
in
range
(
len
(
tensor_shape
)):
data_size
*=
tensor_shape
[
i
]
weight
=
np
.
fromstring
(
f
.
read
(
data_size
),
dtype
=
'float32'
)
f
.
close
()
return
np
.
reshape
(
weight
,
tensor_shape
)
def
load_pdparams
(
exe
,
main_prog
,
model_dir
):
import
paddle.fluid
as
fluid
from
paddle.fluid.proto.framework_pb2
import
VarType
from
paddle.fluid.framework
import
Program
vars_to_load
=
list
()
vars_not_load
=
list
()
import
pickle
with
open
(
osp
.
join
(
model_dir
,
'model.pdparams'
),
'rb'
)
as
f
:
params_dict
=
pickle
.
load
(
f
)
if
six
.
PY2
else
pickle
.
load
(
f
,
encoding
=
'latin1'
)
unused_vars
=
list
()
for
var
in
main_prog
.
list_vars
():
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
continue
if
var
.
name
not
in
params_dict
:
print
(
"{} is not in saved model"
.
format
(
var
.
name
))
vars_not_load
.
append
(
var
.
name
)
continue
if
var
.
shape
!=
params_dict
[
var
.
name
].
shape
:
unused_vars
.
append
(
var
.
name
)
vars_not_load
.
append
(
var
.
name
)
print
(
"[SKIP] Shape of pretrained weight {} doesn't match.(Pretrained: {}, Actual: {})"
.
format
(
var
.
name
,
params_dict
[
var
.
name
].
shape
,
var
.
shape
))
continue
vars_to_load
.
append
(
var
)
for
var_name
in
unused_vars
:
del
params_dict
[
var_name
]
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
print
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else
:
print
(
"There are {}/{} varaibles in {} are loaded."
.
format
(
len
(
vars_to_load
),
len
(
vars_to_load
)
+
len
(
vars_not_load
),
model_dir
))
def
load_pretrained_weights
(
exe
,
main_prog
,
weights_dir
):
if
not
osp
.
exists
(
weights_dir
):
raise
Exception
(
"Path {} not exists."
.
format
(
weights_dir
))
if
osp
.
exists
(
osp
.
join
(
weights_dir
,
"model.pdparams"
)):
return
load_pdparams
(
exe
,
main_prog
,
weights_dir
)
import
paddle.fluid
as
fluid
vars_to_load
=
list
()
vars_not_load
=
list
()
for
var
in
main_prog
.
list_vars
():
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
continue
if
not
osp
.
exists
(
osp
.
join
(
weights_dir
,
var
.
name
)):
print
(
"[SKIP] Pretrained weight {}/{} doesn't exist"
.
format
(
weights_dir
,
var
.
name
))
vars_not_load
.
append
(
var
)
continue
pretrained_shape
=
parse_param_file
(
osp
.
join
(
weights_dir
,
var
.
name
))
actual_shape
=
tuple
(
var
.
shape
)
if
pretrained_shape
!=
actual_shape
:
print
(
"[SKIP] Shape of pretrained weight {}/{} doesn't match.(Pretrained: {}, Actual: {})"
.
format
(
weights_dir
,
var
.
name
,
pretrained_shape
,
actual_shape
))
vars_not_load
.
append
(
var
)
continue
vars_to_load
.
append
(
var
)
params_dict
=
fluid
.
io
.
load_program_state
(
weights_dir
,
var_list
=
vars_to_load
)
fluid
.
io
.
set_program_state
(
main_prog
,
params_dict
)
if
len
(
vars_to_load
)
==
0
:
print
(
"There is no pretrain weights loaded, maybe you should check you pretrain model!"
)
else
:
print
(
"There are {}/{} varaibles in {} are loaded."
.
format
(
len
(
vars_to_load
),
len
(
vars_to_load
)
+
len
(
vars_not_load
),
weights_dir
))
pdseg/utils/timer.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
pdseg/vis.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -115,7 +115,12 @@ def visualize(cfg,
ckpt_dir
=
cfg
.
TEST
.
TEST_MODEL
if
not
ckpt_dir
else
ckpt_dir
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
save_dir
=
vis_dir
makedirs
(
save_dir
)
...
...
@@ -169,18 +174,13 @@ def visualize(cfg,
print
(
"VisualDL visualization epoch"
,
epoch
)
pred_mask_np
=
np
.
array
(
pred_mask
.
convert
(
"RGB"
))
log_writer
.
add_image
(
"Predict/{}"
.
format
(
img_name
),
pred_mask_np
,
epoch
)
log_writer
.
add_image
(
"Predict/{}"
.
format
(
img_name
),
pred_mask_np
,
epoch
)
# Original image
# BGR->RGB
img
=
cv2
.
imread
(
os
.
path
.
join
(
cfg
.
DATASET
.
DATA_DIR
,
img_name
))[...,
::
-
1
]
log_writer
.
add_image
(
"Images/{}"
.
format
(
img_name
),
img
,
epoch
)
img
=
cv2
.
imread
(
os
.
path
.
join
(
cfg
.
DATASET
.
DATA_DIR
,
img_name
))[...,
::
-
1
]
log_writer
.
add_image
(
"Images/{}"
.
format
(
img_name
),
img
,
epoch
)
# add ground truth (label) images
grt
=
grts
[
i
]
if
grt
is
not
None
:
...
...
@@ -189,10 +189,8 @@ def visualize(cfg,
grt_pil
.
putpalette
(
color_map
)
grt_pil
=
grt_pil
.
resize
((
org_shape
[
1
],
org_shape
[
0
]))
grt
=
np
.
array
(
grt_pil
.
convert
(
"RGB"
))
log_writer
.
add_image
(
"Label/{}"
.
format
(
img_name
),
grt
,
epoch
)
log_writer
.
add_image
(
"Label/{}"
.
format
(
img_name
),
grt
,
epoch
)
# If in local_test mode, only visualize 5 images just for testing
# procedure
...
...
pretrained_model/download_model.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -42,7 +43,7 @@ model_urls = {
"hrnet_w30_bn_imagenet"
:
"https://paddleseg.bj.bcebos.com/models/hrnet_w30_imagenet.tar"
,
"hrnet_w32_bn_imagenet"
:
"https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar"
,
"https://paddleseg.bj.bcebos.com/models/hrnet_w32_imagenet.tar"
,
"hrnet_w40_bn_imagenet"
:
"https://paddleseg.bj.bcebos.com/models/hrnet_w40_imagenet.tar"
,
"hrnet_w44_bn_imagenet"
:
...
...
slim/distillation/model_builder.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
slim/distillation/train_distill.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -44,6 +44,7 @@ from model_builder import parse_shape_from_file
from
eval
import
evaluate
from
vis
import
visualize
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
import
solver
from
paddleslim.dist.single_distiller
import
merge
,
l2_loss
...
...
@@ -116,38 +117,7 @@ def parse_args():
return
parser
.
parse_args
()
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program
=
fluid
.
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
fluid
.
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
save_block
.
create_var
(
name
=
each_var
.
name
,
shape
=
each_var
.
shape
,
dtype
=
each_var
.
dtype
,
type
=
each_var
.
type
,
lod_level
=
each_var
.
lod_level
,
persistable
=
True
)
file_path
=
os
.
path
.
join
(
dirname
,
new_var
.
name
)
file_path
=
os
.
path
.
normpath
(
file_path
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
file_path
})
executor
.
run
(
save_program
)
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
def
save_checkpoint
(
program
,
ckpt_name
):
"""
Save checkpoint for evaluation or resume training
"""
...
...
@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
save_vars
(
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
fluid
.
save
(
program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
"""
Load checkpoiont f
rom pretrained model directory for resume
training
Load checkpoiont f
or resuming
training
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
print
(
'Resume model training from:'
,
model_path
)
if
not
os
.
path
.
exists
(
model_path
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
model_path
))
fluid
.
load
(
program
,
os
.
path
.
join
(
model_path
,
'model'
),
exe
)
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
...
...
@@ -193,7 +156,6 @@ def load_checkpoint(exe, program):
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
...
...
@@ -289,7 +251,11 @@ def train(cfg):
ckpt_dir
=
cfg
.
SLIM
.
KNOWLEDGE_DISTILL_TEACHER_MODEL_DIR
assert
ckpt_dir
is
not
None
print
(
'load teacher model:'
,
ckpt_dir
)
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
teacher_program
)
if
os
.
path
.
exists
(
ckpt_dir
):
try
:
fluid
.
load
(
teacher_program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
teacher_program
)
# cfg = load_config(FLAGS.config)
cfg
.
update_from_file
(
args
.
cfg_file
)
...
...
@@ -355,42 +321,8 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
fluid
.
default_main_program
())
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
fluid
.
default_main_program
().
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
load_pretrained_weights
(
exe
,
fluid
.
default_main_program
(),
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
else
:
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
...
...
@@ -475,12 +407,9 @@ def train(cfg):
step
)
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
sys
.
stdout
.
flush
()
avg_loss
=
0.0
cm
.
zero_matrix
()
...
...
@@ -503,16 +432,13 @@ def train(cfg):
speed
=
args
.
log_steps
/
timer
.
elapsed_time
()
print
((
"epoch={} step={} lr={:.5f} loss={:.4f} teacher loss={:.4f} distill loss={:.4f} step/sec={:.3f} | ETA {}"
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
avg_
t_loss
,
avg_
d_loss
,
speed
,
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
avg_t_loss
,
avg_d_loss
,
speed
,
calculate_eta
(
all_step
-
step
,
speed
)))
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
sys
.
stdout
.
flush
()
avg_loss
=
0.0
avg_t_loss
=
0.0
...
...
@@ -527,7 +453,7 @@ def train(cfg):
if
(
epoch
%
cfg
.
TRAIN
.
SNAPSHOT_EPOCH
==
0
or
epoch
==
cfg
.
SOLVER
.
NUM_EPOCHS
)
and
cfg
.
TRAINER_ID
==
0
:
ckpt_dir
=
save_checkpoint
(
exe
,
fluid
.
default_main_program
(),
epoch
)
ckpt_dir
=
save_checkpoint
(
fluid
.
default_main_program
(),
epoch
)
if
args
.
do_eval
:
print
(
"Evaluation start"
)
...
...
@@ -537,10 +463,8 @@ def train(cfg):
use_gpu
=
args
.
use_gpu
,
use_mpio
=
args
.
use_mpio
)
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
if
mean_iou
>
best_mIoU
:
best_mIoU
=
mean_iou
...
...
@@ -560,11 +484,11 @@ def train(cfg):
ckpt_dir
=
ckpt_dir
,
log_writer
=
log_writer
)
if
cfg
.
TRAINER_ID
==
0
:
ckpt_dir
=
save_checkpoint
(
exe
,
fluid
.
default_main_program
(),
epoch
)
ckpt_dir
=
save_checkpoint
(
fluid
.
default_main_program
(),
epoch
)
# save final model
if
cfg
.
TRAINER_ID
==
0
:
save_checkpoint
(
exe
,
fluid
.
default_main_program
(),
'final'
)
save_checkpoint
(
fluid
.
default_main_program
(),
'final'
)
def
main
(
args
):
...
...
slim/nas/deeplab.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -27,6 +27,7 @@ from models.libs.model_libs import separate_conv
from
models.backbone.mobilenet_v2
import
MobileNetV2
as
mobilenet_backbone
from
models.backbone.xception
import
Xception
as
xception_backbone
def
encoder
(
input
):
# 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
# ASPP_WITH_SEP_CONV:默认为真,使用depthwise可分离卷积,否则使用普通卷积
...
...
@@ -47,8 +48,7 @@ def encoder(input):
with
scope
(
'encoder'
):
channel
=
256
with
scope
(
"image_pool"
):
image_avg
=
fluid
.
layers
.
reduce_mean
(
input
,
[
2
,
3
],
keep_dim
=
True
)
image_avg
=
fluid
.
layers
.
reduce_mean
(
input
,
[
2
,
3
],
keep_dim
=
True
)
image_avg
=
bn_relu
(
conv
(
image_avg
,
...
...
@@ -191,7 +191,10 @@ def nas_backbone(input, arch):
end_points
=
8
decode_point
=
3
data
,
decode_shortcuts
=
arch
(
input
,
end_points
=
end_points
,
return_block
=
decode_point
,
output_stride
=
16
)
input
,
end_points
=
end_points
,
return_block
=
decode_point
,
output_stride
=
16
)
decode_shortcut
=
decode_shortcuts
[
decode_point
]
return
data
,
decode_shortcut
...
...
slim/nas/eval_nas.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -123,7 +123,10 @@ def evaluate(cfg, ckpt_dir=None, use_gpu=False, use_mpio=False, **kwargs):
if
ckpt_dir
is
not
None
:
print
(
'load test model:'
,
ckpt_dir
)
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
try
:
fluid
.
load
(
test_prog
,
os
.
path
.
join
(
ckpt_dir
,
'model'
),
exe
)
except
:
fluid
.
io
.
load_params
(
exe
,
ckpt_dir
,
main_program
=
test_prog
)
# Use streaming confusion matrix to calculate mean_iou
np
.
set_printoptions
(
...
...
slim/nas/mobilenetv2_search_space.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -31,7 +32,7 @@ __all__ = ["MobileNetV2SpaceSeg"]
class
MobileNetV2SpaceSeg
(
SearchSpaceBase
):
def
__init__
(
self
,
input_size
,
output_size
,
block_num
,
block_mask
=
None
):
super
(
MobileNetV2SpaceSeg
,
self
).
__init__
(
input_size
,
output_size
,
block_num
,
block_mask
)
block_num
,
block_mask
)
# self.head_num means the first convolution channel
self
.
head_num
=
np
.
array
([
3
,
4
,
8
,
12
,
16
,
24
,
32
])
#7
# self.filter_num1 ~ self.filter_num6 means following convlution channel
...
...
@@ -48,7 +49,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
self
.
k_size
=
np
.
array
([
3
,
5
])
#2
# self.multiply means expansion_factor of each _inverted_residual_unit
self
.
multiply
=
np
.
array
([
1
,
2
,
3
,
4
,
6
])
#5
# self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks
# self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks
self
.
repeat
=
np
.
array
([
1
,
2
,
3
,
4
,
5
,
6
])
#6
def
init_tokens
(
self
):
...
...
@@ -72,7 +73,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
def
range_table
(
self
):
"""
Get range table of current search space, constrains the range of tokens.
Get range table of current search space, constrains the range of tokens.
"""
# head_num + 6 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# yapf: disable
...
...
@@ -95,8 +96,8 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
tokens
=
self
.
init_tokens
()
self
.
bottleneck_params_list
=
[]
self
.
bottleneck_params_list
.
append
(
(
1
,
self
.
head_num
[
tokens
[
0
]],
1
,
1
,
3
))
self
.
bottleneck_params_list
.
append
(
(
1
,
self
.
head_num
[
tokens
[
0
]],
1
,
1
,
3
))
self
.
bottleneck_params_list
.
append
(
(
self
.
multiply
[
tokens
[
1
]],
self
.
filter_num1
[
tokens
[
2
]],
self
.
repeat
[
tokens
[
3
]],
2
,
self
.
k_size
[
tokens
[
4
]]))
...
...
@@ -150,7 +151,7 @@ class MobileNetV2SpaceSeg(SearchSpaceBase):
return
(
True
if
count
==
points
else
False
)
#conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
input
=
conv_bn_layer
(
input
,
num_filters
=
int
(
32
*
self
.
scale
),
...
...
slim/nas/model_builder.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
slim/nas/train_nas.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -47,6 +47,7 @@ from model_builder import parse_shape_from_file
from
eval_nas
import
evaluate
from
vis
import
visualize
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
from
mobilenetv2_search_space
import
MobileNetV2SpaceSeg
from
paddleslim.nas.search_space.search_space_factory
import
SearchSpaceFactory
...
...
@@ -116,38 +117,7 @@ def parse_args():
return
parser
.
parse_args
()
def
save_vars
(
executor
,
dirname
,
program
=
None
,
vars
=
None
):
"""
Temporary resolution for Win save variables compatability.
Will fix in PaddlePaddle v1.5.2
"""
save_program
=
fluid
.
Program
()
save_block
=
save_program
.
global_block
()
for
each_var
in
vars
:
# NOTE: don't save the variable which type is RAW
if
each_var
.
type
==
fluid
.
core
.
VarDesc
.
VarType
.
RAW
:
continue
new_var
=
save_block
.
create_var
(
name
=
each_var
.
name
,
shape
=
each_var
.
shape
,
dtype
=
each_var
.
dtype
,
type
=
each_var
.
type
,
lod_level
=
each_var
.
lod_level
,
persistable
=
True
)
file_path
=
os
.
path
.
join
(
dirname
,
new_var
.
name
)
file_path
=
os
.
path
.
normpath
(
file_path
)
save_block
.
append_op
(
type
=
'save'
,
inputs
=
{
'X'
:
[
new_var
]},
outputs
=
{},
attrs
=
{
'file_path'
:
file_path
})
executor
.
run
(
save_program
)
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
def
save_checkpoint
(
program
,
ckpt_name
):
"""
Save checkpoint for evaluation or resume training
"""
...
...
@@ -156,29 +126,22 @@ def save_checkpoint(exe, program, ckpt_name):
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
save_vars
(
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
fluid
.
save
(
program
,
os
.
path
.
join
(
ckpt_dir
,
'model'
))
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
"""
Load checkpoiont f
rom pretrained model directory for resume
training
Load checkpoiont f
or resuming
training
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
print
(
'Resume model training from:'
,
model_path
)
if
not
os
.
path
.
exists
(
model_path
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
model_path
))
fluid
.
load
(
program
,
os
.
path
.
join
(
model_path
,
'model'
),
exe
)
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
...
...
@@ -193,7 +156,6 @@ def load_checkpoint(exe, program):
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
...
...
@@ -245,8 +207,6 @@ def train(cfg):
yield
item
[
0
],
item
[
1
],
item
[
2
]
# Get device environment
# places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
# place = places[0]
gpu_id
=
int
(
os
.
environ
.
get
(
'FLAGS_selected_gpus'
,
0
))
place
=
fluid
.
CUDAPlace
(
gpu_id
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
places
=
fluid
.
cuda_places
()
if
args
.
use_gpu
else
fluid
.
cpu_places
()
...
...
@@ -326,43 +286,8 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
else
:
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
...
...
@@ -419,8 +344,7 @@ def train(cfg):
except
Exception
as
e
:
print
(
e
)
if
epoch
>
cfg
.
SLIM
.
NAS_START_EVAL_EPOCH
:
ckpt_dir
=
save_checkpoint
(
exe
,
train_prog
,
'{}_tmp'
.
format
(
port
))
ckpt_dir
=
save_checkpoint
(
train_prog
,
'{}_tmp'
.
format
(
port
))
_
,
mean_iou
,
_
,
mean_acc
=
evaluate
(
cfg
=
cfg
,
arch
=
arch
,
...
...
slim/prune/eval_prune.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
slim/prune/train_prune.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -46,6 +46,7 @@ from models.model_builder import parse_shape_from_file
from
eval_prune
import
evaluate
from
vis
import
visualize
from
utils
import
dist_utils
from
utils.load_model_utils
import
load_pretrained_weights
from
paddleslim.prune
import
Pruner
,
save_model
from
paddleslim.analysis
import
flops
...
...
@@ -285,42 +286,7 @@ def train(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
else
:
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
...
...
@@ -409,12 +375,9 @@ def train(cfg):
step
)
log_writer
.
add_scalar
(
'Train/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/step/sec'
,
speed
,
step
)
sys
.
stdout
.
flush
()
avg_loss
=
0.0
cm
.
zero_matrix
()
...
...
@@ -436,12 +399,9 @@ def train(cfg):
).
format
(
epoch
,
step
,
lr
[
0
],
avg_loss
,
speed
,
calculate_eta
(
all_step
-
step
,
speed
)))
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
log_writer
.
add_scalar
(
'Train/loss'
,
avg_loss
,
step
)
log_writer
.
add_scalar
(
'Train/lr'
,
lr
[
0
],
step
)
log_writer
.
add_scalar
(
'Train/speed'
,
speed
,
step
)
sys
.
stdout
.
flush
()
avg_loss
=
0.0
timer
.
restart
()
...
...
@@ -464,10 +424,8 @@ def train(cfg):
use_gpu
=
args
.
use_gpu
,
use_mpio
=
args
.
use_mpio
)
if
args
.
use_vdl
:
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_iou'
,
mean_iou
,
step
)
log_writer
.
add_scalar
(
'Evaluate/mean_acc'
,
mean_acc
,
step
)
# Use VisualDL to visualize results
if
args
.
use_vdl
and
cfg
.
DATASET
.
VIS_FILE_LIST
is
not
None
:
...
...
slim/quantization/eval_quant.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
slim/quantization/export_model.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
slim/quantization/train_quant.py
浏览文件 @
61645b1d
# coding: utf8
#
c
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
C
opyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -40,7 +40,8 @@ from models.model_builder import parse_shape_from_file
from
eval_quant
import
evaluate
from
vis
import
visualize
from
utils
import
dist_utils
from
train
import
save_vars
,
save_checkpoint
,
load_checkpoint
,
update_best_model
,
print_info
from
utils.load_model_utils
import
load_pretrained_weights
from
train
import
update_best_model
,
print_info
from
paddleslim.quant
import
quant_aware
...
...
@@ -103,6 +104,55 @@ def parse_args():
return
parser
.
parse_args
()
def
save_checkpoint
(
exe
,
program
,
ckpt_name
):
"""
Save checkpoint for evaluation or resume training
"""
ckpt_dir
=
os
.
path
.
join
(
cfg
.
TRAIN
.
MODEL_SAVE_DIR
,
str
(
ckpt_name
))
print
(
"Save model checkpoint to {}"
.
format
(
ckpt_dir
))
if
not
os
.
path
.
isdir
(
ckpt_dir
):
os
.
makedirs
(
ckpt_dir
)
fluid
.
io
.
save_vars
(
exe
,
ckpt_dir
,
program
,
vars
=
list
(
filter
(
fluid
.
io
.
is_persistable
,
program
.
list_vars
())))
return
ckpt_dir
def
load_checkpoint
(
exe
,
program
):
"""
Load checkpoiont from pretrained model directory for resume training
"""
print
(
'Resume model training from:'
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
)
if
not
os
.
path
.
exists
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
):
raise
ValueError
(
"TRAIN.PRETRAIN_MODEL {} not exist!"
.
format
(
cfg
.
TRAIN
.
RESUME_MODEL_DIR
))
fluid
.
io
.
load_persistables
(
exe
,
cfg
.
TRAIN
.
RESUME_MODEL_DIR
,
main_program
=
program
)
model_path
=
cfg
.
TRAIN
.
RESUME_MODEL_DIR
# Check is path ended by path spearator
if
model_path
[
-
1
]
==
os
.
sep
:
model_path
=
model_path
[
0
:
-
1
]
epoch_name
=
os
.
path
.
basename
(
model_path
)
# If resume model is final model
if
epoch_name
==
'final'
:
begin_epoch
=
cfg
.
SOLVER
.
NUM_EPOCHS
# If resume model path is end of digit, restore epoch status
elif
epoch_name
.
isdigit
():
epoch
=
int
(
epoch_name
)
begin_epoch
=
epoch
+
1
else
:
raise
ValueError
(
"Resume model path is not valid!"
)
print
(
"Model checkpoint loaded successfully!"
)
return
begin_epoch
def
train_quant
(
cfg
):
startup_prog
=
fluid
.
Program
()
train_prog
=
fluid
.
Program
()
...
...
@@ -182,42 +232,7 @@ def train_quant(cfg):
begin_epoch
=
load_checkpoint
(
exe
,
train_prog
)
# Load pretrained model
elif
os
.
path
.
exists
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
):
print_info
(
'Pretrained model dir: '
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
load_vars
=
[]
load_fail_vars
=
[]
def
var_shape_matched
(
var
,
shape
):
"""
Check whehter persitable variable shape is match with current network
"""
var_exist
=
os
.
path
.
exists
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
if
var_exist
:
var_shape
=
parse_shape_from_file
(
os
.
path
.
join
(
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
var
.
name
))
return
var_shape
==
shape
return
False
for
x
in
train_prog
.
list_vars
():
if
isinstance
(
x
,
fluid
.
framework
.
Parameter
):
shape
=
tuple
(
fluid
.
global_scope
().
find_var
(
x
.
name
).
get_tensor
().
shape
())
if
var_shape_matched
(
x
,
shape
):
load_vars
.
append
(
x
)
else
:
load_fail_vars
.
append
(
x
)
fluid
.
io
.
load_vars
(
exe
,
dirname
=
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
,
vars
=
load_vars
)
for
var
in
load_vars
:
print_info
(
"Parameter[{}] loaded sucessfully!"
.
format
(
var
.
name
))
for
var
in
load_fail_vars
:
print_info
(
"Parameter[{}] don't exist or shape does not match current network, skip"
" to load it."
.
format
(
var
.
name
))
print_info
(
"{}/{} pretrained parameters loaded successfully!"
.
format
(
len
(
load_vars
),
len
(
load_vars
)
+
len
(
load_fail_vars
)))
load_pretrained_weights
(
exe
,
train_prog
,
cfg
.
TRAIN
.
PRETRAINED_MODEL_DIR
)
else
:
print_info
(
'Pretrained model dir {} not exists, training from scratch...'
.
...
...
test/local_test_cityscapes.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
test/local_test_pet.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
test/test_utils.py
浏览文件 @
61645b1d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the "License"
);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录