未验证 提交 00047359 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update ace2p (#2003)

* update ace2p

* add clean func

* update ace2p
上级 94949b0e
...@@ -99,20 +99,14 @@ ...@@ -99,20 +99,14 @@
* data (numpy.ndarray): 图像分割得到的结果,shape 为`H * W`,元素的取值为0-19,表示每个像素的分类结果,映射顺序与下面的调色板相同。 * data (numpy.ndarray): 图像分割得到的结果,shape 为`H * W`,元素的取值为0-19,表示每个像素的分类结果,映射顺序与下面的调色板相同。
```python ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- 将模型保存到指定路径。 - 将模型保存到指定路径。
- **参数** - **参数**
* dirname: 存在模型的目录名称 * dirname: 模型保存路径
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中。
## 四、服务部署 ## 四、服务部署
...@@ -176,3 +170,11 @@ ...@@ -176,3 +170,11 @@
* 1.1.0 * 1.1.0
适配paddlehub2.0版本 适配paddlehub2.0版本
* 1.2.0
移除 Fluid API
```shell
$ hub install ace2p == 1.2.0
```
\ No newline at end of file
...@@ -105,19 +105,13 @@ ...@@ -105,19 +105,13 @@
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- Save the model to the specified path. - Save the model to the specified path.
- **Parameters** - **Parameters**
* dirname: Save path. * dirname: Model save path.
* model\_filename: mMdel file name,defalt is \_\_model\_\_
* params\_filename: Parameter file name,defalt is \_\_params\_\_(Only takes effect when `combined` is True)
* combined: Whether to save the parameters to a unified file.
## IV. Server Deployment ## IV. Server Deployment
...@@ -182,3 +176,11 @@ ...@@ -182,3 +176,11 @@
* 1.1.0 * 1.1.0
Adapt to paddlehub2.0 Adapt to paddlehub2.0
* 1.2.0
Remove Fluid API
```shell
$ hub install ace2p == 1.2.0
```
...@@ -6,7 +6,7 @@ from collections import OrderedDict ...@@ -6,7 +6,7 @@ from collections import OrderedDict
import cv2 import cv2
import numpy as np import numpy as np
from ace2p.processor import get_direction, get_3rd_point, get_affine_transform from .processor import get_affine_transform
__all__ = ['reader'] __all__ = ['reader']
...@@ -45,7 +45,7 @@ def preprocess(org_im, scale, rotation): ...@@ -45,7 +45,7 @@ def preprocess(org_im, scale, rotation):
img_mean = np.array([0.406, 0.456, 0.485]).reshape((1, 1, 3)) img_mean = np.array([0.406, 0.456, 0.485]).reshape((1, 1, 3))
img_std = np.array([0.225, 0.224, 0.229]).reshape((1, 1, 3)) img_std = np.array([0.225, 0.224, 0.229]).reshape((1, 1, 3))
image = image.astype(np.float) image = image.astype(np.float32)
image = (image / 255.0 - img_mean) / img_std image = (image / 255.0 - img_mean) / img_std
image = image.transpose(2, 0, 1).astype(np.float32) image = image.transpose(2, 0, 1).astype(np.float32)
......
...@@ -7,13 +7,14 @@ import argparse ...@@ -7,13 +7,14 @@ import argparse
import os import os
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle
import paddlehub as hub import paddle.jit
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from ace2p.processor import get_palette, postprocess, base64_to_cv2, cv2_to_base64 from .processor import get_palette, postprocess, base64_to_cv2, cv2_to_base64
from ace2p.data_feed import reader from .data_feed import reader
@moduleinfo( @moduleinfo(
...@@ -22,10 +23,11 @@ from ace2p.data_feed import reader ...@@ -22,10 +23,11 @@ from ace2p.data_feed import reader
author="baidu-idl", author="baidu-idl",
author_email="", author_email="",
summary="ACE2P is an image segmentation model for human parsing solution.", summary="ACE2P is an image segmentation model for human parsing solution.",
version="1.1.0") version="1.2.0")
class ACE2P(hub.Module): class ACE2P:
def _initialize(self): def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "ace2p_human_parsing") self.default_pretrained_model_path = os.path.join(
self.directory, "ace2p_human_parsing", "model")
# label list # label list
label_list_file = os.path.join(self.directory, 'label_list.txt') label_list_file = os.path.join(self.directory, 'label_list.txt')
with open(label_list_file, "r") as file: with open(label_list_file, "r") as file:
...@@ -39,10 +41,12 @@ class ACE2P(hub.Module): ...@@ -39,10 +41,12 @@ class ACE2P(hub.Module):
""" """
predictor config setting predictor config setting
""" """
cpu_config = AnalysisConfig(self.default_pretrained_model_path) model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
try: try:
_places = os.environ["CUDA_VISIBLE_DEVICES"] _places = os.environ["CUDA_VISIBLE_DEVICES"]
...@@ -51,10 +55,10 @@ class ACE2P(hub.Module): ...@@ -51,10 +55,10 @@ class ACE2P(hub.Module):
except: except:
use_gpu = False use_gpu = False
if use_gpu: if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
def segmentation(self, def segmentation(self,
images=None, images=None,
...@@ -114,12 +118,19 @@ class ACE2P(hub.Module): ...@@ -114,12 +118,19 @@ class ACE2P(hub.Module):
pass pass
# feed batch image # feed batch image
batch_image = np.array([data['image'] for data in batch_data]) batch_image = np.array([data['image'] for data in batch_data])
batch_image = PaddleTensor(batch_image.astype('float32'))
data_out = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run([batch_image]) predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(batch_image.astype('float32'))
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
# postprocess one by one # postprocess one by one
for i in range(len(batch_data)): for i in range(len(batch_data)):
out = postprocess( out = postprocess(
data_out=data_out[0].as_ndarray()[i], data_out=output_handle.copy_to_cpu()[i],
org_im=batch_data[i]['org_im'], org_im=batch_data[i]['org_im'],
org_im_path=batch_data[i]['org_im_path'], org_im_path=batch_data[i]['org_im_path'],
image_info=batch_data[i]['image_info'], image_info=batch_data[i]['image_info'],
...@@ -129,25 +140,6 @@ class ACE2P(hub.Module): ...@@ -129,25 +140,6 @@ class ACE2P(hub.Module):
res.append(out) res.append(out)
return res return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
""" """
......
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/pg_WCHWSdT8/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjYyNDM2ODI4&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
img = cv2.imread('tests/test.jpg')
video = cv2.VideoWriter('tests/test.avi', fourcc,
20.0, tuple(img.shape[:2]))
for i in range(40):
video.write(img)
video.release()
cls.module = hub.Module(name="ace2p")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('ace2p_output')
def test_segmentation1(self):
results = self.module.segmentation(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segmentation2(self):
results = self.module.segmentation(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segmentation3(self):
results = self.module.segmentation(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segmentation4(self):
results = self.module.segmentation(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segmentation5(self):
self.assertRaises(
AssertionError,
self.module.segmentation,
paths=['no.jpg']
)
def test_segmentation6(self):
self.assertRaises(
AttributeError,
self.module.segmentation,
images=['test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册