未验证 提交 7a847a39 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update yolov3_darknet53_vehicles (#1957)

* update yolov3_darknet53_vehicles

* update gpu config

* update

* add clean func

* update save inference model
上级 f3d7b12c
...@@ -100,19 +100,13 @@ ...@@ -100,19 +100,13 @@
- save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在) - save\_path (str, optional): 识别结果的保存路径 (仅当visualization=True时存在)
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- 将模型保存到指定路径。 - 将模型保存到指定路径。
- **参数** - **参数**
- dirname: 存在模型的目录名称; <br/> - dirname: 模型保存路径 <br/>
- model\_filename: 模型文件名称,默认为\_\_model\_\_; <br/>
- params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效);<br/>
- combined: 是否将参数保存到统一的一个文件中。
## 四、服务部署 ## 四、服务部署
...@@ -170,6 +164,10 @@ ...@@ -170,6 +164,10 @@
移除 fluid api 移除 fluid api
* 1.1.0
修复推理模型无法导出的问题
- ```shell - ```shell
$ hub install yolov3_darknet53_vehicles==1.0.3 $ hub install yolov3_darknet53_vehicles==1.1.0
``` ```
...@@ -100,19 +100,13 @@ ...@@ -100,19 +100,13 @@
- save\_path (str, optional): output path for saving results - save\_path (str, optional): output path for saving results
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- Save model to specific path - Save model to specific path
- **Parameters** - **Parameters**
- dirname: output dir for saving model - dirname: model save path
- model\_filename: filename for saving model
- params\_filename: filename for saving parameters
- combined: whether save parameters into one file
## IV.Server Deployment ## IV.Server Deployment
...@@ -170,6 +164,10 @@ ...@@ -170,6 +164,10 @@
Remove fluid api Remove fluid api
* 1.1.0
Fix bug of save_inference_model
- ```shell - ```shell
$ hub install yolov3_darknet53_vehicles==1.0.3 $ hub install yolov3_darknet53_vehicles==1.1.0
``` ```
...@@ -8,30 +8,29 @@ from functools import partial ...@@ -8,30 +8,29 @@ from functools import partial
import numpy as np import numpy as np
import paddle import paddle
import paddle.jit
import paddle.static
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
from yolov3_darknet53_vehicles.data_feed import reader from .data_feed import reader
from yolov3_darknet53_vehicles.processor import base64_to_cv2 from .processor import base64_to_cv2
from yolov3_darknet53_vehicles.processor import load_label_info from .processor import load_label_info
from yolov3_darknet53_vehicles.processor import postprocess from .processor import postprocess
import paddlehub as hub
from paddlehub.common.paddle_helper import add_vars_prefix
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable from paddlehub.module.module import runnable
from paddlehub.module.module import serving from paddlehub.module.module import serving
@moduleinfo(name="yolov3_darknet53_vehicles", @moduleinfo(name="yolov3_darknet53_vehicles",
version="1.0.3", version="1.1.0",
type="CV/object_detection", type="CV/object_detection",
summary="Baidu's YOLOv3 model for vehicles detection, with backbone DarkNet53.", summary="Baidu's YOLOv3 model for vehicles detection, with backbone DarkNet53.",
author="paddlepaddle", author="paddlepaddle",
author_email="paddle-dev@baidu.com") author_email="paddle-dev@baidu.com")
class YOLOv3DarkNet53Vehicles(hub.Module): class YOLOv3DarkNet53Vehicles:
def __init__(self):
def _initialize(self): self.default_pretrained_model_path = os.path.join(self.directory, "yolov3_darknet53_vehicles_model", "model")
self.default_pretrained_model_path = os.path.join(self.directory, "yolov3_darknet53_vehicles_model")
self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt")) self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt"))
self._set_config() self._set_config()
...@@ -49,7 +48,9 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -49,7 +48,9 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
""" """
# create default cpu predictor # create default cpu predictor
cpu_config = Config(self.default_pretrained_model_path) model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
...@@ -60,7 +61,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -60,7 +61,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
npu_id = self._get_device_id("FLAGS_selected_npus") npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1: if npu_id != -1:
# use npu # use npu
npu_config = Config(self.default_pretrained_model_path) npu_config = Config(model, params)
npu_config.disable_glog_info() npu_config.disable_glog_info()
npu_config.enable_npu(device_id=npu_id) npu_config.enable_npu(device_id=npu_id)
self.npu_predictor = create_predictor(npu_config) self.npu_predictor = create_predictor(npu_config)
...@@ -69,7 +70,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -69,7 +70,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES") gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1: if gpu_id != -1:
# use gpu # use gpu
gpu_config = Config(self.default_pretrained_model_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id) gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=gpu_id)
self.gpu_predictor = create_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
...@@ -78,7 +79,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -78,7 +79,7 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES") xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1: if xpu_id != -1:
# use xpu # use xpu
xpu_config = Config(self.default_pretrained_model_path) xpu_config = Config(model, params)
xpu_config.disable_glog_info() xpu_config.disable_glog_info()
xpu_config.enable_xpu(100) xpu_config.enable_xpu(100)
self.xpu_predictor = create_predictor(xpu_config) self.xpu_predictor = create_predictor(xpu_config)
...@@ -169,24 +170,6 @@ class YOLOv3DarkNet53Vehicles(hub.Module): ...@@ -169,24 +170,6 @@ class YOLOv3DarkNet53Vehicles(hub.Module):
res.extend(output) res.extend(output)
return res return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = paddle.CPUPlace()
exe = paddle.Executor(place)
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
paddle.static.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
""" """
......
...@@ -88,7 +88,7 @@ def load_label_info(file_path): ...@@ -88,7 +88,7 @@ def load_label_info(file_path):
def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization=True): def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization=True):
""" """
postprocess the lod_tensor produced by fluid.Executor.run postprocess the lod_tensor produced by Executor.run
Args: Args:
paths (list[str]): The paths of images. paths (list[str]): The paths of images.
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://ai-studio-static-online.cdn.bcebos.com/036990d3d8654d789c2138492155d9dd95dba2a2fc8e410ab059eea42b330f59'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="yolov3_darknet53_vehicles")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('yolov3_vehicles_detect_output')
def test_object_detection1(self):
results = self.module.object_detection(
paths=['tests/test.jpg']
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'car')
self.assertTrue(confidence > 0.5)
self.assertTrue(2000 < left < 4000)
self.assertTrue(4000 < right < 6000)
self.assertTrue(1000 < top < 3000)
self.assertTrue(2000 < bottom < 5000)
def test_object_detection2(self):
results = self.module.object_detection(
images=[cv2.imread('tests/test.jpg')]
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'car')
self.assertTrue(confidence > 0.5)
self.assertTrue(2000 < left < 4000)
self.assertTrue(4000 < right < 6000)
self.assertTrue(1000 < top < 3000)
self.assertTrue(2000 < bottom < 5000)
def test_object_detection3(self):
results = self.module.object_detection(
images=[cv2.imread('tests/test.jpg')],
visualization=False
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'car')
self.assertTrue(confidence > 0.5)
self.assertTrue(2000 < left < 4000)
self.assertTrue(4000 < right < 6000)
self.assertTrue(1000 < top < 3000)
self.assertTrue(2000 < bottom < 5000)
def test_object_detection4(self):
self.assertRaises(
AssertionError,
self.module.object_detection,
paths=['no.jpg']
)
def test_object_detection5(self):
self.assertRaises(
AttributeError,
self.module.object_detection,
images=['test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册