未验证 提交 5c923528 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update pyramidbox_lite_mobile_mask (#1997)

* update pyramidbox_lite_mobile_mask

* update

* add clean func

* update save inference model
上级 0ea0f8e8
......@@ -131,19 +131,13 @@
- ```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
def save_inference_model(dirname)
```
- 将模型保存到指定路径。
- **参数**
- dirname: 存在模型的目录名称; <br/>
- model\_filename: 模型文件名称,默认为\_\_model\_\_; <br/>
- params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效);<br/>
- combined: 是否将参数保存到统一的一个文件中。
- dirname: 模型保存路径 <br/>
## 四、服务部署
......@@ -194,7 +188,6 @@
# 将模型保存在test_program文件夹之中
pyramidbox_lite_mobile_mask.save_inference_model(dirname="test_program")
```
通过以上命令,可以获得人脸检测和口罩佩戴判断模型,分别存储在pyramidbox\_lite和mask\_detector之中。文件夹中的\_\_model\_\_是模型结构文件,\_\_params\_\_文件是权重文件。
- ### 进行模型转换
- 从paddlehub下载的是预测模型,可以使用PaddleLite提供的模型优化工具OPT对预测模型进行转换,转换之后进而可以实现在手机等端侧硬件上的部署,具体请请参考[OPT工具](https://paddle-lite.readthedocs.io/zh/latest/user_guides/model_optimize_tool.html)
......@@ -212,6 +205,10 @@
移除 fluid api
* 1.4.0
修复无法导出模型的问题
- ```shell
$ hub install pyramidbox_lite_mobile_mask==1.3.1
$ hub install pyramidbox_lite_mobile_mask==1.4.0
```
......@@ -107,20 +107,13 @@
- ```python
def save_inference_model(dirname,
model_filename=None,
params_filename=None,
combined=True)
def save_inference_model(dirname)
```
- Save model to specific path
- **Parameters**
- dirname: output dir for saving model
- model\_filename: filename for saving model
- params\_filename: filename for saving parameters
- combined: whether save parameters into one file
- dirname: model save path
## IV.Server Deployment
......@@ -188,6 +181,10 @@
Remove fluid api
* 1.4.0
Fix a bug of save_inference_model
- ```shell
$ hub install pyramidbox_lite_mobile_mask==1.3.1
$ hub install pyramidbox_lite_mobile_mask==1.4.0
```
......@@ -10,9 +10,9 @@ import numpy as np
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from pyramidbox_lite_mobile_mask.data_feed import reader
from pyramidbox_lite_mobile_mask.processor import base64_to_cv2
from pyramidbox_lite_mobile_mask.processor import postprocess
from .data_feed import reader
from .processor import base64_to_cv2
from .processor import postprocess
import paddlehub as hub
from paddlehub.module.module import moduleinfo
......@@ -27,15 +27,14 @@ from paddlehub.module.module import serving
author_email="",
summary=
"Pyramidbox-Lite-Mobile-Mask is a high-performance face detection model used to detect whether people wear masks.",
version="1.3.1")
class PyramidBoxLiteMobileMask(hub.Module):
def _initialize(self, face_detector_module=None):
version="1.4.0")
class PyramidBoxLiteMobileMask:
def __init__(self, face_detector_module=None):
"""
Args:
face_detector_module (class): module to detect face.
"""
self.default_pretrained_model_path = os.path.join(self.directory, "pyramidbox_lite_mobile_mask_model")
self.default_pretrained_model_path = os.path.join(self.directory, "pyramidbox_lite_mobile_mask_model", "model")
if face_detector_module is None:
self.face_detector = hub.Module(name='pyramidbox_lite_mobile')
else:
......@@ -47,7 +46,9 @@ class PyramidBoxLiteMobileMask(hub.Module):
"""
predictor config setting
"""
cpu_config = Config(self.default_pretrained_model_path)
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_predictor(cpu_config)
......@@ -59,7 +60,7 @@ class PyramidBoxLiteMobileMask(hub.Module):
except:
use_gpu = False
if use_gpu:
gpu_config = Config(self.default_pretrained_model_path)
gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_predictor(gpu_config)
......@@ -180,33 +181,6 @@ class PyramidBoxLiteMobileMask(hub.Module):
res.append(out)
return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
classifier_dir = os.path.join(dirname, 'mask_detector')
detector_dir = os.path.join(dirname, 'pyramidbox_lite')
self._save_classifier_model(classifier_dir, model_filename, params_filename, combined)
self._save_detector_model(detector_dir, model_filename, params_filename, combined)
def _save_detector_model(self, dirname, model_filename=None, params_filename=None, combined=True):
self.face_detector.save_inference_model(dirname, model_filename, params_filename, combined)
def _save_classifier_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = paddle.CPUPlace()
exe = paddle.Executor(place)
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
paddle.static.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving
def serving_method(self, images, **kwargs):
"""
......
......@@ -5,7 +5,6 @@ from __future__ import print_function
import os
import time
from collections import OrderedDict
import base64
import cv2
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://ai-studio-static-online.cdn.bcebos.com/7799a8ccc5f6471b9d56fb6eff94f82a08b70ca2c7594d3f99877e366c0a2619'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="pyramidbox_lite_mobile_mask")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('detection_result')
def test_face_detection1(self):
results = self.module.face_detection(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'NO MASK')
self.assertTrue(confidence > 0.5)
self.assertTrue(1000 < left < 4000)
self.assertTrue(1000 < right < 4000)
self.assertTrue(0 < top < 2000)
self.assertTrue(0 < bottom < 2000)
def test_face_detection2(self):
results = self.module.face_detection(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'NO MASK')
self.assertTrue(confidence > 0.5)
self.assertTrue(1000 < left < 4000)
self.assertTrue(1000 < right < 4000)
self.assertTrue(0 < top < 2000)
self.assertTrue(0 < bottom < 2000)
def test_face_detection3(self):
results = self.module.face_detection(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'NO MASK')
self.assertTrue(confidence > 0.5)
self.assertTrue(1000 < left < 4000)
self.assertTrue(1000 < right < 4000)
self.assertTrue(0 < top < 2000)
self.assertTrue(0 < bottom < 2000)
def test_face_detection4(self):
results = self.module.face_detection(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False
)
bbox = results[0]['data'][0]
label = bbox['label']
confidence = bbox['confidence']
left = bbox['left']
right = bbox['right']
top = bbox['top']
bottom = bbox['bottom']
self.assertEqual(label, 'NO MASK')
self.assertTrue(confidence > 0.5)
self.assertTrue(1000 < left < 4000)
self.assertTrue(1000 < right < 4000)
self.assertTrue(0 < top < 2000)
self.assertTrue(0 < bottom < 2000)
def test_face_detection5(self):
self.assertRaises(
AssertionError,
self.module.face_detection,
paths=['no.jpg']
)
def test_face_detection6(self):
self.assertRaises(
AttributeError,
self.module.face_detection,
images=['test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model/face_detector.pdmodel'))
self.assertTrue(os.path.exists('./inference/model/face_detector.pdiparams'))
self.assertTrue(os.path.exists('./inference/model/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册