Unverified commit 51427477, authored by jm_12138, committed by GitHub

update ssd_vgg16_300_coco2017 (#1949)

* update ssd_vgg16_300_model

* update unittest

* update unittest

* update gpu config

* update

* add clean func

* update save inference model
Co-authored-by: Nchenjian <chenjian26@baidu.com>
Parent 01002e40
# ssd_vgg16_300_coco2017
|Module Name|ssd_vgg16_300_coco2017|
| :--- | :---: |
|Category|Image - Object Detection|
|Network|SSD|
|Dataset|COCO2017|
|Fine-tuning supported or not|No|
|Module Size|139MB|
|Latest update date|2021-03-15|
|Data indicators|-|
## I. Basic Information
- ### Application Effect Display
- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/131506781-b4ecb77b-5ab1-4795-88da-5f547f7f7f9c.jpg" width='50%' hspace='10'/>
<br />
</p>
- ### Module Introduction
- Single Shot MultiBox Detector (SSD) is a one-stage object detector. Unlike two-stage methods, one-stage detection does not generate region proposals; it regresses bounding boxes and class probabilities directly from the feature maps. SSD builds on this idea and improves it by detecting objects of each scale on a feature map of the corresponding scale. The backbone of this PaddleHub Module is VGG16, pre-trained on the Pascal dataset; the module currently supports prediction only.
## II. Installation
- ### 1. Environmental Dependence
- paddlepaddle >= 1.6.2
- paddlehub >= 1.6.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2. Installation
- ```shell
$ hub install ssd_vgg16_300_coco2017
```
- If you run into problems during installation, see: [Windows quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1. Command Line Prediction
- ```shell
$ hub run ssd_vgg16_300_coco2017 --input_path "/PATH/TO/IMAGE"
```
- This calls the object detection module from the command line. For more, see [PaddleHub command line instructions](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2. Prediction Code Example
- ```python
import paddlehub as hub
import cv2
object_detector = hub.Module(name="ssd_vgg16_300_coco2017")
result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = object_detector.object_detection(paths=['/PATH/TO/IMAGE'])
```
- ### 3. API
- ```python
def object_detection(paths=None,
                     images=None,
                     batch_size=1,
                     use_gpu=False,
                     output_dir='detection_result',
                     score_thresh=0.5,
                     visualization=True)
```
- Prediction API: detects the positions of all objects in the input images.
- **Parameters**
- paths (list\[str\]): image paths; <br/>
- images (list\[numpy.ndarray\]): image data, ndarray.shape is \[H, W, C\], BGR format; <br/>
- batch\_size (int): batch size;<br/>
- use\_gpu (bool): whether to use GPU;<br/>
- output\_dir (str): directory where images are saved, detection\_result by default;<br/>
- score\_thresh (float): confidence threshold for detections;<br/>
- visualization (bool): whether to save the detection results as image files.
**NOTE:** provide data through either paths or images, not both
- **Return**
- res (list\[dict\]): list of detection results; each element is a dict with the fields:
- data (list): detection results; each element is a dict with the fields:
- confidence (float): confidence of the detection
- label (str): label
- left (int): x coordinate of the top-left corner of the bounding box
- top (int): y coordinate of the top-left corner of the bounding box
- right (int): x coordinate of the bottom-right corner of the bounding box
- bottom (int): y coordinate of the bottom-right corner of the bounding box
- save\_path (str, optional): path where the result is saved (present only when visualization=True)
- ```python
def save_inference_model(dirname)
```
- Saves the model to the given path.
- **Parameters**
- dirname: model save path <br/>
## IV. Server Deployment
- PaddleHub Serving can deploy an online object detection service.
- ### Step 1: Start PaddleHub Serving
- Run the startup command:
- ```shell
$ hub serving start -m ssd_vgg16_300_coco2017
```
- This deploys an object detection service API; the default port is 8866.
- **NOTE:** To predict on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it does not need to be set.
- ### Step 2: Send a prediction request
- With the server configured, the following few lines of code send a prediction request and fetch the result
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')
# Send an HTTP request
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ssd_vgg16_300_coco2017"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# Print the prediction results
print(r.json()["results"])
```
## V. Release Note
* 1.0.0
First release
* 1.0.2
Fix a problem reading numpy data
* 1.1.0
Remove fluid api
- ```shell
$ hub install ssd_vgg16_300_coco2017==1.1.0
```
**Removed from the previous README**
- ```python
def context(trainable=True,
            pretrained=True,
            get_prediction=False)
```
- Extracts features for transfer learning.
- **Parameters**
- trainable (bool): whether the parameters are trainable;
- pretrained (bool): whether to load the pre-trained model;
- get\_prediction (bool): whether to run prediction.
- **Return**
- inputs (dict): model inputs, with keys 'image' and 'im\_size':
- image (Variable): the image variable
- im\_size (Variable): the image size
- outputs (dict): model outputs. If get\_prediction is False the output is 'head\_features', otherwise 'bbox\_out'.
- context\_prog (Program): the Program used for transfer learning.
- ```python
def save_inference_model(dirname,
                         model_filename=None,
                         params_filename=None,
                         combined=True)
```
- **Parameters**
- dirname: name of the directory that holds the model
- model\_filename: model file name, \_\_model\_\_ by default
- params\_filename: parameter file name, \_\_params\_\_ by default (only takes effect when `combined` is True)
- combined: whether to save the parameters into a single file
# ssd_vgg16_300_coco2017
|Module Name|ssd_vgg16_300_coco2017|
| :--- | :---: |
|Category|object detection|
|Network|SSD|
|Dataset|COCO2017|
|Fine-tuning supported or not|No|
|Module Size|139MB|
|Latest update date|2021-03-15|
|Data indicators|-|
## I.Basic Information
- ### Application Effect Display
- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/131506781-b4ecb77b-5ab1-4795-88da-5f547f7f7f9c.jpg" width='50%' hspace='10'/>
<br />
</p>
- ### Module Introduction
- Single Shot MultiBox Detector (SSD) is a one-stage detector. Unlike two-stage detectors, SSD frames object detection as a regression problem to spatially separated bounding boxes and associated class probabilities. This module is based on VGG16, trained on the COCO2017 dataset, and can be used for object detection.
## II.Installation
- ### 1、Environmental Dependence
- paddlepaddle >= 1.6.2
- paddlehub >= 1.6.0 | [How to install PaddleHub](../../../../docs/docs_en/get_start/installation.rst)
- ### 2、Installation
- ```shell
$ hub install ssd_vgg16_300_coco2017
```
- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
## III.Module API Prediction
- ### 1、Command line Prediction
- ```shell
$ hub run ssd_vgg16_300_coco2017 --input_path "/PATH/TO/IMAGE"
```
- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、Prediction Code Example
- ```python
import paddlehub as hub
import cv2
object_detector = hub.Module(name="ssd_vgg16_300_coco2017")
result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = object_detector.object_detection(paths=['/PATH/TO/IMAGE'])
```
- ### 3、API
- ```python
def object_detection(paths=None,
images=None,
batch_size=1,
use_gpu=False,
output_dir='detection_result',
score_thresh=0.5,
visualization=True)
```
- Detection API: detects the positions of all objects in the input images.
- **Parameters**
- paths (list[str]): image path;
- images (list\[numpy.ndarray\]): image data, ndarray.shape is in the format [H, W, C], BGR;
- batch_size (int): the size of batch;
- use_gpu (bool): use GPU or not; **set the CUDA_VISIBLE_DEVICES environment variable first if you are using GPU**
- output_dir (str): save path of images;
- score\_thresh (float): confidence threshold;<br/>
- visualization (bool): Whether to save the results as picture files;
**NOTE:** provide data through either paths or images, not both
- **Return**
- res (list\[dict\]): results
- data (list): detection results, each element in the list is dict
- confidence (float): the confidence of the result
- label (str): label
- left (int): the upper left corner x coordinate of the detection box
- top (int): the upper left corner y coordinate of the detection box
- right (int): the lower right corner x coordinate of the detection box
- bottom (int): the lower right corner y coordinate of the detection box
- save\_path (str, optional): output path for saving results
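The return structure above can be walked with plain Python. The following is a minimal sketch, assuming only the documented field names; `filter_detections` and the sample data are hypothetical and are not part of the module.

```python
from typing import Dict, List


def filter_detections(results: List[dict], min_confidence: float = 0.5,
                      labels: frozenset = frozenset()) -> List[Dict]:
    """Flatten per-image results into one list of boxes above a threshold."""
    kept = []
    for image_index, image_result in enumerate(results):
        for box in image_result["data"]:
            if box["confidence"] < min_confidence:
                continue
            if labels and box["label"] not in labels:
                continue
            kept.append({
                "image_index": image_index,
                "label": box["label"],
                "confidence": box["confidence"],
                "width": box["right"] - box["left"],
                "height": box["bottom"] - box["top"],
            })
    return kept


# Hand-written sample in the documented shape (not real model output).
sample_results = [{
    "data": [
        {"label": "cat", "confidence": 0.92, "left": 10, "top": 20, "right": 110, "bottom": 220},
        {"label": "dog", "confidence": 0.40, "left": 0, "top": 0, "right": 50, "bottom": 50},
    ],
    "save_path": "detection_result/sample.jpg",
}]

print(filter_detections(sample_results, min_confidence=0.5))
```

In real use, `sample_results` would be replaced by the list returned from `object_detection`.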
- ```python
def save_inference_model(dirname)
```
- Save model to specific path
- **Parameters**
- dirname: model save path
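The unit test added in this commit calls `save_inference_model('./inference/model')` and then checks for `./inference/model.pdmodel` and `./inference/model.pdiparams`, which suggests the argument behaves as a path prefix rather than a plain directory. A small sketch of that naming, where `expected_model_files` is a hypothetical helper and not part of PaddleHub:

```python
def expected_model_files(prefix: str) -> tuple:
    """Return the two file paths the commit's unit test checks for a given prefix."""
    return prefix + ".pdmodel", prefix + ".pdiparams"


model_file, params_file = expected_model_files("./inference/model")
print(model_file, params_file)
```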
## IV.Server Deployment
- PaddleHub Serving can deploy an online service of object detection.
- ### Step 1: Start PaddleHub Serving
- Run the startup command:
- ```shell
$ hub serving start -m ssd_vgg16_300_coco2017
```
- The service API is now deployed; the default port number is 8866.
- **NOTE:** If GPU is used for prediction, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise it need not be set.
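For `hub serving start` the variable is normally exported in the shell. When the module is loaded from a Python process instead, the variable can be set in code before the module is created, since the module reads `os.environ["CUDA_VISIBLE_DEVICES"]` while configuring its predictors. A minimal sketch; `select_gpus` is a hypothetical helper:

```python
import os


def select_gpus(device_ids):
    """Set CUDA_VISIBLE_DEVICES from a list of integer device ids and return the value."""
    value = ",".join(str(int(i)) for i in device_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


select_gpus([0])
# Import paddlehub and load the module only after this point.
```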
- ### Step 2: Send a prediction request
- With a configured server, use the following lines of code to send the prediction request and obtain the result
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')
# Send an HTTP request
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ssd_vgg16_300_coco2017"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# print prediction results
print(r.json()["results"])
```
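The request body is a JSON object whose `images` field holds base64 strings of JPEG-encoded bytes. As a sketch under the assumption that the server only needs the encoded bytes, a client with the image file already on disk could build the same body without OpenCV. `build_payload` is a hypothetical helper, and other image formats are not verified here.

```python
import base64
import json


def build_payload(image_bytes_list):
    """Build the JSON body for the /predict endpoint from encoded image bytes."""
    images = [base64.b64encode(b).decode("utf8") for b in image_bytes_list]
    return json.dumps({"images": images})


# Stand-in bytes; in practice these would come from open(path, "rb").read().
body = build_payload([b"\xff\xd8\xff\xe0 fake jpeg bytes"])
print(body)
```

The resulting string can be passed as `data=body` to `requests.post` with the same headers and URL as above.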
## V.Release Note
* 1.0.0
First release
* 1.0.2
Fix a problem reading numpy data
* 1.1.0
Remove fluid api
- ```shell
$ hub install ssd_vgg16_300_coco2017==1.1.0
```
@@ -5,12 +5,10 @@ from __future__ import division
 import os
 import random
-from collections import OrderedDict
 import cv2
 import numpy as np
 from PIL import Image
-from paddle import fluid
 __all__ = ['reader']
......
@@ -7,39 +7,43 @@ import os
 from functools import partial
 import yaml
+import paddle
 import numpy as np
-import paddle.fluid as fluid
-import paddlehub as hub
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+import paddle.static
+from paddle.inference import Config, create_predictor
 from paddlehub.module.module import moduleinfo, runnable, serving
-from paddlehub.common.paddle_helper import add_vars_prefix
-from ssd_vgg16_300_coco2017.vgg import VGG
-from ssd_vgg16_300_coco2017.processor import load_label_info, postprocess, base64_to_cv2
-from ssd_vgg16_300_coco2017.data_feed import reader
+from .processor import load_label_info, postprocess, base64_to_cv2
+from .data_feed import reader
 @moduleinfo(
     name="ssd_vgg16_300_coco2017",
-    version="1.0.1",
+    version="1.1.0",
     type="cv/object_detection",
     summary="SSD with backbone VGG16, trained with dataset COCO.",
     author="paddlepaddle",
     author_email="paddle-dev@baidu.com")
-class SSDVGG16(hub.Module):
-    def _initialize(self):
-        self.default_pretrained_model_path = os.path.join(self.directory, "ssd_vgg16_300_model")
-        self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt"))
+class SSDVGG16:
+    def __init__(self):
+        self.default_pretrained_model_path = os.path.join(
+            self.directory, "ssd_vgg16_300_model", "model")
+        self.label_names = load_label_info(
+            os.path.join(self.directory, "label_file.txt"))
         self.model_config = None
         self._set_config()
     def _set_config(self):
-        # predictor config setting.
-        cpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        """
+        predictor config setting.
+        """
+        model = self.default_pretrained_model_path+'.pdmodel'
+        params = self.default_pretrained_model_path+'.pdiparams'
+        cpu_config = Config(model, params)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
         cpu_config.switch_ir_optim(False)
-        self.cpu_predictor = create_paddle_predictor(cpu_config)
+        self.cpu_predictor = create_predictor(cpu_config)
         try:
             _places = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -48,10 +52,10 @@ class SSDVGG16(hub.Module):
         except:
             use_gpu = False
         if use_gpu:
-            gpu_config = AnalysisConfig(self.default_pretrained_model_path)
+            gpu_config = Config(model, params)
             gpu_config.disable_glog_info()
             gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
-            self.gpu_predictor = create_paddle_predictor(gpu_config)
+            self.gpu_predictor = create_predictor(gpu_config)
         # model config setting.
         if not self.model_config:
@@ -61,73 +65,6 @@ class SSDVGG16(hub.Module):
         self.multi_box_head_config = self.model_config['MultiBoxHead']
         self.output_decoder_config = self.model_config['SSDOutputDecoder']
-    def context(self, trainable=True, pretrained=True, get_prediction=False):
-        """
-        Distill the Head Features, so as to perform transfer learning.
-        Args:
-            trainable (bool): whether to set parameters trainable.
-            pretrained (bool): whether to load default pretrained model.
-            get_prediction (bool): whether to get prediction.
-        Returns:
-            inputs(dict): the input variables.
-            outputs(dict): the output variables.
-            context_prog (Program): the program to execute transfer learning.
-        """
-        context_prog = fluid.Program()
-        startup_program = fluid.Program()
-        with fluid.program_guard(context_prog, startup_program):
-            with fluid.unique_name.guard():
-                # image
-                image = fluid.layers.data(name='image', shape=[3, 300, 300], dtype='float32')
-                # backbone
-                backbone = VGG(depth=16, with_extra_blocks=True, normalizations=[20., -1, -1, -1, -1, -1])
-                # body_feats
-                body_feats = backbone(image)
-                # im_size
-                im_size = fluid.layers.data(name='im_size', shape=[2], dtype='int32')
-                # var_prefix
-                var_prefix = '@HUB_{}@'.format(self.name)
-                # names of inputs
-                inputs = {'image': var_prefix + image.name, 'im_size': var_prefix + im_size.name}
-                # names of outputs
-                if get_prediction:
-                    locs, confs, box, box_var = fluid.layers.multi_box_head(
-                        inputs=body_feats, image=image, num_classes=81, **self.multi_box_head_config)
-                    pred = fluid.layers.detection_output(
-                        loc=locs, scores=confs, prior_box=box, prior_box_var=box_var, **self.output_decoder_config)
-                    outputs = {'bbox_out': [var_prefix + pred.name]}
-                else:
-                    outputs = {'body_features': [var_prefix + var.name for var in body_feats]}
-                # add_vars_prefix
-                add_vars_prefix(context_prog, var_prefix)
-                add_vars_prefix(fluid.default_startup_program(), var_prefix)
-                # inputs
-                inputs = {key: context_prog.global_block().vars[value] for key, value in inputs.items()}
-                outputs = {
-                    out_key: [context_prog.global_block().vars[varname] for varname in out_value]
-                    for out_key, out_value in outputs.items()
-                }
-                # trainable
-                for param in context_prog.global_block().iter_parameters():
-                    param.trainable = trainable
-                place = fluid.CPUPlace()
-                exe = fluid.Executor(place)
-                # pretrained
-                if pretrained:
-                    def _if_exist(var):
-                        return os.path.exists(os.path.join(self.default_pretrained_model_path, var.name))
-                    fluid.io.load_vars(exe, self.default_pretrained_model_path, predicate=_if_exist)
-                else:
-                    exe.run(startup_program)
-                return inputs, outputs, context_prog
     def object_detection(self,
                          paths=None,
                          images=None,
@@ -160,20 +97,23 @@ class SSDVGG16(hub.Module):
         """
         paths = paths if paths else list()
         data_reader = partial(reader, paths, images)
-        batch_reader = fluid.io.batch(data_reader, batch_size=batch_size)
+        batch_reader = paddle.batch(data_reader, batch_size=batch_size)
         res = []
         for iter_id, feed_data in enumerate(batch_reader()):
             feed_data = np.array(feed_data)
-            image_tensor = PaddleTensor(np.array(list(feed_data[:, 0])).copy())
-            if use_gpu:
-                data_out = self.gpu_predictor.run([image_tensor])
-            else:
-                data_out = self.cpu_predictor.run([image_tensor])
-            output = postprocess(
-                paths=paths,
+            predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
+            input_handle.copy_from_cpu(np.array(list(feed_data[:, 0])))
+            predictor.run()
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])
+            output = postprocess(paths=paths,
                 images=images,
-                data_out=data_out,
+                data_out=output_handle,
                 score_thresh=score_thresh,
                 label_names=self.label_names,
                 output_dir=output_dir,
@@ -182,25 +122,6 @@ class SSDVGG16(hub.Module):
             res.extend(output)
         return res
-    def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
-        if combined:
-            model_filename = "__model__" if not model_filename else model_filename
-            params_filename = "__params__" if not params_filename else params_filename
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-        program, feeded_var_names, target_vars = fluid.io.load_inference_model(
-            dirname=self.default_pretrained_model_path, executor=exe)
-        fluid.io.save_inference_model(
-            dirname=dirname,
-            main_program=program,
-            executor=exe,
-            feeded_var_names=feeded_var_names,
-            target_vars=target_vars,
-            model_filename=model_filename,
-            params_filename=params_filename)
     @serving
     def serving_method(self, images, **kwargs):
         """
@@ -220,9 +141,12 @@ class SSDVGG16(hub.Module):
             prog='hub run {}'.format(self.name),
             usage='%(prog)s',
             add_help=True)
-        self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
+        self.arg_input_group = self.parser.add_argument_group(
+            title="Input options", description="Input data. Required")
         self.arg_config_group = self.parser.add_argument_group(
-            title="Config options", description="Run configuration for controlling module behavior, not required.")
+            title="Config options",
+            description=
+            "Run configuration for controlling module behavior, not required.")
         self.add_module_config_arg()
         self.add_module_input_arg()
         args = self.parser.parse_args(argvs)
@@ -240,17 +164,34 @@ class SSDVGG16(hub.Module):
         Add the command config options.
         """
         self.arg_config_group.add_argument(
-            '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
+            '--use_gpu',
+            type=ast.literal_eval,
+            default=False,
+            help="whether use GPU or not")
         self.arg_config_group.add_argument(
-            '--output_dir', type=str, default='detection_result', help="The directory to save output images.")
+            '--output_dir',
+            type=str,
+            default='detection_result',
+            help="The directory to save output images.")
         self.arg_config_group.add_argument(
-            '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
+            '--visualization',
+            type=ast.literal_eval,
+            default=False,
+            help="whether to save output as images.")
     def add_module_input_arg(self):
         """
         Add the command input options.
         """
-        self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
-        self.arg_input_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
         self.arg_input_group.add_argument(
-            '--score_thresh', type=ast.literal_eval, default=0.5, help="threshold for object detecion.")
+            '--input_path', type=str, help="path to image.")
+        self.arg_input_group.add_argument(
+            '--batch_size',
+            type=ast.literal_eval,
+            default=1,
+            help="batch size.")
+        self.arg_input_group.add_argument(
+            '--score_thresh',
+            type=ast.literal_eval,
+            default=0.5,
+            help="threshold for object detecion.")
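The options above use `type=ast.literal_eval` rather than `type=bool` or `type=int`. One likely reason is that `bool("False")` is `True`, so a plain `bool` converter cannot express a false flag value. A self-contained sketch with a stand-in parser (the `demo` program name is an assumption, not the module's parser):

```python
import argparse
import ast

parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False)
parser.add_argument("--batch_size", type=ast.literal_eval, default=1)
parser.add_argument("--score_thresh", type=ast.literal_eval, default=0.5)

# Each string is evaluated as a Python literal, so the result keeps its natural type.
args = parser.parse_args(["--use_gpu", "True", "--batch_size", "4", "--score_thresh", "0.3"])
print(type(args.use_gpu), type(args.batch_size), type(args.score_thresh))
```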
@@ -85,7 +85,7 @@ def load_label_info(file_path):
 def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization=True):
     """
-    postprocess the lod_tensor produced by fluid.Executor.run
+    postprocess the lod_tensor produced by Executor.run
     Args:
         paths (list[str]): the path of images.
@@ -108,9 +108,9 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir,
             confidence (float): The confidence of detection result.
             save_path (str): The path to save output images.
     """
-    lod_tensor = data_out[0]
-    lod = lod_tensor.lod[0]
-    results = lod_tensor.as_ndarray()
+    lod = data_out.lod()[0]
+    results = data_out.copy_to_cpu()
     if handle_id < len(paths):
         unhandled_paths = paths[handle_id:]
         unhandled_paths_num = len(unhandled_paths)
......
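The `lod` value read in `postprocess` is a list of cumulative offsets that says which rows of the flat result array belong to which image in the batch. A minimal sketch of that grouping with plain lists; `split_by_lod` is a hypothetical helper, and the six-column row layout is an assumption rather than something shown in this diff.

```python
def split_by_lod(rows, lod):
    """Group flat detection rows per image using cumulative LoD offsets."""
    return [rows[lod[i]:lod[i + 1]] for i in range(len(lod) - 1)]


# Assumed row layout: [label_id, score, xmin, ymin, xmax, ymax].
rows = [
    [17, 0.9, 1, 2, 3, 4],   # image 0
    [1, 0.6, 5, 6, 7, 8],    # image 0
    [3, 0.8, 0, 0, 9, 9],    # image 1
]
lod = [0, 2, 3]  # image 0 owns rows 0..1, image 1 owns row 2

per_image = split_by_lod(rows, lod)
print([len(group) for group in per_image])
```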
import os
import shutil
import unittest

import cv2
import requests
import paddlehub as hub


class TestHubModule(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
        if not os.path.exists('tests'):
            os.makedirs('tests')
        response = requests.get(img_url)
        assert response.status_code == 200, 'Network Error.'
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        cls.module = hub.Module(name="ssd_vgg16_300_coco2017")

    @classmethod
    def tearDownClass(cls) -> None:
        shutil.rmtree('tests')
        shutil.rmtree('inference')
        shutil.rmtree('detection_result')

    def test_object_detection1(self):
        results = self.module.object_detection(
            paths=['tests/test.jpg']
        )
        bbox = results[0]['data'][0]
        label = bbox['label']
        confidence = bbox['confidence']
        left = bbox['left']
        right = bbox['right']
        top = bbox['top']
        bottom = bbox['bottom']

        self.assertEqual(label, 'cat')
        self.assertTrue(confidence > 0.5)
        self.assertTrue(200 < left < 800)
        self.assertTrue(2500 < right < 3500)
        self.assertTrue(500 < top < 1500)
        self.assertTrue(3500 < bottom < 4500)

    def test_object_detection2(self):
        results = self.module.object_detection(
            images=[cv2.imread('tests/test.jpg')]
        )
        bbox = results[0]['data'][0]
        label = bbox['label']
        confidence = bbox['confidence']
        left = bbox['left']
        right = bbox['right']
        top = bbox['top']
        bottom = bbox['bottom']

        self.assertEqual(label, 'cat')
        self.assertTrue(confidence > 0.5)
        self.assertTrue(200 < left < 800)
        self.assertTrue(2500 < right < 3500)
        self.assertTrue(500 < top < 1500)
        self.assertTrue(3500 < bottom < 4500)

    def test_object_detection3(self):
        results = self.module.object_detection(
            images=[cv2.imread('tests/test.jpg')],
            visualization=False
        )
        bbox = results[0]['data'][0]
        label = bbox['label']
        confidence = bbox['confidence']
        left = bbox['left']
        right = bbox['right']
        top = bbox['top']
        bottom = bbox['bottom']

        self.assertEqual(label, 'cat')
        self.assertTrue(confidence > 0.5)
        self.assertTrue(200 < left < 800)
        self.assertTrue(2500 < right < 3500)
        self.assertTrue(500 < top < 1500)
        self.assertTrue(3500 < bottom < 4500)

    def test_object_detection4(self):
        self.assertRaises(
            AssertionError,
            self.module.object_detection,
            paths=['no.jpg']
        )

    def test_object_detection5(self):
        self.assertRaises(
            cv2.error,
            self.module.object_detection,
            images=['test.jpg']
        )

    def test_save_inference_model(self):
        self.module.save_inference_model('./inference/model')

        self.assertTrue(os.path.exists('./inference/model.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model.pdiparams'))


if __name__ == "__main__":
    unittest.main()
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from paddle.fluid.param_attr import ParamAttr

__all__ = ['VGG']


class VGG(object):
    """
    VGG, see https://arxiv.org/abs/1409.1556

    Args:
        depth (int): the VGG net depth (16 or 19)
        normalizations (list): params list of init scale in l2 norm, skip init
            scale if param is -1.
        with_extra_blocks (bool): whether or not extra blocks should be added
        extra_block_filters (list): in each extra block, params:
            [in_channel, out_channel, padding_size, stride_size, filter_size]
        class_dim (int): number of class while classification
    """

    def __init__(self,
                 depth=16,
                 with_extra_blocks=False,
                 normalizations=[20., -1, -1, -1, -1, -1],
                 extra_block_filters=[[256, 512, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 0, 1, 3],
                                      [128, 256, 0, 1, 3]],
                 class_dim=1000):
        assert depth in [16, 19], "depth {} not in [16, 19]"
        self.depth = depth
        self.depth_cfg = {16: [2, 2, 3, 3, 3], 19: [2, 2, 4, 4, 4]}
        self.with_extra_blocks = with_extra_blocks
        self.normalizations = normalizations
        self.extra_block_filters = extra_block_filters
        self.class_dim = class_dim

    def __call__(self, input):
        layers = []
        layers += self._vgg_block(input)

        if not self.with_extra_blocks:
            return layers[-1]

        layers += self._add_extras_block(layers[-1])
        norm_cfg = self.normalizations
        for k, v in enumerate(layers):
            if not norm_cfg[k] == -1:
                layers[k] = self._l2_norm_scale(v, init_scale=norm_cfg[k])

        return layers

    def _vgg_block(self, input):
        nums = self.depth_cfg[self.depth]
        vgg_base = [64, 128, 256, 512, 512]
        conv = input
        res_layer = []
        layers = []
        for k, v in enumerate(vgg_base):
            conv = self._conv_block(conv, v, nums[k], name="conv{}_".format(k + 1))
            layers.append(conv)
            if self.with_extra_blocks:
                if k == 4:
                    conv = self._pooling_block(conv, 3, 1, pool_padding=1)
                else:
                    conv = self._pooling_block(conv, 2, 2)
            else:
                conv = self._pooling_block(conv, 2, 2)
        if not self.with_extra_blocks:
            fc_dim = 4096
            fc_name = ["fc6", "fc7", "fc8"]
            fc1 = fluid.layers.fc(
                input=conv,
                size=fc_dim,
                act='relu',
                param_attr=fluid.param_attr.ParamAttr(name=fc_name[0] + "_weights"),
                bias_attr=fluid.param_attr.ParamAttr(name=fc_name[0] + "_offset"))
            fc2 = fluid.layers.fc(
                input=fc1,
                size=fc_dim,
                act='relu',
                param_attr=fluid.param_attr.ParamAttr(name=fc_name[1] + "_weights"),
                bias_attr=fluid.param_attr.ParamAttr(name=fc_name[1] + "_offset"))
            out = fluid.layers.fc(
                input=fc2,
                size=self.class_dim,
                param_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_weights"),
                bias_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_offset"))
            out = fluid.layers.softmax(out)
            res_layer.append(out)
            return [out]
        else:
            fc6 = self._conv_layer(conv, 1024, 3, 1, 6, dilation=6, name="fc6")
            fc7 = self._conv_layer(fc6, 1024, 1, 1, 0, name="fc7")
            return [layers[3], fc7]

    def _add_extras_block(self, input):
        cfg = self.extra_block_filters
        conv = input
        layers = []
        for k, v in enumerate(cfg):
            assert len(v) == 5, "extra_block_filters size not fix"
            conv = self._extra_block(conv, v[0], v[1], v[2], v[3], v[4], name="conv{}_".format(6 + k))
            layers.append(conv)
        return layers

    def _conv_block(self, input, num_filter, groups, name=None):
        conv = input
        for i in range(groups):
            conv = self._conv_layer(
                input=conv,
                num_filters=num_filter,
                filter_size=3,
                stride=1,
                padding=1,
                act='relu',
                name=name + str(i + 1))
        return conv

    def _extra_block(self, input, num_filters1, num_filters2, padding_size, stride_size, filter_size, name=None):
        # 1x1 conv
        conv_1 = self._conv_layer(
            input=input, num_filters=int(num_filters1), filter_size=1, stride=1, act='relu', padding=0, name=name + "1")
        # 3x3 conv
        conv_2 = self._conv_layer(
            input=conv_1,
            num_filters=int(num_filters2),
            filter_size=filter_size,
            stride=stride_size,
            act='relu',
            padding=padding_size,
            name=name + "2")
        return conv_2

    def _conv_layer(self,
                    input,
                    num_filters,
                    filter_size,
                    stride,
                    padding,
                    dilation=1,
                    act='relu',
                    use_cudnn=True,
                    name=None):
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            act=act,
            use_cudnn=use_cudnn,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=ParamAttr(name=name + "_biases") if self.with_extra_blocks else False,
            name=name + '.conv2d.output.1')
        return conv

    def _pooling_block(self, conv, pool_size, pool_stride, pool_padding=0, ceil_mode=True):
        pool = fluid.layers.pool2d(
            input=conv,
            pool_size=pool_size,
            pool_type='max',
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            ceil_mode=ceil_mode)
        return pool

    def _l2_norm_scale(self, input, init_scale=1.0, channel_shared=False):
        from paddle.fluid.layer_helper import LayerHelper
        from paddle.fluid.initializer import Constant
        helper = LayerHelper("Scale")
        l2_norm = fluid.layers.l2_normalize(input, axis=1)  # l2 norm along channel
        shape = [1] if channel_shared else [input.shape[1]]
        scale = helper.create_parameter(
            attr=helper.param_attr, shape=shape, dtype=input.dtype, default_initializer=Constant(init_scale))
        out = fluid.layers.elementwise_mul(
            x=l2_norm, y=scale, axis=-1 if channel_shared else 1, name="conv4_3_norm_scale")
        return out