Unverified commit 75b0857f, authored by jm_12138, committed by GitHub

update humanseg_mobile (#2182)

Parent e301f567
# humanseg_mobile

|Module Name|humanseg_mobile|
| :--- | :---: |
|Category|Image segmentation|
|Network|hrnet|
|Dataset|Baidu self-built dataset|
@@ -16,12 +16,12 @@
- ### Application Effect Display

  - Sample results:
    <p align="center">
    <img src="https://user-images.githubusercontent.com/35907364/130913092-312a5f37-842e-4fd0-8db4-5f853fd8419f.jpg" width = "337" height = "505" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130914325-3795e241-b611-46a1-aa70-ffc47326c86a.png" width = "337" height = "505" hspace='10'/>
    </p>

- ### Module Introduction

  - HumanSeg-mobile uses the HRNet_w18_small_v1 network structure. The model is only 5.8 MB, which makes it suitable for front-camera scenarios on mobile devices or server-side CPUs.
@@ -41,7 +41,7 @@
- ```shell
  $ hub install humanseg_mobile
  ```

- If you run into problems during installation, please refer to: [Zero-base Windows installation](../../../../docs/docs_ch/get_start/windows_quickstart.md) | [Zero-base Linux installation](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [Zero-base MacOS installation](../../../../docs/docs_ch/get_start/mac_quickstart.md)
@@ -230,6 +230,9 @@
  cv2.imwrite("segment_human_mobile.png", rgba)
  ```
- ### Gradio APP Support

  Starting with PaddleHub 2.3.1, the Gradio APP for humanseg_mobile can be accessed in a browser at http://127.0.0.1:8866/gradio/humanseg_mobile.
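  The link above assumes a local PaddleHub Serving instance is already running on port 8866 (typically started with `hub serving start -m humanseg_mobile`). Alternatively, the Gradio interface returned by the module's `create_gradio_app` method (added in this update) can be launched directly from Python; a minimal sketch, assuming PaddleHub >= 2.3.1 and gradio are installed:

  ```python
  import paddlehub as hub

  # Load the module; create_gradio_app() returns a gradio Interface object.
  module = hub.Module(name="humanseg_mobile")
  # launch() starts a local Gradio server and prints the address to open in a browser.
  module.create_gradio_app().launch()
  ```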
## V. Update History
@@ -238,19 +241,23 @@
  Initial release
* 1.1.0
  Added the video portrait segmentation interface
  Added the video stream portrait segmentation interface
* 1.1.1
  Fixed a GPU memory leak that occurred with cuDNN 8.0.4
* 1.2.0
  Removed the Fluid API
* 1.3.0
  Added Gradio APP support
  ```shell
  $ hub install humanseg_mobile==1.3.0
  ```
# humanseg_mobile

|Module Name |humanseg_mobile|
| :--- | :---: |
|Category |Image segmentation|
|Network|hrnet|
|Dataset|Baidu self-built dataset|
@@ -10,10 +10,10 @@
|Data indicators|-|
|Latest update date|2021-02-26|

## I. Basic Information

- ### Application Effect Display

  - Sample results:
    <p align="center">
    <img src="https://user-images.githubusercontent.com/35907364/130913092-312a5f37-842e-4fd0-8db4-5f853fd8419f.jpg" width = "337" height = "505" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/130914325-3795e241-b611-46a1-aa70-ffc47326c86a.png" width = "337" height = "505" hspace='10'/>
@@ -39,7 +39,7 @@
- ```shell
  $ hub install humanseg_mobile
  ```

- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md) | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
@@ -49,11 +49,11 @@
- ```
  hub run humanseg_mobile --input_path "/PATH/TO/IMAGE"
  ```

- If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_en/tutorial/cmd_usage.rst)

- ### 2、Prediction Code Example

  - Image segmentation and video segmentation example:

    ```python
@@ -122,9 +122,9 @@
  - **Return**

    * res (list\[dict\]): The list of recognition results, where each element is a dict and each field is:
      * save\_path (str, optional): Save path of the result.
      * data (numpy.ndarray): The result of portrait segmentation.
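  A short sketch of consuming these fields after a `segment` call; the image path below is a placeholder:

  ```python
  import paddlehub as hub

  module = hub.Module(name="humanseg_mobile")
  results = module.segment(paths=['/PATH/TO/IMAGE'], visualization=False)
  # 'data' is always present; 'save_path' only appears when visualization is enabled.
  mask = results[0]['data']  # numpy.ndarray holding the portrait segmentation result
  print(mask.shape)
  ```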
  ```python
  def video_stream_segment(self,
@@ -231,6 +231,8 @@
  cv2.imwrite("segment_human_mobile.png", rgba)
  ```
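  For frame-by-frame use of `video_stream_segment`, the loop below (adapted from demo code this update removes; the video file name is a placeholder) feeds each decoded frame together with the previous frame's outputs so optical-flow tracking can smooth the mask:

  ```python
  import cv2
  import numpy as np
  import paddlehub as hub

  module = hub.Module(name="humanseg_mobile")
  cap = cv2.VideoCapture('video_test.mp4')  # placeholder input video
  prev_gray, prev_cfd = None, None
  while cap.isOpened():
      ret, frame = cap.read()
      if not ret:
          break
      # cap.get(1) returns the current frame index (CAP_PROP_POS_FRAMES == 1).
      img_matting, prev_gray, prev_cfd = module.video_stream_segment(
          frame_org=frame, frame_id=cap.get(1), prev_gray=prev_gray, prev_cfd=prev_cfd)
      alpha = np.repeat(img_matting[:, :, np.newaxis], 3, axis=2)
      bg = np.ones_like(alpha) * 255  # composite the matte onto a white background
      comb = (alpha * frame + (1 - alpha) * bg).astype(np.uint8)
      # write or display `comb` as needed
  cap.release()
  ```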
- ### Gradio APP support
  Starting with PaddleHub 2.3.1, the Gradio APP for humanseg_mobile can be accessed in a browser at http://127.0.0.1:8866/gradio/humanseg_mobile.
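  As with the serving link above, this assumes a local PaddleHub Serving instance on port 8866 (typically `hub serving start -m humanseg_mobile`). The same interface can also be launched directly from Python; a minimal sketch, assuming PaddleHub >= 2.3.1 and gradio are installed:

  ```python
  import paddlehub as hub

  # create_gradio_app() returns a gradio Interface; launch() serves it locally.
  hub.Module(name="humanseg_mobile").create_gradio_app().launch()
  ```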
## V. Release Note
@@ -239,7 +241,7 @@
  First release
- 1.1.0
  Added the video portrait segmentation interface
  Added the video stream portrait segmentation interface
@@ -252,7 +254,10 @@
  Remove Fluid API
* 1.3.0
Add Gradio APP support.
  ```shell
  $ hub install humanseg_mobile==1.3.0
  ```
@@ -12,43 +12,49 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import os
import os.path as osp

import cv2
import numpy as np
import paddle
import paddle.jit
import paddle.static
from paddle.inference import Config
from paddle.inference import create_predictor

from .data_feed import preprocess_v
from .data_feed import reader
from .optimal import postprocess_v
from .optimal import threshold_mask
from .processor import base64_to_cv2
from .processor import check_dir
from .processor import cv2_to_base64
from .processor import postprocess
from paddlehub.module.module import moduleinfo
from paddlehub.module.module import runnable
from paddlehub.module.module import serving


@moduleinfo(name="humanseg_mobile",
            type="CV/semantic_segmentation",
            author="paddlepaddle",
            author_email="",
            summary="HRNet_w18_samll_v1 is a semantic segmentation model.",
            version="1.3.0")
class HRNetw18samllv1humanseg:

    def __init__(self):
        self.default_pretrained_model_path = os.path.join(self.directory, "humanseg_mobile_inference", "model")
        self._set_config()
    def _set_config(self):
        """
        predictor config setting
        """
        model = self.default_pretrained_model_path + '.pdmodel'
        params = self.default_pretrained_model_path + '.pdiparams'
        cpu_config = Config(model, params)
        cpu_config.disable_glog_info()
        cpu_config.disable_gpu()
@@ -63,7 +69,7 @@ class HRNetw18samllv1humanseg:
            gpu_config = Config(model, params)
            gpu_config.disable_glog_info()
            gpu_config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
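            # Added comment, based on release note 1.1.1: cuDNN 8.0.4 leaked GPU memory when
            # the conv + elementwise_add activation fusion passes were enabled, so drop them.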
            if paddle.get_cudnn_version() == 8004:
                gpu_config.delete_pass('conv_elementwise_add_act_fuse_pass')
                gpu_config.delete_pass('conv_elementwise_add2_act_fuse_pass')
@@ -131,13 +137,12 @@ class HRNetw18samllv1humanseg:
        output = np.expand_dims(output[:, 1, :, :], axis=1)
        # postprocess one by one
        for i in range(len(batch_data)):
            out = postprocess(data_out=output[i],
                              org_im=batch_data[i]['org_im'],
                              org_im_shape=batch_data[i]['org_im_shape'],
                              org_im_path=batch_data[i]['org_im_path'],
                              output_dir=output_dir,
                              visualization=visualization)
            res.append(out)
        return res
@@ -317,23 +322,21 @@ class HRNetw18samllv1humanseg:
""" """
Run as a command. Run as a command.
""" """
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(description="Run the {} module.".format(self.name),
description="Run the {} module.".format(self.name), prog='hub run {}'.format(self.name),
prog='hub run {}'.format(self.name), usage='%(prog)s',
usage='%(prog)s', add_help=True)
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.") title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg() self.add_module_config_arg()
self.add_module_input_arg() self.add_module_input_arg()
args = self.parser.parse_args(argvs) args = self.parser.parse_args(argvs)
results = self.segment( results = self.segment(paths=[args.input_path],
paths=[args.input_path], batch_size=args.batch_size,
batch_size=args.batch_size, use_gpu=args.use_gpu,
use_gpu=args.use_gpu, output_dir=args.output_dir,
output_dir=args.output_dir, visualization=args.visualization)
visualization=args.visualization)
if args.save_dir is not None: if args.save_dir is not None:
check_dir(args.save_dir) check_dir(args.save_dir)
self.save_inference_model(args.save_dir) self.save_inference_model(args.save_dir)
...@@ -344,14 +347,22 @@ class HRNetw18samllv1humanseg: ...@@ -344,14 +347,22 @@ class HRNetw18samllv1humanseg:
""" """
Add the command config options. Add the command config options.
""" """
self.arg_config_group.add_argument( self.arg_config_group.add_argument('--use_gpu',
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not") type=ast.literal_eval,
self.arg_config_group.add_argument( default=False,
'--output_dir', type=str, default='humanseg_mobile_output', help="The directory to save output images.") help="whether use GPU or not")
self.arg_config_group.add_argument( self.arg_config_group.add_argument('--output_dir',
'--save_dir', type=str, default='humanseg_mobile_model', help="The directory to save model.") type=str,
self.arg_config_group.add_argument( default='humanseg_mobile_output',
'--visualization', type=ast.literal_eval, default=False, help="whether to save output as images.") help="The directory to save output images.")
self.arg_config_group.add_argument('--save_dir',
type=str,
default='humanseg_mobile_model',
help="The directory to save model.")
self.arg_config_group.add_argument('--visualization',
type=ast.literal_eval,
default=False,
help="whether to save output as images.")
self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.") self.arg_config_group.add_argument('--batch_size', type=ast.literal_eval, default=1, help="batch size.")
def add_module_input_arg(self): def add_module_input_arg(self):
...@@ -360,31 +371,20 @@ class HRNetw18samllv1humanseg: ...@@ -360,31 +371,20 @@ class HRNetw18samllv1humanseg:
""" """
self.arg_input_group.add_argument('--input_path', type=str, help="path to image.") self.arg_input_group.add_argument('--input_path', type=str, help="path to image.")
    def create_gradio_app(self):
        import gradio as gr
        import tempfile
        import os
        from PIL import Image

        def inference(image, use_gpu=False):
            with tempfile.TemporaryDirectory() as temp_dir:
                self.segment(paths=[image], use_gpu=use_gpu, visualization=True, output_dir=temp_dir)
                return Image.open(os.path.join(temp_dir, os.listdir(temp_dir)[0]))

        interface = gr.Interface(
            inference,
            [gr.inputs.Image(type="filepath"), gr.Checkbox(label='use_gpu')],
            gr.outputs.Image(type="ndarray"),
            title='humanseg_mobile')
        return interface
# -*- coding:utf-8 -*-
import numpy as np
@@ -33,8 +32,8 @@ def human_seg_tracking(pre_gray, cur_gray, prev_cfd, dl_weights, disflow):
    # do not track points that fall outside the image boundary
    not_track = (cur_x < 0) + (cur_x >= w) + (cur_y < 0) + (cur_y >= h)
    flow_bw[~not_track] = flow_bw[cur_y[~not_track], cur_x[~not_track]]
    not_track += (np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) +
                  np.square(flow_fw[:, :, 1] + flow_bw[:, :, 1])) >= check_thres
    track_cfd[cur_y[~not_track], cur_x[~not_track]] = prev_cfd[~not_track]
    is_track[cur_y[~not_track], cur_x[~not_track]] = 1
...
# -*- coding:utf-8 -*-
import base64
import os
import time

import cv2
import numpy as np
...
@@ -3,15 +3,16 @@ import shutil
import unittest

import cv2
import numpy as np
import requests

import paddlehub as hub

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class TestHubModule(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        img_url = 'https://unsplash.com/photos/pg_WCHWSdT8/download?ixid=MnwxMjA3fDB8MXxhbGx8fHx8fHx8fHwxNjYyNDM2ODI4&force=true&w=640'
@@ -23,8 +24,7 @@ class TestHubModule(unittest.TestCase):
            f.write(response.content)
        fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
        img = cv2.imread('tests/test.jpg')
        video = cv2.VideoWriter('tests/test.avi', fourcc, 20.0, tuple(img.shape[:2]))
        for i in range(40):
            video.write(img)
        video.release()
@@ -38,100 +38,65 @@ class TestHubModule(unittest.TestCase):
        shutil.rmtree('humanseg_mobile_video_result')
    def test_segment1(self):
        results = self.module.segment(paths=['tests/test.jpg'], use_gpu=False, visualization=False)
        self.assertIsInstance(results[0]['data'], np.ndarray)

    def test_segment2(self):
        results = self.module.segment(images=[cv2.imread('tests/test.jpg')], use_gpu=False, visualization=False)
        self.assertIsInstance(results[0]['data'], np.ndarray)

    def test_segment3(self):
        results = self.module.segment(images=[cv2.imread('tests/test.jpg')], use_gpu=False, visualization=True)
        self.assertIsInstance(results[0]['data'], np.ndarray)

    def test_segment4(self):
        results = self.module.segment(images=[cv2.imread('tests/test.jpg')], use_gpu=True, visualization=False)
        self.assertIsInstance(results[0]['data'], np.ndarray)

    def test_segment5(self):
        self.assertRaises(AssertionError, self.module.segment, paths=['no.jpg'])

    def test_segment6(self):
        self.assertRaises(AttributeError, self.module.segment, images=['test.jpg'])

    def test_video_stream_segment1(self):
        img_matting, cur_gray, optflow_map = self.module.video_stream_segment(frame_org=cv2.imread('tests/test.jpg'),
                                                                              frame_id=1,
                                                                              prev_gray=None,
                                                                              prev_cfd=None,
                                                                              use_gpu=False)
        self.assertIsInstance(img_matting, np.ndarray)
        self.assertIsInstance(cur_gray, np.ndarray)
        self.assertIsInstance(optflow_map, np.ndarray)
        img_matting, cur_gray, optflow_map = self.module.video_stream_segment(frame_org=cv2.imread('tests/test.jpg'),
                                                                              frame_id=2,
                                                                              prev_gray=cur_gray,
                                                                              prev_cfd=optflow_map,
                                                                              use_gpu=False)
        self.assertIsInstance(img_matting, np.ndarray)
        self.assertIsInstance(cur_gray, np.ndarray)
        self.assertIsInstance(optflow_map, np.ndarray)

    def test_video_stream_segment2(self):
        img_matting, cur_gray, optflow_map = self.module.video_stream_segment(frame_org=cv2.imread('tests/test.jpg'),
                                                                              frame_id=1,
                                                                              prev_gray=None,
                                                                              prev_cfd=None,
                                                                              use_gpu=True)
        self.assertIsInstance(img_matting, np.ndarray)
        self.assertIsInstance(cur_gray, np.ndarray)
        self.assertIsInstance(optflow_map, np.ndarray)
        img_matting, cur_gray, optflow_map = self.module.video_stream_segment(frame_org=cv2.imread('tests/test.jpg'),
                                                                              frame_id=2,
                                                                              prev_gray=cur_gray,
                                                                              prev_cfd=optflow_map,
                                                                              use_gpu=True)
        self.assertIsInstance(img_matting, np.ndarray)
        self.assertIsInstance(cur_gray, np.ndarray)
        self.assertIsInstance(optflow_map, np.ndarray)

    def test_video_segment1(self):
        self.module.video_segment(video_path="tests/test.avi", use_gpu=False)

    def test_save_inference_model(self):
        self.module.save_inference_model('./inference/model')
...