未验证 提交 71ee4cf6 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update deeplabv3p_xception65_humanseg (#2008)

* update deeplabv3p_xception65_humanseg

* update save inference model
上级 00047359
...@@ -70,11 +70,11 @@ ...@@ -70,11 +70,11 @@
```python ```python
def segmentation(images=None, def segmentation(images=None,
paths=None, paths=None,
batch_size=1, batch_size=1,
use_gpu=False, use_gpu=False,
visualization=False, visualization=False,
output_dir='humanseg_output') output_dir='humanseg_output')
``` ```
- 预测API,用于人像分割。 - 预测API,用于人像分割。
...@@ -95,20 +95,14 @@ ...@@ -95,20 +95,14 @@
* data (numpy.ndarray): 人像分割结果,仅包含Alpha通道,取值为0-255 (0为全透明,255为不透明),也即取值越大的像素点越可能为人体,取值越小的像素点越可能为背景。 * data (numpy.ndarray): 人像分割结果,仅包含Alpha通道,取值为0-255 (0为全透明,255为不透明),也即取值越大的像素点越可能为人体,取值越小的像素点越可能为背景。
```python ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- 将模型保存到指定路径。 - 将模型保存到指定路径。
- **参数** - **参数**
* dirname: 存在模型的目录名称 * dirname: 模型保存路径
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
## 四、服务部署 ## 四、服务部署
...@@ -175,10 +169,10 @@ ...@@ -175,10 +169,10 @@
修复预测后处理图像数据超过[0,255]范围 修复预测后处理图像数据超过[0,255]范围
* 1.1.2 * 1.2.0
移除 fluid api 移除 fluid api
- ```shell - ```shell
$ hub install deeplabv3p_xception65_humanseg==1.1.2 $ hub install deeplabv3p_xception65_humanseg==1.2.0
``` ```
...@@ -70,11 +70,11 @@ ...@@ -70,11 +70,11 @@
- ```python - ```python
def segmentation(images=None, def segmentation(images=None,
paths=None, paths=None,
batch_size=1, batch_size=1,
use_gpu=False, use_gpu=False,
visualization=False, visualization=False,
output_dir='humanseg_output') output_dir='humanseg_output')
``` ```
- Prediction API, generating segmentation result. - Prediction API, generating segmentation result.
...@@ -94,19 +94,13 @@ ...@@ -94,19 +94,13 @@
* data (numpy.ndarray): The result of portrait segmentation. * data (numpy.ndarray): The result of portrait segmentation.
- ```python - ```python
def save_inference_model(dirname, def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- Save the model to the specified path. - Save the model to the specified path.
- **Parameters** - **Parameters**
* dirname: Save path. * dirname: Model save path.
* model\_filename: Model file name,defalt is \_\_model\_\_
* params\_filename: Parameter file name,defalt is \_\_params\_\_(Only takes effect when `combined` is True)
* combined: Whether to save the parameters to a unified file.
## IV. Server Deployment ## IV. Server Deployment
...@@ -171,10 +165,10 @@ ...@@ -171,10 +165,10 @@
Fix the bug of image value out of range Fix the bug of image value out of range
* 1.1.2 * 1.2.0
Remove fluid api Remove fluid api
- ```shell - ```shell
$ hub install deeplabv3p_xception65_humanseg==1.1.2 $ hub install deeplabv3p_xception65_humanseg==1.2.0
``` ```
...@@ -5,7 +5,6 @@ from collections import OrderedDict ...@@ -5,7 +5,6 @@ from collections import OrderedDict
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image
__all__ = ['reader'] __all__ = ['reader']
......
...@@ -8,14 +8,13 @@ import os ...@@ -8,14 +8,13 @@ import os
import numpy as np import numpy as np
import paddle import paddle
from deeplabv3p_xception65_humanseg.data_feed import reader from .data_feed import reader
from deeplabv3p_xception65_humanseg.processor import base64_to_cv2 from .processor import base64_to_cv2
from deeplabv3p_xception65_humanseg.processor import cv2_to_base64 from .processor import cv2_to_base64
from deeplabv3p_xception65_humanseg.processor import postprocess from .processor import postprocess
from paddle.inference import Config from paddle.inference import Config
from paddle.inference import create_predictor from paddle.inference import create_predictor
import paddlehub as hub
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable from paddlehub.module.module import runnable
from paddlehub.module.module import serving from paddlehub.module.module import serving
...@@ -26,18 +25,20 @@ from paddlehub.module.module import serving ...@@ -26,18 +25,20 @@ from paddlehub.module.module import serving
author="baidu-vis", author="baidu-vis",
author_email="", author_email="",
summary="DeepLabv3+ is a semantic segmentation model.", summary="DeepLabv3+ is a semantic segmentation model.",
version="1.1.2") version="1.2.0")
class DeeplabV3pXception65HumanSeg(hub.Module): class DeeplabV3pXception65HumanSeg:
def _initialize(self): def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "deeplabv3p_xception65_humanseg_model") self.default_pretrained_model_path = os.path.join(self.directory, "deeplabv3p_xception65_humanseg_model", "model")
self._set_config() self._set_config()
def _set_config(self): def _set_config(self):
""" """
predictor config setting predictor config setting
""" """
cpu_config = Config(self.default_pretrained_model_path) model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
...@@ -49,7 +50,7 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -49,7 +50,7 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
except: except:
use_gpu = False use_gpu = False
if use_gpu: if use_gpu:
gpu_config = Config(self.default_pretrained_model_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
...@@ -134,24 +135,6 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -134,24 +135,6 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
res.append(out) res.append(out)
return res return res
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = paddle.CPUPlace()
exe = paddle.Executor(place)
program, feeded_var_names, target_vars = paddle.static.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
paddle.static.save_inference_model(dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
""" """
......
...@@ -5,7 +5,6 @@ from __future__ import print_function ...@@ -5,7 +5,6 @@ from __future__ import print_function
import os import os
import time import time
from collections import OrderedDict
import base64 import base64
import cv2 import cv2
......
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/pg_WCHWSdT8/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjYyNDM2ODI4&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="deeplabv3p_xception65_humanseg")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('humanseg_output')
def test_segmentation1(self):
results = self.module.segmentation(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segmentation2(self):
results = self.module.segmentation(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segmentation3(self):
results = self.module.segmentation(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segmentation4(self):
results = self.module.segmentation(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segmentation5(self):
self.assertRaises(
AssertionError,
self.module.segmentation,
paths=['no.jpg']
)
def test_segmentation6(self):
self.assertRaises(
AttributeError,
self.module.segmentation,
images=['test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册