From 16165a742fed5b37aa2a6c7c750b1950c8c29040 Mon Sep 17 00:00:00 2001 From: jm12138 <2286040843@qq.com> Date: Fri, 16 Sep 2022 15:17:39 +0800 Subject: [PATCH] update pyramidbox_face_detection (#1975) * update pyramidbox_face_detection * update * add clean func * update save inference model --- .../pyramidbox_face_detection/README.md | 17 ++- .../pyramidbox_face_detection/README_en.md | 17 ++- .../pyramidbox_face_detection/module.py | 60 ++++---- .../pyramidbox_face_detection/processor.py | 3 +- .../pyramidbox_face_detection/test.py | 133 ++++++++++++++++++ 5 files changed, 175 insertions(+), 55 deletions(-) create mode 100644 modules/image/face_detection/pyramidbox_face_detection/test.py diff --git a/modules/image/face_detection/pyramidbox_face_detection/README.md b/modules/image/face_detection/pyramidbox_face_detection/README.md index d7c26e9b..7a629372 100644 --- a/modules/image/face_detection/pyramidbox_face_detection/README.md +++ b/modules/image/face_detection/pyramidbox_face_detection/README.md @@ -100,19 +100,13 @@ - ```python - def save_inference_model(dirname, - model_filename=None, - params_filename=None, - combined=True) + def save_inference_model(dirname) ``` - 将模型保存到指定路径。 - **参数** - - dirname: 存在模型的目录名称;
- - model\_filename: 模型文件名称,默认为\_\_model\_\_;
- - params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效);
- - combined: 是否将参数保存到统一的一个文件中。 + - dirname: 模型保存路径
## 四、服务部署 @@ -165,6 +159,11 @@ * 1.1.0 修复numpy数据读取问题 + +* 1.2.0 + + 修复无法导出推理模型的问题 + - ```shell - $ hub install pyramidbox_face_detection==1.1.0 + $ hub install pyramidbox_face_detection==1.2.0 ``` diff --git a/modules/image/face_detection/pyramidbox_face_detection/README_en.md b/modules/image/face_detection/pyramidbox_face_detection/README_en.md index 5f12c1de..502437e0 100644 --- a/modules/image/face_detection/pyramidbox_face_detection/README_en.md +++ b/modules/image/face_detection/pyramidbox_face_detection/README_en.md @@ -99,19 +99,13 @@ - ```python - def save_inference_model(dirname, - model_filename=None, - params_filename=None, - combined=True) + def save_inference_model(dirname) ``` - Save model to specific path - **Parameters** - - dirname: output dir for saving model - - model\_filename: filename for saving model - - params\_filename: filename for saving parameters - - combined: whether save parameters into one file + - dirname: model save path ## IV.Server Deployment @@ -164,6 +158,11 @@ * 1.1.0 Fix the problem of reading numpy + +* 1.2.0 + + Fix a bug of save_inference_model + - ```shell - $ hub install pyramidbox_face_detection==1.1.0 + $ hub install pyramidbox_face_detection==1.2.0 ``` diff --git a/modules/image/face_detection/pyramidbox_face_detection/module.py b/modules/image/face_detection/pyramidbox_face_detection/module.py index 8b44a11d..89fa16c4 100644 --- a/modules/image/face_detection/pyramidbox_face_detection/module.py +++ b/modules/image/face_detection/pyramidbox_face_detection/module.py @@ -7,13 +7,14 @@ import argparse import os import numpy as np -import paddle.fluid as fluid -import paddlehub as hub -from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor +import paddle +import paddle.jit +import paddle.static +from paddle.inference import Config, create_predictor from paddlehub.module.module import moduleinfo, runnable, serving -from pyramidbox_face_detection.data_feed import reader -from pyramidbox_face_detection.processor import postprocess, base64_to_cv2 +from .data_feed import reader +from .processor import postprocess, base64_to_cv2 @moduleinfo( @@ -22,20 +23,22 @@ from pyramidbox_face_detection.processor import postprocess, base64_to_cv2 author="baidu-vis", author_email="", summary="Baidu's PyramidBox model for face detection.", - version="1.1.0") -class PyramidBoxFaceDetection(hub.Module): - def _initialize(self): - self.default_pretrained_model_path = os.path.join(self.directory, "pyramidbox_face_detection_widerface") + version="1.2.0") +class PyramidBoxFaceDetection: + def __init__(self): + self.default_pretrained_model_path = os.path.join(self.directory, "pyramidbox_face_detection_widerface", "model") self._set_config() def _set_config(self): """ predictor config setting """ - cpu_config = AnalysisConfig(self.default_pretrained_model_path) + model = self.default_pretrained_model_path+'.pdmodel' + params = self.default_pretrained_model_path+'.pdiparams' + cpu_config = Config(model, params) cpu_config.disable_glog_info() cpu_config.disable_gpu() - self.cpu_predictor = create_paddle_predictor(cpu_config) + self.cpu_predictor = create_predictor(cpu_config) try: _places = os.environ["CUDA_VISIBLE_DEVICES"] @@ -44,10 +47,10 @@ class PyramidBoxFaceDetection(hub.Module): except: use_gpu = False if use_gpu: - gpu_config = AnalysisConfig(self.default_pretrained_model_path) + gpu_config = Config(model, params) gpu_config.disable_glog_info() gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0) - self.gpu_predictor = create_paddle_predictor(gpu_config) + self.gpu_predictor = create_predictor(gpu_config) def face_detection(self, images=None, @@ -95,11 +98,17 @@ class PyramidBoxFaceDetection(hub.Module): # process one by one for element in reader(images, paths): image = np.expand_dims(element['image'], axis=0).astype('float32') - image_tensor = PaddleTensor(image.copy()) - data_out = self.gpu_predictor.run([image_tensor]) if use_gpu else self.cpu_predictor.run([image_tensor]) + predictor = self.gpu_predictor if use_gpu else self.cpu_predictor + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + input_handle.copy_from_cpu(image) + predictor.run() + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + output = np.expand_dims(output_handle.copy_to_cpu(), axis=1) # print(len(data_out)) # 1 out = postprocess( - data_out=data_out[0].as_ndarray(), + data_out=output_handle.copy_to_cpu(), org_im=element['org_im'], org_im_path=element['org_im_path'], org_im_width=element['org_im_width'], @@ -110,25 +119,6 @@ class PyramidBoxFaceDetection(hub.Module): res.append(out) return res - def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True): - if combined: - model_filename = "__model__" if not model_filename else model_filename - params_filename = "__params__" if not params_filename else params_filename - place = fluid.CPUPlace() - exe = fluid.Executor(place) - - program, feeded_var_names, target_vars = fluid.io.load_inference_model( - dirname=self.default_pretrained_model_path, executor=exe) - - fluid.io.save_inference_model( - dirname=dirname, - main_program=program, - executor=exe, - feeded_var_names=feeded_var_names, - target_vars=target_vars, - model_filename=model_filename, - params_filename=params_filename) - @serving def serving_method(self, images, **kwargs): """ diff --git a/modules/image/face_detection/pyramidbox_face_detection/processor.py b/modules/image/face_detection/pyramidbox_face_detection/processor.py index 0d27ee57..3fee41e8 100644 --- a/modules/image/face_detection/pyramidbox_face_detection/processor.py +++ b/modules/image/face_detection/pyramidbox_face_detection/processor.py @@ -5,12 +5,11 @@ from __future__ import print_function import os import time -from collections import OrderedDict import base64 import cv2 import numpy as np -from PIL import Image, ImageDraw +from PIL import ImageDraw __all__ = ['base64_to_cv2', 'postprocess'] diff --git a/modules/image/face_detection/pyramidbox_face_detection/test.py b/modules/image/face_detection/pyramidbox_face_detection/test.py new file mode 100644 index 00000000..730a3141 --- /dev/null +++ b/modules/image/face_detection/pyramidbox_face_detection/test.py @@ -0,0 +1,133 @@ +import os +import shutil +import unittest + +import cv2 +import requests +import paddlehub as hub + + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +class TestHubModule(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + img_url = 'https://ai-studio-static-online.cdn.bcebos.com/7799a8ccc5f6471b9d56fb6eff94f82a08b70ca2c7594d3f99877e366c0a2619' + if not os.path.exists('tests'): + os.makedirs('tests') + response = requests.get(img_url) + assert response.status_code == 200, 'Network Error.' + with open('tests/test.jpg', 'wb') as f: + f.write(response.content) + cls.module = hub.Module(name="pyramidbox_face_detection") + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree('tests') + shutil.rmtree('inference') + shutil.rmtree('detection_result') + + def test_face_detection1(self): + results = self.module.face_detection( + paths=['tests/test.jpg'], + use_gpu=False, + visualization=False + ) + bbox = results[0]['data'][0] + + confidence = bbox['confidence'] + left = bbox['left'] + right = bbox['right'] + top = bbox['top'] + bottom = bbox['bottom'] + + self.assertTrue(confidence > 0.5) + self.assertTrue(1000 < left < 4000) + self.assertTrue(1000 < right < 4000) + self.assertTrue(0 < top < 2000) + self.assertTrue(0 < bottom < 2000) + + def test_face_detection2(self): + results = self.module.face_detection( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=False + ) + bbox = results[0]['data'][0] + + confidence = bbox['confidence'] + left = bbox['left'] + right = bbox['right'] + top = bbox['top'] + bottom = bbox['bottom'] + + self.assertTrue(confidence > 0.5) + self.assertTrue(1000 < left < 4000) + self.assertTrue(1000 < right < 4000) + self.assertTrue(0 < top < 2000) + self.assertTrue(0 < bottom < 2000) + + def test_face_detection3(self): + results = self.module.face_detection( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=True + ) + bbox = results[0]['data'][0] + + confidence = bbox['confidence'] + left = bbox['left'] + right = bbox['right'] + top = bbox['top'] + bottom = bbox['bottom'] + + self.assertTrue(confidence > 0.5) + self.assertTrue(1000 < left < 4000) + self.assertTrue(1000 < right < 4000) + self.assertTrue(0 < top < 2000) + self.assertTrue(0 < bottom < 2000) + + def test_face_detection4(self): + results = self.module.face_detection( + images=[cv2.imread('tests/test.jpg')], + use_gpu=True, + visualization=False + ) + bbox = results[0]['data'][0] + + confidence = bbox['confidence'] + left = bbox['left'] + right = bbox['right'] + top = bbox['top'] + bottom = bbox['bottom'] + + self.assertTrue(confidence > 0.5) + self.assertTrue(1000 < left < 4000) + self.assertTrue(1000 < right < 4000) + self.assertTrue(0 < top < 2000) + self.assertTrue(0 < bottom < 2000) + + def test_face_detection5(self): + self.assertRaises( + AssertionError, + self.module.face_detection, + paths=['no.jpg'] + ) + + def test_face_detection6(self): + self.assertRaises( + cv2.error, + self.module.face_detection, + images=['test.jpg'] + ) + + def test_save_inference_model(self): + self.module.save_inference_model('./inference/model') + + self.assertTrue(os.path.exists('./inference/model.pdmodel')) + self.assertTrue(os.path.exists('./inference/model.pdiparams')) + + +if __name__ == "__main__": + unittest.main() -- GitLab