Unverified commit 7a847a39, authored by jm_12138 and committed by GitHub

update yolov3_darknet53_vehicles (#1957)

* update yolov3_darknet53_vehicles

* update gpu config

* update

* add clean func

* update save inference model
Parent f3d7b12c
@@ -100,19 +100,13 @@
     - save\_path (str, optional): save path of the recognition results (exists only when visualization=True)
 
 - ```python
-  def save_inference_model(dirname,
-                           model_filename=None,
-                           params_filename=None,
-                           combined=True)
+  def save_inference_model(dirname)
   ```
 - Save the model to the specified path.
 
 - **Parameters**
 
-  - dirname: name of the directory where the model is stored; <br/>
-  - model\_filename: model file name, defaults to \_\_model\_\_; <br/>
-  - params\_filename: parameter file name, defaults to \_\_params\_\_ (takes effect only when `combined` is True); <br/>
-  - combined: whether to save the parameters into a single file.
+  - dirname: model save path <br/>
 
 ## IV. Server Deployment
@@ -170,6 +164,10 @@
   Remove fluid api
 
+* 1.1.0
+
+  Fix the issue that the inference model could not be exported
+
 - ```shell
-  $ hub install yolov3_darknet53_vehicles==1.0.3
+  $ hub install yolov3_darknet53_vehicles==1.1.0
   ```
@@ -100,19 +100,13 @@
     - save\_path (str, optional): output path for saving results
 
 - ```python
-  def save_inference_model(dirname,
-                           model_filename=None,
-                           params_filename=None,
-                           combined=True)
+  def save_inference_model(dirname)
   ```
 - Save model to specific path
 
 - **Parameters**
 
-  - dirname: output dir for saving model
-  - model\_filename: filename for saving model
-  - params\_filename: filename for saving parameters
-  - combined: whether save parameters into one file
+  - dirname: model save path
 
 ## IV. Server Deployment
@@ -170,6 +164,10 @@
   Remove fluid api
 
+* 1.1.0
+
+  Fix bug of save_inference_model
+
 - ```shell
-  $ hub install yolov3_darknet53_vehicles==1.0.3
+  $ hub install yolov3_darknet53_vehicles==1.1.0
   ```
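For reference, a minimal usage sketch of the updated API, assuming the module is installed via `hub install yolov3_darknet53_vehicles==1.1.0` (the image path is a placeholder). With the new signature, `save_inference_model('./inference/model')` writes `./inference/model.pdmodel` and `./inference/model.pdiparams`, as the tests at the end of this diff verify:

```python
import paddlehub as hub

# Load the installed module.
vehicles_detector = hub.Module(name="yolov3_darknet53_vehicles")

# Detect vehicles in a local image; each result carries a 'data' list of
# boxes with label, confidence, left, right, top and bottom fields.
results = vehicles_detector.object_detection(paths=['test.jpg'])
print(results)

# New signature: one path prefix replaces the old dirname/model_filename/
# params_filename/combined arguments.
vehicles_detector.save_inference_model('./inference/model')
```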
@@ -8,30 +8,29 @@ from functools import partial
 import numpy as np
 import paddle
 import paddle.jit
 import paddle.static
 from paddle.inference import Config
 from paddle.inference import create_predictor
-from yolov3_darknet53_vehicles.data_feed import reader
-from yolov3_darknet53_vehicles.processor import base64_to_cv2
-from yolov3_darknet53_vehicles.processor import load_label_info
-from yolov3_darknet53_vehicles.processor import postprocess
+from .data_feed import reader
+from .processor import base64_to_cv2
+from .processor import load_label_info
+from .processor import postprocess
 
-import paddlehub as hub
-from paddlehub.common.paddle_helper import add_vars_prefix
 from paddlehub.module.module import moduleinfo
 from paddlehub.module.module import runnable
 from paddlehub.module.module import serving
 
 
 @moduleinfo(name="yolov3_darknet53_vehicles",
-            version="1.0.3",
+            version="1.1.0",
             type="CV/object_detection",
             summary="Baidu's YOLOv3 model for vehicles detection, with backbone DarkNet53.",
             author="paddlepaddle",
             author_email="paddle-dev@baidu.com")
-class YOLOv3DarkNet53Vehicles(hub.Module):
-    def _initialize(self):
-        self.default_pretrained_model_path = os.path.join(self.directory, "yolov3_darknet53_vehicles_model")
+class YOLOv3DarkNet53Vehicles:
+    def __init__(self):
+        self.default_pretrained_model_path = os.path.join(self.directory, "yolov3_darknet53_vehicles_model", "model")
         self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt"))
         self._set_config()
@@ -49,7 +48,9 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
         """
         # create default cpu predictor
-        cpu_config = Config(self.default_pretrained_model_path)
+        model = self.default_pretrained_model_path + '.pdmodel'
+        params = self.default_pretrained_model_path + '.pdiparams'
+        cpu_config = Config(model, params)
         cpu_config.disable_glog_info()
         cpu_config.disable_gpu()
         self.cpu_predictor = create_predictor(cpu_config)
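Since `default_pretrained_model_path` now ends in the `model` file prefix (see the `__init__` hunk above), appending `.pdmodel` and `.pdiparams` yields exactly the pair that `paddle.inference.Config` expects. Below is a standalone sketch of the same pattern, extended to actually run the predictor; the path prefix, input names and shapes are illustrative assumptions, not taken from this diff:

```python
import numpy as np
from paddle.inference import Config, create_predictor

# Hypothetical location of the exported model.pdmodel / model.pdiparams pair.
prefix = 'yolov3_darknet53_vehicles_model/model'

config = Config(prefix + '.pdmodel', prefix + '.pdiparams')
config.disable_glog_info()  # silence verbose logging, as in the hunk above
config.disable_gpu()        # build a CPU-only predictor
predictor = create_predictor(config)

# Feed inputs by name. A YOLOv3 detector usually takes the image tensor and
# an image-size tensor; the exact names depend on how the model was exported.
names = predictor.get_input_names()
image = np.random.rand(1, 3, 608, 608).astype('float32')  # dummy NCHW batch
predictor.get_input_handle(names[0]).copy_from_cpu(image)
if len(names) > 1:
    im_size = np.array([[608, 608]], dtype='int32')
    predictor.get_input_handle(names[1]).copy_from_cpu(im_size)

predictor.run()
raw = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()
print(raw.shape)  # raw detections; postprocess() turns these into labeled boxes
```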
@@ -60,7 +61,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
         npu_id = self._get_device_id("FLAGS_selected_npus")
         if npu_id != -1:
             # use npu
-            npu_config = Config(self.default_pretrained_model_path)
+            npu_config = Config(model, params)
             npu_config.disable_glog_info()
             npu_config.enable_npu(device_id=npu_id)
             self.npu_predictor = create_predictor(npu_config)
@@ -69,7 +70,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
         gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
         if gpu_id != -1:
             # use gpu
-            gpu_config = Config(self.default_pretrained_model_path)
+            gpu_config = Config(model, params)
             gpu_config.disable_glog_info()
             gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
             self.gpu_predictor = create_predictor(gpu_config)
@@ -78,7 +79,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
         xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
         if xpu_id != -1:
             # use xpu
-            xpu_config = Config(self.default_pretrained_model_path)
+            xpu_config = Config(model, params)
             xpu_config.disable_glog_info()
             xpu_config.enable_xpu(100)
             self.xpu_predictor = create_predictor(xpu_config)
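All four device branches call a `_get_device_id` helper that this diff does not show. A plausible reconstruction, assuming it reads the id from the named environment variable and falls back to -1 (the "device unavailable" sentinel checked above) when the variable is unset or not numeric:

```python
import os

def _get_device_id(places: str) -> int:
    # Hypothetical reconstruction (would be a method on the module class):
    # e.g. CUDA_VISIBLE_DEVICES=0 -> 0; unset or non-numeric -> -1.
    try:
        return int(os.environ[places].split(',')[0])
    except (KeyError, ValueError):
        return -1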
@@ -169,24 +170,6 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
             res.extend(output)
         return res
 
-    def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
-        if combined:
-            model_filename = "__model__" if not model_filename else model_filename
-            params_filename = "__params__" if not params_filename else params_filename
-
-        place = paddle.CPUPlace()
-        exe = paddle.Executor(place)
-
-        program, feeded_var_names, target_vars = paddle.static.load_inference_model(
-            dirname=self.default_pretrained_model_path, executor=exe)
-
-        paddle.static.save_inference_model(dirname=dirname,
-                                           main_program=program,
-                                           executor=exe,
-                                           feeded_var_names=feeded_var_names,
-                                           target_vars=target_vars,
-                                           model_filename=model_filename,
-                                           params_filename=params_filename)
-
     @serving
     def serving_method(self, images, **kwargs):
         """
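The hunk above truncates `serving_method`. Given the `base64_to_cv2` import earlier in this diff, the body presumably decodes base64-encoded images and delegates to `object_detection`; a sketch under that assumption (only the signature comes from the diff):

```python
@serving
def serving_method(self, images, **kwargs):
    """Run as a service: decode base64 images, then detect vehicles."""
    images_decode = [base64_to_cv2(image) for image in images]
    return self.object_detection(images=images_decode, **kwargs)
```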
@@ -88,7 +88,7 @@ def load_label_info(file_path):
 
 def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization=True):
     """
-    postprocess the lod_tensor produced by fluid.Executor.run
+    postprocess the lod_tensor produced by Executor.run
 
     Args:
         paths (list[str]): The paths of images.
import os
import shutil
import unittest

import cv2
import requests

import paddlehub as hub


class TestHubModule(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        img_url = 'https://ai-studio-static-online.cdn.bcebos.com/036990d3d8654d789c2138492155d9dd95dba2a2fc8e410ab059eea42b330f59'
        if not os.path.exists('tests'):
            os.makedirs('tests')
        response = requests.get(img_url)
        assert response.status_code == 200, 'Network Error.'
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        cls.module = hub.Module(name="yolov3_darknet53_vehicles")

    @classmethod
    def tearDownClass(cls) -> None:
        shutil.rmtree('tests')
        shutil.rmtree('inference')
        shutil.rmtree('yolov3_vehicles_detect_output')

    def test_object_detection1(self):
        results = self.module.object_detection(paths=['tests/test.jpg'])
        bbox = results[0]['data'][0]
        label = bbox['label']
        confidence = bbox['confidence']
        left = bbox['left']
        right = bbox['right']
        top = bbox['top']
        bottom = bbox['bottom']
        self.assertEqual(label, 'car')
        self.assertTrue(confidence > 0.5)
        self.assertTrue(2000 < left < 4000)
        self.assertTrue(4000 < right < 6000)
        self.assertTrue(1000 < top < 3000)
        self.assertTrue(2000 < bottom < 5000)

    def test_object_detection2(self):
        results = self.module.object_detection(images=[cv2.imread('tests/test.jpg')])
        bbox = results[0]['data'][0]
        label = bbox['label']
        confidence = bbox['confidence']
        left = bbox['left']
        right = bbox['right']
        top = bbox['top']
        bottom = bbox['bottom']
        self.assertEqual(label, 'car')
        self.assertTrue(confidence > 0.5)
        self.assertTrue(2000 < left < 4000)
        self.assertTrue(4000 < right < 6000)
        self.assertTrue(1000 < top < 3000)
        self.assertTrue(2000 < bottom < 5000)

    def test_object_detection3(self):
        results = self.module.object_detection(images=[cv2.imread('tests/test.jpg')], visualization=False)
        bbox = results[0]['data'][0]
        label = bbox['label']
        confidence = bbox['confidence']
        left = bbox['left']
        right = bbox['right']
        top = bbox['top']
        bottom = bbox['bottom']
        self.assertEqual(label, 'car')
        self.assertTrue(confidence > 0.5)
        self.assertTrue(2000 < left < 4000)
        self.assertTrue(4000 < right < 6000)
        self.assertTrue(1000 < top < 3000)
        self.assertTrue(2000 < bottom < 5000)

    def test_object_detection4(self):
        self.assertRaises(AssertionError, self.module.object_detection, paths=['no.jpg'])

    def test_object_detection5(self):
        self.assertRaises(AttributeError, self.module.object_detection, images=['test.jpg'])

    def test_save_inference_model(self):
        self.module.save_inference_model('./inference/model')
        self.assertTrue(os.path.exists('./inference/model.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model.pdiparams'))


if __name__ == "__main__":
    unittest.main()
\ No newline at end of file