未验证 提交 3fcf01d5 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update humanseg_server (#2183)

上级 75b0857f
......@@ -229,6 +229,9 @@
cv2.imwrite("segment_human_server.png", rgba)
```
- ### Gradio APP 支持
从 PaddleHub 2.3.1 开始支持使用链接 http://127.0.0.1:8866/gradio/humanseg_server 在浏览器中访问 humanseg_server 的 Gradio APP。
## 五、更新历史
......@@ -258,6 +261,10 @@
移除 Fluid API
* 1.4.0
添加 Gradio APP 支持
```shell
$ hub install humanseg_server == 1.3.0
$ hub install humanseg_server == 1.4.0
```
......@@ -229,6 +229,8 @@
cv2.imwrite("segment_human_server.png", rgba)
```
- ### Gradio APP support
Starting with PaddleHub 2.3.1, the Gradio APP for humanseg_server is supported to be accessed in the browser using the link http://127.0.0.1:8866/gradio/humanseg_server.
## V. Release Note
......@@ -246,11 +248,14 @@
Fix memory leakage problem of on cudnn 8.0.4
* 1.2.0
* 1.3.0
Remove Fluid API
* 1.4.0
Add Gradio APP support.
```shell
$ hub install humanseg_server == 1.2.0
$ hub install humanseg_server == 1.4.0
```
......@@ -12,33 +12,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import os.path as osp
import argparse
import cv2
import numpy as np
import paddle
import paddle.jit
import paddle.static
from paddle.inference import Config, create_predictor
from paddlehub.module.module import moduleinfo, runnable, serving
from paddle.inference import Config
from paddle.inference import create_predictor
from humanseg_server.processor import postprocess, base64_to_cv2, cv2_to_base64, check_dir
from humanseg_server.data_feed import reader, preprocess_v
from humanseg_server.optimal import postprocess_v, threshold_mask
from .data_feed import preprocess_v
from .data_feed import reader
from .optimal import postprocess_v
from .optimal import threshold_mask
from .processor import base64_to_cv2
from .processor import check_dir
from .processor import cv2_to_base64
from .processor import postprocess
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving
@moduleinfo(
name="humanseg_server",
@moduleinfo(name="humanseg_server",
type="CV/semantic_segmentation",
author="baidu-vis",
author_email="",
summary="DeepLabv3+ is a semantic segmentation model.",
version="1.3.0")
version="1.4.0")
class DeeplabV3pXception65HumanSeg:
def __init__(self):
self.default_pretrained_model_path = os.path.join(self.directory, "humanseg_server_inference", "model")
self._set_config()
......@@ -47,8 +53,8 @@ class DeeplabV3pXception65HumanSeg:
"""
predictor config setting
"""
model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
model = self.default_pretrained_model_path + '.pdmodel'
params = self.default_pretrained_model_path + '.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
......@@ -132,8 +138,7 @@ class DeeplabV3pXception65HumanSeg:
output = np.expand_dims(output[:, 1, :, :], axis=1)
# postprocess one by one
for i in range(len(batch_data)):
out = postprocess(
data_out=output[i],
out = postprocess(data_out=output[i],
org_im=batch_data[i]['org_im'],
org_im_shape=batch_data[i]['org_im_shape'],
org_im_path=batch_data[i]['org_im_path'],
......@@ -300,8 +305,7 @@ class DeeplabV3pXception65HumanSeg:
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
......@@ -311,8 +315,7 @@ class DeeplabV3pXception65HumanSeg:
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
results = self.segment(
paths=[args.input_path],
results = self.segment(paths=[args.input_path],
batch_size=args.batch_size,
use_gpu=args.use_gpu,
output_dir=args.output_dir,
......@@ -326,14 +329,22 @@ class DeeplabV3pXception65HumanSeg:
"""
Add the command config options.
"""
self.arg_config_group.add_argument(
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
self.arg_config_group.add_argument(
'--output_dir', type=str, default='humanseg_server_output', help="The directory to save output images.")
self.arg_config_group.add_argument(
'--save_dir', type=str, default='humanseg_server_model', help="The directory to save model.")
self.arg_config_group.add_argument(
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.")
self.arg_config_group.add_argument('--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU or not")
self.arg_config_group.add_argument('--output_dir',
type=str,
default='humanseg_server_output',
help="The directory to save output images.")
self.arg_config_group.add_argument('--save_dir',
type=str,
default='humanseg_server_model',
help="The directory to save model.")
self.arg_config_group.add_argument('--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
def add_module_input_arg(self):
......@@ -342,36 +353,20 @@ class DeeplabV3pXception65HumanSeg:
"""
self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
def create_gradio_app(self):
import gradio as gr
import tempfile
import os
from PIL import Image
if __name__ == "__main__":
m = DeeplabV3pXception65HumanSeg()
# img = cv2.imread('photo.jpg')
# res = m.segment(images=[img])
# print(res[0]['data'])
# m.save_inference_model()
#m.video_segment(video_path='video_test.mp4')
img = cv2.imread('photo.jpg')
# res = m.segment(images=[img], visualization=True)
# print(res[0]['data'])
# m.video_segment('')
cap_video = cv2.VideoCapture('video_test.mp4')
fps = cap_video.get(cv2.CAP_PROP_FPS)
save_path = 'result_frame.avi'
width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap_out = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))
prev_gray = None
prev_cfd = None
while cap_video.isOpened():
ret, frame_org = cap_video.read()
if ret:
[img_matting, prev_gray, prev_cfd] = m.video_stream_segment(
frame_org=frame_org, frame_id=cap_video.get(1), prev_gray=prev_gray, prev_cfd=prev_cfd)
img_matting = np.repeat(img_matting[:, :, np.newaxis], 3, axis=2)
bg_im = np.ones_like(img_matting) * 255
comb = (img_matting * frame_org + (1 - img_matting) * bg_im).astype(np.uint8)
cap_out.write(comb)
else:
break
cap_video.release()
cap_out.release()
def inference(image, use_gpu=False):
with tempfile.TemporaryDirectory() as temp_dir:
self.segment(paths=[image], use_gpu=use_gpu, visualization=True, output_dir=temp_dir)
return Image.open(os.path.join(temp_dir, os.listdir(temp_dir)[0]))
interface = gr.Interface(
inference,
[gr.inputs.Image(type="filepath"), gr.Checkbox(label='use_gpu')],
gr.outputs.Image(type="ndarray"),
title='humanseg_server')
return interface
......@@ -32,8 +32,8 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
# 超出边界不跟踪
not_track = (cur_x < 0) + (cur_x >= w) + (cur_y < 0) + (cur_y >= h)
flow_bw[~not_track] = flow_bw[cur_y[~not_track], cur_x[~not_track]]
not_track += (
np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) + np.square(flow_fw[:, :, 1] + flow_bw[:, :, 1])) >= check_thres
not_track += (np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) +
np.square(flow_fw[:, :, 1] + flow_bw[:, :, 1])) >= check_thres
track_cfd[cur_y[~not_track], cur_x[~not_track]] = prev_cfd[~not_track]
is_track[cur_y[~not_track], cur_x[~not_track]] = 1
......
# -*- coding:utf-8 -*-
import base64
import os
import time
import base64
import cv2
import numpy as np
......
......@@ -3,15 +3,16 @@ import shutil
import unittest
import cv2
import requests
import numpy as np
import paddlehub as hub
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://unsplash.com/photos/pg_WCHWSdT8/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjYyNDM2ODI4&force=true&w=640'
......@@ -23,8 +24,7 @@ class TestHubModule(unittest.TestCase):
f.write(response.content)
fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
img = cv2.imread('tests/test.jpg')
video = cv2.VideoWriter('tests/test.avi', fourcc,
20.0, tuple(img.shape[:2]))
video = cv2.VideoWriter('tests/test.avi', fourcc, 20.0, tuple(img.shape[:2]))
for i in range(40):
video.write(img)
video.release()
......@@ -38,100 +38,65 @@ class TestHubModule(unittest.TestCase):
shutil.rmtree('humanseg_server_video_result')
def test_segment1(self):
results = self.module.segment(
paths=['tests/test.jpg'],
use_gpu=False,
visualization=False
)
results = self.module.segment(paths=['tests/test.jpg'], use_gpu=False, visualization=False)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segment2(self):
results = self.module.segment(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=False
)
results = self.module.segment(images=[cv2.imread('tests/test.jpg')], use_gpu=False, visualization=False)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segment3(self):
results = self.module.segment(
images=[cv2.imread('tests/test.jpg')],
use_gpu=False,
visualization=True
)
results = self.module.segment(images=[cv2.imread('tests/test.jpg')], use_gpu=False, visualization=True)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segment4(self):
results = self.module.segment(
images=[cv2.imread('tests/test.jpg')],
use_gpu=True,
visualization=False
)
results = self.module.segment(images=[cv2.imread('tests/test.jpg')], use_gpu=True, visualization=False)
self.assertIsInstance(results[0]['data'], np.ndarray)
def test_segment5(self):
self.assertRaises(
AssertionError,
self.module.segment,
paths=['no.jpg']
)
self.assertRaises(AssertionError, self.module.segment, paths=['no.jpg'])
def test_segment6(self):
self.assertRaises(
AttributeError,
self.module.segment,
images=['test.jpg']
)
self.assertRaises(AttributeError, self.module.segment, images=['test.jpg'])
def test_video_stream_segment1(self):
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(
frame_org=cv2.imread('tests/test.jpg'),
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(frame_org=cv2.imread('tests/test.jpg'),
frame_id=1,
prev_gray=None,
prev_cfd=None,
use_gpu=False
)
use_gpu=False)
self.assertIsInstance(img_matting, np.ndarray)
self.assertIsInstance(cur_gray, np.ndarray)
self.assertIsInstance(optflow_map, np.ndarray)
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(
frame_org=cv2.imread('tests/test.jpg'),
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(frame_org=cv2.imread('tests/test.jpg'),
frame_id=2,
prev_gray=cur_gray,
prev_cfd=optflow_map,
use_gpu=False
)
use_gpu=False)
self.assertIsInstance(img_matting, np.ndarray)
self.assertIsInstance(cur_gray, np.ndarray)
self.assertIsInstance(optflow_map, np.ndarray)
def test_video_stream_segment2(self):
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(
frame_org=cv2.imread('tests/test.jpg'),
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(frame_org=cv2.imread('tests/test.jpg'),
frame_id=1,
prev_gray=None,
prev_cfd=None,
use_gpu=True
)
use_gpu=True)
self.assertIsInstance(img_matting, np.ndarray)
self.assertIsInstance(cur_gray, np.ndarray)
self.assertIsInstance(optflow_map, np.ndarray)
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(
frame_org=cv2.imread('tests/test.jpg'),
img_matting, cur_gray, optflow_map = self.module.video_stream_segment(frame_org=cv2.imread('tests/test.jpg'),
frame_id=2,
prev_gray=cur_gray,
prev_cfd=optflow_map,
use_gpu=True
)
use_gpu=True)
self.assertIsInstance(img_matting, np.ndarray)
self.assertIsInstance(cur_gray, np.ndarray)
self.assertIsInstance(optflow_map, np.ndarray)
def test_video_segment1(self):
self.module.video_segment(
video_path="tests/test.avi",
use_gpu=False
)
self.module.video_segment(video_path="tests/test.avi", use_gpu=False)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册