未验证 提交 cf5f3112 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update humanseg_server (#2002)

* update humanseg_server

* add clean func

* update save inference model
上级 8873a70c
...@@ -173,19 +173,13 @@ ...@@ -173,19 +173,13 @@
```python ```python
def save_inference_model(dirname='humanseg_server_model', def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
- 将模型保存到指定路径。 - 将模型保存到指定路径。
- **参数** - **参数**
* dirname: 存在模型的目录名称 * dirname: 模型保存路径
* model\_filename: 模型文件名称,默认为\_\_model\_\_
* params\_filename: 参数文件名称,默认为\_\_params\_\_(仅当`combined`为True时生效)
* combined: 是否将参数保存到统一的一个文件中
## 四、服务部署 ## 四、服务部署
...@@ -243,11 +237,21 @@ ...@@ -243,11 +237,21 @@
* 1.0.0 * 1.0.0
初始发布 初始发布
* 1.1.0 * 1.1.0
新增视频人像分割接口 新增视频人像分割接口
新增视频流人像分割接口 新增视频流人像分割接口
* 1.1.1 * 1.1.1
修复cudnn为8.0.4显存泄露问题 修复cudnn为8.0.4显存泄露问题
* 1.2.0
移除 Fluid API
```shell
$ hub install humanseg_server == 1.2.0
```
...@@ -170,10 +170,7 @@ ...@@ -170,10 +170,7 @@
```python ```python
def save_inference_model(dirname='humanseg_server_model', def save_inference_model(dirname)
model_filename=None,
params_filename=None,
combined=True)
``` ```
...@@ -181,10 +178,7 @@ ...@@ -181,10 +178,7 @@
- **Parameters** - **Parameters**
* dirname: Save path. * dirname: Model save path.
* model\_filename: Model file name,defalt is \_\_model\_\_
* params\_filename: Parameter file name,defalt is \_\_params\_\_(Only takes effect when `combined` is True)
* combined: Whether to save the parameters to a unified file.
...@@ -242,7 +236,7 @@ ...@@ -242,7 +236,7 @@
- 1.0.0 - 1.0.0
First release First release
- 1.1.0 - 1.1.0
...@@ -252,4 +246,13 @@ ...@@ -252,4 +246,13 @@
* 1.1.1 * 1.1.1
Fix memory leakage problem of on cudnn 8.0.4 Fix memory leakage problem of on cudnn 8.0.4
* 1.2.0
Remove Fluid API
```shell
$ hub install humanseg_server == 1.2.0
```
...@@ -5,7 +5,6 @@ from collections import OrderedDict ...@@ -5,7 +5,6 @@ from collections import OrderedDict
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image
__all__ = ['reader', 'preprocess_v'] __all__ = ['reader', 'preprocess_v']
......
...@@ -20,9 +20,10 @@ import argparse ...@@ -20,9 +20,10 @@ import argparse
import cv2 import cv2
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle
import paddlehub as hub import paddle.jit
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving from paddlehub.module.module import moduleinfo, runnable, serving
from humanseg_server.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir from humanseg_server.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
...@@ -36,22 +37,22 @@ from humanseg_server.optimal import postprocess_v, threshold_mask ...@@ -36,22 +37,22 @@ from humanseg_server.optimal import postprocess_v, threshold_mask
author="baidu-vis", author="baidu-vis",
author_email="", author_email="",
summary="DeepLabv3+ is a semantic segmentation model.", summary="DeepLabv3+ is a semantic segmentation model.",
version="1.1.0") version="1.2.0")
class DeeplabV3pXception65HumanSeg(hub.Module): class DeeplabV3pXception65HumanSeg:
def _initialize(self): def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "humanseg_server_inference") self.default_pretrained_model_path = os.path.join(self.directory, "humanseg_server_inference", "model")
self._set_config() self._set_config()
def _set_config(self): def _set_config(self):
""" """
predictor config setting predictor config setting
""" """
self.model_file_path = os.path.join(self.default_pretrained_model_path, '__model__') model = self.default_pretrained_model_path+'.pdmodel'
self.params_file_path = os.path.join(self.default_pretrained_model_path, '__params__') params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = AnalysisConfig(self.model_file_path, self.params_file_path) cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
try: try:
_places = os.environ["CUDA_VISIBLE_DEVICES"] _places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0]) int(_places[0])
...@@ -59,10 +60,14 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -59,10 +60,14 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
except: except:
use_gpu = False use_gpu = False
if use_gpu: if use_gpu:
gpu_config = AnalysisConfig(self.model_file_path, self.params_file_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
if paddle.get_cudnn_version() == 8004:
gpu_config.delete_pass('conv_elementwise_add_act_fuse_pass')
gpu_config.delete_pass('conv_elementwise_add2_act_fuse_pass')
self.gpu_predictor = create_predictor(gpu_config)
def segment(self, def segment(self,
images=None, images=None,
...@@ -114,9 +119,16 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -114,9 +119,16 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
pass pass
# feed batch image # feed batch image
batch_image = np.array([data['image'] for data in batch_data]) batch_image = np.array([data['image'] for data in batch_data])
batch_image = PaddleTensor(batch_image.copy())
output = self.gpu_predictor.run([batch_image]) if use_gpu else self.cpu_predictor.run([batch_image]) predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
output = output[1].as_ndarray() input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(batch_image.copy())
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[1])
output = output_handle.copy_to_cpu()
output = np.expand_dims(output[:, 1, :, :], axis=1) output = np.expand_dims(output[:, 1, :, :], axis=1)
# postprocess one by one # postprocess one by one
for i in range(len(batch_data)): for i in range(len(batch_data)):
...@@ -154,9 +166,16 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -154,9 +166,16 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
height = int(frame_org.shape[1]) height = int(frame_org.shape[1])
disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST) disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
frame = preprocess_v(frame_org, resize_w, resize_h) frame = preprocess_v(frame_org, resize_w, resize_h)
image = PaddleTensor(np.array([frame.copy()]))
output = self.gpu_predictor.run([image]) if use_gpu else self.cpu_predictor.run([image]) predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
score_map = output[1].as_ndarray() input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(frame.copy()[None, ...])
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[1])
score_map = output_handle.copy_to_cpu()
frame = np.transpose(frame, axes=[1, 2, 0]) frame = np.transpose(frame, axes=[1, 2, 0])
score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0]) score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0])
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
...@@ -173,7 +192,7 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -173,7 +192,7 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
img_matting = cv2.resize(optflow_map, (height, width), cv2.INTER_LINEAR) img_matting = cv2.resize(optflow_map, (height, width), cv2.INTER_LINEAR)
return [img_matting, cur_gray, optflow_map] return [img_matting, cur_gray, optflow_map]
def video_segment(self, video_path=None, use_gpu=False, save_dir='humanseg_server_video'): def video_segment(self, video_path=None, use_gpu=False, save_dir='humanseg_server_video_result'):
resize_h = 512 resize_h = 512
resize_w = 512 resize_w = 512
if not video_path: if not video_path:
...@@ -201,9 +220,16 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -201,9 +220,16 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
ret, frame_org = cap_video.read() ret, frame_org = cap_video.read()
if ret: if ret:
frame = preprocess_v(frame_org, resize_w, resize_h) frame = preprocess_v(frame_org, resize_w, resize_h)
image = PaddleTensor(np.array([frame.copy()]))
output = self.gpu_predictor.run([image]) if use_gpu else self.cpu_predictor.run([image]) predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
score_map = output[1].as_ndarray() input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(frame.copy()[None, ...])
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[1])
score_map = output_handle.copy_to_cpu()
frame = np.transpose(frame, axes=[1, 2, 0]) frame = np.transpose(frame, axes=[1, 2, 0])
score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0]) score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0])
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
...@@ -228,9 +254,16 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -228,9 +254,16 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
ret, frame_org = cap_video.read() ret, frame_org = cap_video.read()
if ret: if ret:
frame = preprocess_v(frame_org, resize_w, resize_h) frame = preprocess_v(frame_org, resize_w, resize_h)
image = PaddleTensor(np.array([frame.copy()]))
output = self.gpu_predictor.run([image]) if use_gpu else self.cpu_predictor.run([image]) predictor = self.gpu_predictor if use_gpu else self.cpu_predictor
score_map = output[1].as_ndarray() input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
input_handle.copy_from_cpu(frame.copy()[None, ...])
predictor.run()
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[1])
score_map = output_handle.copy_to_cpu()
frame = np.transpose(frame, axes=[1, 2, 0]) frame = np.transpose(frame, axes=[1, 2, 0])
score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0]) score_map = np.transpose(np.squeeze(score_map, 0), axes=[1, 2, 0])
cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) cur_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
...@@ -252,30 +285,6 @@ class DeeplabV3pXception65HumanSeg(hub.Module): ...@@ -252,30 +285,6 @@ class DeeplabV3pXception65HumanSeg(hub.Module):
break break
cap_video.release() cap_video.release()
def save_inference_model(self,
dirname='humanseg_server_model',
model_filename=None,
params_filename=None,
combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path,
model_filename=model_filename,
params_filename=params_filename,
executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
@serving @serving
def serving_method(self, images, **kwargs): def serving_method(self, images, **kwargs):
""" """
......
import os
import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/pg_WCHWSdT8/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjYyNDM2ODI4&force=true&w=640'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
img = cv2.imread('tests/test.jpg')
video = cv2.VideoWriter('tests/test.avi', fourcc,
20.0, tuple(img.shape[:2]))
for i in range(40):
video.write(img)
video.release()
cls.module = hub.Module(name="humanseg_server")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('humanseg_server_output')
shutil.rmtree('humanseg_server_video_result')
def test_segment1(self):
results = self.module.segment(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segment2(self):
results = self.module.segment(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segment3(self):
results = self.module.segment(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segment4(self):
results = self.module.segment(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False
)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segment5(self):
self.assertRaises(
AssertionError,
self.module.segment,
paths=['no.jpg']
)
def test_segment6(self):
self.assertRaises(
AttributeError,
self.module.segment,
images=['test.jpg']
)
def test_video_stream_segment1(self):
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(
frame_org=cv2.imread('tests/test.jpg'),
frame_id=1,
prev_gray=None,
prev_cfd=None,
use_gpu=False
)
self.assertIsInstance(img_matting, np.ndarray)
self.assertIsInstance(cur_gray, np.ndarray)
self.assertIsInstance(optflow_map, np.ndarray)
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(
frame_org=cv2.imread('tests/test.jpg'),
frame_id=2,
prev_gray=cur_gray,
prev_cfd=optflow_map,
use_gpu=False
)
self.assertIsInstance(img_matting, np.ndarray)
self.assertIsInstance(cur_gray, np.ndarray)
self.assertIsInstance(optflow_map, np.ndarray)
def test_video_stream_segment2(self):
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(
frame_org=cv2.imread('tests/test.jpg'),
frame_id=1,
prev_gray=None,
prev_cfd=None,
use_gpu=True
)
self.assertIsInstance(img_matting, np.ndarray)
self.assertIsInstance(cur_gray, np.ndarray)
self.assertIsInstance(optflow_map, np.ndarray)
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(
frame_org=cv2.imread('tests/test.jpg'),
frame_id=2,
prev_gray=cur_gray,
prev_cfd=optflow_map,
use_gpu=True
)
self.assertIsInstance(img_matting, np.ndarray)
self.assertIsInstance(cur_gray, np.ndarray)
self.assertIsInstance(optflow_map, np.ndarray)
def test_video_segment1(self):
self.module.video_segment(
video_path="tests/test.avi",
use_gpu=False
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册