未验证 提交 51427477 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ssd_vgg16_300_coco2017 (#1949)

* update ssd_vgg16_300_model

* update unittest

* update unittest

* update gpu config

* update

* add clean func

* update save inference model
Co-authored-by: Nchenjian <chenjian26@baidu.com>
上级 01002e40
## 命令行预测
# ssd_vgg16_300_coco2017
```shell
$ hub run ssd_vgg16_300_coco2017 --input_path "/PATH/TO/IMAGE"
```
|模型名称|ssd_vgg16_300_coco2017|
| :--- | :---: |
|类别|图像 - 目标检测|
|网络|SSD|
|数据集|COCO2017|
|是否支持Fine-tuning|否|
|模型大小|139MB|
|最新更新日期|2021-03-15|
|数据指标|-|
## API
```python
def context(trainable=True,
pretrained=True,
get_prediction=False)
```
## 一、模型基本信息
提取特征,用于迁移学习。
- ### 应用效果展示
- 样例结果示例:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/131506781-b4ecb77b-5ab1-4795-88da-5f547f7f7f9c.jpg" width='50%' hspace='10'/>
<br />
</p>
**参数**
- ### 模型介绍
* trainable(bool): 参数是否可训练;
* pretrained (bool): 是否加载预训练模型;
* get\_prediction (bool): 是否执行预测。
- Single Shot MultiBox Detector (SSD) 是一种单阶段的目标检测器。与两阶段的检测方法不同,单阶段目标检测并不进行区域推荐,而是直接从特征图回归出目标的边界框和分类概率。SSD 运用了这种单阶段检测的思想,并且对其进行改进:在不同尺度的特征图上检测对应尺度的目标。该PaddleHub Module的基网络为VGG16模型,在Pascal数据集上预训练得到,目前仅支持预测。
**返回**
* inputs (dict): 模型的输入,keys 包括 'image', 'im\_size',相应的取值为:
* image (Variable): 图像变量
* im\_size (Variable): 图片的尺寸
* outputs (dict): 模型的输出。如果 get\_prediction 为 False,输出 'head\_features',否则输出 'bbox\_out'。
* context\_prog (Program): 用于迁移学习的 Program.
## 二、安装
```python
def object_detection(paths=None,
images=None,
batch_size=1,
use_gpu=False,
output_dir='detection_result',
score_thresh=0.5,
visualization=True)
```
- ### 1、环境依赖
预测API,检测输入图片中的所有目标的位置。
- paddlepaddle >= 1.6.2
**参数**
- paddlehub >= 1.6.0 | [如何安装paddlehub](../../../../docs/docs_ch/get_start/installation.rst)
* paths (list\[str\]): 图片的路径;
* images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
* batch\_size (int): batch 的大小;
* use\_gpu (bool): 是否使用 GPU;
* score\_thresh (float): 识别置信度的阈值;
* visualization (bool): 是否将识别结果保存为图片文件;
* output\_dir (str): 图片的保存路径,默认设为 detection\_result;
- ### 2、安装
**返回**
- ```shell
$ hub install ssd_vgg16_300_coco2017
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
* res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
* data (list): 检测结果,list的每一个元素为 dict,各字段为:
* confidence (float): 识别的置信度;
* label (str): 标签;
* left (int): 边界框的左上角x坐标;
* top (int): 边界框的左上角y坐标;
* right (int): 边界框的右下角x坐标;
* bottom (int): 边界框的右下角y坐标;
* save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)。
## 三、模型API预测
```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
```
- ### 1、命令行预测
将模型保存到指定路径。
- ```shell
$ hub run ssd_vgg16_300_coco2017 --input_path "/PATH/TO/IMAGE"
```
- 通过命令行方式实现目标检测模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、预测代码示例
**参数**
- ```python
import paddlehub as hub
import cv2
* dirname: 存在模型的目录名称
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
object_detector = hub.Module(name="ssd_vgg16_300_coco2017")
result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = object_detector.object_detection((paths=['/PATH/TO/IMAGE'])
```
## 代码示例
- ### 3、API
```python
import paddlehub as hub
import cv2
- ```python
def object_detection(paths=None,
images=None,
batch_size=1,
use_gpu=False,
output_dir='detection_result',
score_thresh=0.5,
visualization=True)
```
object_detector = hub.Module(name="ssd_vgg16_300_coco2017")
result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = object_detector.object_detection((paths=['/PATH/TO/IMAGE'])
```
- 预测API,检测输入图片中的所有目标的位置。
## 服务部署
- **参数**
PaddleHub Serving可以部署一个目标检测的在线服务。
- paths (list\[str\]): 图片的路径; <br/>
- images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式; <br/>
- batch\_size (int): batch 的大小;<br/>
- use\_gpu (bool): 是否使用 GPU;<br/>
- output\_dir (str): 图片的保存路径,默认设为 detection\_result;<br/>
- score\_thresh (float): 识别置信度的阈值;<br/>
- visualization (bool): 是否将识别结果保存为图片文件。
## 第一步:启动PaddleHub Serving
**NOTE:** paths和images两个参数选择其一进行提供数据
运行启动命令:
```shell
$ hub serving start -m ssd_vgg16_300_coco2017
```
- **返回**
这样就完成了一个目标检测的服务化API的部署,默认端口号为8866。
- res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
- data (list): 检测结果,list的每一个元素为 dict,各字段为:
- confidence (float): 识别的置信度
- label (str): 标签
- left (int): 边界框的左上角x坐标
- top (int): 边界框的左上角y坐标
- right (int): 边界框的右下角x坐标
- bottom (int): 边界框的右下角y坐标
- save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)
**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
- ```python
def save_inference_model(dirname)
```
- 将模型保存到指定路径。
## 第二步:发送预测请求
- **参数**
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
- dirname: 模型保存路径 <br/>
```python
import requests
import json
import cv2
import base64
## 四、服务部署
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
- PaddleHub Serving可以部署一个目标检测的在线服务。
- ### 第一步:启动PaddleHub Serving
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ssd_vgg16_300_coco2017"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
- 运行启动命令:
- ```shell
$ hub serving start -m ssd_vgg16_300_coco2017
```
# 打印预测结果
print(r.json()["results"])
```
- 这样就完成了一个目标检测的服务化API的部署,默认端口号为8866。
### 依赖
- **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
paddlepaddle >= 1.6.2
- ### 第二步:发送预测请求
paddlehub >= 1.6.0
- 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# 发送HTTP请求
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ssd_vgg16_300_coco2017"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# 打印预测结果
print(r.json()["results"])
```
## 五、更新历史
* 1.0.0
初始发布
* 1.0.2
修复numpy数据读取问题
* 1.1.0
移除 fluid api
- ```shell
$ hub install ssd_vgg16_300_coco2017==1.1.0
```
# ssd_vgg16_300_coco2017
|Module Name|ssd_vgg16_300_coco2017|
| :--- | :---: |
|Category|object detection|
|Network|SSD|
|Dataset|COCO2017|
|Fine-tuning supported or not|No|
|Module Size|139MB|
|Latest update date|2021-03-15|
|Data indicators|-|
## I.Basic Information
- ### Application Effect Display
- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/131506781-b4ecb77b-5ab1-4795-88da-5f547f7f7f9c.jpg" width='50%' hspace='10'/>
<br />
</p>
- ### Module Introduction
- Single Shot MultiBox Detector (SSD) is a one-stage detector. Different from two-stage detector, SSD frames object detection as a re- gression problem to spatially separated bounding boxes and associated class probabilities. This module is based on VGG16, trained on COCO2017 dataset, and can be used for object detection.
## II.Installation
- ### 1、Environmental Dependence
- paddlepaddle >= 1.6.2
- paddlehub >= 1.6.0 | [How to install PaddleHub](../../../../docs/docs_en/get_start/installation.rst)
- ### 2、Installation
- ```shell
$ hub install ssd_vgg16_300_coco2017
```
- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
## III.Module API Prediction
- ### 1、Command line Prediction
- ```shell
$ hub run ssd_vgg16_300_coco2017 --input_path "/PATH/TO/IMAGE"
```
- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2、Prediction Code Example
- ```python
import paddlehub as hub
import cv2
object_detector = hub.Module(name="ssd_vgg16_300_coco2017")
result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = object_detector.object_detection((paths=['/PATH/TO/IMAGE'])
```
- ### 3、API
- ```python
def object_detection(paths=None,
images=None,
batch_size=1,
use_gpu=False,
output_dir='detection_result',
score_thresh=0.5,
visualization=True)
```
- Detection API, detect positions of all objects in image
- **Parameters**
- paths (list[str]): image path;
- images (list\[numpy.ndarray\]): image data, ndarray.shape is in the format [H, W, C], BGR;
- batch_size (int): the size of batch;
- use_gpu (bool): use GPU or not; **set the CUDA_VISIBLE_DEVICES environment variable first if you are using GPU**
- output_dir (str): save path of images;
- score\_thresh (float): confidence threshold;<br/>
- visualization (bool): Whether to save the results as picture files;
**NOTE:** choose one parameter to provide data from paths and images
- **Return**
- res (list\[dict\]): results
- data (list): detection results, each element in the list is dict
- confidence (float): the confidence of the result
- label (str): label
- left (int): the upper left corner x coordinate of the detection box
- top (int): the upper left corner y coordinate of the detection box
- right (int): the lower right corner x coordinate of the detection box
- bottom (int): the lower right corner y coordinate of the detection box
- save\_path (str, optional): output path for saving results
- ```python
def save_inference_model(dirname)
```
- Save model to specific path
- **Parameters**
- dirname: model save path
## IV.Server Deployment
- PaddleHub Serving can deploy an online service of object detection.
- ### Step 1: Start PaddleHub Serving
- Run the startup command:
- ```shell
$ hub serving start -m ssd_vgg16_300_coco2017
```
- The servitization API is now deployed and the default port number is 8866.
- **NOTE:** If GPU is used for prediction, set CUDA_VISIBLE_DEVICES environment variable before the service, otherwise it need not be set.
- ### Step 2: Send a predictive request
- With a configured server, use the following lines of code to send the prediction request and obtain the result
- ```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
# Send an HTTP request
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ssd_vgg16_300_coco2017"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# print prediction results
print(r.json()["results"])
```
## V.Release Note
* 1.0.0
First release
* 1.0.2
Fix the problem of reading numpy
* 1.1.0
Remove fluid api
- ```shell
$ hub install ssd_vgg16_300_coco2017==1.1.0
```
......@@ -5,12 +5,10 @@ from __future__ import division
import os
import random
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
from paddle import fluid
__all__ = ['reader']
......
......@@ -7,39 +7,43 @@ import os
from functools import partial
import yaml
import paddle
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from paddlehub.common.paddle_helper import add_vars_prefix
from ssd_vgg16_300_coco2017.vgg import VGG
from ssd_vgg16_300_coco2017.processor import load_label_info, postprocess, base64_to_cv2
from ssd_vgg16_300_coco2017.data_feed import reader
from .processor import load_label_info, postprocess, base64_to_cv2
from .data_feed import reader
@moduleinfo(
name="ssd_vgg16_300_coco2017",
version="1.0.1",
version="1.1.0",
type="cv/object_detection",
summary="SSD with backbone VGG16, trained with dataset COCO.",
author="paddlepaddle",
author_email="paddle-dev@baidu.com")
class SSDVGG16(hub.Module):
def _initialize(self):
self.default_pretrained_model_path = os.path.join(self.directory, "ssd_vgg16_300_model")
self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt"))
class SSDVGG16:
def __init__(self):
self.default_pretrained_model_path = os.path.join(
self.directory, "ssd_vgg16_300_model", "model")
self.label_names = load_label_info(
os.path.join(self.directory, "label_file.txt"))
self.model_config = None
self._set_config()
def _set_config(self):
# predictor config setting.
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
"""
predictor config setting.
"""
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
cpu_config.switch_ir_optim(False)
self.cpu_predictor = create_paddle_predictor(cpu_config)
self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
......@@ -48,10 +52,10 @@ class SSDVGG16(hub.Module):
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path)
gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
self.gpu_predictor = create_predictor(gpu_config)
# model config setting.
if not self.model_config:
......@@ -61,73 +65,6 @@ class SSDVGG16(hub.Module):
self.multi_box_head_config = self.model_config['MultiBoxHead']
self.output_decoder_config = self.model_config['SSDOutputDecoder']
def context(self, trainable=True, pretrained=True, get_prediction=False):
"""
Distill the Head Features, so as to perform transfer learning.
Args:
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
get_prediction (bool): whether to get prediction.
Returns:
inputs(dict): the input variables.
outputs(dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
# image
image = fluid.layers.data(name='image', shape=[3, 300, 300], dtype='float32')
# backbone
backbone = VGG(depth=16, with_extra_blocks=True, normalizations=[20., -1, -1, -1, -1, -1])
# body_feats
body_feats = backbone(image)
# im_size
im_size = fluid.layers.data(name='im_size', shape=[2], dtype='int32')
# var_prefix
var_prefix = '@HUB_{}@'.format(self.name)
# names of inputs
inputs = {'image': var_prefix + image.name, 'im_size': var_prefix + im_size.name}
# names of outputs
if get_prediction:
locs, confs, box, box_var = fluid.layers.multi_box_head(
inputs=body_feats, image=image, num_classes=81, **self.multi_box_head_config)
pred = fluid.layers.detection_output(
loc=locs, scores=confs, prior_box=box, prior_box_var=box_var, **self.output_decoder_config)
outputs = {'bbox_out': [var_prefix + pred.name]}
else:
outputs = {'body_features': [var_prefix + var.name for var in body_feats]}
# add_vars_prefix
add_vars_prefix(context_prog, var_prefix)
add_vars_prefix(fluid.default_startup_program(), var_prefix)
# inputs
inputs = {key: context_prog.global_block().vars[value] for key, value in inputs.items()}
outputs = {
out_key: [context_prog.global_block().vars[varname] for varname in out_value]
for out_key, out_value in outputs.items()
}
# trainable
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# pretrained
if pretrained:
def _if_exist(var):
return os.path.exists(os.path.join(self.default_pretrained_model_path, var.name))
fluid.io.load_vars(exe, self.default_pretrained_model_path, predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog
def object_detection(self,
paths=None,
images=None,
......@@ -160,47 +97,31 @@ class SSDVGG16(hub.Module):
"""
paths = paths if paths else list()
data_reader = partial(reader, paths, images)
batch_reader = fluid.io.batch(data_reader, batch_size=batch_size)
batch_reader = paddle.batch(data_reader, batch_size=batch_size)
res = []
for iter_id, feed_data in enumerate(batch_reader()):
feed_data = np.array(feed_data)
image_tensor = PaddleTensor(np.array(list(feed_data[:, 0])).copy())
if use_gpu:
data_out = self.gpu_predictor.run([image_tensor])
else:
data_out = self.cpu_predictor.run([image_tensor])
output = postprocess(
paths=paths,
images=images,
data_out=data_out,
score_thresh=score_thresh,
label_names=self.label_names,
output_dir=output_dir,
handle_id=iter_id * batch_size,
visualization=visualization)
predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(np.array(list(feed_data[:, 0])))
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output = postprocess(paths=paths,
images=images,
data_out=output_handle,
score_thresh=score_thresh,
label_names=self.label_names,
output_dir=output_dir,
handle_id=iter_id * batch_size,
visualization=visualization)
res.extend(output)
return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving
def serving_method(self, images, **kwargs):
"""
......@@ -220,9 +141,12 @@ class SSDVGG16(hub.Module):
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
title="Config options",
description=
"Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
......@@ -240,17 +164,34 @@ class SSDVGG16(hub.Module):
Add the command config options.
"""
self.arg_config_group.add_argument(
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
'--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument(
'--output_dir', type=str, default='detection_result', help="The directory to save output images.")
'--output_dir',
type=str,
default='detection_result',
help="The directory to save output images.")
self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
'--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
self.arg_input_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
self.arg_input_group.add_argument(
'--score_thresh', type=ast.literal_eval, default=0.5, help="threshold for object detecion.")
'--input_path', type=str, help="path to image.")
self.arg_input_group.add_argument(
'--batch_size',
type=ast.literal_eval,
default=1,
help="batch size.")
self.arg_input_group.add_argument(
'--score_thresh',
type=ast.literal_eval,
default=0.5,
help="threshold for object detecion.")
......@@ -85,7 +85,7 @@ def load_label_info(file_path):
def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization=True):
"""
postprocess the lod_tensor produced by fluid.Executor.run
postprocess the lod_tensor produced by Executor.run
Args:
paths (list[str]): the path of images.
......@@ -108,9 +108,9 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor = data_out[0]
lod = lod_tensor.lod[0]
results = lod_tensor.as_ndarray()
lod = data_out.lod()[0]
results = data_out.copy_to_cpu()
if handle_id < len(paths):
unhandled_paths = paths[handle_id:]
unhandled_paths_num = len(unhandled_paths)
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="ssd_vgg16_300_coco2017")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('detection_result')
def test_object_detection1(self):
results = self.module.object_detection(
paths=['tests/test.jpg']
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(200 < left < 800)
self.assertTrue(2500 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(3500 < bottom < 4500)
def test_object_detection2(self):
results = self.module.object_detection(
images=[cv2.imread('tests/test.jpg')]
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(200 < left < 800)
self.assertTrue(2500 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(3500 < bottom < 4500)
def test_object_detection3(self):
results = self.module.object_detection(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'cat')
self.assertTrue(confidence > 0.5)
self.assertTrue(200 < left < 800)
self.assertTrue(2500 < right < 3500)
self.assertTrue(500 < top < 1500)
self.assertTrue(3500 < bottom < 4500)
def test_object_detection4(self):
self.assertRaises(
AssertionError,
self.module.object_detection,
paths=['no.jpg']
)
def test_object_detection5(self):
self.assertRaises(
cv2.error,
self.module.object_detection,
images=['test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ['VGG']
class VGG(object):
"""
VGG, see https://arxiv.org/abs/1409.1556
Args:
depth (int): the VGG net depth (16 or 19)
normalizations (list): params list of init scale in l2 norm, skip init
scale if param is -1.
with_extra_blocks (bool): whether or not extra blocks should be added
extra_block_filters (list): in each extra block, params:
[in_channel, out_channel, padding_size, stride_size, filter_size]
class_dim (int): number of class while classification
"""
def __init__(self,
depth=16,
with_extra_blocks=False,
normalizations=[20., -1, -1, -1, -1, -1],
extra_block_filters=[[256, 512, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 0, 1, 3],
[128, 256, 0, 1, 3]],
class_dim=1000):
assert depth in [16, 19], "depth {} not in [16, 19]"
self.depth = depth
self.depth_cfg = {16: [2, 2, 3, 3, 3], 19: [2, 2, 4, 4, 4]}
self.with_extra_blocks = with_extra_blocks
self.normalizations = normalizations
self.extra_block_filters = extra_block_filters
self.class_dim = class_dim
def __call__(self, input):
layers = []
layers += self._vgg_block(input)
if not self.with_extra_blocks:
return layers[-1]
layers += self._add_extras_block(layers[-1])
norm_cfg = self.normalizations
for k, v in enumerate(layers):
if not norm_cfg[k] == -1:
layers[k] = self._l2_norm_scale(v, init_scale=norm_cfg[k])
return layers
def _vgg_block(self, input):
nums = self.depth_cfg[self.depth]
vgg_base = [64, 128, 256, 512, 512]
conv = input
res_layer = []
layers = []
for k, v in enumerate(vgg_base):
conv = self._conv_block(conv, v, nums[k], name="conv{}_".format(k + 1))
layers.append(conv)
if self.with_extra_blocks:
if k == 4:
conv = self._pooling_block(conv, 3, 1, pool_padding=1)
else:
conv = self._pooling_block(conv, 2, 2)
else:
conv = self._pooling_block(conv, 2, 2)
if not self.with_extra_blocks:
fc_dim = 4096
fc_name = ["fc6", "fc7", "fc8"]
fc1 = fluid.layers.fc(
input=conv,
size=fc_dim,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name=fc_name[0] + "_weights"),
bias_attr=fluid.param_attr.ParamAttr(name=fc_name[0] + "_offset"))
fc2 = fluid.layers.fc(
input=fc1,
size=fc_dim,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name=fc_name[1] + "_weights"),
bias_attr=fluid.param_attr.ParamAttr(name=fc_name[1] + "_offset"))
out = fluid.layers.fc(
input=fc2,
size=self.class_dim,
param_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_weights"),
bias_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_offset"))
out = fluid.layers.softmax(out)
res_layer.append(out)
return [out]
else:
fc6 = self._conv_layer(conv, 1024, 3, 1, 6, dilation=6, name="fc6")
fc7 = self._conv_layer(fc6, 1024, 1, 1, 0, name="fc7")
return [layers[3], fc7]
def _add_extras_block(self, input):
cfg = self.extra_block_filters
conv = input
layers = []
for k, v in enumerate(cfg):
assert len(v) == 5, "extra_block_filters size not fix"
conv = self._extra_block(conv, v[0], v[1], v[2], v[3], v[4], name="conv{}_".format(6 + k))
layers.append(conv)
return layers
def _conv_block(self, input, num_filter, groups, name=None):
conv = input
for i in range(groups):
conv = self._conv_layer(
input=conv,
num_filters=num_filter,
filter_size=3,
stride=1,
padding=1,
act='relu',
name=name + str(i + 1))
return conv
def _extra_block(self, input, num_filters1, num_filters2, padding_size, stride_size, filter_size, name=None):
# 1x1 conv
conv_1 = self._conv_layer(
input=input, num_filters=int(num_filters1), filter_size=1, stride=1, act='relu', padding=0, name=name + "1")
# 3x3 conv
conv_2 = self._conv_layer(
input=conv_1,
num_filters=int(num_filters2),
filter_size=filter_size,
stride=stride_size,
act='relu',
padding=padding_size,
name=name + "2")
return conv_2
def _conv_layer(self,
input,
num_filters,
filter_size,
stride,
padding,
dilation=1,
act='relu',
use_cudnn=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
dilation=dilation,
act=act,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=ParamAttr(name=name + "_biases") if self.with_extra_blocks else False,
name=name + '.conv2d.output.1')
return conv
def _pooling_block(self, conv, pool_size, pool_stride, pool_padding=0, ceil_mode=True):
pool = fluid.layers.pool2d(
input=conv,
pool_size=pool_size,
pool_type='max',
pool_stride=pool_stride,
pool_padding=pool_padding,
ceil_mode=ceil_mode)
return pool
def _l2_norm_scale(self, input, init_scale=1.0, channel_shared=False):
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.initializer import Constant
helper = LayerHelper("Scale")
l2_norm = fluid.layers.l2_normalize(input, axis=1) # l2 norm along channel
shape = [1] if channel_shared else [input.shape[1]]
scale = helper.create_parameter(
attr=helper.param_attr, shape=shape, dtype=input.dtype, default_initializer=Constant(init_scale))
out = fluid.layers.elementwise_mul(
x=l2_norm, y=scale, axis=-1 if channel_shared else 1, name="conv4_3_norm_scale")
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册