提交 12434165 编写于 作者: P pennypm

add contrib/SpatialEmbeddings

上级 12bf97cf
# SpatialEmbeddings
## 模型概述
本模型是基于proposal-free的实例分割模型,快速实时,同时准确率高,适用于自动驾驶等实时场景。
本模型基于KITTI中的MOTS数据集训练得到,是论文《Segment as Points for Efficient Online Multi-Object Tracking and Segmentation》中的分割部分。
[论文地址](https://arxiv.org/pdf/2007.01550.pdf)
## KITTI MOTS指标
KITTI MOTS验证集AP:0.76, AP_50%:0.915
## 代码使用说明
### 1. 模型下载
执行以下命令下载并解压SpatialEmbeddings预测模型:
```
python download_SpatialEmbeddings_kitti.py
```
或点击[链接](https://paddleseg.bj.bcebos.com/models/SpatialEmbeddings_kitti.tar)进行手动下载并解压。
### 2. 数据下载
前往KITTI官网下载MOTS比赛数据[链接](https://www.vision.rwth-aachen.de/page/mots)
下载后解压到./data文件夹下,并在该文件夹中生成列出验证集图片路径的test.txt(每行一个相对于./data的图片路径)。
### 3. 快速预测
使用GPU预测:
```
python -u infer.py --use_gpu
```
使用CPU预测:
```
python -u infer.py
```
数据及模型路径等详细配置见config.py文件
### 4. 预测结果示例
原图:
![](imgs/kitti_0007_000518_ori.png)
预测结果:
![](imgs/kitti_0007_000518_pred.png)
## 引用
**论文**
*Instance Segmentation by Jointly Optimizing Spatial Embeddings and Clustering Bandwidth*
**代码**
https://github.com/davyneven/SpatialEmbeddings
# -*- coding: utf-8 -*-
from utils.util import AttrDict, merge_cfg_from_args, get_arguments
import os
# Parse the command-line flags first so they can override the defaults below.
args = get_arguments()

# Default inference settings, stored as items of an attribute-style dict.
cfg = AttrDict()
cfg.update({
    # Directory that holds the images to run prediction on.
    "data_dir": "data",
    # File listing the names of the images to predict.
    "data_list_file": os.path.join("data", "test.txt"),
    # Path the model is loaded from.
    "model_path": 'SpatialEmbeddings_kitti',
    # Directory where prediction results are saved.
    "vis_dir": "result",
    # Sigma setting.
    "n_sigma": 2,
    # Center-point (seed) threshold.
    "threshold": 0.94,
    # Threshold on the number of points in a point set.
    "min_pixel": 160,
})

# Overlay the values supplied on the command line.
merge_cfg_from_args(args, cfg)
kitti/0007/kitti_0007_000512.png
kitti/0007/kitti_0007_000518.png
# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
# Directory containing this script; the archive is saved and extracted here.
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
# "test" directory two levels above this script; it is appended to sys.path
# so that the test_utils import below can be resolved.
TEST_PATH = os.path.join(LOCAL_PATH, "..", "..", "test")
sys.path.append(TEST_PATH)

from test_utils import download_file_and_uncompress

if __name__ == "__main__":
    # Arguments for fetching and unpacking the pretrained model archive.
    download_kwargs = dict(
        url='https://paddleseg.bj.bcebos.com/models/SpatialEmbeddings_kitti.tar',
        savepath=LOCAL_PATH,
        extrapath=LOCAL_PATH,
        extraname='SpatialEmbeddings_kitti',
    )
    download_file_and_uncompress(**download_kwargs)

    print("Pretrained Model download success!")
# -*- coding: utf-8 -*-
import os
import numpy as np
from utils.util import get_arguments
from utils.palette import get_palette
from utils.data_util import Cluster, pad_img
from PIL import Image as PILImage
import importlib
import paddle.fluid as fluid
# Command-line flags for this run.
args = get_arguments()

# Import the config module by name and take its settings object.
config = importlib.import_module('config')
cfg = config.cfg

# One clustering helper instance, reused for every image.
cluster = Cluster()
# Dataset class used for prediction.
class TestDataSet():
    """Resolve the images listed for inference and turn them into net inputs.

    Attributes:
        data_dir: Root directory of the images (``cfg.data_dir``).
        data_list_file: Text file with one image name per line
            (``cfg.data_list_file``).
        data_list: Image paths built from the list file.
        data_num: Number of entries in ``data_list``.
    """

    def __init__(self):
        self.data_dir = cfg.data_dir
        self.data_list_file = cfg.data_list_file
        self.data_list = self.get_data_list()
        self.data_num = len(self.data_list)

    def get_data_list(self):
        """Return the image paths named in ``self.data_list_file``.

        Each non-blank line is stripped and joined onto ``self.data_dir``.
        A name that contains no '.' gets '.jpg' appended. Blank lines are
        skipped so they do not turn into a bogus '<data_dir>/.jpg' entry.
        """
        data_list = []
        # The context manager closes the list file even if reading fails.
        with open(self.data_list_file, 'r') as data_file_handler:
            for line in data_file_handler:
                img_name = line.strip()
                if not img_name:
                    continue
                if '.' not in img_name:
                    img_name = img_name + '.jpg'
                data_list.append(os.path.join(self.data_dir, img_name))
        return data_list

    def preprocess(self, img):
        """Convert an H x W x C image array into a 1 x C x H' x W' float input.

        H and W are rounded up to the next multiple of 32 by replicating the
        bottom/right edge pixels; values are scaled by 1/255 as float32.
        """
        h, w = img.shape[:2]
        # Ceiling division: smallest multiple of 32 that is >= h (resp. w).
        h_new = -(-h // 32) * 32
        w_new = -(-w // 32) * 32
        img = np.pad(img, ((0, h_new - h), (0, w_new - w), (0, 0)), 'edge')
        img = img.astype(np.float32) / 255.0
        # Channels first, then a leading batch axis of size 1.
        img = img.transpose((2, 0, 1))
        img = np.expand_dims(img, axis=0)
        return img

    def get_data(self, index):
        """Load and preprocess the image at ``index``.

        Returns:
            ``(img_process, name_prefix, img_shape)`` where ``img_process`` is
            the preprocessed array, ``name_prefix`` is the file name without
            directory and extension, and ``img_shape`` is the original
            ``(height, width)``.
            If the image cannot be opened or decoded, returns
            ``(None, img_path, None)`` so callers that unpack three values
            can report the path and skip it.
        """
        img_path = self.data_list[index]
        try:
            img = np.array(PILImage.open(img_path))
        except (IOError, OSError):
            # Keep the 3-tuple shape that callers unpack.
            return None, img_path, None
        # basename/splitext handle both path separators and strip only the
        # final extension.
        name_prefix = os.path.splitext(os.path.basename(img_path))[0]
        img_shape = img.shape[:2]
        img_process = self.preprocess(img)
        return img_process, name_prefix, img_shape
def infer():
    """Run the exported model on every listed image and save instance maps.

    Creates ``cfg.vis_dir`` if needed, loads the inference model from
    ``cfg.model_path`` on GPU 0 or CPU depending on ``cfg.use_gpu``, and for
    each image clusters the network output into an instance map that is
    written as a palettised PNG named ``<image stem>.png`` in ``cfg.vis_dir``.

    Returns:
        0 once all images have been processed.
    """
    if not os.path.exists(cfg.vis_dir):
        os.makedirs(cfg.vis_dir)

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load the inference model.
    test_prog, feed_name, fetch_list = fluid.io.load_inference_model(
        dirname=cfg.model_path, executor=exe, params_filename='__params__')

    # Load the prediction dataset.
    test_dataset = TestDataSet()
    data_num = test_dataset.data_num

    for idx in range(data_num):
        # Fetch the preprocessed image, its name stem and original size.
        image, im_name, im_shape = test_dataset.get_data(idx)
        if image is None:
            print(im_name, 'is None')
            continue

        # Forward pass, then cluster the first output of the first sample.
        output = exe.run(program=test_prog,
                         feed={feed_name[0]: image},
                         fetch_list=fetch_list)
        instance_map, predictions = cluster.cluster(
            output[0][0],
            n_sigma=cfg.n_sigma,
            min_pixel=cfg.min_pixel,
            threshold=cfg.threshold)

        # Pad up to the network input size, then crop back to the original
        # image size before saving.
        instance_map = pad_img(instance_map, image.shape[2:])
        instance_map = instance_map[:im_shape[0], :im_shape[1]]
        output_im = PILImage.fromarray(np.asarray(instance_map, dtype=np.uint8))
        # One palette entry per instance plus one for the background id 0.
        palette = get_palette(len(predictions) + 1)
        output_im.putpalette(palette)
        result_path = os.path.join(cfg.vis_dir, im_name + '.png')
        output_im.save(result_path)

        if (idx + 1) % 100 == 0:
            print('%d processd' % (idx + 1))

    # Report data_num rather than the loop variable: the loop variable does
    # not exist when the list is empty, and otherwise idx + 1 == data_num.
    print('%d processd done' % data_num)

    return 0


if __name__ == "__main__":
    infer()
# Expose only CUDA device 4 to the process, then run inference with the GPU flag.
export CUDA_VISIBLE_DEVICES=4
python infer.py --use_gpu
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from PIL import Image as PILImage
def sigmoid_np(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)), computed with NumPy."""
    exp_neg = np.exp(-x)
    return 1 / (1 + exp_neg)
class Cluster:
    """Greedy seed-based clustering of a spatial-embedding prediction.

    A fixed coordinate grid is built once; ``cluster`` adds it to the
    predicted offsets and repeatedly grows an instance around the
    highest-scoring unclustered seed pixel.
    """

    def __init__(self, ):
        # Coordinate grid of shape (2, 1024, 2048): channel 0 holds x values
        # spanning [0, 2] over 2048 columns, channel 1 holds y values
        # spanning [0, 1] over 1024 rows.
        xm = np.repeat(np.linspace(0, 2, 2048)[np.newaxis, np.newaxis, :], 1024, axis=1)
        ym = np.repeat(np.linspace(0, 1, 1024)[np.newaxis, :, np.newaxis], 2048, axis=2)
        self.xym = np.vstack((xm, ym))

    def cluster(self, prediction, n_sigma=1, min_pixel=160, threshold=0.5):
        """Cluster one prediction into instances.

        Args:
            prediction: Array indexed as (channel, height, width). Channels
                0-1 are offsets (passed through tanh), the next ``n_sigma``
                channels are sigma values, and the following channel is the
                seed logit (passed through a sigmoid).
            n_sigma: Number of sigma channels.
            min_pixel: Foreground, remaining-unclustered and proposal pixel
                counts must all exceed this value for work to continue or a
                proposal to be considered.
            threshold: Minimum seed score; clustering stops once the best
                remaining seed scores below it.

        Returns:
            ``(instance_map, instances)``: a float32 (height, width) map with
            0 for background and 1, 2, ... for accepted instances, and a list
            of dicts with a uint8 ``'mask'`` (0 or 255) and the seed
            ``'score'`` for each accepted instance.

        Note:
            The grid is sliced to the prediction size, so a prediction larger
            than the 1024 x 2048 grid will not broadcast against it.
        """
        height, width = prediction.shape[1:3]
        xym_s = self.xym[:, 0:height, 0:width]

        spatial_emb = np.tanh(prediction[0:2]) + xym_s
        sigma = prediction[2:2 + n_sigma]
        seed_map = sigmoid_np(prediction[2 + n_sigma:2 + n_sigma + 1])

        instance_map = np.zeros((height, width), np.float32)
        instances = []
        count = 1

        # Foreground pixels: seed score above 0.5. ``mask`` keeps the leading
        # channel axis; ``mask_2d`` is the plain (height, width) view, taken
        # by indexing so that a height or width of 1 is not collapsed.
        mask = seed_map > 0.5
        mask_2d = mask[0]
        num_fg = mask.sum()

        if num_fg > min_pixel:
            spatial_emb_masked = spatial_emb[
                np.repeat(mask, spatial_emb.shape[0], 0)].reshape(2, -1)
            sigma_masked = sigma[np.repeat(mask, n_sigma, 0)].reshape(n_sigma, -1)
            seed_map_masked = seed_map[mask].reshape(1, -1)

            # 1 while a foreground pixel is still unassigned, 0 afterwards.
            unclustered = np.ones(num_fg, np.float32)
            instance_map_masked = np.zeros(num_fg, np.float32)

            while unclustered.sum() > min_pixel:
                # Score of every still-unclustered pixel, computed once.
                scores = seed_map_masked * unclustered
                seed = scores.argmax().item()
                seed_score = scores.max().item()
                if seed_score < threshold:
                    break

                center = spatial_emb_masked[:, seed:seed + 1]
                unclustered[seed] = 0
                s = np.exp(sigma_masked[:, seed:seed + 1] * 10)
                dist = np.exp(-1 * np.sum((spatial_emb_masked - center) ** 2 * s, 0))

                # ``dist`` is already 1-D (one value per foreground pixel).
                proposal = dist > 0.5

                if proposal.sum() > min_pixel:
                    # Accept only if more than half of the proposal is still
                    # unclustered.
                    if unclustered[proposal].sum() / proposal.sum() > 0.5:
                        instance_map_masked[proposal] = count
                        instance_mask = np.zeros((height, width), np.float32)
                        instance_mask[mask_2d] = proposal
                        instances.append(
                            {'mask': (instance_mask * 255).astype(np.uint8),
                             'score': seed_score})
                        count += 1

                # Proposal pixels are consumed whether or not it was accepted.
                unclustered[proposal] = 0

            instance_map[mask_2d] = instance_map_masked

        return instance_map, instances
def pad_img(img, dst_shape, mode='constant'):
    """Pad ``img`` at the bottom/right so its first two axes reach ``dst_shape``.

    Args:
        img: Array with at least two axes (height, width, ...).
        dst_shape: Target ``(height, width)``.
        mode: Padding mode forwarded to ``np.pad``.

    Returns:
        The padded array. An axis that is already at least as large as its
        target is left unchanged (the image is never cropped), and any axes
        after the first two (e.g. channels) are not padded.
    """
    img_h, img_w = img.shape[:2]
    dst_h, dst_w = dst_shape
    pad_shape = ((0, max(0, dst_h - img_h)), (0, max(0, dst_w - img_w)))
    # Zero padding for trailing axes; this adds nothing for 2-D input.
    pad_shape += ((0, 0),) * (len(img.shape) - 2)
    return np.pad(img, pad_shape, mode)
def save_for_eval(predictions, infer_shape, im_shape, vis_dir, im_name, class_id=26):
    """Write one mask PNG per prediction plus an index text file.

    For each prediction the mask is padded to ``infer_shape``, cropped to
    ``im_shape`` and saved as ``<im_name>_<NN>.png`` in ``vis_dir``. A line
    ``"<png name> <class_id> <score>"`` (score with two decimals) is written
    to ``<im_name>.txt`` in the same directory.

    Args:
        predictions: Iterable of dicts with a ``'mask'`` array and a
            ``'score'`` value.
        infer_shape: ``(height, width)`` to pad each mask to.
        im_shape: ``(height, width)`` to crop each mask to.
        vis_dir: Output directory.
        im_name: Name stem used for all output files.
        class_id: Class label written on every line; defaults to 26, the
            value that was previously hard-coded (presumably the dataset's
            class id for the evaluated category - TODO confirm).
    """
    txt_file = os.path.join(vis_dir, im_name + '.txt')
    with open(txt_file, 'w') as f:
        for idx, pred in enumerate(predictions):
            save_name = im_name + '_{:02d}.png'.format(idx)
            pred_mask = pad_img(pred['mask'], infer_shape)
            pred_mask = pred_mask[:im_shape[0], :im_shape[1]]
            im = PILImage.fromarray(pred_mask)
            im.save(os.path.join(vis_dir, save_name))
            # A single line is written, so write() is the right call.
            f.write("{} {} {:.02f}\n".format(save_name, class_id, pred['score']))
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: RainbowSecret
## Microsoft Research
## yuyua@microsoft.com
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
def get_palette(num_cls):
    """Build the flat RGB colour table used to visualise label maps.

    Args:
        num_cls: Number of classes.
    Returns:
        A flat list of ``3 * num_cls`` ints: the R, G, B values of class 0,
        then class 1, and so on.
    """
    palette = []
    for label in range(num_cls):
        rgb = [0, 0, 0]
        remaining = label
        bit_pos = 7
        # Consume the label three bits at a time; each bit goes to one
        # channel, filling from the most significant colour bit downwards.
        while remaining:
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << bit_pos
            bit_pos -= 1
            remaining >>= 3
        palette.extend(rgb)
    return palette
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
def get_arguments():
    """Parse the process command line and return the resulting namespace.

    Recognised options are ``--use_gpu`` (a boolean flag) and ``--example``
    (a string).
    """
    option_specs = (
        ("--use_gpu", dict(action="store_true",
                           help="Use gpu or cpu to test.")),
        ('--example', dict(type=str,
                           help='RoadLine, HumanSeg or ACE2P')),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
class AttrDict(dict):
    """A dict whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)

    def __getattr__(self, name):
        # Instance attributes take precedence, then dict items.
        instance_attrs = self.__dict__
        if name in instance_attrs:
            return instance_attrs[name]
        if name not in self:
            raise AttributeError(name)
        return self[name]

    def __setattr__(self, name, value):
        # Update an existing instance attribute in place; anything else is
        # stored as a dict item.
        target = self.__dict__ if name in self.__dict__ else self
        target[name] = value
def merge_cfg_from_args(args, cfg):
    """Merge config keys, values in args into the global config.

    Each attribute of ``args`` is evaluated as a Python expression when
    possible (so ``"2"`` becomes ``2``); if evaluation fails the raw value is
    kept. Entries whose resulting value is ``None`` are skipped, every other
    entry is stored in ``cfg`` under the attribute name.

    Args:
        args: Namespace-like object; its attributes are read with ``vars``.
        cfg: Mapping updated in place.
    """
    for k, v in vars(args).items():
        # NOTE(security): eval() executes whatever expression the value
        # contains. That is acceptable only while the values come from a
        # trusted operator's command line; switch to ast.literal_eval if they
        # can ever originate from untrusted input.
        try:
            value = eval(v)
        except Exception:
            # Not an evaluable expression (plain word, non-string value, ...):
            # keep it unchanged. SystemExit and KeyboardInterrupt are
            # deliberately not swallowed here.
            value = v
        if value is not None:
            cfg[k] = value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册