未验证 提交 939258c6 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

update Extract_Line_Draft (#2021)

上级 5699dfb3
Extract_Line_Draft # Extract_Line_Draft
类别 图像 - 图像分割
# 模型概述 |模型名称|Extract_Line_Draft|
提取线稿(Extract_Line_Draft),该模型可自动根据彩色图生成线稿图。该PaddleHub Module支持API预测及命令行预测。 | :--- | :---: |
|类别|图像-图像分割|
|网络|-|
|数据集|-|
|是否支持Fine-tuning|否|
|模型大小|259MB|
|指标|-|
|最新更新日期|2021-02-26|
# 选择模型版本进行安装
$ hub install Extract_Line_Draft==1.0.0
# 命令行预测示例 ## 一、模型基本信息
$ hub run Extract_Line_Draft --image 1.png --use_gpu True
# Module API说明 - ### 应用效果展示
## ExtractLine(self, image, use_gpu=False)
提取线稿预测接口,预测输入一张图像,输出该图像的线稿
### 参数
- image(str): 待检测的图片路径
- use_gpu (bool): 是否使用 GPU
- 样例结果示例:
<p align="center">
<img src="https://ai-studio-static-online.cdn.bcebos.com/1c30757e069541a18dc89b92f0750983b77ad762560849afa0170046672e57a3" width = "337" height = "505" hspace='10'/> <img src="https://ai-studio-static-online.cdn.bcebos.com/7ef00637e5974be2847317053f8abe97236cec75fba14f77be2c095529a1eeb3" width = "337" height = "505" hspace='10'/>
</p>
# 代码示例 - ### 模型介绍
## API调用 - 提取线稿(Extract_Line_Draft),该模型可自动根据彩色图生成线稿图。该PaddleHub Module支持API预测及命令行预测。
~~~
import paddlehub as hub
Extract_Line_Draft_test = hub.Module(name="Extract_Line_Draft")
test_img = "testImage.png" ## 二、安装
# execute predict - ### 1、环境依赖
Extract_Line_Draft_test.ExtractLine(test_img, use_gpu=True)
~~~
## 命令行调用 - paddlepaddle >= 2.0.0
~~~
!hub run Extract_Line_Draft --input_path "testImage" --use_gpu True
~~~
# 效果展示 - paddlehub >= 2.0.0
## 原图 - ### 2.安装
![](https://ai-studio-static-online.cdn.bcebos.com/1c30757e069541a18dc89b92f0750983b77ad762560849afa0170046672e57a3)
![](https://ai-studio-static-online.cdn.bcebos.com/4a544c9ecd79461bbc1d1556d100b21d28b41b4f23db440ab776af78764292f2)
- ```shell
$ hub install Extract_Line_Draft
```
- 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## 线稿图 ## 三、模型API预测
![](https://ai-studio-static-online.cdn.bcebos.com/7ef00637e5974be2847317053f8abe97236cec75fba14f77be2c095529a1eeb3) - ### 1、命令行预测
![](https://ai-studio-static-online.cdn.bcebos.com/074ea02d89bc4b5c9004a077b61301fa49583c13af734bd6a49e81f59f9cd322)
```shell
$ hub run Extract_Line_Draft --input_path "testImage" --use_gpu True
```
# 贡献者 - ### 2、预测代码示例
彭兆帅、郑博培
# 依赖 ```python
paddlepaddle >= 1.8.2 import paddlehub as hub
paddlehub >= 1.8.0
Extract_Line_Draft_test = hub.Module(name="Extract_Line_Draft")
test_img = "testImage.png"
# execute predict
Extract_Line_Draft_test.ExtractLine(test_img, use_gpu=True)
```
- ### 3、API
```python
def ExtractLine(image, use_gpu=False)
```
- 预测API,用于图像分割得到人体解析。
- **参数**
* image(str): 待检测的图片路径
* use_gpu (bool): 是否使用 GPU
## 四、更新历史
* 1.0.0
初始发布
* 1.1.0
移除 Fluid API
```shell
$ hub install Extract_Line_Draft == 1.1.0
```
\ No newline at end of file
...@@ -4,9 +4,9 @@ from scipy import ndimage ...@@ -4,9 +4,9 @@ from scipy import ndimage
def get_normal_map(img): def get_normal_map(img):
img = img.astype(np.float) img = img.astype(np.float32)
img = img / 255.0 img = img / 255.0
img = -img + 1 img = - img + 1
img[img < 0] = 0 img[img < 0] = 0
img[img > 1] = 1 img[img > 1] = 1
return img return img
...@@ -14,7 +14,7 @@ def get_normal_map(img): ...@@ -14,7 +14,7 @@ def get_normal_map(img):
def get_gray_map(img): def get_gray_map(img):
gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY) gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY)
highPass = gray.astype(np.float) highPass = gray.astype(np.float32)
highPass = highPass / 255.0 highPass = highPass / 255.0
highPass = 1 - highPass highPass = 1 - highPass
highPass = highPass[None] highPass = highPass[None]
...@@ -25,7 +25,7 @@ def get_light_map(img): ...@@ -25,7 +25,7 @@ def get_light_map(img):
gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY) gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY)
blur = cv2.GaussianBlur(gray, (0, 0), 3) blur = cv2.GaussianBlur(gray, (0, 0), 3)
highPass = gray.astype(int) - blur.astype(int) highPass = gray.astype(int) - blur.astype(int)
highPass = highPass.astype(np.float) highPass = highPass.astype(np.float32)
highPass = highPass / 128.0 highPass = highPass / 128.0
highPass = highPass[None] highPass = highPass[None]
return highPass.transpose((1, 2, 0)) return highPass.transpose((1, 2, 0))
...@@ -38,7 +38,7 @@ def get_light_map_single(img): ...@@ -38,7 +38,7 @@ def get_light_map_single(img):
blur = cv2.GaussianBlur(gray, (0, 0), 3) blur = cv2.GaussianBlur(gray, (0, 0), 3)
gray = gray.reshape((gray.shape[0], gray.shape[1])) gray = gray.reshape((gray.shape[0], gray.shape[1]))
highPass = gray.astype(int) - blur.astype(int) highPass = gray.astype(int) - blur.astype(int)
highPass = highPass.astype(np.float) highPass = highPass.astype(np.float32)
highPass = highPass / 128.0 highPass = highPass / 128.0
return highPass return highPass
...@@ -49,7 +49,7 @@ def get_light_map_drawer(img): ...@@ -49,7 +49,7 @@ def get_light_map_drawer(img):
highPass = gray.astype(int) - blur.astype(int) + 255 highPass = gray.astype(int) - blur.astype(int) + 255
highPass[highPass < 0] = 0 highPass[highPass < 0] = 0
highPass[highPass > 255] = 255 highPass[highPass > 255] = 255
highPass = highPass.astype(np.float) highPass = highPass.astype(np.float32)
highPass = highPass / 255.0 highPass = highPass / 255.0
highPass = 1 - highPass highPass = 1 - highPass
highPass = highPass[None] highPass = highPass[None]
...@@ -58,7 +58,7 @@ def get_light_map_drawer(img): ...@@ -58,7 +58,7 @@ def get_light_map_drawer(img):
def get_light_map_drawer2(img): def get_light_map_drawer2(img):
ret = img.copy() ret = img.copy()
ret = ret.astype(np.float) ret = ret.astype(np.float32)
ret[:, :, 0] = get_light_map_drawer3(img[:, :, 0]) ret[:, :, 0] = get_light_map_drawer3(img[:, :, 0])
ret[:, :, 1] = get_light_map_drawer3(img[:, :, 1]) ret[:, :, 1] = get_light_map_drawer3(img[:, :, 1])
ret[:, :, 2] = get_light_map_drawer3(img[:, :, 2]) ret[:, :, 2] = get_light_map_drawer3(img[:, :, 2])
...@@ -72,7 +72,7 @@ def get_light_map_drawer3(img): ...@@ -72,7 +72,7 @@ def get_light_map_drawer3(img):
highPass = gray.astype(int) - blur.astype(int) + 255 highPass = gray.astype(int) - blur.astype(int) + 255
highPass[highPass < 0] = 0 highPass[highPass < 0] = 0
highPass[highPass > 255] = 255 highPass[highPass > 255] = 255
highPass = highPass.astype(np.float) highPass = highPass.astype(np.float32)
highPass = highPass / 255.0 highPass = highPass / 255.0
highPass = 1 - highPass highPass = 1 - highPass
return highPass return highPass
...@@ -91,7 +91,7 @@ def superlize_pic(img): ...@@ -91,7 +91,7 @@ def superlize_pic(img):
def mask_pic(img, mask): def mask_pic(img, mask):
mask_mat = mask mask_mat = mask
mask_mat = mask_mat.astype(np.float) mask_mat = mask_mat.astype(np.float32)
mask_mat = cv2.GaussianBlur(mask_mat, (0, 0), 1) mask_mat = cv2.GaussianBlur(mask_mat, (0, 0), 1)
mask_mat = mask_mat / np.max(mask_mat) mask_mat = mask_mat / np.max(mask_mat)
mask_mat = mask_mat * 255 mask_mat = mask_mat * 255
...@@ -106,14 +106,14 @@ def mask_pic(img, mask): ...@@ -106,14 +106,14 @@ def mask_pic(img, mask):
def resize_img_512(img): def resize_img_512(img):
zeros = np.zeros((512, 512, img.shape[2]), dtype=np.float) zeros = np.zeros((512, 512, img.shape[2]), dtype=np.float32)
zeros[:img.shape[0], :img.shape[1]] = img zeros[:img.shape[0], :img.shape[1]] = img
return zeros return zeros
def resize_img_512_3d(img): def resize_img_512_3d(img):
zeros = np.zeros((1, 3, 512, 512), dtype=np.float) zeros = np.zeros((1, 3, 512, 512), dtype=np.float32)
zeros[0, 0:img.shape[0], 0:img.shape[1], 0:img.shape[2]] = img zeros[0, 0: img.shape[0], 0: img.shape[1], 0: img.shape[2]] = img
return zeros.transpose((1, 2, 3, 0)) return zeros.transpose((1, 2, 3, 0))
...@@ -122,8 +122,8 @@ def denoise_mat(img, i): ...@@ -122,8 +122,8 @@ def denoise_mat(img, i):
def show_active_img_and_save_denoise(img, path): def show_active_img_and_save_denoise(img, path):
mat = img.astype(np.float) mat = img.astype(np.float32)
mat = -mat + 1 mat = - mat + 1
mat = mat * 255.0 mat = mat * 255.0
mat[mat < 0] = 0 mat[mat < 0] = 0
mat[mat > 255] = 255 mat[mat > 255] = 255
...@@ -134,8 +134,8 @@ def show_active_img_and_save_denoise(img, path): ...@@ -134,8 +134,8 @@ def show_active_img_and_save_denoise(img, path):
def show_active_img(name, img): def show_active_img(name, img):
mat = img.astype(np.float) mat = img.astype(np.float32)
mat = -mat + 1 mat = - mat + 1
mat = mat * 255.0 mat = mat * 255.0
mat[mat < 0] = 0 mat[mat < 0] = 0
mat[mat > 255] = 255 mat[mat > 255] = 255
...@@ -145,8 +145,8 @@ def show_active_img(name, img): ...@@ -145,8 +145,8 @@ def show_active_img(name, img):
def get_active_img(img): def get_active_img(img):
mat = img.astype(np.float) mat = img.astype(np.float32)
mat = -mat + 1 mat = - mat + 1
mat = mat * 255.0 mat = mat * 255.0
mat[mat < 0] = 0 mat[mat < 0] = 0
mat[mat > 255] = 255 mat[mat > 255] = 255
...@@ -155,9 +155,9 @@ def get_active_img(img): ...@@ -155,9 +155,9 @@ def get_active_img(img):
def get_active_img_fil(img): def get_active_img_fil(img):
mat = img.astype(np.float) mat = img.astype(np.float32)
mat[mat < 0.18] = 0 mat[mat < 0.18] = 0
mat = -mat + 1 mat = - mat + 1
mat = mat * 255.0 mat = mat * 255.0
mat[mat < 0] = 0 mat[mat < 0] = 0
mat[mat > 255] = 255 mat[mat > 255] = 255
...@@ -166,7 +166,7 @@ def get_active_img_fil(img): ...@@ -166,7 +166,7 @@ def get_active_img_fil(img):
def show_double_active_img(name, img): def show_double_active_img(name, img):
mat = img.astype(np.float) mat = img.astype(np.float32)
mat = mat * 128.0 mat = mat * 128.0
mat = mat + 127.0 mat = mat + 127.0
mat[mat < 0] = 0 mat[mat < 0] = 0
......
import argparse import argparse
import ast import ast
import os import os
import math import cv2
import six
import time
from pathlib import Path from pathlib import Path
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor from paddle.inference import Config, create_predictor
from paddlehub.module.module import runnable, serving, moduleinfo from paddlehub.module.module import runnable, moduleinfo
from paddlehub.io.parser import txt_parser
import numpy as np import numpy as np
import paddle.fluid as fluid from .function import get_light_map_single, normalize_pic, resize_img_512_3d, show_active_img_and_save_denoise
import paddlehub as hub
from Extract_Line_Draft.function import *
@moduleinfo( @moduleinfo(
name="Extract_Line_Draft", name="Extract_Line_Draft",
version="1.0.0", version="1.1.0",
type="cv/segmentation", type="cv/segmentation",
summary="Import the color picture and generate the line draft of the picture", summary="Import the color picture and generate the line draft of the picture",
author="彭兆帅,郑博培", author="彭兆帅,郑博培",
author_email="1084667371@qq.com,2733821739@qq.com") author_email="1084667371@qq.com,2733821739@qq.com")
class ExtractLineDraft(hub.Module): class ExtractLineDraft:
def _initialize(self): def __init__(self):
""" """
Initialize with the necessary elements Initialize with the necessary elements
""" """
# 加载模型路径 # 加载模型路径
self.default_pretrained_model_path = os.path.join(self.directory, "assets", "infer_model") self.default_pretrained_model_path = os.path.join(
self.directory, "assets", "infer_model", "model")
self._set_config() self._set_config()
def _set_config(self): def _set_config(self):
...@@ -36,7 +32,9 @@ class ExtractLineDraft(hub.Module): ...@@ -36,7 +32,9 @@ class ExtractLineDraft(hub.Module):
predictor config setting predictor config setting
""" """
self.model_file_path = self.default_pretrained_model_path self.model_file_path = self.default_pretrained_model_path
cpu_config = AnalysisConfig(self.model_file_path) model = self.default_pretrained_model_path+'.pdmodel'
params = self.default_pretrained_model_path+'.pdiparams'
cpu_config = Config(model, params)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.switch_ir_optim(True) cpu_config.switch_ir_optim(True)
cpu_config.enable_memory_optim() cpu_config.enable_memory_optim()
...@@ -44,7 +42,7 @@ class ExtractLineDraft(hub.Module): ...@@ -44,7 +42,7 @@ class ExtractLineDraft(hub.Module):
cpu_config.switch_specify_input_names(True) cpu_config.switch_specify_input_names(True)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = create_predictor(cpu_config)
try: try:
_places = os.environ["CUDA_VISIBLE_DEVICES"] _places = os.environ["CUDA_VISIBLE_DEVICES"]
...@@ -53,7 +51,7 @@ class ExtractLineDraft(hub.Module): ...@@ -53,7 +51,7 @@ class ExtractLineDraft(hub.Module):
except: except:
use_gpu = False use_gpu = False
if use_gpu: if use_gpu:
gpu_config = AnalysisConfig(self.model_file_path) gpu_config = Config(model, params)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.switch_ir_optim(True) gpu_config.switch_ir_optim(True)
gpu_config.enable_memory_optim() gpu_config.enable_memory_optim()
...@@ -61,7 +59,7 @@ class ExtractLineDraft(hub.Module): ...@@ -61,7 +59,7 @@ class ExtractLineDraft(hub.Module):
gpu_config.switch_specify_input_names(True) gpu_config.switch_specify_input_names(True)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(100, 0) gpu_config.enable_use_gpu(100, 0)
self.gpu_predictor = create_paddle_predictor(gpu_config) self.gpu_predictor = create_predictor(gpu_config)
# 模型预测函数 # 模型预测函数
def predict(self, input_datas): def predict(self, input_datas):
...@@ -69,9 +67,9 @@ class ExtractLineDraft(hub.Module): ...@@ -69,9 +67,9 @@ class ExtractLineDraft(hub.Module):
# 遍历输入数据进行预测 # 遍历输入数据进行预测
for input_data in input_datas: for input_data in input_datas:
inputs = input_data.copy() inputs = input_data.copy()
self.input_tensor.copy_from_cpu(inputs) self.input_handle.copy_from_cpu(inputs)
self.predictor.zero_copy_run() self.predictor.run()
output = self.output_tensor.copy_to_cpu() output = self.output_handle.copy_to_cpu()
outputs.append(output) outputs.append(output)
# 预测结果合并 # 预测结果合并
...@@ -85,7 +83,7 @@ class ExtractLineDraft(hub.Module): ...@@ -85,7 +83,7 @@ class ExtractLineDraft(hub.Module):
Get the input and program of the infer model Get the input and program of the infer model
Args: Args:
image (list(numpy.ndarray)): images data, shape of each is [H, W, C], the color space is BGR. image (str): image path
use_gpu(bool): Weather to use gpu use_gpu(bool): Weather to use gpu
""" """
if use_gpu: if use_gpu:
...@@ -103,16 +101,18 @@ class ExtractLineDraft(hub.Module): ...@@ -103,16 +101,18 @@ class ExtractLineDraft(hub.Module):
new_width = 0 new_width = 0
new_height = 0 new_height = 0
if (width > height): if (width > height):
from_mat = cv2.resize(from_mat, (512, int(512 / width * height)), interpolation=cv2.INTER_AREA) from_mat = cv2.resize(
from_mat, (512, int(512 / width * height)), interpolation=cv2.INTER_AREA)
new_width = 512 new_width = 512
new_height = int(512 / width * height) new_height = int(512 / width * height)
else: else:
from_mat = cv2.resize(from_mat, (int(512 / height * width), 512), interpolation=cv2.INTER_AREA) from_mat = cv2.resize(
from_mat, (int(512 / height * width), 512), interpolation=cv2.INTER_AREA)
new_width = int(512 / height * width) new_width = int(512 / height * width)
new_height = 512 new_height = 512
from_mat = from_mat.transpose((2, 0, 1)) from_mat = from_mat.transpose((2, 0, 1))
light_map = np.zeros(from_mat.shape, dtype=np.float) light_map = np.zeros(from_mat.shape, dtype=np.float32)
for channel in range(3): for channel in range(3):
light_map[channel] = get_light_map_single(from_mat[channel]) light_map[channel] = get_light_map_single(from_mat[channel])
light_map = normalize_pic(light_map) light_map = normalize_pic(light_map)
...@@ -127,9 +127,12 @@ class ExtractLineDraft(hub.Module): ...@@ -127,9 +127,12 @@ class ExtractLineDraft(hub.Module):
self.input_names = self.predictor.get_input_names() self.input_names = self.predictor.get_input_names()
self.output_names = self.predictor.get_output_names() self.output_names = self.predictor.get_output_names()
self.input_tensor = self.predictor.get_input_tensor(self.input_names[0]) self.input_handle = self.predictor.get_input_handle(
self.output_tensor = self.predictor.get_output_tensor(self.output_names[0]) self.input_names[0])
line_mat = self.predict(np.expand_dims(light_map, axis=0).astype('float32')) self.output_handle = self.predictor.get_output_handle(
self.output_names[0])
line_mat = self.predict(np.expand_dims(
light_map, axis=0).astype('float32'))
# 去除 batch 维度 (512, 512, 3) # 去除 batch 维度 (512, 512, 3)
line_mat = line_mat.transpose((3, 1, 2, 0))[0] line_mat = line_mat.transpose((3, 1, 2, 0))[0]
# 裁剪 (512, 384, 3) # 裁剪 (512, 384, 3)
...@@ -137,10 +140,12 @@ class ExtractLineDraft(hub.Module): ...@@ -137,10 +140,12 @@ class ExtractLineDraft(hub.Module):
line_mat = np.amax(line_mat, 2) line_mat = np.amax(line_mat, 2)
# 保存图片 # 保存图片
if Path('./output/').exists(): if Path('./output/').exists():
show_active_img_and_save_denoise(line_mat, './output/' + 'output.png') show_active_img_and_save_denoise(
line_mat, './output/' + 'output.png')
else: else:
os.makedirs('./output/') os.makedirs('./output/')
show_active_img_and_save_denoise(line_mat, './output/' + 'output.png') show_active_img_and_save_denoise(
line_mat, './output/' + 'output.png')
print('图片已经完成') print('图片已经完成')
@runnable @runnable
...@@ -154,9 +159,11 @@ class ExtractLineDraft(hub.Module): ...@@ -154,9 +159,11 @@ class ExtractLineDraft(hub.Module):
usage='%(prog)s', usage='%(prog)s',
add_help=True) add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group( self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.") title="Config options",
description="Run configuration for controlling module behavior, not required.")
self.add_module_input_arg() self.add_module_input_arg()
...@@ -175,8 +182,16 @@ class ExtractLineDraft(hub.Module): ...@@ -175,8 +182,16 @@ class ExtractLineDraft(hub.Module):
""" """
Add the command input options Add the command input options
""" """
self.arg_input_group.add_argument('--image', type=str, default=None, help="file contain input data") self.arg_input_group.add_argument(
self.arg_input_group.add_argument('--use_gpu', type=ast.literal_eval, default=None, help="weather to use gpu") '--image',
type=str,
default=None,
help="file contain input data")
self.arg_input_group.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=None,
help="weather to use gpu")
def check_input_data(self, args): def check_input_data(self, args):
input_data = [] input_data = []
......
import os
import shutil
import unittest
import cv2
import requests
import paddlehub as hub
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class TestHubModule(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
img_url = 'https://ai-studio-static-online.cdn.bcebos.com/1c30757e069541a18dc89b92f0750983b77ad762560849afa0170046672e57a3'
if not os.path.exists('tests'):
os.makedirs('tests')
response = requests.get(img_url)
assert response.status_code == 200, 'Network Error.'
with open('tests/test.jpg', 'wb') as f:
f.write(response.content)
cls.module = hub.Module(name="Extract_Line_Draft")
@classmethod
def tearDownClass(cls) -> None:
shutil.rmtree('tests')
shutil.rmtree('inference')
shutil.rmtree('output')
def test_ExtractLine1(self):
self.module.ExtractLine(
image='tests/test.jpg',
use_gpu=False
)
self.assertTrue(os.path.exists('output/output.png'))
def test_ExtractLine2(self):
self.module.ExtractLine(
image='tests/test.jpg',
use_gpu=True
)
self.assertTrue(os.path.exists('output/output.png'))
def test_ExtractLine3(self):
self.assertRaises(
AttributeError,
self.module.ExtractLine,
image='no.jpg'
)
def test_ExtractLine4(self):
self.assertRaises(
TypeError,
self.module.ExtractLine,
image=['tests/test.jpg']
)
def test_save_inference_model(self):
self.module.save_inference_model('./inference/model')
self.assertTrue(os.path.exists('./inference/model.pdmodel'))
self.assertTrue(os.path.exists('./inference/model.pdiparams'))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册