Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
12434165
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
286
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
12434165
编写于
8月 20, 2020
作者:
P
pennypm
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add contrib/SpatialEmbeddings
上级
12bf97cf
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
404 addition
and
0 deletion
+404
-0
contrib/SpatialEmbeddings/README.md
contrib/SpatialEmbeddings/README.md
+63
-0
contrib/SpatialEmbeddings/config.py
contrib/SpatialEmbeddings/config.py
+24
-0
contrib/SpatialEmbeddings/data/kitti/0007/kitti_0007_000512.png
...b/SpatialEmbeddings/data/kitti/0007/kitti_0007_000512.png
+0
-0
contrib/SpatialEmbeddings/data/kitti/0007/kitti_0007_000518.png
...b/SpatialEmbeddings/data/kitti/0007/kitti_0007_000518.png
+0
-0
contrib/SpatialEmbeddings/data/test.txt
contrib/SpatialEmbeddings/data/test.txt
+2
-0
contrib/SpatialEmbeddings/download_SpatialEmbeddings_kitti.py
...rib/SpatialEmbeddings/download_SpatialEmbeddings_kitti.py
+32
-0
contrib/SpatialEmbeddings/imgs/kitti_0007_000518_ori.png
contrib/SpatialEmbeddings/imgs/kitti_0007_000518_ori.png
+0
-0
contrib/SpatialEmbeddings/imgs/kitti_0007_000518_pred.png
contrib/SpatialEmbeddings/imgs/kitti_0007_000518_pred.png
+0
-0
contrib/SpatialEmbeddings/infer.py
contrib/SpatialEmbeddings/infer.py
+109
-0
contrib/SpatialEmbeddings/run.sh
contrib/SpatialEmbeddings/run.sh
+2
-0
contrib/SpatialEmbeddings/utils/__init__.py
contrib/SpatialEmbeddings/utils/__init__.py
+0
-0
contrib/SpatialEmbeddings/utils/data_util.py
contrib/SpatialEmbeddings/utils/data_util.py
+87
-0
contrib/SpatialEmbeddings/utils/palette.py
contrib/SpatialEmbeddings/utils/palette.py
+38
-0
contrib/SpatialEmbeddings/utils/util.py
contrib/SpatialEmbeddings/utils/util.py
+47
-0
未找到文件。
contrib/SpatialEmbeddings/README.md
0 → 100644
浏览文件 @
12434165
# SpatialEmbeddings
## 模型概述
本模型是基于proposal-free的实例分割模型,快速实时,同时准确率高,适用于自动驾驶等实时场景。
本模型基于KITTI中MOTS数据集训练得到,是论文 Segment as Points for Efficient Online Multi-Object Tracking and Segmentation中的分割部分
[
论文地址
](
https://arxiv.org/pdf/2007.01550.pdf
)
## KITTI MOTS指标
KITTI MOTS验证集AP:0.76, AP_50%:0.915
## 代码使用说明
### 1. 模型下载
执行以下命令下载并解压SpatialEmbeddings预测模型:
```
python download_SpatialEmbeddings_kitti.py
```
或点击
[
链接
](
https://paddleseg.bj.bcebos.com/models/SpatialEmbeddings_kitti.tar
)
进行手动下载并解压。
### 2. 数据下载
前往KITTI官网下载MOTS比赛数据
[
链接
](
https://www.vision.rwth-aachen.de/page/mots
)
下载后解压到./data文件夹下, 并生成验证集图片路径的test.txt
### 3. 快速预测
使用GPU预测
```
python -u infer.py --use_gpu
```
使用CPU预测:
```
python -u infer.py
```
数据及模型路径等详细配置见config.py文件
#### 4. 预测结果示例:
原图:
!
[](
imgs/kitti_0007_000518_ori.png
)
预测结果:
!
[](
imgs/kitti_0007_000518_pred.png
)
## 引用
**论文**
*Instance Segmentation by Jointly Optimizing Spatial Embeddings and Clustering Bandwidth*
**代码**
https://github.com/davyneven/SpatialEmbeddings
contrib/SpatialEmbeddings/config.py
0 → 100644
浏览文件 @
12434165
# -*- coding: utf-8 -*-
from
utils.util
import
AttrDict
,
merge_cfg_from_args
,
get_arguments
import
os
args
=
get_arguments
()
cfg
=
AttrDict
()
# 待预测图像所在路径
cfg
.
data_dir
=
"data"
# 待预测图像名称列表
cfg
.
data_list_file
=
os
.
path
.
join
(
"data"
,
"test.txt"
)
# 模型加载路径
cfg
.
model_path
=
'SpatialEmbeddings_kitti'
# 预测结果保存路径
cfg
.
vis_dir
=
"result"
# sigma值
cfg
.
n_sigma
=
2
# 中心点阈值
cfg
.
threshold
=
0.94
# 点集数阈值
cfg
.
min_pixel
=
160
merge_cfg_from_args
(
args
,
cfg
)
contrib/SpatialEmbeddings/data/kitti/0007/kitti_0007_000512.png
0 → 100755
浏览文件 @
12434165
952.5 KB
contrib/SpatialEmbeddings/data/kitti/0007/kitti_0007_000518.png
0 → 100755
浏览文件 @
12434165
960.0 KB
contrib/SpatialEmbeddings/data/test.txt
0 → 100644
浏览文件 @
12434165
kitti/0007/kitti_0007_000512.png
kitti/0007/kitti_0007_000518.png
contrib/SpatialEmbeddings/download_SpatialEmbeddings_kitti.py
0 → 100644
浏览文件 @
12434165
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
os
LOCAL_PATH
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
TEST_PATH
=
os
.
path
.
join
(
LOCAL_PATH
,
".."
,
".."
,
"test"
)
sys
.
path
.
append
(
TEST_PATH
)
from
test_utils
import
download_file_and_uncompress
if
__name__
==
"__main__"
:
download_file_and_uncompress
(
url
=
'https://paddleseg.bj.bcebos.com/models/SpatialEmbeddings_kitti.tar'
,
savepath
=
LOCAL_PATH
,
extrapath
=
LOCAL_PATH
,
extraname
=
'SpatialEmbeddings_kitti'
)
print
(
"Pretrained Model download success!"
)
contrib/SpatialEmbeddings/imgs/kitti_0007_000518_ori.png
0 → 100755
浏览文件 @
12434165
960.0 KB
contrib/SpatialEmbeddings/imgs/kitti_0007_000518_pred.png
0 → 100644
浏览文件 @
12434165
1.7 KB
contrib/SpatialEmbeddings/infer.py
0 → 100644
浏览文件 @
12434165
# -*- coding: utf-8 -*-
import
os
import
numpy
as
np
from
utils.util
import
get_arguments
from
utils.palette
import
get_palette
from
utils.data_util
import
Cluster
,
pad_img
from
PIL
import
Image
as
PILImage
import
importlib
import
paddle.fluid
as
fluid
args
=
get_arguments
()
config
=
importlib
.
import_module
(
'config'
)
cfg
=
getattr
(
config
,
'cfg'
)
cluster
=
Cluster
()
# 预测数据集类
class
TestDataSet
():
def
__init__
(
self
):
self
.
data_dir
=
cfg
.
data_dir
self
.
data_list_file
=
cfg
.
data_list_file
self
.
data_list
=
self
.
get_data_list
()
self
.
data_num
=
len
(
self
.
data_list
)
def
get_data_list
(
self
):
# 获取预测图像路径列表
data_list
=
[]
data_file_handler
=
open
(
self
.
data_list_file
,
'r'
)
for
line
in
data_file_handler
:
img_name
=
line
.
strip
()
name_prefix
=
img_name
.
split
(
'.'
)[
0
]
if
len
(
img_name
.
split
(
'.'
))
==
1
:
img_name
=
img_name
+
'.jpg'
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
img_name
)
data_list
.
append
(
img_path
)
return
data_list
def
preprocess
(
self
,
img
):
# 图像预处理
h
,
w
=
img
.
shape
[:
2
]
h_new
=
(
h
//
32
+
1
if
h
%
32
!=
0
else
h
//
32
)
*
32
w_new
=
(
w
//
32
+
1
if
w
%
32
!=
0
else
w
//
32
)
*
32
img
=
np
.
pad
(
img
,
((
0
,
h_new
-
h
),
(
0
,
w_new
-
w
),
(
0
,
0
)),
'edge'
)
img
=
img
.
astype
(
np
.
float32
)
/
255.0
img
=
img
.
transpose
((
2
,
0
,
1
))
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
return
img
def
get_data
(
self
,
index
):
# 获取图像信息
img_path
=
self
.
data_list
[
index
]
img
=
np
.
array
(
PILImage
.
open
(
img_path
))
if
img
is
None
:
return
img
,
img
,
img_path
,
None
img_name
=
img_path
.
split
(
os
.
sep
)[
-
1
]
name_prefix
=
img_name
.
replace
(
'.'
+
img_name
.
split
(
'.'
)[
-
1
],
''
)
img_shape
=
img
.
shape
[:
2
]
img_process
=
self
.
preprocess
(
img
)
return
img_process
,
name_prefix
,
img_shape
def
infer
():
if
not
os
.
path
.
exists
(
cfg
.
vis_dir
):
os
.
makedirs
(
cfg
.
vis_dir
)
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
# 加载预测模型
test_prog
,
feed_name
,
fetch_list
=
fluid
.
io
.
load_inference_model
(
dirname
=
cfg
.
model_path
,
executor
=
exe
,
params_filename
=
'__params__'
)
#加载预测数据集
test_dataset
=
TestDataSet
()
data_num
=
test_dataset
.
data_num
for
idx
in
range
(
data_num
):
# 数据获取
image
,
im_name
,
im_shape
=
test_dataset
.
get_data
(
idx
)
if
image
is
None
:
print
(
im_name
,
'is None'
)
continue
# 预测
output
=
exe
.
run
(
program
=
test_prog
,
feed
=
{
feed_name
[
0
]:
image
},
fetch_list
=
fetch_list
)
instance_map
,
predictions
=
cluster
.
cluster
(
output
[
0
][
0
],
n_sigma
=
cfg
.
n_sigma
,
\
min_pixel
=
cfg
.
min_pixel
,
threshold
=
cfg
.
threshold
)
# 预测结果保存
instance_map
=
pad_img
(
instance_map
,
image
.
shape
[
2
:])
instance_map
=
instance_map
[:
im_shape
[
0
],
:
im_shape
[
1
]]
output_im
=
PILImage
.
fromarray
(
np
.
asarray
(
instance_map
,
dtype
=
np
.
uint8
))
palette
=
get_palette
(
len
(
predictions
)
+
1
)
output_im
.
putpalette
(
palette
)
result_path
=
os
.
path
.
join
(
cfg
.
vis_dir
,
im_name
+
'.png'
)
output_im
.
save
(
result_path
)
if
(
idx
+
1
)
%
100
==
0
:
print
(
'%d processd'
%
(
idx
+
1
))
print
(
'%d processd done'
%
(
idx
+
1
))
return
0
if
__name__
==
"__main__"
:
infer
()
contrib/SpatialEmbeddings/run.sh
0 → 100644
浏览文件 @
12434165
export
CUDA_VISIBLE_DEVICES
=
4
python infer.py
--use_gpu
contrib/SpatialEmbeddings/utils/__init__.py
0 → 100644
浏览文件 @
12434165
contrib/SpatialEmbeddings/utils/data_util.py
0 → 100644
浏览文件 @
12434165
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
numpy
as
np
from
PIL
import
Image
as
PILImage
def
sigmoid_np
(
x
):
return
1
/
(
1
+
np
.
exp
(
-
x
))
class
Cluster
:
def
__init__
(
self
,
):
xm
=
np
.
repeat
(
np
.
linspace
(
0
,
2
,
2048
)[
np
.
newaxis
,
np
.
newaxis
,:],
1024
,
axis
=
1
)
ym
=
np
.
repeat
(
np
.
linspace
(
0
,
1
,
1024
)[
np
.
newaxis
,
:,
np
.
newaxis
],
2048
,
axis
=
2
)
self
.
xym
=
np
.
vstack
((
xm
,
ym
))
def
cluster
(
self
,
prediction
,
n_sigma
=
1
,
min_pixel
=
160
,
threshold
=
0.5
):
height
,
width
=
prediction
.
shape
[
1
:
3
]
xym_s
=
self
.
xym
[:,
0
:
height
,
0
:
width
]
spatial_emb
=
np
.
tanh
(
prediction
[
0
:
2
])
+
xym_s
sigma
=
prediction
[
2
:
2
+
n_sigma
]
seed_map
=
sigmoid_np
(
prediction
[
2
+
n_sigma
:
2
+
n_sigma
+
1
])
instance_map
=
np
.
zeros
((
height
,
width
),
np
.
float32
)
instances
=
[]
count
=
1
mask
=
seed_map
>
0.5
if
mask
.
sum
()
>
min_pixel
:
spatial_emb_masked
=
spatial_emb
[
np
.
repeat
(
mask
,
\
spatial_emb
.
shape
[
0
],
0
)].
reshape
(
2
,
-
1
)
sigma_masked
=
sigma
[
np
.
repeat
(
mask
,
n_sigma
,
0
)].
reshape
(
n_sigma
,
-
1
)
seed_map_masked
=
seed_map
[
mask
].
reshape
(
1
,
-
1
)
unclustered
=
np
.
ones
(
mask
.
sum
(),
np
.
float32
)
instance_map_masked
=
np
.
zeros
(
mask
.
sum
(),
np
.
float32
)
while
(
unclustered
.
sum
()
>
min_pixel
):
seed
=
(
seed_map_masked
*
unclustered
).
argmax
().
item
()
seed_score
=
(
seed_map_masked
*
unclustered
).
max
().
item
()
if
seed_score
<
threshold
:
break
center
=
spatial_emb_masked
[:,
seed
:
seed
+
1
]
unclustered
[
seed
]
=
0
s
=
np
.
exp
(
sigma_masked
[:,
seed
:
seed
+
1
]
*
10
)
dist
=
np
.
exp
(
-
1
*
np
.
sum
((
spatial_emb_masked
-
center
)
**
2
*
s
,
0
))
proposal
=
(
dist
>
0.5
).
squeeze
()
if
proposal
.
sum
()
>
min_pixel
:
if
unclustered
[
proposal
].
sum
()
/
proposal
.
sum
()
>
0.5
:
instance_map_masked
[
proposal
.
squeeze
()]
=
count
instance_mask
=
np
.
zeros
((
height
,
width
),
np
.
float32
)
instance_mask
[
mask
.
squeeze
()]
=
proposal
instances
.
append
(
{
'mask'
:
(
instance_mask
.
squeeze
()
*
255
).
astype
(
np
.
uint8
),
\
'score'
:
seed_score
})
count
+=
1
unclustered
[
proposal
]
=
0
instance_map
[
mask
.
squeeze
()]
=
instance_map_masked
return
instance_map
,
instances
def
pad_img
(
img
,
dst_shape
,
mode
=
'constant'
):
img_h
,
img_w
=
img
.
shape
[:
2
]
dst_h
,
dst_w
=
dst_shape
pad_shape
=
((
0
,
max
(
0
,
dst_h
-
img_h
)),
(
0
,
max
(
0
,
dst_w
-
img_w
)))
return
np
.
pad
(
img
,
pad_shape
,
mode
)
def
save_for_eval
(
predictions
,
infer_shape
,
im_shape
,
vis_dir
,
im_name
):
txt_file
=
os
.
path
.
join
(
vis_dir
,
im_name
+
'.txt'
)
with
open
(
txt_file
,
'w'
)
as
f
:
for
id
,
pred
in
enumerate
(
predictions
):
save_name
=
im_name
+
'_{:02d}.png'
.
format
(
id
)
pred_mask
=
pad_img
(
pred
[
'mask'
],
infer_shape
)
pred_mask
=
pred_mask
[:
im_shape
[
0
],
:
im_shape
[
1
]]
im
=
PILImage
.
fromarray
(
pred_mask
)
im
.
save
(
os
.
path
.
join
(
vis_dir
,
save_name
))
cl
=
26
score
=
pred
[
'score'
]
f
.
writelines
(
"{} {} {:.02f}
\n
"
.
format
(
save_name
,
cl
,
score
))
contrib/SpatialEmbeddings/utils/palette.py
0 → 100644
浏览文件 @
12434165
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: RainbowSecret
## Microsoft Research
## yuyua@microsoft.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
cv2
def
get_palette
(
num_cls
):
""" Returns the color map for visualizing the segmentation mask.
Args:
num_cls: Number of classes
Returns:
The color map
"""
n
=
num_cls
palette
=
[
0
]
*
(
n
*
3
)
for
j
in
range
(
0
,
n
):
lab
=
j
palette
[
j
*
3
+
0
]
=
0
palette
[
j
*
3
+
1
]
=
0
palette
[
j
*
3
+
2
]
=
0
i
=
0
while
lab
:
palette
[
j
*
3
+
0
]
|=
(((
lab
>>
0
)
&
1
)
<<
(
7
-
i
))
palette
[
j
*
3
+
1
]
|=
(((
lab
>>
1
)
&
1
)
<<
(
7
-
i
))
palette
[
j
*
3
+
2
]
|=
(((
lab
>>
2
)
&
1
)
<<
(
7
-
i
))
i
+=
1
lab
>>=
3
return
palette
contrib/SpatialEmbeddings/utils/util.py
0 → 100644
浏览文件 @
12434165
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
import
argparse
import
os
def
get_arguments
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--use_gpu"
,
action
=
"store_true"
,
help
=
"Use gpu or cpu to test."
)
parser
.
add_argument
(
'--example'
,
type
=
str
,
help
=
'RoadLine, HumanSeg or ACE2P'
)
return
parser
.
parse_args
()
class
AttrDict
(
dict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
AttrDict
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
__getattr__
(
self
,
name
):
if
name
in
self
.
__dict__
:
return
self
.
__dict__
[
name
]
elif
name
in
self
:
return
self
[
name
]
else
:
raise
AttributeError
(
name
)
def
__setattr__
(
self
,
name
,
value
):
if
name
in
self
.
__dict__
:
self
.
__dict__
[
name
]
=
value
else
:
self
[
name
]
=
value
def
merge_cfg_from_args
(
args
,
cfg
):
"""Merge config keys, values in args into the global config."""
for
k
,
v
in
vars
(
args
).
items
():
d
=
cfg
try
:
value
=
eval
(
v
)
except
:
value
=
v
if
value
is
not
None
:
cfg
[
k
]
=
value
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录