Unverified commit 00047359 authored by jm_12138, committed by GitHub

update ace2p (#2003)

* update ace2p

* add clean func

* update ace2p
Parent 94949b0e
@@ -99,20 +99,14 @@
* data (numpy.ndarray): the image segmentation result, an array of shape `H * W`; each element takes a value in 0-19, giving the predicted class of that pixel, in the same order as the palette below.
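For orientation, here is a minimal usage sketch that consumes this result; the image path is a placeholder, and the call pattern mirrors the unit tests added at the bottom of this commit:

```python
import cv2
import paddlehub as hub

module = hub.Module(name="ace2p")
# 'person.jpg' is a hypothetical input image path.
results = module.segmentation(images=[cv2.imread('person.jpg')], visualization=False)
parsing = results[0]['data']  # ndarray of shape (H, W) holding class ids 0-19
print(parsing.shape, parsing.max())
```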
```python
-def save_inference_model(dirname,
-                         model_filename=None,
-                         params_filename=None,
-                         combined=True)
+def save_inference_model(dirname)
```
- Save the model to the specified path.

- **Parameters**

-  * dirname: name of the directory that holds the model
-  * model\_filename: model file name, default is \_\_model\_\_
-  * params\_filename: parameter file name, default is \_\_params\_\_ (only takes effect when `combined` is True)
-  * combined: whether to save the parameters into a single file
+  * dirname: model save path
## IV. Server Deployment
@@ -176,3 +170,11 @@
* 1.1.0

  Adapt to PaddleHub 2.0

+* 1.2.0
+
+  Remove Fluid API
+
+```shell
+$ hub install ace2p==1.2.0
+```
\ No newline at end of file
@@ -105,19 +105,13 @@
- ```python
-def save_inference_model(dirname,
-                         model_filename=None,
-                         params_filename=None,
-                         combined=True)
+def save_inference_model(dirname)
```
- Save the model to the specified path.
- **Parameters**
-  * dirname: Save path.
-  * model\_filename: Model file name, default is \_\_model\_\_
-  * params\_filename: Parameter file name, default is \_\_params\_\_ (only takes effect when `combined` is True)
-  * combined: Whether to save the parameters to a unified file.
+  * dirname: Model save path.
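A short usage sketch of the simplified signature (the output path is arbitrary; the test file at the bottom of this commit checks exactly these artifacts):

```python
import paddlehub as hub

module = hub.Module(name="ace2p")
# Writes model.pdmodel and model.pdiparams under ./inference/.
module.save_inference_model('./inference/model')
```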
## IV. Server Deployment
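The deployment section itself is not changed by this diff; for reference, a hedged client-side sketch, assuming the module is served with `hub serving start -m ace2p` on the default port 8866 and PaddleHub's usual `/predict/<module>` route:

```python
import base64
import json

import cv2
import requests

def cv2_to_base64(image):
    # Encode an OpenCV image as a base64 JPEG string, as serving_method expects.
    return base64.b64encode(cv2.imencode('.jpg', image)[1].tobytes()).decode('utf8')

# 'person.jpg' is a hypothetical input image path.
payload = {'images': [cv2_to_base64(cv2.imread('person.jpg'))]}
r = requests.post('http://127.0.0.1:8866/predict/ace2p',
                  headers={'Content-type': 'application/json'},
                  data=json.dumps(payload))
print(r.json())
```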
@@ -182,3 +176,11 @@
* 1.1.0

  Adapt to PaddleHub 2.0

+* 1.2.0
+
+  Remove Fluid API
+
+```shell
+$ hub install ace2p==1.2.0
+```
@@ -6,7 +6,7 @@ from collections import OrderedDict
import cv2
import numpy as np

-from ace2p.processor import get_direction, get_3rd_point, get_affine_transform
+from .processor import get_affine_transform

__all__ = ['reader']
@@ -45,7 +45,7 @@ def preprocess(org_im, scale, rotation):
    img_mean = np.array([0.406, 0.456, 0.485]).reshape((1, 1, 3))
    img_std = np.array([0.225, 0.224, 0.229]).reshape((1, 1, 3))

-    image = image.astype(np.float)
+    image = image.astype(np.float32)
    image = (image / 255.0 - img_mean) / img_std
    image = image.transpose(2, 0, 1).astype(np.float32)
......
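A note on the `astype` change above: `np.float` was merely a deprecated alias for the built-in `float` and was removed in NumPy 1.24, so the old line raises AttributeError on current NumPy. A quick self-contained check:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
x = x.astype(np.float32)  # explicit dtype; np.float no longer exists in NumPy >= 1.24
print(x.dtype)  # float32
```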
@@ -7,13 +7,14 @@ import argparse
import os

import numpy as np
-import paddle.fluid as fluid
-import paddlehub as hub
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
+import paddle
+import paddle.jit
+import paddle.static
+from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving

-from ace2p.processor import get_palette, postprocess, base64_to_cv2, cv2_to_base64
-from ace2p.data_feed import reader
+from .processor import get_palette, postprocess, base64_to_cv2, cv2_to_base64
+from .data_feed import reader
@moduleinfo(
@@ -22,10 +23,11 @@ from ace2p.data_feed import reader
    author="baidu-idl",
    author_email="",
    summary="ACE2P is an image segmentation model for human parsing solution.",
-    version="1.1.0")
-class ACE2P(hub.Module):
-    def _initialize(self):
-        self.default_pretrained_model_path = os.path.join(self.directory, "ace2p_human_parsing")
+    version="1.2.0")
+class ACE2P:
+    def __init__(self):
+        self.default_pretrained_model_path = os.path.join(
+            self.directory, "ace2p_human_parsing", "model")
        # label list
        label_list_file = os.path.join(self.directory, 'label_list.txt')
        with open(label_list_file, "r") as file:
@@ -39,10 +41,12 @@ class ACE2P(hub.Module):
        """
        predictor config setting
        """
-        cpu_config = AnalysisConfig(self.default_pretrained_model_path)
+        model = self.default_pretrained_model_path + '.pdmodel'
+        params = self.default_pretrained_model_path + '.pdiparams'
+        cpu_config = Config(model, params)
        cpu_config.disable_glog_info()
        cpu_config.disable_gpu()
-        self.cpu_predictor = create_paddle_predictor(cpu_config)
+        self.cpu_predictor = create_predictor(cpu_config)

        try:
            _places = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -51,10 +55,10 @@
        except:
            use_gpu = False

        if use_gpu:
-            gpu_config = AnalysisConfig(self.default_pretrained_model_path)
+            gpu_config = Config(model, params)
            gpu_config.disable_glog_info()
            gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
-            self.gpu_predictor = create_paddle_predictor(gpu_config)
+            self.gpu_predictor = create_predictor(gpu_config)

    def segmentation(self,
                     images=None,
@@ -114,12 +118,19 @@
                pass

            # feed batch image
            batch_image = np.array([data['image'] for data in batch_data])
-            batch_image = PaddleTensor(batch_image.astype('float32'))
-            data_out = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run([batch_image])
+            predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
+
+            input_names = predictor.get_input_names()
+            input_handle = predictor.get_input_handle(input_names[0])
+            input_handle.copy_from_cpu(batch_image.astype('float32'))
+
+            predictor.run()
+
+            output_names = predictor.get_output_names()
+            output_handle = predictor.get_output_handle(output_names[0])

            # postprocess one by one
            for i in range(len(batch_data)):
                out = postprocess(
-                    data_out=data_out[0].as_ndarray()[i],
+                    data_out=output_handle.copy_to_cpu()[i],
                    org_im=batch_data[i]['org_im'],
                    org_im_path=batch_data[i]['org_im_path'],
                    image_info=batch_data[i]['image_info'],
@@ -129,25 +140,6 @@
                res.append(out)
        return res

-    def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
-        if combined:
-            model_filename = "__model__" if not model_filename else model_filename
-            params_filename = "__params__" if not params_filename else params_filename
-
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-        program, feeded_var_names, target_vars = fluid.io.load_inference_model(
-            dirname=self.default_pretrained_model_path, executor=exe)
-
-        fluid.io.save_inference_model(
-            dirname=dirname,
-            main_program=program,
-            executor=exe,
-            feeded_var_names=feeded_var_names,
-            target_vars=target_vars,
-            model_filename=model_filename,
-            params_filename=params_filename)
-
    @serving
    def serving_method(self, images, **kwargs):
        """
......
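Taken together, the predictor changes above follow the standard paddle.inference workflow: build a Config from the .pdmodel/.pdiparams pair, create a predictor, copy inputs in, run, copy outputs out. A minimal standalone sketch, where the model path and the 1x3x473x473 input shape are illustrative assumptions rather than values taken from this diff:

```python
import numpy as np
from paddle.inference import Config, create_predictor

# Hypothetical exported files, e.g. produced by save_inference_model('./inference/model').
config = Config('./inference/model.pdmodel', './inference/model.pdiparams')
config.disable_glog_info()
config.disable_gpu()  # or config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
predictor = create_predictor(config)

# Copy a dummy batch in, run, and copy the first output back out.
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(np.random.rand(1, 3, 473, 473).astype('float32'))
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
result = output_handle.copy_to_cpu()
print(result.shape)
```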
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        img_url = 'https://unsplash.com/photos/pg_WCHWSdT8/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjYyNDM2ODI4&force=true&w=640'
        if not os.path.exists('tests'):
            os.makedirs('tests')
        response = requests.get(img_url)
        assert response.status_code == 200, 'Network Error.'
        with open('tests/test.jpg', 'wb') as f:
            f.write(response.content)
        fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
        img = cv2.imread('tests/test.jpg')
        # cv2.VideoWriter expects the frame size as (width, height),
        # i.e. (img.shape[1], img.shape[0]), not img.shape[:2].
        video = cv2.VideoWriter('tests/test.avi', fourcc, 20.0,
                                (img.shape[1], img.shape[0]))
        for i in range(40):
            video.write(img)
        video.release()
        cls.module = hub.Module(name="ace2p")

    @classmethod
    def tearDownClass(cls) -> None:
        shutil.rmtree('tests')
        shutil.rmtree('inference')
        shutil.rmtree('ace2p_output')

    def test_segmentation1(self):
        results = self.module.segmentation(
            paths=['tests/test.jpg'],
            use_gpu=False,
            visualization=False
        )
        self.assertIsInstance(results[0]['data'], np.ndarray)

    def test_segmentation2(self):
        results = self.module.segmentation(
            images=[cv2.imread('tests/test.jpg')],
            use_gpu=False,
            visualization=False
        )
        self.assertIsInstance(results[0]['data'], np.ndarray)

    def test_segmentation3(self):
        results = self.module.segmentation(
            images=[cv2.imread('tests/test.jpg')],
            use_gpu=False,
            visualization=True
        )
        self.assertIsInstance(results[0]['data'], np.ndarray)

    def test_segmentation4(self):
        results = self.module.segmentation(
            images=[cv2.imread('tests/test.jpg')],
            use_gpu=True,
            visualization=False
        )
        self.assertIsInstance(results[0]['data'], np.ndarray)

    def test_segmentation5(self):
        self.assertRaises(
            AssertionError,
            self.module.segmentation,
            paths=['no.jpg']
        )

    def test_segmentation6(self):
        self.assertRaises(
            AttributeError,
            self.module.segmentation,
            images=['test.jpg']
        )

    def test_save_inference_model(self):
        self.module.save_inference_model('./inference/model')
        self.assertTrue(os.path.exists('./inference/model.pdmodel'))
        self.assertTrue(os.path.exists('./inference/model.pdiparams'))


if __name__ == "__main__":
    unittest.main()