提交 2b83c0a1 编写于 作者: Y yangyaming

Init demo.

上级 63a80a32
# SSD目标检测
## 概述
SSD全称为Single Shot MultiBox Detector,是目标检测领域较新且效果较好的检测算法之一,具体参见论文\[[1](#引用)\]。SSD算法主要特点是检测速度快且检测精度高,当输入图像大小为300*300,显卡采用Nvidia Titan X时,检测速度可达到59FPS,并且在VOC2007 test数据集上mAP高达74.3%。PaddlePaddle已集成SSD算法,本示例旨在介绍如何使用PaddlePaddle的SSD模型进行目标检测。下文展开顺序为:首先简要介绍SSD原理,然后介绍示例包含文件及作用,接着介绍如何在PASCAL VOC数据集上训练、评估及检测,最后简要介绍如何在自有数据集上使用SSD。
## SSD原理
SSD使用一个卷积神经网络实现“端到端”的检测,所谓“端到端”指输入为原始图像,输出为检测结果,无需借助外部工具或流程进行特征提取、候选框生成等。论文中SSD的基础模型为VGG-16,其在VGG-16的某些层后面增加了一些额外的层进行候选框的提取,下图为模型的总体结构:
<p align="center">
<img src="images/ssd_network.png" width="600" hspace='10'/> <br/>
图1. SSD网络结构
</p>
如图所示,候选框的生成规则是预先设定的,比如Conv7输出的特征图每个像素点会对应6个候选框,这些候选框长宽比或面积有区分。在预测阶段模型会对这些提取出来的候选框做后处理,然后输出作为最终的检测结果。
## 示例总览
本示例共包含如下文件:
<center>
文件 | 用途
---- | -----
train.py | 训练脚本
eval.py | 评估脚本,用于评估训好模型
infer.py | 检测脚本,给定图片及模型,实施检测
visual.py | 检测结果可视化
image_util.py | 图像预处理所需公共函数
data_provider.py | 数据处理脚本,生成训练、评估或检测所需数据
config/pascal\_voc\_conf.py | 神经网络超参数配置文件
data/label\_list | 类别列表
data/prepare\_voc\_data.py | 准备训练PASCAL VOC数据列表
</center>
<center>表1. 示例文件</center>
训练阶段需要对数据做预处理,包括裁剪、采样等,这部分操作在```image_util.py``````data_provider.py```中完成;值得注意的是,```config/vgg_config.py```为参数配置文件,包括训练参数、神经网络参数等,本配置文件包含参数是针对PASCAL VOC数据配置的,当训练自有数据时,需要仿照该文件配置新的参数;```data/prepare_voc_data.py```脚本用来生成文件列表,包括切分训练集和测试集,使用时需要用户事先下载并解压数据,默认采用VOC2007和VOC2012。
## PASCAL VOC数据集
### 数据准备
首先需要下载数据集,VOC2007\[[2](#引用)\]和VOC2012\[[3](#引用)\],VOC2007包含训练集和测试集,VOC2012只包含训练集,将下载好的数据解压,目录结构为```VOCdevkit/{VOC2007,VOC2012}```。进入```data```目录,运行```python prepare_voc_data.py```即可生成```trainval.txt``````test.txt```,默认```prepare_voc_data.py``````VOCdevkit```在相同目录下,且生成的文件列表也在该目录。需注意```trainval.txt```既包含VOC2007的训练数据,也包含VOC2012的训练数据,```test.txt```只包含VOC2007的测试数据。
### 预训练模型准备
下载训练好的VGG-16模型,推荐在ImageNet分类数据集上预训练的模型,针对caffe训练的模型,PaddlePaddle提供转换脚本,可方便转换成PaddlePaddle格式(待扩展),这里默认转换后的模型路径为```atrous_vgg/model.tar.gz```
### 模型训练
直接执行```python train.py```即可进行训练。需要注意本示例仅支持CUDA GPU环境,无法在CPU上训练。```train.py```的一些关键执行逻辑:
```python
paddle.init(use_gpu=True, trainer_count=4)
data_args = data_provider.Settings(
data_dir='./data',
label_file='label_list',
resize_h=cfg.IMG_HEIGHT,
resize_w=cfg.IMG_WIDTH,
mean_value=[104,117,124])
train(train_file_list='./data/trainval.txt',
dev_file_list='./data/test.txt',
data_args=data_args,
init_model_path='./atrous_vgg/model.tar.gz')
```
调用```paddle.init```指定使用4卡GPU训练;调用```data_provider.Settings```配置数据预处理所需参数,其中```cfg.IMG_HEIGHT``````cfg.IMG_WIDTH```在配置文件```config/vgg_config.py```中设置,这里均为300;调用```train```执行训练,其中```train_file_list```指定训练数据列表,```dev_file_list```指定评估数据列表,```init_model_path```指定预选模型位置。训练过程中会打印一些日志信息,每训练10个batch会输出当前的轮数、当前batch的cost及mAP,每训练一个pass,会保存一次模型,默认保存在```models```目录下(注:需事先创建)。
### 模型评估
### 图像检测
## 自有数据集
## 引用
1. Liu, Wei, et al. "Ssd: Single shot multibox detector." European conference on computer vision. Springer, Cham, 2016.
2. http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html
3. http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html
from easydict import EasyDict as edict
import numpy as np
__C = edict()
cfg = __C
__C.TRAIN = edict()
__C.IMG_WIDTH = 300
__C.IMG_HEIGHT = 300
__C.IMG_CHANNEL = 3
__C.CLASS_NUM = 21
__C.BACKGROUND_ID = 0
# training settings
__C.TRAIN.LEARNING_RATE = 0.001 / 4
__C.TRAIN.MOMENTUM = 0.9
__C.TRAIN.BATCH_SIZE = 32
__C.TRAIN.NUM_PASS = 200
__C.TRAIN.L2REGULARIZATION = 0.0005 * 4
__C.TRAIN.LEARNING_RATE_DECAY_A = 0.1
__C.TRAIN.LEARNING_RATE_DECAY_B = 16551 * 80
__C.TRAIN.LEARNING_RATE_SCHEDULE = 'discexp'
__C.NET = edict()
# configuration for multibox_loss_layer
__C.NET.MBLOSS = edict()
__C.NET.MBLOSS.OVERLAP_THRESHOLD = 0.5
__C.NET.MBLOSS.NEG_POS_RATIO = 3.0
__C.NET.MBLOSS.NEG_OVERLAP = 0.5
# configuration for detection_map
__C.NET.DETMAP = edict()
__C.NET.DETMAP.OVERLAP_THRESHOLD = 0.5
__C.NET.DETMAP.EVAL_DIFFICULT = False
__C.NET.DETMAP.AP_TYPE = "11point"
# configuration for detection_output_layer
__C.NET.DETOUT = edict()
__C.NET.DETOUT.CONFIDENCE_THRESHOLD = 0.01
__C.NET.DETOUT.NMS_THRESHOLD = 0.45
__C.NET.DETOUT.NMS_TOP_K = 400
__C.NET.DETOUT.KEEP_TOP_K = 200
# configuration for priorbox_layer from conv4_3
__C.NET.CONV4 = edict()
__C.NET.CONV4.PB = edict()
__C.NET.CONV4.PB.MIN_SIZE = [30]
__C.NET.CONV4.PB.ASPECT_RATIO = [2.]
__C.NET.CONV4.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]
# configuration for priorbox_layer from fc7
__C.NET.FC7 = edict()
__C.NET.FC7.PB = edict()
__C.NET.FC7.PB.MIN_SIZE = [60]
__C.NET.FC7.PB.MAX_SIZE = [114]
__C.NET.FC7.PB.ASPECT_RATIO = [2., 3.]
__C.NET.FC7.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]
# configuration for priorbox_layer from conv6_2
__C.NET.CONV6 = edict()
__C.NET.CONV6.PB = edict()
__C.NET.CONV6.PB.MIN_SIZE = [114]
__C.NET.CONV6.PB.MAX_SIZE = [168]
__C.NET.CONV6.PB.ASPECT_RATIO = [2., 3.]
__C.NET.CONV6.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]
# configuration for priorbox_layer from conv7_2
__C.NET.CONV7 = edict()
__C.NET.CONV7.PB = edict()
__C.NET.CONV7.PB.MIN_SIZE = [168]
__C.NET.CONV7.PB.MAX_SIZE = [222]
__C.NET.CONV7.PB.ASPECT_RATIO = [2., 3.]
__C.NET.CONV7.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]
# configuration for priorbox_layer from conv8_2
__C.NET.CONV8 = edict()
__C.NET.CONV8.PB = edict()
__C.NET.CONV8.PB.MIN_SIZE = [222]
__C.NET.CONV8.PB.MAX_SIZE = [276]
__C.NET.CONV8.PB.ASPECT_RATIO = [2., 3.]
__C.NET.CONV8.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]
# configuration for priorbox_layer from pool6
__C.NET.POOL6 = edict()
__C.NET.POOL6.PB = edict()
__C.NET.POOL6.PB.MIN_SIZE = [276]
__C.NET.POOL6.PB.MAX_SIZE = [330]
__C.NET.POOL6.PB.ASPECT_RATIO = [2., 3.]
__C.NET.POOL6.PB.VARIANCE = [0.1, 0.1, 0.2, 0.2]
background
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
import os
import os.path as osp
import re
import random
devkit_dir = './VOCdevkit'
years = ['2007', '2012']
def get_img_dir(devkit_dir, year):
return osp.join(devkit_dir, 'VOC' + year, 'JPEGImages')
def get_annotation_dir(devkit_dir, year):
return osp.join(devkit_dir, 'VOC' + year, 'Annotations')
def get_filelist_dir(devkit_dir, year):
return osp.join(devkit_dir, 'VOC' + year, 'ImageSets/Main')
def walk_dir(devkit_dir, year):
filelist_dir = get_filelist_dir(devkit_dir, year)
annotation_dir = get_annotation_dir(devkit_dir, year)
img_dir = get_img_dir(devkit_dir, year)
trainval_list = []
test_list = []
added = set()
for _, _, files in os.walk(filelist_dir):
for fname in files:
img_ann_list = []
if re.match('[a-z]+_trainval\.txt', fname):
img_ann_list = trainval_list
elif re.match('[a-z]+_test\.txt', fname):
img_ann_list = test_list
else:
continue
fpath = osp.join(filelist_dir, fname)
for line in open(fpath):
name_prefix = line.strip().split()[0]
if name_prefix in added:
continue
added.add(name_prefix)
ann_path = osp.join(annotation_dir, name_prefix + '.xml')
img_path = osp.join(img_dir, name_prefix + '.jpg')
assert os.path.isfile(ann_path), 'file %s not found.' % ann_path
assert os.path.isfile(img_path), 'file %s not found.' % img_path
img_ann_list.append((img_path, ann_path))
return trainval_list, test_list
def prepare_filelist(devkit_dir, years, output_dir):
trainval_list = []
test_list = []
for year in years:
trainval, test = walk_dir(devkit_dir, year)
trainval_list.extend(trainval)
test_list.extend(test)
random.shuffle(trainval_list)
with open(osp.join(output_dir, 'trainval.txt'), 'w') as ftrainval:
for item in trainval_list:
ftrainval.write(item[0] + ' ' + item[1] + '\n')
with open(osp.join(output_dir, 'test.txt'), 'w') as ftest:
for item in test_list:
ftest.write(item[0] + ' ' + item[1] + '\n')
prepare_filelist(devkit_dir, years, '.')
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import image_util
from paddle.utils.image_util import *
import random
from PIL import Image
import numpy as np
import xml.etree.ElementTree
import os
class Settings(object):
def __init__(self, data_dir, label_file, resize_h, resize_w, mean_value):
self._data_dir = data_dir
self._label_list = []
label_fpath = os.path.join(data_dir, label_file)
for line in open(label_fpath):
self._label_list.append(line.strip())
self._resize_height = resize_h
self._resize_width = resize_w
self._mean_value = mean_value
img_size = self._resize_height * self._resize_width
self._img_mean = np.zeros(img_size * 3, dtype=np.single)
for idx, value in enumerate(self._mean_value):
self._img_mean[idx * img_size:(idx + 1) * img_size] = value
self._img_mean = self._img_mean.reshape(3, self._resize_height,
self._resize_width)
self._img_mean = self._img_mean.astype('float32')
@property
def data_dir(self):
return self._data_dir
@property
def label_list(self):
return self._label_list
@property
def resize_h(self):
return self._resize_height
@property
def resize_w(self):
return self._resize_width
@property
def img_mean(self):
return self._img_mean
def _reader_creator(settings, file_list, mode, shuffle):
def reader():
with open(file_list) as flist:
lines = [line.strip() for line in flist]
if shuffle:
random.shuffle(lines)
for line in lines:
if mode == 'train' or mode == 'test':
img_path, label_path = line.split()
img_path = os.path.join(settings.data_dir, img_path)
label_path = os.path.join(settings.data_dir, label_path)
elif mode == 'infer':
img_path = os.path.join(settings.data_dir, line)
img = Image.open(img_path)
img_width, img_height = img.size
img = np.array(img)
# layout: label | xmin | ymin | xmax | ymax | difficult
if mode == 'train' or mode == 'test':
bbox_labels = []
root = xml.etree.ElementTree.parse(label_path).getroot()
for object in root.findall('object'):
bbox_sample = []
# start from 1
bbox_sample.append(
float(
settings.label_list.index(
object.find('name').text)))
bbox = object.find('bndbox')
difficult = float(object.find('difficult').text)
bbox_sample.append(
float(bbox.find('xmin').text) / img_width)
bbox_sample.append(
float(bbox.find('ymin').text) / img_height)
bbox_sample.append(
float(bbox.find('xmax').text) / img_width)
bbox_sample.append(
float(bbox.find('ymax').text) / img_height)
bbox_sample.append(difficult)
bbox_labels.append(bbox_sample)
sample_labels = bbox_labels
if mode == 'train':
batch_sampler = []
# hard-code here
batch_sampler.append(
image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9,
0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0,
1.0))
""" random crop """
sampled_bbox = image_util.generateBatchSamples(
batch_sampler, bbox_labels, img_width, img_height)
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sample_labels = image_util.cropImage(
img, bbox_labels, sampled_bbox[idx], img_width,
img_height)
img = Image.fromarray(img)
img = img.resize((settings.resize_w, settings.resize_h),
Image.ANTIALIAS)
img = np.array(img)
if mode == 'train':
mirror = int(random.uniform(0, 2))
if mirror == 1:
img = img[:, ::-1, :]
for i in xrange(len(sample_labels)):
tmp = sample_labels[i][1]
sample_labels[i][1] = 1 - sample_labels[i][3]
sample_labels[i][3] = 1 - tmp
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
img = img.astype('float32')
img -= settings.img_mean
img = img.flatten()
if mode == 'train' or mode == 'test':
if mode == 'train' and len(sample_labels) == 0: continue
yield img.astype('float32'), sample_labels
elif mode == 'infer':
yield img.astype('float32')
return reader
def train(settings, file_list, shuffle=True):
return _reader_creator(settings, file_list, 'train', shuffle)
def test(settings, file_list):
return _reader_creator(settings, file_list, 'test', False)
def infer(settings, file_list):
return _reader_creator(settings, file_list, 'infer', False)
import paddle.v2 as paddle
import data_provider
import vgg_ssd_net
import os, sys
import gzip
from config.pascal_voc_conf import cfg
def eval(eval_file_list, batch_size, data_args, model_path):
cost, detect_out = vgg_ssd_net.net_conf(mode='eval')
assert os.path.isfile(model_path), 'Invalid model.'
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))
optimizer = paddle.optimizer.Momentum()
trainer = paddle.trainer.SGD(
cost=cost,
parameters=parameters,
extra_layers=[detect_out],
update_equation=optimizer)
feeding = {'image': 0, 'bbox': 1}
reader = paddle.batch(
data_provider.test(data_args, eval_file_list), batch_size=batch_size)
result = trainer.test(reader=reader, feeding=feeding)
print "TestCost: %f, Detection mAP=%g" % \
(result.cost, result.metrics['detection_evaluator'])
if __name__ == "__main__":
paddle.init(use_gpu=True, trainer_count=4) # use 4 gpus
data_args = data_provider.Settings(
data_dir='./data',
label_file='label_list',
resize_h=cfg.IMG_HEIGHT,
resize_w=cfg.IMG_WIDTH,
mean_value=[104, 117, 124])
eval(
eval_file_list='./data/test.txt',
batch_size=4,
data_args=data_args,
model_path='models/pass-00000.tar.gz')
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from PIL import Image
import numpy as np
import random
import math
class sampler():
def __init__(self, max_sample, max_trial, min_scale, max_scale,
min_aspect_ratio, max_aspect_ratio, min_jaccard_overlap,
max_jaccard_overlap):
self.max_sample = max_sample
self.max_trial = max_trial
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
self.min_jaccard_overlap = min_jaccard_overlap
self.max_jaccard_overlap = max_jaccard_overlap
class bbox():
def __init__(self, xmin, ymin, xmax, ymax):
self.xmin = xmin
self.ymin = ymin
self.xmax = xmax
self.ymax = ymax
def bboxSize(src_bbox):
width = src_bbox.xmax - src_bbox.xmin
height = src_bbox.ymax - src_bbox.ymin
return width * height
def preprocessImg(obj, im):
im = im.astype('float32')
pic = im
pic -= obj.img_mean
return pic.flatten()
def generateSample(sampler):
scale = random.uniform(sampler.min_scale, sampler.max_scale)
min_aspect_ratio = max(sampler.min_aspect_ratio, (scale**2.0))
max_aspect_ratio = min(sampler.max_aspect_ratio, 1 / (scale**2.0))
aspect_ratio = random.uniform(min_aspect_ratio, max_aspect_ratio)
bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5)
xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height
xmin = random.uniform(0, xmin_bound)
ymin = random.uniform(0, ymin_bound)
xmax = xmin + bbox_width
ymax = ymin + bbox_height
sampled_bbox = bbox(xmin, ymin, xmax, ymax)
return sampled_bbox
def jaccardOverlap(sample_bbox, object_bbox):
if sample_bbox.xmin >= object_bbox.xmax or \
sample_bbox.xmax <= object_bbox.xmin or \
sample_bbox.ymin >= object_bbox.ymax or \
sample_bbox.ymax <= object_bbox.ymin:
return 0
intersect_xmin = max(sample_bbox.xmin, object_bbox.xmin)
intersect_ymin = max(sample_bbox.ymin, object_bbox.ymin)
intersect_xmax = min(sample_bbox.xmax, object_bbox.xmax)
intersect_ymax = min(sample_bbox.ymax, object_bbox.ymax)
intersect_size = (intersect_xmax - intersect_xmin) * (
intersect_ymax - intersect_ymin)
sample_bbox_size = bboxSize(sample_bbox)
object_bbox_size = bboxSize(object_bbox)
overlap = intersect_size / (
sample_bbox_size + object_bbox_size - intersect_size)
return overlap
def satisfySampleConstraint(sampler, sample_bbox, bbox_labels):
if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0:
return True
for i in range(len(bbox_labels)):
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
bbox_labels[i][3], bbox_labels[i][4])
overlap = jaccardOverlap(sample_bbox, object_bbox)
if sampler.min_jaccard_overlap != 0 and \
overlap < sampler.min_jaccard_overlap:
continue
if sampler.max_jaccard_overlap != 0 and \
overlap > sampler.max_jaccard_overlap:
continue
return True
return False
def generateBatchSamples(batch_sampler, bbox_labels, image_width, image_height):
sampled_bbox = []
index = []
c = 0
for sampler in batch_sampler:
found = 0
for i in range(sampler.max_trial):
if found >= sampler.max_sample:
break
sample_bbox = generateSample(sampler)
if satisfySampleConstraint(sampler, sample_bbox, bbox_labels):
sampled_bbox.append(sample_bbox)
found = found + 1
index.append(c)
c = c + 1
return sampled_bbox
def clipBBox(src_bbox):
src_bbox.xmin = max(min(src_bbox.xmin, 1.0), 0.0)
src_bbox.ymin = max(min(src_bbox.ymin, 1.0), 0.0)
src_bbox.xmax = max(min(src_bbox.xmax, 1.0), 0.0)
src_bbox.ymax = max(min(src_bbox.ymax, 1.0), 0.0)
return src_bbox
def meetEmitConstraint(src_bbox, sample_bbox):
center_x = (src_bbox.xmax + src_bbox.xmin) / 2
center_y = (src_bbox.ymax + src_bbox.ymin) / 2
if center_x >= sample_bbox.xmin and \
center_x <= sample_bbox.xmax and \
center_y >= sample_bbox.ymin and \
center_y <= sample_bbox.ymax:
return True
return False
def transformLabels(bbox_labels, sample_bbox):
proj_bbox = bbox(0, 0, 0, 0)
sample_labels = []
for i in range(len(bbox_labels)):
sample_label = []
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
bbox_labels[i][3], bbox_labels[i][4])
if not meetEmitConstraint(object_bbox, sample_bbox):
continue
sample_width = sample_bbox.xmax - sample_bbox.xmin
sample_height = sample_bbox.ymax - sample_bbox.ymin
proj_bbox.xmin = (object_bbox.xmin - sample_bbox.xmin) / sample_width
proj_bbox.ymin = (object_bbox.ymin - sample_bbox.ymin) / sample_height
proj_bbox.xmax = (object_bbox.xmax - sample_bbox.xmin) / sample_width
proj_bbox.ymax = (object_bbox.ymax - sample_bbox.ymin) / sample_height
proj_bbox = clipBBox(proj_bbox)
if bboxSize(proj_bbox) > 0:
sample_label.append(bbox_labels[i][0])
sample_label.append(float(proj_bbox.xmin))
sample_label.append(float(proj_bbox.ymin))
sample_label.append(float(proj_bbox.xmax))
sample_label.append(float(proj_bbox.ymax))
sample_label.append(bbox_labels[i][5])
sample_labels.append(sample_label)
return sample_labels
def cropImage(img, bbox_labels, sample_bbox, image_width, image_height):
sample_bbox = clipBBox(sample_bbox)
xmin = int(sample_bbox.xmin * image_width)
xmax = int(sample_bbox.xmax * image_width)
ymin = int(sample_bbox.ymin * image_height)
ymax = int(sample_bbox.ymax * image_height)
sample_img = img[ymin:ymax, xmin:xmax]
sample_labels = transformLabels(bbox_labels, sample_bbox)
return sample_img, sample_labels
<html>
<head>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
extensions: ["tex2jax.js", "TeX/AMSsymbols.js", "TeX/AMSmath.js"],
jax: ["input/TeX", "output/HTML-CSS"],
tex2jax: {
inlineMath: [ ['$','$'] ],
displayMath: [ ['$$','$$'] ],
processEscapes: true
},
"HTML-CSS": { availableFonts: ["TeX"] }
});
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js" async></script>
<script type="text/javascript" src="../.tools/theme/marked.js">
</script>
<link href="http://cdn.bootcss.com/highlight.js/9.9.0/styles/darcula.min.css" rel="stylesheet">
<script src="http://cdn.bootcss.com/highlight.js/9.9.0/highlight.min.js"></script>
<link href="http://cdn.bootcss.com/bootstrap/4.0.0-alpha.6/css/bootstrap.min.css" rel="stylesheet">
<link href="https://cdn.jsdelivr.net/perfect-scrollbar/0.6.14/css/perfect-scrollbar.min.css" rel="stylesheet">
<link href="../.tools/theme/github-markdown.css" rel='stylesheet'>
</head>
<style type="text/css" >
.markdown-body {
box-sizing: border-box;
min-width: 200px;
max-width: 980px;
margin: 0 auto;
padding: 45px;
}
</style>
<body>
<div id="context" class="container-fluid markdown-body">
</div>
<!-- This block will be replaced by each markdown file content. Please do not change lines below.-->
<div id="markdown" style='display:none'>
# SSD目标检测
## 概述
SSD全称为Single Shot MultiBox Detector,是目标检测领域较新且效果较好的检测算法之一,具体参见论文\[[1](#引用)\]。SSD算法主要特点是检测速度快且检测精度高,当输入图像大小为300*300,显卡采用Nvidia Titan X时,检测速度可达到59FPS,并且在VOC2007 test数据集上mAP高达74.3%。PaddlePaddle已集成SSD算法,本示例旨在介绍如何使用PaddlePaddle的SSD模型进行目标检测。下文展开顺序为:首先简要介绍SSD原理,然后介绍示例包含文件及作用,接着介绍如何在PASCAL VOC数据集上训练、评估及检测,最后简要介绍如何在自有数据集上使用SSD。
## SSD原理
SSD使用一个卷积神经网络实现“端到端”的检测,所谓“端到端”指输入为原始图像,输出为检测结果,无需借助外部工具或流程进行特征提取、候选框生成等。论文中SSD的基础模型为VGG-16,其在VGG-16的某些层后面增加了一些额外的层进行候选框的提取,下图为模型的总体结构:
<p align="center">
<img src="images/ssd_network.png" width="600" hspace='10'/> <br/>
图1. SSD网络结构
</p>
如图所示,候选框的生成规则是预先设定的,比如Conv7输出的特征图每个像素点会对应6个候选框,这些候选框长宽比或面积有区分。在预测阶段模型会对这些提取出来的候选框做后处理,然后输出作为最终的检测结果。
## 示例总览
本示例共包含如下文件:
<center>
文件 | 用途
---- | -----
train.py | 训练脚本
eval.py | 评估脚本,用于评估训好模型
infer.py | 检测脚本,给定图片及模型,实施检测
visual.py | 检测结果可视化
image_util.py | 图像预处理所需公共函数
data_provider.py | 数据处理脚本,生成训练、评估或检测所需数据
config/pascal\_voc\_conf.py | 神经网络超参数配置文件
data/label\_list | 类别列表
data/prepare\_voc\_data.py | 准备训练PASCAL VOC数据列表
</center>
<center>表1. 示例文件</center>
训练阶段需要对数据做预处理,包括裁剪、采样等,这部分操作在```image_util.py```和```data_provider.py```中完成;值得注意的是,```config/vgg_config.py```为参数配置文件,包括训练参数、神经网络参数等,本配置文件包含参数是针对PASCAL VOC数据配置的,当训练自有数据时,需要仿照该文件配置新的参数;```data/prepare_voc_data.py```脚本用来生成文件列表,包括切分训练集和测试集,使用时需要用户事先下载并解压数据,默认采用VOC2007和VOC2012。
## PASCAL VOC数据集
### 数据准备
首先需要下载数据集,VOC2007\[[2](#引用)\]和VOC2012\[[3](#引用)\],VOC2007包含训练集和测试集,VOC2012只包含训练集,将下载好的数据解压,目录结构为```VOCdevkit/{VOC2007,VOC2012}```。进入```data```目录,运行```python prepare_voc_data.py```即可生成```trainval.txt```和```test.txt```,默认```prepare_voc_data.py```和```VOCdevkit```在相同目录下,且生成的文件列表也在该目录。需注意```trainval.txt```既包含VOC2007的训练数据,也包含VOC2012的训练数据,```test.txt```只包含VOC2007的测试数据。
### 预训练模型准备
下载训练好的VGG-16模型,推荐在ImageNet分类数据集上预训练的模型,针对caffe训练的模型,PaddlePaddle提供转换脚本,可方便转换成PaddlePaddle格式(待扩展),这里默认转换后的模型路径为```atrous_vgg/model.tar.gz```。
### 模型训练
直接执行```python train.py```即可进行训练。需要注意本示例仅支持CUDA GPU环境,无法在CPU上训练。```train.py```的一些关键执行逻辑:
```python
paddle.init(use_gpu=True, trainer_count=4)
data_args = data_provider.Settings(
data_dir='./data',
label_file='label_list',
resize_h=cfg.IMG_HEIGHT,
resize_w=cfg.IMG_WIDTH,
mean_value=[104,117,124])
train(train_file_list='./data/trainval.txt',
dev_file_list='./data/test.txt',
data_args=data_args,
init_model_path='./atrous_vgg/model.tar.gz')
```
调用```paddle.init```指定使用4卡GPU训练;调用```data_provider.Settings```配置数据预处理所需参数,其中```cfg.IMG_HEIGHT```和```cfg.IMG_WIDTH```在配置文件```config/vgg_config.py```中设置,这里均为300;调用```train```执行训练,其中```train_file_list```指定训练数据列表,```dev_file_list```指定评估数据列表,```init_model_path```指定预选模型位置。训练过程中会打印一些日志信息,每训练10个batch会输出当前的轮数、当前batch的cost及mAP,每训练一个pass,会保存一次模型,默认保存在```models```目录下(注:需事先创建)。
### 模型评估
### 图像检测
## 自有数据集
## 引用
1. Liu, Wei, et al. "Ssd: Single shot multibox detector." European conference on computer vision. Springer, Cham, 2016.
2. http://host.robots.ox.ac.uk/pascal/VOC/voc2007/index.html
3. http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html
</div>
<!-- You can change the lines below now. -->
<script type="text/javascript">
marked.setOptions({
renderer: new marked.Renderer(),
gfm: true,
breaks: false,
smartypants: true,
highlight: function(code, lang) {
code = code.replace(/&amp;/g, "&")
code = code.replace(/&gt;/g, ">")
code = code.replace(/&lt;/g, "<")
code = code.replace(/&nbsp;/g, " ")
return hljs.highlightAuto(code, [lang]).value;
}
});
document.getElementById("context").innerHTML = marked(
document.getElementById("markdown").innerHTML)
</script>
</body>
import paddle.v2 as paddle
import data_provider
import vgg_ssd_net
import os, sys
import numpy as np
import gzip
from PIL import Image
from config.pascal_voc_conf import cfg
def _infer(inferer, infer_data, threshold):
ret = []
infer_res = inferer.infer(input=infer_data)
keep_inds = np.where(infer_res[:, 2] >= threshold)[0]
for idx in keep_inds:
ret.append([
infer_res[idx][0], infer_res[idx][1] - 1, infer_res[idx][2],
infer_res[idx][3], infer_res[idx][4], infer_res[idx][5],
infer_res[idx][6]
])
return ret
def infer(eval_file_list, save_path, data_args, batch_size, model_path,
threshold):
detect_out = vgg_net_ssd_v2.net_conf(mode='infer')
assert os.path.isfile(init_model_path), 'Invalid model.'
parameters = paddle.parameters.Parameters.from_tar(gzip.open(model_path))
inferer = paddle.inference.Inference(
output_layer=detect_out, parameters=parameters)
reader = data_provider.infer(data_args, eval_file_list)
all_fname_list = [line.strip() for line in open(eval_file_list).readlines()]
test_data = []
fname_list = []
img_w = []
img_h = []
idx = 0
"""Do inference batch by batch,
coords of bbox will be scaled based on image size
"""
with open(save_path, 'w') as fout:
for img in reader():
test_data.append([img])
fname_list.append(all_fname_list[idx])
w, h = \
Image.open(os.path.join('./data', fname_list[-1])).size
img_w.append(w)
img_h.append(h)
if len(test_data) == batch_size:
ret_res = _infer(inferer, test_data, threshold)
for det_res in ret_res:
img_idx = int(det_res[0])
label = int(det_res[1])
conf_score = det_res[2]
xmin = det_res[3] * img_w[img_idx]
ymin = det_res[4] * img_h[img_idx]
xmax = det_res[5] * img_w[img_idx]
ymax = det_res[6] * img_h[img_idx]
fout.write(fname_list[img_idx] + '\t' + str(label) + '\t' +
str(conf_score) + '\t' + str(xmin) + ' ' + str(
ymin) + ' ' + str(xmax) + ' ' + str(
ymax) + '\n')
test_data = []
fname_list = []
img_w = []
img_h = []
idx += 1
if len(test_data) > 0:
ret_res = _infer(inferer, test_data, threshold)
for det_res in ret_res:
img_idx = int(det_res[0])
label = int(det_res[1])
conf_score = det_res[2]
xmin = det_res[3] * img_w[img_idx]
ymin = det_res[4] * img_h[img_idx]
xmax = det_res[5] * img_w[img_idx]
ymax = det_res[6] * img_h[img_idx]
fout.write(fname_list[img_idx] + '\t' + str(label) + '\t' + str(
conf_score) + '\t' + str(xmin) + ' ' + str(ymin) + ' ' +
str(xmax) + ' ' + str(ymax) + '\n')
if __name__ == "__main__":
paddle.init(use_gpu=True, trainer_count=1)
data_args = data_provider.Settings(
data_dir='./data',
label_file='label_list',
resize_h=cfg.IMG_HEIGHT,
resize_w=cfg.IMG_WIDTH,
mean_value=[104, 117, 124])
infer(
eval_file_list='./data/infer.txt',
save_path='infer.res',
data_args=data_args,
batch_size=4,
model_path='models/pass-00000.tar.gz',
threshold=0.3)
import paddle.v2 as paddle
import data_provider
import vgg_ssd_net
import os, sys
import gzip
import tarfile
from pascal_voc_conf import cfg
def train(train_file_list, dev_file_list, data_args, init_model_path):
cost, detect_out = vgg_ssd_net.net_conf('train')
parameters = paddle.parameters.create(cost)
if not (init_model_path is None):
assert os.path.isfile(init_model_path), 'Invalid model.'
fparams = paddle.parameters.Parameters.from_tar(
gzip.open(init_model_path))
for param_name in fparams.names():
parameters.set(param_name, fparams.get(param_name))
optimizer = paddle.optimizer.Momentum(
momentum=cfg.TRAIN.MOMENTUM,
learning_rate=cfg.TRAIN.LEARNING_RATE,
regularization=paddle.optimizer.L2Regularization(
rate=cfg.TRAIN.L2REGULARIZATION),
learning_rate_decay_a=cfg.TRAIN.LEARNING_RATE_DECAY_A,
learning_rate_decay_b=cfg.TRAIN.LEARNING_RATE_DECAY_B,
learning_rate_schedule=cfg.TRAIN.LEARNING_RATE_SCHEDULE)
trainer = paddle.trainer.SGD(
cost=cost,
parameters=parameters,
extra_layers=[detect_out],
update_equation=optimizer)
feeding = {'image': 0, 'bbox': 1}
train_reader = paddle.batch(
paddle.reader.shuffle(
data_provider.train(data_args, train_file_list), buf_size=200),
batch_size=cfg.TRAIN.BATCH_SIZE) # generate a batch image each time
dev_reader = paddle.batch(
data_provider.test(data_args, dev_file_list),
batch_size=cfg.TRAIN.BATCH_SIZE)
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 1 == 0:
print "\nPass %d, Batch %d, TrainCost %f, Detection mAP=%f" % \
(event.pass_id,
event.batch_id,
event.cost,
event.metrics['detection_evaluator'])
else:
sys.stdout.write('.')
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
with gzip.open('models/params_pass_%05d.tar.gz' % event.pass_id,
'w') as f:
parameters.to_tar(f)
result = trainer.test(reader=dev_reader, feeding=feeding)
print "\nTest with Pass %d, TestCost: %f, Detection mAP=%g" % \
(event.pass_id,
result.cost,
result.metrics['detection_evaluator'])
trainer.train(
reader=train_reader,
event_handler=event_handler,
num_passes=cfg.TRAIN.NUM_PASS,
feeding=feeding)
if __name__ == "__main__":
paddle.init(use_gpu=True, trainer_count=4)
data_args = data_provider.Settings(
data_dir='./data',
label_file='label_list',
resize_h=cfg.IMG_HEIGHT,
resize_w=cfg.IMG_WIDTH,
mean_value=[104, 117, 124])
train(
train_file_list='./data/trainval.txt',
dev_file_list='./data/test.txt',
data_args=data_args,
init_model_path='./atrous_vgg/model.tar.gz')
import paddle.v2 as paddle
from config.vgg_config import cfg
def net_conf(mode):
"""Network configuration. Total three modes included 'train' 'eval'
and 'infer'. Loss and mAP evaluation layer will return if using 'train'
and 'eval'. In 'infer' mode, only detection output layer will be returned.
"""
default_l2regularization = cfg.TRAIN.L2REGULARIZATION
default_bias_attr = paddle.attr.ParamAttr(
l2_rate=0.0, learning_rate=2.0, momentum=cfg.TRAIN.MOMENTUM)
default_static_bias_attr = paddle.attr.ParamAttr(is_static=True)
def xavier(channels, filter_size, local_lr, regularization):
init_w = (3.0 / (filter_size**2 * channels))**0.5
is_static = False
if local_lr == 0.0:
is_static = True
return paddle.attr.ParamAttr(
initial_min=(0.0 - init_w),
initial_max=init_w,
learning_rate=local_lr,
l2_rate=regularization,
momentum=cfg.TRAIN.MOMENTUM,
is_static=is_static)
def vgg_block(idx_str, input, num_channels, num_filters, pool_size,
pool_stride, pool_pad):
layer_name = "conv%s_" % idx_str
conv1 = paddle.layer.img_conv(
name=layer_name + "1",
input=input,
filter_size=3,
num_channels=num_channels,
num_filters=num_filters,
stride=1,
padding=1,
bias_attr=default_bias_attr,
param_attr=xavier(num_filters, 3, 1, default_l2regularization),
act=paddle.activation.Relu())
conv2 = paddle.layer.img_conv(
name=layer_name + "2",
input=conv1,
filter_size=3,
num_channels=num_filters,
num_filters=num_filters,
stride=1,
padding=1,
bias_attr=default_bias_attr,
param_attr=xavier(num_filters, 3, 1, default_l2regularization),
act=paddle.activation.Relu())
conv3 = paddle.layer.img_conv(
name=layer_name + "3",
input=conv2,
filter_size=3,
num_channels=num_filters,
num_filters=num_filters,
stride=1,
padding=1,
bias_attr=default_bias_attr,
param_attr=xavier(num_filters, 3, 1, default_l2regularization),
act=paddle.activation.Relu())
pool = paddle.layer.img_pool(
input=conv3,
pool_size=pool_size,
num_channels=num_filters,
pool_type=paddle.pooling.CudnnMax(),
stride=pool_stride,
padding=pool_pad)
return conv3, pool
def mbox_block(layer_idx, input, num_channels, filter_size, loc_filters,
conf_filters):
mbox_loc_name = layer_idx + "_mbox_loc"
mbox_loc = paddle.layer.img_conv(
name=mbox_loc_name,
input=input,
filter_size=filter_size,
num_channels=num_channels,
num_filters=loc_filters,
stride=1,
padding=1,
bias_attr=default_bias_attr,
param_attr=xavier(loc_filters, filter_size, 1,
default_l2regularization),
act=paddle.activation.Identity())
mbox_conf_name = layer_idx + "_mbox_conf"
mbox_conf = paddle.layer.img_conv(
name=mbox_conf_name,
input=input,
filter_size=filter_size,
num_channels=num_channels,
num_filters=conf_filters,
stride=1,
padding=1,
bias_attr=default_bias_attr,
param_attr=xavier(conf_filters, filter_size, 1,
default_l2regularization),
act=paddle.activation.Identity())
return mbox_loc, mbox_conf
def ssd_block(layer_idx, input, img_shape, num_channels, num_filters1,
num_filters2, aspect_ratio, variance, min_size, max_size):
layer_name = "conv" + layer_idx + "_"
conv1_name = layer_name + "1"
conv1 = paddle.layer.img_conv(
name=conv1_name,
input=input,
filter_size=1,
num_channels=num_channels,
num_filters=num_filters1,
stride=1,
padding=0,
bias_attr=default_bias_attr,
param_attr=xavier(num_filters1, 1, 1, default_l2regularization),
act=paddle.activation.Relu())
conv2_name = layer_name + "2"
conv2 = paddle.layer.img_conv(
name=conv2_name,
input=conv1,
filter_size=3,
num_channels=num_filters1,
num_filters=num_filters2,
stride=2,
padding=1,
bias_attr=default_bias_attr,
param_attr=xavier(num_filters2, 3, 1, default_l2regularization),
act=paddle.activation.Relu())
loc_filters = (len(aspect_ratio) * 2 + 1 + len(max_size)) * 4
conf_filters = (
len(aspect_ratio) * 2 + 1 + len(max_size)) * cfg.CLASS_NUM
mbox_loc, mbox_conf = mbox_block(conv2_name, conv2, num_filters2, 3,
loc_filters, conf_filters)
mbox_priorbox = paddle.layer.priorbox(
input=conv2,
image=img_shape,
min_size=min_size,
max_size=max_size,
aspect_ratio=aspect_ratio,
variance=variance)
return conv2, mbox_loc, mbox_conf, mbox_priorbox
img = paddle.layer.data(
name='image',
type=paddle.data_type.dense_vector(cfg.IMG_CHANNEL * cfg.IMG_HEIGHT *
cfg.IMG_WIDTH),
height=cfg.IMG_HEIGHT,
width=cfg.IMG_WIDTH)
conv1_1 = paddle.layer.img_conv(
name="conv1_1",
input=img,
filter_size=3,
num_channels=3,
num_filters=64,
stride=1,
padding=1,
bias_attr=default_static_bias_attr,
param_attr=xavier(64, 3, 0, 0),
act=paddle.activation.Relu())
conv1_2 = paddle.layer.img_conv(
name="conv1_2",
input=conv1_1,
filter_size=3,
num_channels=64,
num_filters=64,
stride=1,
padding=1,
bias_attr=default_static_bias_attr,
param_attr=xavier(64, 3, 0, 0),
act=paddle.activation.Relu())
pool1 = paddle.layer.img_pool(
name="pool1",
input=conv1_2,
pool_type=paddle.pooling.CudnnMax(),
pool_size=2,
num_channels=64,
stride=2)
conv2_1 = paddle.layer.img_conv(
name="conv2_1",
input=pool1,
filter_size=3,
num_channels=64,
num_filters=128,
stride=1,
padding=1,
bias_attr=default_static_bias_attr,
param_attr=xavier(128, 3, 0, 0),
act=paddle.activation.Relu())
conv2_2 = paddle.layer.img_conv(
name="conv2_2",
input=conv2_1,
filter_size=3,
num_channels=128,
num_filters=128,
stride=1,
padding=1,
bias_attr=default_static_bias_attr,
param_attr=xavier(128, 3, 0, 0),
act=paddle.activation.Relu())
pool2 = paddle.layer.img_pool(
name="pool2",
input=conv2_2,
pool_type=paddle.pooling.CudnnMax(),
pool_size=2,
num_channels=128,
stride=2)
conv3_3, pool3 = vgg_block("3", pool2, 128, 256, 2, 2, 0)
conv4_3, pool4 = vgg_block("4", pool3, 256, 512, 2, 2, 0)
conv4_3_mbox_priorbox = paddle.layer.priorbox(
input=conv4_3,
image=img,
min_size=cfg.NET.CONV4.PB.MIN_SIZE,
aspect_ratio=cfg.NET.CONV4.PB.ASPECT_RATIO,
variance=cfg.NET.CONV4.PB.VARIANCE)
conv4_3_norm = paddle.layer.cross_channel_norm(
name="conv4_3_norm",
input=conv4_3,
param_attr=paddle.attr.ParamAttr(
initial_mean=20,
initial_std=0,
is_static=False,
learning_rate=1,
momentum=cfg.TRAIN.MOMENTUM))
conv4_3_norm_mbox_loc, conv4_3_norm_mbox_conf = \
mbox_block("conv4_3_norm", conv4_3_norm, 512, 3, 12, 63)
conv5_3, pool5 = vgg_block("5", pool4, 512, 512, 3, 1, 1)
fc6 = paddle.layer.img_conv(
name="fc6",
input=pool5,
filter_size=3,
num_channels=512,
num_filters=1024,
stride=1,
padding=1,
bias_attr=default_bias_attr,
param_attr=xavier(1024, 3, 1, default_l2regularization),
act=paddle.activation.Relu())
fc7 = paddle.layer.img_conv(
name="fc7",
input=fc6,
filter_size=1,
num_channels=1024,
num_filters=1024,
stride=1,
padding=0,
bias_attr=default_bias_attr,
param_attr=xavier(1024, 1, 1, default_l2regularization),
act=paddle.activation.Relu())
fc7_mbox_loc, fc7_mbox_conf = mbox_block("fc7", fc7, 1024, 3, 24, 126)
fc7_mbox_priorbox = paddle.layer.priorbox(
input=fc7,
image=img,
min_size=cfg.NET.FC7.PB.MIN_SIZE,
max_size=cfg.NET.FC7.PB.MAX_SIZE,
aspect_ratio=cfg.NET.FC7.PB.ASPECT_RATIO,
variance=cfg.NET.FC7.PB.VARIANCE)
conv6_2, conv6_2_mbox_loc, conv6_2_mbox_conf, conv6_2_mbox_priorbox = \
ssd_block("6", fc7, img, 1024, 256, 512,
cfg.NET.CONV6.PB.ASPECT_RATIO,
cfg.NET.CONV6.PB.VARIANCE,
cfg.NET.CONV6.PB.MIN_SIZE,
cfg.NET.CONV6.PB.MAX_SIZE)
conv7_2, conv7_2_mbox_loc, conv7_2_mbox_conf, conv7_2_mbox_priorbox = \
ssd_block("7", conv6_2, img, 512, 128, 256,
cfg.NET.CONV7.PB.ASPECT_RATIO,
cfg.NET.CONV7.PB.VARIANCE,
cfg.NET.CONV7.PB.MIN_SIZE,
cfg.NET.CONV7.PB.MAX_SIZE)
conv8_2, conv8_2_mbox_loc, conv8_2_mbox_conf, conv8_2_mbox_priorbox = \
ssd_block("8", conv7_2, img, 256, 128, 256,
cfg.NET.CONV8.PB.ASPECT_RATIO,
cfg.NET.CONV8.PB.VARIANCE,
cfg.NET.CONV8.PB.MIN_SIZE,
cfg.NET.CONV8.PB.MAX_SIZE)
pool6 = paddle.layer.img_pool(
name="pool6",
input=conv8_2,
pool_size=3,
num_channels=256,
stride=1,
pool_type=paddle.pooling.Avg())
pool6_mbox_loc, pool6_mbox_conf = mbox_block("pool6", pool6, 256, 3, 24,
126)
pool6_mbox_priorbox = paddle.layer.priorbox(
input=pool6,
image=img,
min_size=cfg.NET.POOL6.PB.MIN_SIZE,
max_size=cfg.NET.POOL6.PB.MAX_SIZE,
aspect_ratio=cfg.NET.POOL6.PB.ASPECT_RATIO,
variance=cfg.NET.POOL6.PB.VARIANCE)
mbox_priorbox = paddle.layer.concat(
name="mbox_priorbox",
input=[
conv4_3_mbox_priorbox, fc7_mbox_priorbox, conv6_2_mbox_priorbox,
conv7_2_mbox_priorbox, conv8_2_mbox_priorbox, pool6_mbox_priorbox
])
loc_loss_input = [
conv4_3_norm_mbox_loc, fc7_mbox_loc, conv6_2_mbox_loc, conv7_2_mbox_loc,
conv8_2_mbox_loc, pool6_mbox_loc
]
conf_loss_input = [
conv4_3_norm_mbox_conf, fc7_mbox_conf, conv6_2_mbox_conf,
conv7_2_mbox_conf, conv8_2_mbox_conf, pool6_mbox_conf
]
detection_out = paddle.layer.detection_output(
input_loc=loc_loss_input,
input_conf=conf_loss_input,
priorbox=mbox_priorbox,
confidence_threshold=cfg.NET.DETOUT.CONFIDENCE_THRESHOLD,
nms_threshold=cfg.NET.DETOUT.NMS_THRESHOLD,
num_classes=cfg.CLASS_NUM,
nms_top_k=cfg.NET.DETOUT.NMS_TOP_K,
keep_top_k=cfg.NET.DETOUT.KEEP_TOP_K,
background_id=cfg.BACKGROUND_ID,
name="detection_output")
if mode == 'train' or mode == 'eval':
bbox = paddle.layer.data(
name='bbox', type=paddle.data_type.dense_vector_sequence(6))
loss = paddle.layer.multibox_loss(
input_loc=loc_loss_input,
input_conf=conf_loss_input,
priorbox=mbox_priorbox,
label=bbox,
num_classes=cfg.CLASS_NUM,
overlap_threshold=cfg.NET.MBLOSS.OVERLAP_THRESHOLD,
neg_pos_ratio=cfg.NET.MBLOSS.NEG_POS_RATIO,
neg_overlap=cfg.NET.MBLOSS.NEG_OVERLAP,
background_id=cfg.BACKGROUND_ID,
name="multibox_loss")
paddle.evaluator.detection_map(
input=detection_out,
label=bbox,
overlap_threshold=cfg.NET.DETMAP.OVERLAP_THRESHOLD,
background_id=cfg.BACKGROUND_ID,
evaluate_difficult=cfg.NET.DETMAP.EVAL_DIFFICULT,
ap_type=cfg.NET.DETMAP.AP_TYPE,
name="detection_evaluator")
return loss, detection_out
elif mode == 'infer':
return detection_out
import cv2
import os
data_dir = './data'
infer_file = './infer.res'
path_to_im = dict()
for line in open(infer_file):
img_path, _, _, _ = line.strip().split('\t')
if img_path not in path_to_im:
im = cv2.imread(os.path.join(data_dir, img_path))
path_to_im[img_path] = im
for line in open(infer_file):
img_path, label, conf, bbox = line.strip().split('\t')
xmin, ymin, xmax, ymax = map(float, bbox.split(' '))
xmin = int(round(xmin))
ymin = int(round(ymin))
xmax = int(round(xmax))
ymax = int(round(ymax))
img = path_to_im[img_path]
cv2.rectangle(img, (xmin, ymin), (xmax, ymax),
(0, (1 - xmin) * 255, xmin * 255), 2)
for img_path in path_to_im:
im = path_to_im[img_path]
cv2.imwrite(img_path, im)
print 'Done.'
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册