diff --git a/modules/image/semantic_segmentation/ace2p/README.md b/modules/image/semantic_segmentation/ace2p/README.md index 12b23cf4f1beed338058a89e64a0ac1d854e3892..5677f51e7901ce9026e2445be1b1dba4835ad361 100644 --- a/modules/image/semantic_segmentation/ace2p/README.md +++ b/modules/image/semantic_segmentation/ace2p/README.md @@ -99,20 +99,14 @@ * data (numpy.ndarray): 图像分割得到的结果,shape 为`H * W`,元素的取值为0-19,表示每个像素的分类结果,映射顺序与下面的调色板相同。 ```python - def save_inference_model(dirname, - model_filename=None, - params_filename=None, - combined=True) + def save_inference_model(dirname) ``` - 将模型保存到指定路径。 - **参数** - * dirname: 存在模型的目录名称 - * model\_filename: 模型文件名称,默认为\_\_model\_\_ - * params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效) - * combined: 是否将参数保存到统一的一个文件中。 + * dirname: 模型保存路径 ## 四、服务部署 @@ -176,3 +170,11 @@ * 1.1.0 适配paddlehub2.0版本 + +* 1.2.0 + + 移除 Fluid API + + ```shell + $ hub install ace2p == 1.2.0 + ``` \ No newline at end of file diff --git a/modules/image/semantic_segmentation/ace2p/README_en.md b/modules/image/semantic_segmentation/ace2p/README_en.md index 3fa0c273e3b3095ce8ba7b8abf97543e3be6ca48..e8acf04f285d55d12462d3556fcb315bd3fb8775 100644 --- a/modules/image/semantic_segmentation/ace2p/README_en.md +++ b/modules/image/semantic_segmentation/ace2p/README_en.md @@ -105,19 +105,13 @@ - ```python - def save_inference_model(dirname, - model_filename=None, - params_filename=None, - combined=True) + def save_inference_model(dirname) ``` - Save the model to the specified path. - **Parameters** - * dirname: Save path. - * model\_filename: mMdel file name,defalt is \_\_model\_\_ - * params\_filename: Parameter file name,defalt is \_\_params\_\_(Only takes effect when `combined` is True) - * combined: Whether to save the parameters to a unified file. + * dirname: Model save path. ## IV. Server Deployment @@ -182,3 +176,11 @@ * 1.1.0 Adapt to paddlehub2.0 + +* 1.2.0 + + Remove Fluid API + + ```shell + $ hub install ace2p == 1.2.0 + ``` diff --git a/modules/image/semantic_segmentation/ace2p/data_feed.py b/modules/image/semantic_segmentation/ace2p/data_feed.py index 39094654805525fbba5bea55a30a57d605151646..230520253b94a6bc72c4c9cfe74fab7aedfca746 100644 --- a/modules/image/semantic_segmentation/ace2p/data_feed.py +++ b/modules/image/semantic_segmentation/ace2p/data_feed.py @@ -6,7 +6,7 @@ from collections import OrderedDict import cv2 import numpy as np -from ace2p.processor import get_direction, get_3rd_point, get_affine_transform +from .processor import get_affine_transform __all__ = ['reader'] @@ -45,7 +45,7 @@ def preprocess(org_im, scale, rotation): img_mean = np.array([0.406, 0.456, 0.485]).reshape((1, 1, 3)) img_std = np.array([0.225, 0.224, 0.229]).reshape((1, 1, 3)) - image = image.astype(np.float) + image = image.astype(np.float32) image = (image / 255.0 - img_mean) / img_std image = image.transpose(2, 0, 1).astype(np.float32) diff --git a/modules/image/semantic_segmentation/ace2p/module.py b/modules/image/semantic_segmentation/ace2p/module.py index 458f33d10def98dd46616add639e5a9998205b10..2f30df4d7fda4b24ec729976938ad75304e35f58 100644 --- a/modules/image/semantic_segmentation/ace2p/module.py +++ b/modules/image/semantic_segmentation/ace2p/module.py @@ -7,13 +7,14 @@ import argparse import os import numpy as np -import paddle.fluid as fluid -import paddlehub as hub -from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor +import paddle +import paddle.jit +import paddle.static +from paddle.inference import Config, create_predictor from paddlehub.module.module import moduleinfo, runnable, serving -from ace2p.processor import get_palette, postprocess, base64_to_cv2, cv2_to_base64 -from ace2p.data_feed import reader +from .processor import get_palette, postprocess, base64_to_cv2, cv2_to_base64 +from .data_feed import reader @moduleinfo( @@ -22,10 +23,11 @@ from ace2p.data_feed import reader author="baidu-idl", author_email="", summary="ACE2P is an image segmentation model for human parsing solution.", - version="1.1.0") -class ACE2P(hub.Module): - def _initialize(self): - self.default_pretrained_model_path = os.path.join(self.directory, "ace2p_human_parsing") + version="1.2.0") +class ACE2P: + def __init__(self): + self.default_pretrained_model_path = os.path.join( + self.directory, "ace2p_human_parsing", "model") # label list label_list_file = os.path.join(self.directory, 'label_list.txt') with open(label_list_file, "r") as file: @@ -39,10 +41,12 @@ class ACE2P(hub.Module): """ predictor config setting """ - cpu_config = AnalysisConfig(self.default_pretrained_model_path) + model = self.default_pretrained_model_path+'.pdmodel' + params = self.default_pretrained_model_path+'.pdiparams' + cpu_config = Config(model, params) cpu_config.disable_glog_info() cpu_config.disable_gpu() - self.cpu_predictor = create_paddle_predictor(cpu_config) + self.cpu_predictor = create_predictor(cpu_config) try: _places = os.environ["CUDA_VISIBLE_DEVICES"] @@ -51,10 +55,10 @@ class ACE2P(hub.Module): except: use_gpu = False if use_gpu: - gpu_config = AnalysisConfig(self.default_pretrained_model_path) + gpu_config = Config(model, params) gpu_config.disable_glog_info() gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0) - self.gpu_predictor = create_paddle_predictor(gpu_config) + self.gpu_predictor = create_predictor(gpu_config) def segmentation(self, images=None, @@ -114,12 +118,19 @@ class ACE2P(hub.Module): pass # feed batch image batch_image = np.array([data['image'] for data in batch_data]) - batch_image = PaddleTensor(batch_image.astype('float32')) - data_out = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run([batch_image]) + + predictor = self.gpu_predictor if use_gpu else self.cpu_predictor + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + input_handle.copy_from_cpu(batch_image.astype('float32')) + predictor.run() + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + # postprocess one by one for i in range(len(batch_data)): out = postprocess( - data_out=data_out[0].as_ndarray()[i], + data_out=output_handle.copy_to_cpu()[i], org_im=batch_data[i]['org_im'], org_im_path=batch_data[i]['org_im_path'], image_info=batch_data[i]['image_info'], @@ -129,25 +140,6 @@ class ACE2P(hub.Module): res.append(out) return res - def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True): - if combined: - model_filename = "__model__" if not model_filename else model_filename - params_filename = "__params__" if not params_filename else params_filename - place = fluid.CPUPlace() - exe = fluid.Executor(place) - - program, feeded_var_names, target_vars = fluid.io.load_inference_model( - dirname=self.default_pretrained_model_path, executor=exe) - - fluid.io.save_inference_model( - dirname=dirname, - main_program=program, - executor=exe, - feeded_var_names=feeded_var_names, - target_vars=target_vars, - model_filename=model_filename, - params_filename=params_filename) - @serving def serving_method(self, images, **kwargs): """ diff --git a/modules/image/semantic_segmentation/ace2p/test.py b/modules/image/semantic_segmentation/ace2p/test.py new file mode 100644 index 0000000000000000000000000000000000000000..fa738eb7193f4aef95b9c88162c1f2a61439c004 --- /dev/null +++ b/modules/image/semantic_segmentation/ace2p/test.py @@ -0,0 +1,93 @@ +import os +import shutil +import unittest + +import cv2 +import requests +import numpy as np +import paddlehub as hub + + +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + +class TestHubModule(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + img_url = 'https://unsplash.com/photos/pg_WCHWSdT8/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjYyNDM2ODI4&force=true&w=640' + if not os.path.exists('tests'): + os.makedirs('tests') + response = requests.get(img_url) + assert response.status_code == 200, 'Network Error.' + with open('tests/test.jpg', 'wb') as f: + f.write(response.content) + fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G') + img = cv2.imread('tests/test.jpg') + video = cv2.VideoWriter('tests/test.avi', fourcc, + 20.0, tuple(img.shape[:2])) + for i in range(40): + video.write(img) + video.release() + cls.module = hub.Module(name="ace2p") + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree('tests') + shutil.rmtree('inference') + shutil.rmtree('ace2p_output') + + def test_segmentation1(self): + results = self.module.segmentation( + paths=['tests/test.jpg'], + use_gpu=False, + visualization=False + ) + self.assertIsInstance(results[0]['data'], np.ndarray) + + def test_segmentation2(self): + results = self.module.segmentation( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=False + ) + self.assertIsInstance(results[0]['data'], np.ndarray) + + def test_segmentation3(self): + results = self.module.segmentation( + images=[cv2.imread('tests/test.jpg')], + use_gpu=False, + visualization=True + ) + self.assertIsInstance(results[0]['data'], np.ndarray) + + def test_segmentation4(self): + results = self.module.segmentation( + images=[cv2.imread('tests/test.jpg')], + use_gpu=True, + visualization=False + ) + self.assertIsInstance(results[0]['data'], np.ndarray) + + def test_segmentation5(self): + self.assertRaises( + AssertionError, + self.module.segmentation, + paths=['no.jpg'] + ) + + def test_segmentation6(self): + self.assertRaises( + AttributeError, + self.module.segmentation, + images=['test.jpg'] + ) + + def test_save_inference_model(self): + self.module.save_inference_model('./inference/model') + + self.assertTrue(os.path.exists('./inference/model.pdmodel')) + self.assertTrue(os.path.exists('./inference/model.pdiparams')) + + +if __name__ == "__main__": + unittest.main()