提交 27fd7982 编写于 作者: W wuzewu

add ssd demo

上级 0a4b1020
import os
import numpy as np
import processor
import paddle_hub as hub
import paddle
import paddle.fluid as fluid
from mobilenet_ssd import mobile_net
def build_program():
image_shape = [3, 300, 300]
class_num = 21
image = fluid.layers.data(dtype="float32", shape=image_shape, name="image")
gt_box = fluid.layers.data(
dtype="float32", shape=[4], name="gtbox", lod_level=1)
gt_label = fluid.layers.data(
dtype="int32", shape=[1], name="label", lod_level=1)
difficult = fluid.layers.data(
dtype="int32", shape=[1], name="difficult", lod_level=1)
with fluid.unique_name.guard():
locs, confs, box, box_var = mobile_net(class_num, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
return image, nmsed_out
def create_module():
image, nmsed_out = build_program()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
pretrained_model = "resources/ssd_mobilenet_v1_pascalvoc"
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
assets = ["resources/label_list.txt"]
sign = hub.create_signature(
"object_detection", inputs=[image], outputs=[nmsed_out])
hub.create_module(
sign_arr=[sign],
module_dir="hub_module_ssd",
exe=exe,
processor=processor.Processor,
assets=assets)
if __name__ == '__main__':
create_module()
#!/bin/bash
set -o nounset
set -o errexit
script_path=$(cd `dirname $0`; pwd)
cd $script_path
python create_module.py
python ../../paddle_hub/commands/hub.py run hub_module_ssd/ --signature object_detection --config resources/test/test.yml --dataset resources/test/test.csv
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
def conv_bn(input,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='relu',
use_cudnn=True):
parameter_attr = ParamAttr(learning_rate=0.1, initializer=MSRA())
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=parameter_attr,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
def depthwise_separable(input, num_filters1, num_filters2, num_groups, stride,
scale):
depthwise_conv = conv_bn(
input=input,
filter_size=3,
num_filters=int(num_filters1 * scale),
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False)
pointwise_conv = conv_bn(
input=depthwise_conv,
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
return pointwise_conv
def extra_block(input, num_filters1, num_filters2, num_groups, stride, scale):
# 1x1 conv
pointwise_conv = conv_bn(
input=input,
filter_size=1,
num_filters=int(num_filters1 * scale),
stride=1,
num_groups=int(num_groups * scale),
padding=0)
# 3x3 conv
normal_conv = conv_bn(
input=pointwise_conv,
filter_size=3,
num_filters=int(num_filters2 * scale),
stride=2,
num_groups=int(num_groups * scale),
padding=1)
return normal_conv
def mobile_net(num_classes, img, img_shape, scale=1.0):
# 300x300
tmp = conv_bn(img, 3, int(32 * scale), 2, 1, 3)
# 150x150
tmp = depthwise_separable(tmp, 32, 64, 32, 1, scale)
tmp = depthwise_separable(tmp, 64, 128, 64, 2, scale)
# 75x75
tmp = depthwise_separable(tmp, 128, 128, 128, 1, scale)
tmp = depthwise_separable(tmp, 128, 256, 128, 2, scale)
# 38x38
tmp = depthwise_separable(tmp, 256, 256, 256, 1, scale)
tmp = depthwise_separable(tmp, 256, 512, 256, 2, scale)
# 19x19
for i in range(5):
tmp = depthwise_separable(tmp, 512, 512, 512, 1, scale)
module11 = tmp
tmp = depthwise_separable(tmp, 512, 1024, 512, 2, scale)
# 10x10
module13 = depthwise_separable(tmp, 1024, 1024, 1024, 1, scale)
module14 = extra_block(module13, 256, 512, 1, 2, scale)
# 5x5
module15 = extra_block(module14, 128, 256, 1, 2, scale)
# 3x3
module16 = extra_block(module15, 128, 256, 1, 2, scale)
# 2x2
module17 = extra_block(module16, 64, 128, 1, 2, scale)
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[module11, module13, module14, module15, module16, module17],
image=img,
num_classes=num_classes,
min_ratio=20,
max_ratio=90,
min_sizes=[60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0],
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2., 3.]],
base_size=img_shape[2],
offset=0.5,
flip=True)
return mbox_locs, mbox_confs, box, box_var
import paddle
import paddle_hub as hub
import numpy as np
import os
from paddle_hub import BaseProcessor
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
def clip_bbox(bbox):
xmin = max(min(bbox[0], 1.), 0.)
ymin = max(min(bbox[1], 1.), 0.)
xmax = max(min(bbox[2], 1.), 0.)
ymax = max(min(bbox[3], 1.), 0.)
return xmin, ymin, xmax, ymax
def draw_bounding_box_on_image(image_path, data_list, save_path):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
for data in data_list:
left, right, top, bottom = data['left'], data['right'], data[
'top'], data['bottom']
draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=4,
fill='red')
if image.mode == 'RGB':
draw.text((left, top), data['label'], (255, 255, 0))
image_name = image_path.split('/')[-1]
if not os.path.exists(save_path):
os.mkdir(save_path)
save_path = os.path.join(save_path, image_name)
print("image with bbox drawed saved as {}".format(save_path))
image.save(save_path)
class Processor(BaseProcessor):
def __init__(self, module):
self.module = module
label_list_file = os.path.join(self.module.helper.assets_path(),
"label_list.txt")
with open(label_list_file, "r") as file:
content = file.read()
self.label_list = content.split("\n")
self.confs_threshold = 0.5
def preprocess(self, sign_name, data_dict):
def process_image(img):
if img.mode == 'L':
img = im.convert('RGB')
im_width, im_height = img.size
img = img.resize((300, 300), Image.ANTIALIAS)
img = np.array(img)
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RBG to BGR
img = img[[2, 1, 0], :, :]
img = img.astype('float32')
mean_value = [127.5, 127.5, 127.5]
mean_value = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
'float32')
img -= mean_value
img = img * 0.007843
return img
result = {'image': []}
for path in data_dict['image']:
img = Image.open(path)
im_width, im_height = img.size
result_i = {}
result_i['path'] = path
result_i['width'] = im_width
result_i['height'] = im_height
result_i['processed'] = process_image(img)
result['image'].append(result_i)
return result
def postprocess(self, sign_name, data_out, data_info, **kwargs):
if sign_name == "object_detection":
lod_tensor = data_out[0]
lod = lod_tensor.lod()[0]
results = np.array(data_out[0])
output = []
for index in range(len(lod) - 1):
result_i = results[lod[index]:lod[index + 1]]
output_i = {
'path': data_info['image'][index]['path'],
'data': []
}
for dt in result_i:
if dt[1] < self.confs_threshold:
continue
dt_i = {}
category_id = dt[0]
bbox = dt[2:]
xmin, ymin, xmax, ymax = clip_bbox(dt[2:])
(left, right, top,
bottom) = (xmin * data_info['image'][index]['width'],
xmax * data_info['image'][index]['width'],
ymin * data_info['image'][index]['height'],
ymax * data_info['image'][index]['height'])
dt_i['left'] = left
dt_i['right'] = right
dt_i['top'] = top
dt_i['bottom'] = bottom
dt_i['label'] = self.label_list[int(category_id)]
output_i['data'].append(dt_i)
draw_bounding_box_on_image(
output_i['path'], output_i['data'], save_path="test_result")
output.append(output_i)
return output
def data_format(self, sign_name):
if sign_name == "object_detection":
return {
"image": {
'type': hub.DataType.IMAGE,
'feed_key': self.module.signatures[sign_name].inputs[0].name
}
}
return None
#!/bin/bash
set -o nounset
set -o errexit
script_path=$(cd `dirname $0`; pwd)
cd $script_path
wget --no-check-certificate https://paddlehub.bj.bcebos.com/paddle_model/ssd_mobilenet_v1_pascalvoc.tar.gz
tar xvzf ssd_mobilenet_v1_pascalvoc.tar.gz
rm ssd_mobilenet_v1_pascalvoc.tar.gz
background
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
name: ssd_mobilenet_v1_pascalvoc
type: CV/object-detection
author: paddlepaddle
author_email: paddle-dev@baidu.com
version: 1.0.0
IMAGE_PATH
./resources/test/test_img_sheep.jpg
./resources/test/test_img_cat.jpg
./resources/test/test_img_bird.jpg
input_data:
image:
type : IMAGE
key : IMAGE_PATH
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册