未验证 提交 0ed21a48 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ssd_vgg16_512_coco2017 (#1950)

* update ssd_vgg16_512_model

* update unittest

* update unittest

* update gpu config

* update

* add clean func

* update save inference model
Co-authored-by: wuzewu <wuzewu@baidu.com>
Co-authored-by: chenjian <chenjian26@baidu.com>
上级 51427477
...@@ -100,19 +100,13 @@ ...@@ -100,19 +100,13 @@
- save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在) - save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- 将模型保存到指定路径。 - 将模型保存到指定路径。
- **参数** - **参数**
- dirname: 存在模型的目录名称; <br/> - dirname: 模型保存路径 <br/>
- model\_filename: 模型文件名称,默认为\_\_model\_\_; <br/>
- params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效);<br/>
- combined: 是否将参数保存到统一的一个文件中。
## 四、服务部署 ## 四、服务部署
...@@ -166,6 +160,10 @@ ...@@ -166,6 +160,10 @@
修复numpy数据读取问题 修复numpy数据读取问题
* 1.1.0
移除 fluid api
- ```shell - ```shell
$ hub install ssd_vgg16_512_coco2017==1.0.2 $ hub install ssd_vgg16_512_coco2017==1.1.0
``` ```
...@@ -100,19 +100,13 @@ ...@@ -100,19 +100,13 @@
- save\_path (str, optional): output path for saving results - save\_path (str, optional): output path for saving results
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- Save model to specific path - Save model to specific path
- **Parameters** - **Parameters**
- dirname: output dir for saving model - dirname: model save path
- model\_filename: filename for saving model
- params\_filename: filename for saving parameters
- combined: whether save parameters into one file
## IV.Server Deployment ## IV.Server Deployment
...@@ -166,6 +160,10 @@ ...@@ -166,6 +160,10 @@
Fix the problem of reading numpy Fix the problem of reading numpy
* 1.1.0
Remove fluid api
- ```shell - ```shell
$ hub install ssd_vgg16_512_coco2017==1.0.2 $ hub install ssd_vgg16_512_coco2017==1.1.0
``` ```
...@@ -5,12 +5,10 @@ from __future__ import division ...@@ -5,12 +5,10 @@ from __future__ import division
import os import os
import random import random
from collections import OrderedDict
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from paddle import fluid
__all__ = ['reader'] __all__ = ['reader']
......
...@@ -7,41 +7,43 @@ import os ...@@ -7,41 +7,43 @@ import os
from functools import partial from functools import partial
import yaml import yaml
import paddle
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.static
import paddlehub as hub from paddle.inference import Config, create_predictor
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
from ssd_vgg16_512_coco2017.vgg import VGG from .processor import load_label_info, postprocess, base64_to_cv2
from ssd_vgg16_512_coco2017.processor import load_label_info, postprocess, base64_to_cv2 from .data_feed import reader
from ssd_vgg16_512_coco2017.data_feed import reader
@moduleinfo( @moduleinfo(
name="ssd_vgg16_512_coco2017", name="ssd_vgg16_512_coco2017",
version="1.0.2", version="1.1.0",
type="cv/object_detection", type="cv/object_detection",
summary="SSD with backbone VGG16, trained with dataset COCO.", summary="SSD with backbone VGG16, trained with dataset COCO.",
author="paddlepaddle", author="paddlepaddle",
author_email="paddle-dev@baidu.com") author_email="paddle-dev@baidu.com")
class SSDVGG16_512(hub.Module): class SSDVGG16_512:
def _initialize(self): def __init__(self):
self.default_pretrained_model_path = os.path.join( self.default_pretrained_model_path = os.path.join(
self.directory, "ssd_vgg16_512_model") self.directory, "ssd_vgg16_512_model", "model")
self.label_names = load_label_info( self.label_names = load_label_info(
os.path.join(self.directory, "label_file.txt")) os.path.join(self.directory, "label_file.txt"))
self.model_config = None self.model_config = None
self._set_config() self._set_config()
def _set_config(self): def _set_config(self):
# predictor config setting. """
cpu_config = AnalysisConfig(self.default_pretrained_model_path) predictor config setting.
"""
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
cpu_config.switch_ir_optim(False) cpu_config.switch_ir_optim(False)
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
try: try:
_places = os.environ["CUDA_VISIBLE_DEVICES"] _places = os.environ["CUDA_VISIBLE_DEVICES"]
...@@ -50,10 +52,10 @@ class SSDVGG16_512(hub.Module): ...@@ -50,10 +52,10 @@ class SSDVGG16_512(hub.Module):
except: except:
use_gpu = False use_gpu = False
if use_gpu: if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
# model config setting. # model config setting.
if not self.model_config: if not self.model_config:
...@@ -63,107 +65,6 @@ class SSDVGG16_512(hub.Module): ...@@ -63,107 +65,6 @@ class SSDVGG16_512(hub.Module):
self.multi_box_head_config = self.model_config['MultiBoxHead'] self.multi_box_head_config = self.model_config['MultiBoxHead']
self.output_decoder_config = self.model_config['SSDOutputDecoder'] self.output_decoder_config = self.model_config['SSDOutputDecoder']
def context(self, trainable=True, pretrained=True, get_prediction=False):
"""
Distill the Head Features, so as to perform transfer learning.
Args:
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction.
Returns:
inputs(dict): the input variables.
outputs(dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
# image
image = fluid.layers.data(
name='image', shape=[3, 512, 512], dtype='float32')
# backbone
backbone = VGG(
depth=16,
with_extra_blocks=True,
normalizations=[20., -1, -1, -1, -1, -1, -1],
extra_block_filters=[[256, 512, 1, 2,
3], [128, 256, 1, 2, 3],
[128, 256, 1, 2,
3], [128, 256, 1, 2, 3],
[128, 256, 1, 1, 4]])
# body_feats
body_feats = backbone(image)
# im_size
im_size = fluid.layers.data(
name='im_size', shape=[2], dtype='int32')
# var_prefix
var_prefix = '@HUB_{}@'.format(self.name)
# names of inputs
inputs = {
'image': var_prefix + image.name,
'im_size': var_prefix + im_size.name
}
# names of outputs
if get_prediction:
locs, confs, box, box_var = fluid.layers.multi_box_head(
inputs=body_feats,
image=image,
num_classes=81,
**self.multi_box_head_config)
pred = fluid.layers.detection_output(
loc=locs,
scores=confs,
prior_box=box,
prior_box_var=box_var,
**self.output_decoder_config)
outputs = {'bbox_out': [var_prefix + pred.name]}
else:
outputs = {
'body_features':
[var_prefix + var.name for var in body_feats]
}
# add_vars_prefix
add_vars_prefix(context_prog, var_prefix)
add_vars_prefix(fluid.default_startup_program(), var_prefix)
# inputs
inputs = {
key: context_prog.global_block().vars[value]
for key, value in inputs.items()
}
outputs = {
out_key: [
context_prog.global_block().vars[varname]
for varname in out_value
]
for out_key, out_value in outputs.items()
}
# trainable
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# pretrained
if pretrained:
def _if_exist(var):
return os.path.exists(
os.path.join(self.default_pretrained_model_path,
var.name))
fluid.io.load_vars(
exe,
self.default_pretrained_model_path,
predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog
def object_detection(self, def object_detection(self,
paths=None, paths=None,
images=None, images=None,
...@@ -205,51 +106,31 @@ class SSDVGG16_512(hub.Module): ...@@ -205,51 +106,31 @@ class SSDVGG16_512(hub.Module):
paths = paths if paths else list() paths = paths if paths else list()
data_reader = partial(reader, paths, images) data_reader = partial(reader, paths, images)
batch_reader = fluid.io.batch(data_reader, batch_size=batch_size) batch_reader = paddle.batch(data_reader, batch_size=batch_size)
res = [] res = []
for iter_id, feed_data in enumerate(batch_reader()): for iter_id, feed_data in enumerate(batch_reader()):
feed_data = np.array(feed_data) feed_data = np.array(feed_data)
image_tensor = PaddleTensor(np.array(list(feed_data[:, 0])).copy())
if use_gpu:
data_out = self.gpu_predictor.run([image_tensor])
else:
data_out = self.cpu_predictor.run([image_tensor])
output = postprocess( predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
paths=paths, input_names = predictor.get_input_names()
images=images, input_handle = predictor.get_input_handle(input_names[0])
data_out=data_out, input_handle.copy_from_cpu(np.array(list(feed_data[:, 0])))
score_thresh=score_thresh,
label_names=self.label_names, predictor.run()
output_dir=output_dir, output_names = predictor.get_output_names()
handle_id=iter_id * batch_size, output_handle = predictor.get_output_handle(output_names[0])
visualization=visualization)
output = postprocess(paths=paths,
images=images,
data_out=output_handle,
score_thresh=score_thresh,
label_names=self.label_names,
output_dir=output_dir,
handle_id=iter_id * batch_size,
visualization=visualization)
res.extend(output) res.extend(output)
return res return res
def save_inference_model(self,
dirname,
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
""" """
......
...@@ -104,7 +104,7 @@ def postprocess(paths, ...@@ -104,7 +104,7 @@ def postprocess(paths,
handle_id, handle_id,
visualization=True): visualization=True):
""" """
postprocess the lod_tensor produced by fluid.Executor.run postprocess the lod_tensor produced by Executor.run
Args: Args:
paths (list[str]): the path of images. paths (list[str]): the path of images.
...@@ -127,9 +127,8 @@ def postprocess(paths, ...@@ -127,9 +127,8 @@ def postprocess(paths,
confidence (float): The confidence of detection result. confidence (float): The confidence of detection result.
save_path (str): The path to save output images. save_path (str): The path to save output images.
""" """
lod_tensor = data_out[0] lod = data_out.lod()[0]
lod = lod_tensor.lod[0] results = data_out.copy_to_cpu()
results = lod_tensor.as_ndarray()
check_dir(output_dir) check_dir(output_dir)
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
class TestHubModule(unittest.TestCase):
    """End-to-end tests for the ``ssd_vgg16_512_coco2017`` PaddleHub module.

    The suite downloads one sample image into ``tests/``, loads the module
    once for the whole class, and removes every directory the tests may
    have produced when the class is torn down.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Download the sample image and load the module under test."""
        img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
        os.makedirs('tests', exist_ok=True)
        # A timeout keeps the suite from hanging forever on a stalled
        # connection.
        response = requests.get(img_url, timeout=60)
        assert response.status_code == 200, 'Network Error.'
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        cls.module = hub.Module(name="ssd_vgg16_512_coco2017")

    @classmethod
    def tearDownClass(cls) -> None:
        """Remove the directories created by the tests.

        Some of them only exist when the corresponding test actually ran
        and succeeded, so missing directories are ignored instead of
        raising and masking the real test failure.
        """
        for path in ('tests', 'inference', 'detection_result'):
            shutil.rmtree(path, ignore_errors=True)

    def _check_cat_detection(self, results):
        """Assert that the first box of the first result is the expected cat."""
        bbox = results[0]['data'][0]
        label = bbox['label']
        confidence = bbox['confidence']
        left = bbox['left']
        right = bbox['right']
        top = bbox['top']
        bottom = bbox['bottom']

        self.assertEqual(label, 'cat')
        self.assertTrue(confidence > 0.5)
        self.assertTrue(200 < left < 800)
        self.assertTrue(2500 < right < 3500)
        self.assertTrue(500 < top < 1500)
        self.assertTrue(3500 < bottom < 4500)

    def test_object_detection1(self):
        """Detect from an image path."""
        results = self.module.object_detection(
            paths=['tests/test.jpg']
        )
        self._check_cat_detection(results)

    def test_object_detection2(self):
        """Detect from an in-memory image loaded with ``cv2.imread``."""
        results = self.module.object_detection(
            images=[cv2.imread('tests/test.jpg')]
        )
        self._check_cat_detection(results)

    def test_object_detection3(self):
        """Detect from an in-memory image with visualization turned off."""
        results = self.module.object_detection(
            images=[cv2.imread('tests/test.jpg')],
            visualization=False
        )
        self._check_cat_detection(results)

    def test_object_detection4(self):
        """A path that does not exist is expected to raise AssertionError."""
        self.assertRaises(
            AssertionError,
            self.module.object_detection,
            paths=['no.jpg']
        )

    def test_object_detection5(self):
        """Passing a string where an image array is expected raises cv2.error."""
        self.assertRaises(
            cv2.error,
            self.module.object_detection,
            images=['test.jpg']
        )

    def test_save_inference_model(self):
        """Saving the model writes both the .pdmodel and .pdiparams files."""
        self.module.save_inference_model('./inference/model')

        self.assertTrue(os.path.exists('./inference/model.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ['VGG']
class VGG(object):
    """
    VGG, see https://arxiv.org/abs/1409.1556

    Args:
        depth (int): the VGG net depth (16 or 19)
        with_extra_blocks (bool): whether or not extra blocks should be added
        normalizations (list): params list of init scale in l2 norm, skip init
            scale if param is -1. ``None`` selects
            ``[20., -1, -1, -1, -1, -1]``.
        extra_block_filters (list): one entry per extra block, each of the
            form [num_filters of the 1x1 conv, num_filters of the second
            conv, padding_size, stride_size, filter_size]. ``None`` selects
            the four default blocks.
        class_dim (int): number of class while classification
    """

    def __init__(self,
                 depth=16,
                 with_extra_blocks=False,
                 normalizations=None,
                 extra_block_filters=None,
                 class_dim=1000):
        # Report the offending value; the message used to keep a literal "{}".
        assert depth in [16, 19], "depth {} not in [16, 19]".format(depth)

        # Build the defaults per instance so that mutating one instance's
        # configuration can never leak into another instance.
        if normalizations is None:
            normalizations = [20., -1, -1, -1, -1, -1]
        if extra_block_filters is None:
            extra_block_filters = [[256, 512, 1, 2, 3], [128, 256, 1, 2, 3],
                                   [128, 256, 0, 1, 3], [128, 256, 0, 1, 3]]

        self.depth = depth
        # Number of conv layers in each of the five stages, keyed by depth.
        self.depth_cfg = {16: [2, 2, 3, 3, 3], 19: [2, 2, 4, 4, 4]}
        self.with_extra_blocks = with_extra_blocks
        self.normalizations = normalizations
        self.extra_block_filters = extra_block_filters
        self.class_dim = class_dim

    def __call__(self, input):
        """Build the network on top of ``input``.

        Returns:
            The last output of ``_vgg_block`` when ``with_extra_blocks`` is
            False; otherwise the list of feature layers (VGG outputs followed
            by the extra-block outputs), with an l2-norm scale applied to
            every layer whose entry in ``normalizations`` is not -1.
        """
        layers = []
        layers += self._vgg_block(input)

        if not self.with_extra_blocks:
            return layers[-1]

        layers += self._add_extras_block(layers[-1])
        norm_cfg = self.normalizations
        for k, v in enumerate(layers):
            if not norm_cfg[k] == -1:
                layers[k] = self._l2_norm_scale(v, init_scale=norm_cfg[k])

        return layers

    def _vgg_block(self, input):
        """Build the five conv stages and the head that follows them.

        Returns:
            A one-element list holding the softmax output when
            ``with_extra_blocks`` is False; otherwise
            ``[output of the 4th conv stage, fc7]``.
        """
        nums = self.depth_cfg[self.depth]
        vgg_base = [64, 128, 256, 512, 512]
        conv = input
        layers = []

        for k, v in enumerate(vgg_base):
            conv = self._conv_block(
                conv, v, nums[k], name="conv{}_".format(k + 1))
            layers.append(conv)
            if self.with_extra_blocks:
                if k == 4:
                    # Last stage: 3x3 pooling with stride 1 and padding 1.
                    conv = self._pooling_block(conv, 3, 1, pool_padding=1)
                else:
                    conv = self._pooling_block(conv, 2, 2)
            else:
                conv = self._pooling_block(conv, 2, 2)

        if not self.with_extra_blocks:
            # Classification head: two fc layers with relu, then fc + softmax.
            fc_dim = 4096
            fc_name = ["fc6", "fc7", "fc8"]
            fc1 = fluid.layers.fc(
                input=conv,
                size=fc_dim,
                act='relu',
                param_attr=fluid.param_attr.ParamAttr(
                    name=fc_name[0] + "_weights"),
                bias_attr=fluid.param_attr.ParamAttr(
                    name=fc_name[0] + "_offset"))
            fc2 = fluid.layers.fc(
                input=fc1,
                size=fc_dim,
                act='relu',
                param_attr=fluid.param_attr.ParamAttr(
                    name=fc_name[1] + "_weights"),
                bias_attr=fluid.param_attr.ParamAttr(
                    name=fc_name[1] + "_offset"))
            out = fluid.layers.fc(
                input=fc2,
                size=self.class_dim,
                param_attr=fluid.param_attr.ParamAttr(
                    name=fc_name[2] + "_weights"),
                bias_attr=fluid.param_attr.ParamAttr(
                    name=fc_name[2] + "_offset"))
            out = fluid.layers.softmax(out)
            return [out]
        else:
            # Feature-extractor head: fc6/fc7 are built as conv layers.
            fc6 = self._conv_layer(conv, 1024, 3, 1, 6, dilation=6, name="fc6")
            fc7 = self._conv_layer(fc6, 1024, 1, 1, 0, name="fc7")
            return [layers[3], fc7]

    def _add_extras_block(self, input):
        """Chain one extra block per entry of ``extra_block_filters``.

        Returns:
            list: the output of every extra block, in order.
        """
        cfg = self.extra_block_filters
        conv = input
        layers = []
        for k, v in enumerate(cfg):
            assert len(v) == 5, "extra_block_filters size not fix"
            conv = self._extra_block(
                conv,
                v[0],
                v[1],
                v[2],
                v[3],
                v[4],
                name="conv{}_".format(6 + k))
            layers.append(conv)

        return layers

    def _conv_block(self, input, num_filter, groups, name=None):
        """Stack ``groups`` 3x3 relu conv layers named ``name`` + 1-based index."""
        conv = input
        for i in range(groups):
            conv = self._conv_layer(
                input=conv,
                num_filters=num_filter,
                filter_size=3,
                stride=1,
                padding=1,
                act='relu',
                name=name + str(i + 1))
        return conv

    def _extra_block(self,
                     input,
                     num_filters1,
                     num_filters2,
                     padding_size,
                     stride_size,
                     filter_size,
                     name=None):
        """Build one extra block: a 1x1 conv followed by a configurable conv."""
        # 1x1 conv
        conv_1 = self._conv_layer(
            input=input,
            num_filters=int(num_filters1),
            filter_size=1,
            stride=1,
            act='relu',
            padding=0,
            name=name + "1")

        # Second conv; its size, stride and padding come from the block config.
        conv_2 = self._conv_layer(
            input=conv_1,
            num_filters=int(num_filters2),
            filter_size=filter_size,
            stride=stride_size,
            act='relu',
            padding=padding_size,
            name=name + "2")
        return conv_2

    def _conv_layer(self,
                    input,
                    num_filters,
                    filter_size,
                    stride,
                    padding,
                    dilation=1,
                    act='relu',
                    use_cudnn=True,
                    name=None):
        """Create a named ``conv2d`` layer.

        Weights are named ``name + "_weights"``. A bias named
        ``name + "_biases"`` is used only when ``with_extra_blocks`` is set;
        otherwise the layer has no bias.
        """
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            act=act,
            use_cudnn=use_cudnn,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=ParamAttr(
                name=name + "_biases") if self.with_extra_blocks else False,
            name=name + '.conv2d.output.1')
        return conv

    def _pooling_block(self,
                       conv,
                       pool_size,
                       pool_stride,
                       pool_padding=0,
                       ceil_mode=True):
        """Apply max pooling to ``conv`` with the given geometry."""
        pool = fluid.layers.pool2d(
            input=conv,
            pool_size=pool_size,
            pool_type='max',
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            ceil_mode=ceil_mode)
        return pool

    def _l2_norm_scale(self, input, init_scale=1.0, channel_shared=False):
        """L2-normalize ``input`` along axis 1 and multiply by a scale parameter.

        The scale parameter is initialized to ``init_scale``. It has shape
        ``[1]`` when ``channel_shared`` is True, otherwise
        ``[input.shape[1]]``.
        """
        from paddle.fluid.layer_helper import LayerHelper
        from paddle.fluid.initializer import Constant
        helper = LayerHelper("Scale")
        l2_norm = fluid.layers.l2_normalize(
            input, axis=1)  # l2 norm along channel
        shape = [1] if channel_shared else [input.shape[1]]
        scale = helper.create_parameter(
            attr=helper.param_attr,
            shape=shape,
            dtype=input.dtype,
            default_initializer=Constant(init_scale))
        out = fluid.layers.elementwise_mul(
            x=l2_norm,
            y=scale,
            axis=-1 if channel_shared else 1,
            name="conv4_3_norm_scale")
        return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册