diff --git a/modules/image/object_detection/ssd_vgg16_300_coco2017/README.md b/modules/image/object_detection/ssd_vgg16_300_coco2017/README.md
index 85510f3e39bc2814978c822b805c868b0294f1e5..567e575a375992e3af06e1bc4c4913da9ce5e942 100644
--- a/modules/image/object_detection/ssd_vgg16_300_coco2017/README.md
+++ b/modules/image/object_detection/ssd_vgg16_300_coco2017/README.md
@@ -1,138 +1,169 @@
-## 命令行预测
+# ssd_vgg16_300_coco2017
-```shell
-$ hub run ssd_vgg16_300_coco2017 --input_path "/PATH/TO/IMAGE"
-```
+|模型名称|ssd_vgg16_300_coco2017|
+| :--- | :---: |
+|类别|图像 - 目标检测|
+|网络|SSD|
+|数据集|COCO2017|
+|是否支持Fine-tuning|否|
+|模型大小|139MB|
+|最新更新日期|2021-03-15|
+|数据指标|-|
-## API
-```python
-def context(trainable=True,
- pretrained=True,
- get_prediction=False)
-```
+## 一、模型基本信息
-提取特征,用于迁移学习。
+- ### 应用效果展示
+ - 样例结果示例:
+
+
+
+
-**参数**
+- ### 模型介绍
-* trainable(bool): 参数是否可训练;
-* pretrained (bool): 是否加载预训练模型;
-* get\_prediction (bool): 是否执行预测。
+ - Single Shot MultiBox Detector (SSD) 是一种单阶段的目标检测器。与两阶段的检测方法不同,单阶段目标检测并不进行区域推荐,而是直接从特征图回归出目标的边界框和分类概率。SSD 运用了这种单阶段检测的思想,并且对其进行改进:在不同尺度的特征图上检测对应尺度的目标。该PaddleHub Module的基网络为VGG16模型,在COCO2017数据集上预训练得到,目前仅支持预测。
-**返回**
-* inputs (dict): 模型的输入,keys 包括 'image', 'im\_size',相应的取值为:
- * image (Variable): 图像变量
- * im\_size (Variable): 图片的尺寸
-* outputs (dict): 模型的输出。如果 get\_prediction 为 False,输出 'head\_features',否则输出 'bbox\_out'。
-* context\_prog (Program): 用于迁移学习的 Program.
+## 二、安装
-```python
-def object_detection(paths=None,
- images=None,
- batch_size=1,
- use_gpu=False,
- output_dir='detection_result',
- score_thresh=0.5,
- visualization=True)
-```
+- ### 1、环境依赖
-预测API,检测输入图片中的所有目标的位置。
+ - paddlepaddle >= 1.6.2
-**参数**
+ - paddlehub >= 1.6.0 | [如何安装paddlehub](../../../../docs/docs_ch/get_start/installation.rst)
-* paths (list\[str\]): 图片的路径;
-* images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
-* batch\_size (int): batch 的大小;
-* use\_gpu (bool): 是否使用 GPU;
-* score\_thresh (float): 识别置信度的阈值;
-* visualization (bool): 是否将识别结果保存为图片文件;
-* output\_dir (str): 图片的保存路径,默认设为 detection\_result;
+- ### 2、安装
-**返回**
+ - ```shell
+ $ hub install ssd_vgg16_300_coco2017
+ ```
+ - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+ | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
-* res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
- * data (list): 检测结果,list的每一个元素为 dict,各字段为:
- * confidence (float): 识别的置信度;
- * label (str): 标签;
- * left (int): 边界框的左上角x坐标;
- * top (int): 边界框的左上角y坐标;
- * right (int): 边界框的右下角x坐标;
- * bottom (int): 边界框的右下角y坐标;
- * save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)。
+## 三、模型API预测
-```python
-def save_inference_model(dirname,
- model_filename=None,
- params_filename=None,
- combined=True)
-```
+- ### 1、命令行预测
-将模型保存到指定路径。
+ - ```shell
+ $ hub run ssd_vgg16_300_coco2017 --input_path "/PATH/TO/IMAGE"
+ ```
+ - 通过命令行方式实现目标检测模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
+- ### 2、预测代码示例
-**参数**
+ - ```python
+ import paddlehub as hub
+ import cv2
-* dirname: 存在模型的目录名称
-* model\_filename: 模型文件名称,默认为\_\_model\_\_
-* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
-* combined: 是否将参数保存到统一的一个文件中
+ object_detector = hub.Module(name="ssd_vgg16_300_coco2017")
+ result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
+ # or
+ # result = object_detector.object_detection(paths=['/PATH/TO/IMAGE'])
+ ```
-## 代码示例
+- ### 3、API
-```python
-import paddlehub as hub
-import cv2
+ - ```python
+ def object_detection(paths=None,
+ images=None,
+ batch_size=1,
+ use_gpu=False,
+ output_dir='detection_result',
+ score_thresh=0.5,
+ visualization=True)
+ ```
-object_detector = hub.Module(name="ssd_vgg16_300_coco2017")
-result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
-# or
-# result = object_detector.object_detection((paths=['/PATH/TO/IMAGE'])
-```
+ - 预测API,检测输入图片中的所有目标的位置。
-## 服务部署
+ - **参数**
-PaddleHub Serving可以部署一个目标检测的在线服务。
+ - paths (list\[str\]): 图片的路径;
+ - images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\],BGR格式;
+ - batch\_size (int): batch 的大小;
+ - use\_gpu (bool): 是否使用 GPU;
+ - output\_dir (str): 图片的保存路径,默认设为 detection\_result;
+ - score\_thresh (float): 识别置信度的阈值;
+ - visualization (bool): 是否将识别结果保存为图片文件。
-## 第一步:启动PaddleHub Serving
+ **NOTE:** paths和images两个参数选择其一进行提供数据
-运行启动命令:
-```shell
-$ hub serving start -m ssd_vgg16_300_coco2017
-```
+ - **返回**
-这样就完成了一个目标检测的服务化API的部署,默认端口号为8866。
+ - res (list\[dict\]): 识别结果的列表,列表中每一个元素为 dict,各字段为:
+ - data (list): 检测结果,list的每一个元素为 dict,各字段为:
+ - confidence (float): 识别的置信度
+ - label (str): 标签
+ - left (int): 边界框的左上角x坐标
+ - top (int): 边界框的左上角y坐标
+ - right (int): 边界框的右下角x坐标
+ - bottom (int): 边界框的右下角y坐标
+ - save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)
-**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
+ - ```python
+ def save_inference_model(dirname)
+ ```
+ - 将模型保存到指定路径。
-## 第二步:发送预测请求
+ - **参数**
-配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
+ - dirname: 模型保存路径
-```python
-import requests
-import json
-import cv2
-import base64
+## 四、服务部署
-def cv2_to_base64(image):
- data = cv2.imencode('.jpg', image)[1]
- return base64.b64encode(data.tostring()).decode('utf8')
+- PaddleHub Serving可以部署一个目标检测的在线服务。
+- ### 第一步:启动PaddleHub Serving
-# 发送HTTP请求
-data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
-headers = {"Content-type": "application/json"}
-url = "http://127.0.0.1:8866/predict/ssd_vgg16_300_coco2017"
-r = requests.post(url=url, headers=headers, data=json.dumps(data))
+ - 运行启动命令:
+ - ```shell
+ $ hub serving start -m ssd_vgg16_300_coco2017
+ ```
-# 打印预测结果
-print(r.json()["results"])
-```
+ - 这样就完成了一个目标检测的服务化API的部署,默认端口号为8866。
-### 依赖
+ - **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
-paddlepaddle >= 1.6.2
+- ### 第二步:发送预测请求
-paddlehub >= 1.6.0
+ - 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
+
+ - ```python
+ import requests
+ import json
+ import cv2
+ import base64
+
+
+ def cv2_to_base64(image):
+ data = cv2.imencode('.jpg', image)[1]
+ return base64.b64encode(data.tobytes()).decode('utf8')
+
+ # 发送HTTP请求
+ data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
+ headers = {"Content-type": "application/json"}
+ url = "http://127.0.0.1:8866/predict/ssd_vgg16_300_coco2017"
+ r = requests.post(url=url, headers=headers, data=json.dumps(data))
+
+ # 打印预测结果
+ print(r.json()["results"])
+ ```
+
+
+## 五、更新历史
+
+* 1.0.0
+
+ 初始发布
+
+* 1.0.2
+
+ 修复numpy数据读取问题
+
+* 1.1.0
+
+ 移除 fluid api
+
+ - ```shell
+ $ hub install ssd_vgg16_300_coco2017==1.1.0
+ ```
diff --git a/modules/image/object_detection/ssd_vgg16_300_coco2017/README_en.md b/modules/image/object_detection/ssd_vgg16_300_coco2017/README_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..0d53ce2f7f467095b28072e56a594df6043839fe
--- /dev/null
+++ b/modules/image/object_detection/ssd_vgg16_300_coco2017/README_en.md
@@ -0,0 +1,169 @@
+# ssd_vgg16_300_coco2017
+
+|Module Name|ssd_vgg16_300_coco2017|
+| :--- | :---: |
+|Category|object detection|
+|Network|SSD|
+|Dataset|COCO2017|
+|Fine-tuning supported or not|No|
+|Module Size|139MB|
+|Latest update date|2021-03-15|
+|Data indicators|-|
+
+
+## I.Basic Information
+
+- ### Application Effect Display
+ - Sample results:
+
+
+
+
+
+- ### Module Introduction
+
+ - Single Shot MultiBox Detector (SSD) is a one-stage object detector. Unlike two-stage detectors, it does not generate region proposals; instead, it frames object detection as a regression problem and predicts bounding boxes and class probabilities directly from feature maps, detecting objects of different sizes on feature maps of different scales. This module uses VGG16 as its backbone, is trained on the COCO2017 dataset, and currently supports prediction only.
+
+
+
+## II.Installation
+
+- ### 1、Environmental Dependence
+
+ - paddlepaddle >= 1.6.2
+
+ - paddlehub >= 1.6.0 | [How to install PaddleHub](../../../../docs/docs_en/get_start/installation.rst)
+
+- ### 2、Installation
+
+ - ```shell
+ $ hub install ssd_vgg16_300_coco2017
+ ```
+ - In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
+
+## III.Module API Prediction
+
+- ### 1、Command line Prediction
+
+ - ```shell
+ $ hub run ssd_vgg16_300_coco2017 --input_path "/PATH/TO/IMAGE"
+ ```
+ - If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
+- ### 2、Prediction Code Example
+
+ - ```python
+ import paddlehub as hub
+ import cv2
+
+ object_detector = hub.Module(name="ssd_vgg16_300_coco2017")
+ result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
+ # or
+ # result = object_detector.object_detection(paths=['/PATH/TO/IMAGE'])
+ ```
+
+- ### 3、API
+
+ - ```python
+ def object_detection(paths=None,
+ images=None,
+ batch_size=1,
+ use_gpu=False,
+ output_dir='detection_result',
+ score_thresh=0.5,
+ visualization=True)
+ ```
+
+ - Detection API: detects the positions of all objects in the input images.
+
+ - **Parameters**
+
+ - paths (list[str]): paths of the input images;
+ - images (list\[numpy.ndarray\]): image data; each ndarray has shape [H, W, C] in BGR format;
+ - batch_size (int): batch size;
+ - use_gpu (bool): whether to use GPU; **set the CUDA_VISIBLE_DEVICES environment variable first if you are using GPU**
+ - output_dir (str): directory in which result images are saved; defaults to detection_result;
+ - score\_thresh (float): confidence threshold;
+ - visualization (bool): whether to save the results as image files.
+
+ **NOTE:** provide the input data through one of the two parameters, paths or images.
+
+ - **Return**
+
+ - res (list\[dict\]): list of results; each element is a dict with the following fields:
+ - data (list): detection results; each element is a dict with the following fields:
+ - confidence (float): confidence of the detection
+ - label (str): label
+ - left (int): x coordinate of the upper-left corner of the bounding box
+ - top (int): y coordinate of the upper-left corner of the bounding box
+ - right (int): x coordinate of the lower-right corner of the bounding box
+ - bottom (int): y coordinate of the lower-right corner of the bounding box
+ - save\_path (str, optional): path where the visualized result is saved (present only when visualization=True)
+
+ - ```python
+ def save_inference_model(dirname)
+ ```
+ - Save the model to the specified path.
+
+ - **Parameters**
+
+ - dirname: model save path
+
+
+## IV.Server Deployment
+
+- PaddleHub Serving can deploy an online service of object detection.
+
+- ### Step 1: Start PaddleHub Serving
+
+ - Run the startup command:
+ - ```shell
+ $ hub serving start -m ssd_vgg16_300_coco2017
+ ```
+
+ - The object detection service API is now deployed; the default port number is 8866.
+
+ - **NOTE:** If GPU is used for prediction, set the CUDA_VISIBLE_DEVICES environment variable before starting the service; otherwise, it does not need to be set.
+
+- ### Step 2: Send a prediction request
+
+ - With a configured server, use the following lines of code to send the prediction request and obtain the result
+
+ - ```python
+ import requests
+ import json
+ import cv2
+ import base64
+
+
+ def cv2_to_base64(image):
+ data = cv2.imencode('.jpg', image)[1]
+ return base64.b64encode(data.tobytes()).decode('utf8')
+
+ # Send an HTTP request
+ data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
+ headers = {"Content-type": "application/json"}
+ url = "http://127.0.0.1:8866/predict/ssd_vgg16_300_coco2017"
+ r = requests.post(url=url, headers=headers, data=json.dumps(data))
+
+ # print prediction results
+ print(r.json()["results"])
+ ```
+
+
+## V.Release Note
+
+* 1.0.0
+
+ First release
+
+* 1.0.2
+
+ Fix the problem of reading numpy data
+
+* 1.1.0
+
+ Remove fluid api
+
+ - ```shell
+ $ hub install ssd_vgg16_300_coco2017==1.1.0
+ ```
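Note on the return value documented in README_en.md above: `object_detection` returns one dict per image, each holding a `data` list of detections and, when visualization is enabled, a `save_path`. The following is a minimal sketch of consuming that structure, assuming only the field names listed in the README. The helper `filter_detections` and the sample values are hypothetical and are not part of the module.

```python
def filter_detections(results, min_confidence=0.5, labels=None):
    """Flatten detection results into (image_index, label, confidence, box) tuples.

    `results` is assumed to follow the README: one dict per image with a
    'data' list whose items carry label, confidence, left, top, right, bottom.
    """
    kept = []
    for image_index, item in enumerate(results):
        for det in item.get("data", []):
            # Drop detections below the requested confidence.
            if det["confidence"] < min_confidence:
                continue
            # Optionally keep only a chosen set of labels.
            if labels is not None and det["label"] not in labels:
                continue
            box = (det["left"], det["top"], det["right"], det["bottom"])
            kept.append((image_index, det["label"], det["confidence"], box))
    return kept


# Hypothetical sample shaped like the documented return value.
sample_results = [{
    "data": [
        {"label": "cat", "confidence": 0.92,
         "left": 10, "top": 20, "right": 110, "bottom": 220},
        {"label": "dog", "confidence": 0.31,
         "left": 5, "top": 8, "right": 50, "bottom": 60},
    ],
    "save_path": "detection_result/example.jpg",
}]

cats = filter_detections(sample_results, min_confidence=0.5, labels={"cat"})
```

In real use, `sample_results` would be replaced by the list returned from `object_detector.object_detection(...)`.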
diff --git a/modules/image/object_detection/ssd_vgg16_300_coco2017/data_feed.py b/modules/image/object_detection/ssd_vgg16_300_coco2017/data_feed.py
index 9fad7c95ec6207ad758b75b4799e5698509f07e6..3d3382bb2ba6ecbe9b3655adb2a667fd7b0ec20f 100644
--- a/modules/image/object_detection/ssd_vgg16_300_coco2017/data_feed.py
+++ b/modules/image/object_detection/ssd_vgg16_300_coco2017/data_feed.py
@@ -5,12 +5,10 @@ from __future__ import division
import os
import random
-from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
-from paddle import fluid
__all__ = ['reader']
diff --git a/modules/image/object_detection/ssd_vgg16_300_coco2017/module.py b/modules/image/object_detection/ssd_vgg16_300_coco2017/module.py
index e0083b95f7c4e3567fda508109af196e1226d087..beefaf6abe623446f1612e4ecf90b8c681392d7b 100644
--- a/modules/image/object_detection/ssd_vgg16_300_coco2017/module.py
+++ b/modules/image/object_detection/ssd_vgg16_300_coco2017/module.py
@@ -7,39 +7,43 @@ import os
from functools import partial
import yaml
+import paddle
import numpy as np
-import paddle.fluid as fluid
-import paddlehub as hub
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+import paddle.static
+from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
-from paddlehub.common.paddle_helper import add_vars_prefix
-from ssd_vgg16_300_coco2017.vgg import VGG
-from ssd_vgg16_300_coco2017.processor import load_label_info, postprocess, base64_to_cv2
-from ssd_vgg16_300_coco2017.data_feed import reader
+from .processor import load_label_info, postprocess, base64_to_cv2
+from .data_feed import reader
@moduleinfo(
name="ssd_vgg16_300_coco2017",
- version="1.0.1",
+ version="1.1.0",
type="cv/object_detection",
summary="SSD with backbone VGG16, trained with dataset COCO.",
author="paddlepaddle",
author_email="paddle-dev@baidu.com")
-class SSDVGG16(hub.Module):
- def _initialize(self):
- self.default_pretrained_model_path = os.path.join(self.directory, "ssd_vgg16_300_model")
- self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt"))
+class SSDVGG16:
+ def __init__(self):
+ self.default_pretrained_model_path = os.path.join(
+ self.directory, "ssd_vgg16_300_model", "model")
+ self.label_names = load_label_info(
+ os.path.join(self.directory, "label_file.txt"))
self.model_config = None
self._set_config()
def _set_config(self):
- # predictor config setting.
- cpu_config = AnalysisConfig(self.default_pretrained_model_path)
+ """
+ predictor config setting.
+ """
+ model = self.default_pretrained_model_path+'.pdmodel'
+ params = self.default_pretrained_model_path+'.pdiparams'
+ cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
cpu_config.switch_ir_optim(False)
- self.cpu_predictor = create_paddle_predictor(cpu_config)
+ self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -48,10 +52,10 @@ class SSDVGG16(hub.Module):
except:
use_gpu = False
if use_gpu:
- gpu_config = AnalysisConfig(self.default_pretrained_model_path)
+ gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
- self.gpu_predictor = create_paddle_predictor(gpu_config)
+ self.gpu_predictor = create_predictor(gpu_config)
# model config setting.
if not self.model_config:
@@ -61,73 +65,6 @@ class SSDVGG16(hub.Module):
self.multi_box_head_config = self.model_config['MultiBoxHead']
self.output_decoder_config = self.model_config['SSDOutputDecoder']
- def context(self, trainable=True, pretrained=True, get_prediction=False):
- """
- Distill the Head Features, so as to perform transfer learning.
-
- Args:
- trainable (bool): whether to set parameters trainable.
- pretrained (bool): whether to load default pretrained model.
- get_prediction (bool): whether to get prediction.
-
- Returns:
- inputs(dict): the input variables.
- outputs(dict): the output variables.
- context_prog (Program): the program to execute transfer learning.
- """
- context_prog = fluid.Program()
- startup_program = fluid.Program()
- with fluid.program_guard(context_prog, startup_program):
- with fluid.unique_name.guard():
- # image
- image = fluid.layers.data(name='image', shape=[3, 300, 300], dtype='float32')
- # backbone
- backbone = VGG(depth=16, with_extra_blocks=True, normalizations=[20., -1, -1, -1, -1, -1])
- # body_feats
- body_feats = backbone(image)
- # im_size
- im_size = fluid.layers.data(name='im_size', shape=[2], dtype='int32')
- # var_prefix
- var_prefix = '@HUB_{}@'.format(self.name)
- # names of inputs
- inputs = {'image': var_prefix + image.name, 'im_size': var_prefix + im_size.name}
- # names of outputs
- if get_prediction:
- locs, confs, box, box_var = fluid.layers.multi_box_head(
- inputs=body_feats, image=image, num_classes=81, **self.multi_box_head_config)
- pred = fluid.layers.detection_output(
- loc=locs, scores=confs, prior_box=box, prior_box_var=box_var, **self.output_decoder_config)
- outputs = {'bbox_out': [var_prefix + pred.name]}
- else:
- outputs = {'body_features': [var_prefix + var.name for var in body_feats]}
-
- # add_vars_prefix
- add_vars_prefix(context_prog, var_prefix)
- add_vars_prefix(fluid.default_startup_program(), var_prefix)
- # inputs
- inputs = {key: context_prog.global_block().vars[value] for key, value in inputs.items()}
- outputs = {
- out_key: [context_prog.global_block().vars[varname] for varname in out_value]
- for out_key, out_value in outputs.items()
- }
- # trainable
- for param in context_prog.global_block().iter_parameters():
- param.trainable = trainable
-
- place = fluid.CPUPlace()
- exe = fluid.Executor(place)
- # pretrained
- if pretrained:
-
- def _if_exist(var):
- return os.path.exists(os.path.join(self.default_pretrained_model_path, var.name))
-
- fluid.io.load_vars(exe, self.default_pretrained_model_path, predicate=_if_exist)
- else:
- exe.run(startup_program)
-
- return inputs, outputs, context_prog
-
def object_detection(self,
paths=None,
images=None,
@@ -160,47 +97,31 @@ class SSDVGG16(hub.Module):
"""
paths = paths if paths else list()
data_reader = partial(reader, paths, images)
- batch_reader = fluid.io.batch(data_reader, batch_size=batch_size)
+ batch_reader = paddle.batch(data_reader, batch_size=batch_size)
res = []
for iter_id, feed_data in enumerate(batch_reader()):
feed_data = np.array(feed_data)
- image_tensor = PaddleTensor(np.array(list(feed_data[:, 0])).copy())
- if use_gpu:
- data_out = self.gpu_predictor.run([image_tensor])
- else:
- data_out = self.cpu_predictor.run([image_tensor])
- output = postprocess(
- paths=paths,
- images=images,
- data_out=data_out,
- score_thresh=score_thresh,
- label_names=self.label_names,
- output_dir=output_dir,
- handle_id=iter_id * batch_size,
- visualization=visualization)
+ predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
+ input_names = predictor.get_input_names()
+ input_handle = predictor.get_input_handle(input_names[0])
+ input_handle.copy_from_cpu(np.array(list(feed_data[:, 0])))
+
+ predictor.run()
+ output_names = predictor.get_output_names()
+ output_handle = predictor.get_output_handle(output_names[0])
+
+ output = postprocess(paths=paths,
+ images=images,
+ data_out=output_handle,
+ score_thresh=score_thresh,
+ label_names=self.label_names,
+ output_dir=output_dir,
+ handle_id=iter_id * batch_size,
+ visualization=visualization)
res.extend(output)
return res
- def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
- if combined:
- model_filename = "__model__" if not model_filename else model_filename
- params_filename = "__params__" if not params_filename else params_filename
- place = fluid.CPUPlace()
- exe = fluid.Executor(place)
-
- program, feeded_var_names, target_vars = fluid.io.load_inference_model(
- dirname=self.default_pretrained_model_path, executor=exe)
-
- fluid.io.save_inference_model(
- dirname=dirname,
- main_program=program,
- executor=exe,
- feeded_var_names=feeded_var_names,
- target_vars=target_vars,
- model_filename=model_filename,
- params_filename=params_filename)
-
@serving
def serving_method(self, images, **kwargs):
"""
@@ -220,9 +141,12 @@ class SSDVGG16(hub.Module):
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
- self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
+ self.arg_input_group = self.parser.add_argument_group(
+ title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
- title="Config options", description="Run configuration for controlling module behavior, not required.")
+ title="Config options",
+ description=
+ "Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
@@ -240,17 +164,34 @@ class SSDVGG16(hub.Module):
Add the command config options.
"""
self.arg_config_group.add_argument(
- '--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
+ '--use_gpu',
+ type=ast.literal_eval,
+ default=False,
+ help="whether use GPU or not")
self.arg_config_group.add_argument(
- '--output_dir', type=str, default='detection_result', help="The directory to save output images.")
+ '--output_dir',
+ type=str,
+ default='detection_result',
+ help="The directory to save output images.")
self.arg_config_group.add_argument(
- '--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
+ '--visualization',
+ type=ast.literal_eval,
+ default=False,
+ help="whether to save output as images.")
def add_module_input_arg(self):
"""
Add the command input options.
"""
- self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
- self.arg_input_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
self.arg_input_group.add_argument(
- '--score_thresh', type=ast.literal_eval, default=0.5, help="threshold for object detecion.")
+ '--input_path', type=str, help="path to image.")
+ self.arg_input_group.add_argument(
+ '--batch_size',
+ type=ast.literal_eval,
+ default=1,
+ help="batch size.")
+ self.arg_input_group.add_argument(
+ '--score_thresh',
+ type=ast.literal_eval,
+ default=0.5,
+ help="threshold for object detection.")
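Note on the command-line options in module.py above: several of them use `type=ast.literal_eval`, so values such as `True`, `4`, or `0.3` are parsed as Python literals instead of being kept as strings. The sketch below reproduces that behavior with the standard library only. The parser is a reduced stand-in written for illustration, not the module's actual parser, and the argument values are made up.

```python
import argparse
import ast

# Reduced stand-in for the module's parser (illustration only).
parser = argparse.ArgumentParser(prog="hub run ssd_vgg16_300_coco2017")
parser.add_argument("--input_path", type=str, help="path to image.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=False)
parser.add_argument("--batch_size", type=ast.literal_eval, default=1)
parser.add_argument("--score_thresh", type=ast.literal_eval, default=0.5)

# "True" and "0.3" arrive as strings and are converted by ast.literal_eval;
# --batch_size is omitted, so its non-string default is used unchanged.
args = parser.parse_args(
    ["--input_path", "cat.jpg", "--use_gpu", "True", "--score_thresh", "0.3"])
```

One consequence of this design is that the value must be a valid Python literal: `--use_gpu True` works, while lowercase `true` is not a literal and is rejected by argparse with a usage error.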
diff --git a/modules/image/object_detection/ssd_vgg16_300_coco2017/processor.py b/modules/image/object_detection/ssd_vgg16_300_coco2017/processor.py
index ff4eb9fe5fd596233ef90c1a0a5baa9d0ff0e56f..9bf964540ef7657c8052040768c00b2758574ca9 100644
--- a/modules/image/object_detection/ssd_vgg16_300_coco2017/processor.py
+++ b/modules/image/object_detection/ssd_vgg16_300_coco2017/processor.py
@@ -85,7 +85,7 @@ def load_label_info(file_path):
def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization=True):
"""
- postprocess the lod_tensor produced by fluid.Executor.run
+ postprocess the output tensor (with LoD) fetched from the paddle.inference predictor
Args:
paths (list[str]): the path of images.
@@ -108,9 +108,9 @@ def postprocess(paths, images, data_out, score_thresh, label_names, output_dir,
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
- lod_tensor = data_out[0]
- lod = lod_tensor.lod[0]
- results = lod_tensor.as_ndarray()
+ lod = data_out.lod()[0]
+ results = data_out.copy_to_cpu()
+
if handle_id < len(paths):
unhandled_paths = paths[handle_id:]
unhandled_paths_num = len(unhandled_paths)
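Note on the processor.py change above: it reads the LoD offsets with `data_out.lod()[0]` and the flat result array with `data_out.copy_to_cpu()`. A level-0 LoD is a list of cumulative row offsets, so rows `lod[i]` up to `lod[i + 1]` belong to image `i` of the batch. The sketch below shows that slicing on plain Python lists. The six-column row layout `[label_id, confidence, xmin, ymin, xmax, ymax]` and the helper name `split_by_lod` are assumptions made for illustration and are not taken from this diff.

```python
def split_by_lod(rows, lod, score_thresh=0.5):
    """Group flat detection rows per image using cumulative LoD offsets.

    Assumed row layout (not confirmed by the diff):
    [label_id, confidence, xmin, ymin, xmax, ymax].
    """
    per_image = []
    for i in range(len(lod) - 1):
        # Rows lod[i] .. lod[i + 1] - 1 belong to image i.
        image_rows = rows[lod[i]:lod[i + 1]]
        # Keep only rows whose confidence reaches the threshold.
        per_image.append([row for row in image_rows if row[1] >= score_thresh])
    return per_image


# Hypothetical batch of two images: two rows for the first, one for the second.
rows = [
    [17.0, 0.90, 10.0, 20.0, 110.0, 220.0],
    [18.0, 0.20, 5.0, 8.0, 50.0, 60.0],
    [1.0, 0.75, 0.0, 0.0, 30.0, 40.0],
]
lod = [0, 2, 3]
grouped = split_by_lod(rows, lod)
```

With a real predictor, `rows` would come from `copy_to_cpu()` and `lod` from `lod()[0]` on the output handle.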
diff --git a/modules/image/object_detection/ssd_vgg16_300_coco2017/test.py b/modules/image/object_detection/ssd_vgg16_300_coco2017/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..922f3b6012457297b3a1078585b14ebc47c4d249
--- /dev/null
+++ b/modules/image/object_detection/ssd_vgg16_300_coco2017/test.py
@@ -0,0 +1,108 @@
+import os
+import shutil
+import unittest
+
+import cv2
+import requests
+import paddlehub as hub
+
+
+class TestHubModule(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ img_url = 'https://ai-studio-static-online.cdn.bcebos.com/68313e182f5e4ad9907e69dac9ece8fc50840d7ffbd24fa88396f009958f969a'
+ if not os.path.exists('tests'):
+ os.makedirs('tests')
+ response = requests.get(img_url)
+ assert response.status_code == 200, 'Network Error.'
+ with open('tests/test.jpg', 'wb') as f:
+ f.write(response.content)
+ cls.module = hub.Module(name="ssd_vgg16_300_coco2017")
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ shutil.rmtree('tests')
+ shutil.rmtree('inference')
+ shutil.rmtree('detection_result')
+
+ def test_object_detection1(self):
+ results = self.module.object_detection(
+ paths=['tests/test.jpg']
+ )
+ bbox = results[0]['data'][0]
+ label = bbox['label']
+ confidence = bbox['confidence']
+ left = bbox['left']
+ right = bbox['right']
+ top = bbox['top']
+ bottom = bbox['bottom']
+
+ self.assertEqual(label, 'cat')
+ self.assertTrue(confidence > 0.5)
+ self.assertTrue(200 < left < 800)
+ self.assertTrue(2500 < right < 3500)
+ self.assertTrue(500 < top < 1500)
+ self.assertTrue(3500 < bottom < 4500)
+
+ def test_object_detection2(self):
+ results = self.module.object_detection(
+ images=[cv2.imread('tests/test.jpg')]
+ )
+ bbox = results[0]['data'][0]
+ label = bbox['label']
+ confidence = bbox['confidence']
+ left = bbox['left']
+ right = bbox['right']
+ top = bbox['top']
+ bottom = bbox['bottom']
+
+ self.assertEqual(label, 'cat')
+ self.assertTrue(confidence > 0.5)
+ self.assertTrue(200 < left < 800)
+ self.assertTrue(2500 < right < 3500)
+ self.assertTrue(500 < top < 1500)
+ self.assertTrue(3500 < bottom < 4500)
+
+ def test_object_detection3(self):
+ results = self.module.object_detection(
+ images=[cv2.imread('tests/test.jpg')],
+ visualization=False
+ )
+ bbox = results[0]['data'][0]
+ label = bbox['label']
+ confidence = bbox['confidence']
+ left = bbox['left']
+ right = bbox['right']
+ top = bbox['top']
+ bottom = bbox['bottom']
+
+ self.assertEqual(label, 'cat')
+ self.assertTrue(confidence > 0.5)
+ self.assertTrue(200 < left < 800)
+ self.assertTrue(2500 < right < 3500)
+ self.assertTrue(500 < top < 1500)
+ self.assertTrue(3500 < bottom < 4500)
+
+ def test_object_detection4(self):
+ self.assertRaises(
+ AssertionError,
+ self.module.object_detection,
+ paths=['no.jpg']
+ )
+
+ def test_object_detection5(self):
+ self.assertRaises(
+ cv2.error,
+ self.module.object_detection,
+ images=['test.jpg']
+ )
+
+ def test_save_inference_model(self):
+ self.module.save_inference_model('./inference/model')
+
+ self.assertTrue(os.path.exists('./inference/model.pdmodel'))
+ self.assertTrue(os.path.exists('./inference/model.pdiparams'))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/modules/image/object_detection/ssd_vgg16_300_coco2017/vgg.py b/modules/image/object_detection/ssd_vgg16_300_coco2017/vgg.py
deleted file mode 100644
index d950c6b553d9af29086ba6f942d005920e74c296..0000000000000000000000000000000000000000
--- a/modules/image/object_detection/ssd_vgg16_300_coco2017/vgg.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# coding=utf-8
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from paddle import fluid
-from paddle.fluid.param_attr import ParamAttr
-
-__all__ = ['VGG']
-
-
-class VGG(object):
- """
- VGG, see https://arxiv.org/abs/1409.1556
-
- Args:
- depth (int): the VGG net depth (16 or 19)
- normalizations (list): params list of init scale in l2 norm, skip init
- scale if param is -1.
- with_extra_blocks (bool): whether or not extra blocks should be added
- extra_block_filters (list): in each extra block, params:
- [in_channel, out_channel, padding_size, stride_size, filter_size]
- class_dim (int): number of class while classification
- """
-
- def __init__(self,
- depth=16,
- with_extra_blocks=False,
- normalizations=[20., -1, -1, -1, -1, -1],
- extra_block_filters=[[256, 512, 1, 2, 3], [128, 256, 1, 2, 3], [128, 256, 0, 1, 3],
- [128, 256, 0, 1, 3]],
- class_dim=1000):
- assert depth in [16, 19], "depth {} not in [16, 19]"
- self.depth = depth
- self.depth_cfg = {16: [2, 2, 3, 3, 3], 19: [2, 2, 4, 4, 4]}
- self.with_extra_blocks = with_extra_blocks
- self.normalizations = normalizations
- self.extra_block_filters = extra_block_filters
- self.class_dim = class_dim
-
- def __call__(self, input):
- layers = []
- layers += self._vgg_block(input)
-
- if not self.with_extra_blocks:
- return layers[-1]
-
- layers += self._add_extras_block(layers[-1])
- norm_cfg = self.normalizations
- for k, v in enumerate(layers):
- if not norm_cfg[k] == -1:
- layers[k] = self._l2_norm_scale(v, init_scale=norm_cfg[k])
-
- return layers
-
- def _vgg_block(self, input):
- nums = self.depth_cfg[self.depth]
- vgg_base = [64, 128, 256, 512, 512]
- conv = input
- res_layer = []
- layers = []
- for k, v in enumerate(vgg_base):
- conv = self._conv_block(conv, v, nums[k], name="conv{}_".format(k + 1))
- layers.append(conv)
- if self.with_extra_blocks:
- if k == 4:
- conv = self._pooling_block(conv, 3, 1, pool_padding=1)
- else:
- conv = self._pooling_block(conv, 2, 2)
- else:
- conv = self._pooling_block(conv, 2, 2)
- if not self.with_extra_blocks:
- fc_dim = 4096
- fc_name = ["fc6", "fc7", "fc8"]
- fc1 = fluid.layers.fc(
- input=conv,
- size=fc_dim,
- act='relu',
- param_attr=fluid.param_attr.ParamAttr(name=fc_name[0] + "_weights"),
- bias_attr=fluid.param_attr.ParamAttr(name=fc_name[0] + "_offset"))
- fc2 = fluid.layers.fc(
- input=fc1,
- size=fc_dim,
- act='relu',
- param_attr=fluid.param_attr.ParamAttr(name=fc_name[1] + "_weights"),
- bias_attr=fluid.param_attr.ParamAttr(name=fc_name[1] + "_offset"))
- out = fluid.layers.fc(
- input=fc2,
- size=self.class_dim,
- param_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_weights"),
- bias_attr=fluid.param_attr.ParamAttr(name=fc_name[2] + "_offset"))
- out = fluid.layers.softmax(out)
- res_layer.append(out)
- return [out]
- else:
- fc6 = self._conv_layer(conv, 1024, 3, 1, 6, dilation=6, name="fc6")
- fc7 = self._conv_layer(fc6, 1024, 1, 1, 0, name="fc7")
- return [layers[3], fc7]
-
- def _add_extras_block(self, input):
- cfg = self.extra_block_filters
- conv = input
- layers = []
- for k, v in enumerate(cfg):
- assert len(v) == 5, "extra_block_filters size not fix"
- conv = self._extra_block(conv, v[0], v[1], v[2], v[3], v[4], name="conv{}_".format(6 + k))
- layers.append(conv)
-
- return layers
-
- def _conv_block(self, input, num_filter, groups, name=None):
- conv = input
- for i in range(groups):
- conv = self._conv_layer(
- input=conv,
- num_filters=num_filter,
- filter_size=3,
- stride=1,
- padding=1,
- act='relu',
- name=name + str(i + 1))
- return conv
-
- def _extra_block(self, input, num_filters1, num_filters2, padding_size, stride_size, filter_size, name=None):
- # 1x1 conv
- conv_1 = self._conv_layer(
- input=input, num_filters=int(num_filters1), filter_size=1, stride=1, act='relu', padding=0, name=name + "1")
-
- # 3x3 conv
- conv_2 = self._conv_layer(
- input=conv_1,
- num_filters=int(num_filters2),
- filter_size=filter_size,
- stride=stride_size,
- act='relu',
- padding=padding_size,
- name=name + "2")
- return conv_2
-
- def _conv_layer(self,
- input,
- num_filters,
- filter_size,
- stride,
- padding,
- dilation=1,
- act='relu',
- use_cudnn=True,
- name=None):
- conv = fluid.layers.conv2d(
- input=input,
- num_filters=num_filters,
- filter_size=filter_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- act=act,
- use_cudnn=use_cudnn,
- param_attr=ParamAttr(name=name + "_weights"),
- bias_attr=ParamAttr(name=name + "_biases") if self.with_extra_blocks else False,
- name=name + '.conv2d.output.1')
- return conv
-
- def _pooling_block(self, conv, pool_size, pool_stride, pool_padding=0, ceil_mode=True):
- pool = fluid.layers.pool2d(
- input=conv,
- pool_size=pool_size,
- pool_type='max',
- pool_stride=pool_stride,
- pool_padding=pool_padding,
- ceil_mode=ceil_mode)
- return pool
-
- def _l2_norm_scale(self, input, init_scale=1.0, channel_shared=False):
- from paddle.fluid.layer_helper import LayerHelper
- from paddle.fluid.initializer import Constant
- helper = LayerHelper("Scale")
- l2_norm = fluid.layers.l2_normalize(input, axis=1) # l2 norm along channel
- shape = [1] if channel_shared else [input.shape[1]]
- scale = helper.create_parameter(
- attr=helper.param_attr, shape=shape, dtype=input.dtype, default_initializer=Constant(init_scale))
- out = fluid.layers.elementwise_mul(
- x=l2_norm, y=scale, axis=-1 if channel_shared else 1, name="conv4_3_norm_scale")
- return out