diff --git a/modules/image/semantic_segmentation/Extract_Line_Draft/Readme.md b/modules/image/semantic_segmentation/Extract_Line_Draft/Readme.md
index 92808c8b4ba6ea7b4f1f0aad8aa6dd39268703c7..813a94eab04e23ab541e72514c55fb8f3f92531e 100644
--- a/modules/image/semantic_segmentation/Extract_Line_Draft/Readme.md
+++ b/modules/image/semantic_segmentation/Extract_Line_Draft/Readme.md
@@ -1,57 +1,90 @@
-Extract_Line_Draft
-Category: Image - Image Segmentation
+# Extract_Line_Draft
-# Model Overview
-Extract_Line_Draft automatically generates a line-draft image from a color image. This PaddleHub Module supports both API and command-line prediction.
+|Model Name|Extract_Line_Draft|
+| :--- | :---: |
+|Category|Image - Image Segmentation|
+|Network|-|
+|Dataset|-|
+|Fine-tuning supported|No|
+|Model Size|259MB|
+|Metrics|-|
+|Latest update date|2021-02-26|
-# Install the chosen model version
-$ hub install Extract_Line_Draft==1.0.0
-# Command-line prediction example
-$ hub run Extract_Line_Draft --image 1.png --use_gpu True
+## I. Basic Information
-# Module API Description
-## ExtractLine(self, image, use_gpu=False)
-Line-draft prediction API: takes one image as input and outputs its line draft.
-### Parameters
-- image(str): path of the input image
-- use_gpu (bool): whether to use GPU
+- ### Application Effect Display
+    - Sample result (original image and extracted line draft):
+
+      ![](https://ai-studio-static-online.cdn.bcebos.com/1c30757e069541a18dc89b92f0750983b77ad762560849afa0170046672e57a3)
+      ![](https://ai-studio-static-online.cdn.bcebos.com/7ef00637e5974be2847317053f8abe97236cec75fba14f77be2c095529a1eeb3)
+
-# Code Example
+- ### Model Introduction
-## API Call
-~~~
-import paddlehub as hub
+    - Extract_Line_Draft automatically generates a line-draft image from a color image. This PaddleHub Module supports both API and command-line prediction.
-Extract_Line_Draft_test = hub.Module(name="Extract_Line_Draft")
-test_img = "testImage.png"
+## II. Installation
-# execute predict
-Extract_Line_Draft_test.ExtractLine(test_img, use_gpu=True)
-~~~
+- ### 1. Environment Dependencies
-## Command-line Call
-~~~
-!hub run Extract_Line_Draft --input_path "testImage" --use_gpu True
-~~~
+ - paddlepaddle >= 2.0.0
-# Effect Display
+ - paddlehub >= 2.0.0
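+
+  - A minimal sketch for preparing these dependencies with pip (assuming a CPU-only environment; use the GPU build of paddlepaddle if you plan to predict with `use_gpu=True`):
+
+    ```shell
+    $ pip install "paddlepaddle>=2.0.0" "paddlehub>=2.0.0"
+    ```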
-## Original Images
-![](https://ai-studio-static-online.cdn.bcebos.com/1c30757e069541a18dc89b92f0750983b77ad762560849afa0170046672e57a3)
-![](https://ai-studio-static-online.cdn.bcebos.com/4a544c9ecd79461bbc1d1556d100b21d28b41b4f23db440ab776af78764292f2)
+- ### 2. Installation
+ - ```shell
+ $ hub install Extract_Line_Draft
+ ```
+  - In case of any problems during installation, please refer to: [Windows Installation Guide](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [Linux Installation Guide](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS Installation Guide](../../../../docs/docs_ch/get_start/mac_quickstart.md)
-## Line Drafts
-![](https://ai-studio-static-online.cdn.bcebos.com/7ef00637e5974be2847317053f8abe97236cec75fba14f77be2c095529a1eeb3)
-![](https://ai-studio-static-online.cdn.bcebos.com/074ea02d89bc4b5c9004a077b61301fa49583c13af734bd6a49e81f59f9cd322)
+## III. Module API Prediction
+  - ### 1. Command-line Prediction
+ ```shell
+    $ hub run Extract_Line_Draft --image "testImage.png" --use_gpu True
+ ```
-# Contributors
-彭兆帅、郑博培
+  - ### 2. Prediction Code Example
-# Dependencies
-paddlepaddle >= 1.8.2
-paddlehub >= 1.8.0
+ ```python
+ import paddlehub as hub
+
+ Extract_Line_Draft_test = hub.Module(name="Extract_Line_Draft")
+
+ test_img = "testImage.png"
+
+ # execute predict
+ Extract_Line_Draft_test.ExtractLine(test_img, use_gpu=True)
+ ```
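+
+    - With the current implementation, `ExtractLine` returns nothing; it writes the generated line draft to `./output/output.png` under the working directory. A minimal sketch for reading that file after the call above (the output path is the module's built-in default, not a parameter):
+
+      ```python
+      import cv2
+
+      # read the line draft that ExtractLine saved to ./output/output.png
+      line_draft = cv2.imread("output/output.png", cv2.IMREAD_GRAYSCALE)
+      print(line_draft.shape)
+      ```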
+
+  - ### 3. API
+
+ ```python
+ def ExtractLine(image, use_gpu=False)
+ ```
+
+    - Prediction API: takes a color image and generates its line draft.
+
+    - **Parameters**
+
+      * image (str): path of the input image
+      * use_gpu (bool): whether to use GPU (see the note below)
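+
+    - Note: judging from the module's `_set_config`, a GPU predictor is only built when `CUDA_VISIBLE_DEVICES` is set in the environment, so `use_gpu=True` assumes something like the following before the module is loaded:
+
+      ```python
+      import os
+
+      # expose GPU 0 so the module can create its GPU predictor
+      os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+      ```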
+
+
+## IV. Release Note
+
+* 1.0.0
+
+  First release
+
+* 1.1.0
+
+  Remove Fluid API
+
+ ```shell
+    $ hub install Extract_Line_Draft==1.1.0
+ ```
\ No newline at end of file
diff --git a/modules/image/semantic_segmentation/Extract_Line_Draft/function.py b/modules/image/semantic_segmentation/Extract_Line_Draft/function.py
index bb83b2237b9bcff8cfdb730175bb2e4250868123..aa3fe8385fc05f7d3fe562eb686c6682fe3342e9 100644
--- a/modules/image/semantic_segmentation/Extract_Line_Draft/function.py
+++ b/modules/image/semantic_segmentation/Extract_Line_Draft/function.py
@@ -4,9 +4,9 @@ from scipy import ndimage
def get_normal_map(img):
- img = img.astype(np.float)
+ img = img.astype(np.float32)
img = img / 255.0
- img = -img + 1
+ img = - img + 1
img[img < 0] = 0
img[img > 1] = 1
return img
@@ -14,7 +14,7 @@ def get_normal_map(img):
def get_gray_map(img):
gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY)
- highPass = gray.astype(np.float)
+ highPass = gray.astype(np.float32)
highPass = highPass / 255.0
highPass = 1 - highPass
highPass = highPass[None]
@@ -25,7 +25,7 @@ def get_light_map(img):
gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY)
blur = cv2.GaussianBlur(gray, (0, 0), 3)
highPass = gray.astype(int) - blur.astype(int)
- highPass = highPass.astype(np.float)
+ highPass = highPass.astype(np.float32)
highPass = highPass / 128.0
highPass = highPass[None]
return highPass.transpose((1, 2, 0))
@@ -38,7 +38,7 @@ def get_light_map_single(img):
blur = cv2.GaussianBlur(gray, (0, 0), 3)
gray = gray.reshape((gray.shape[0], gray.shape[1]))
highPass = gray.astype(int) - blur.astype(int)
- highPass = highPass.astype(np.float)
+ highPass = highPass.astype(np.float32)
highPass = highPass / 128.0
return highPass
@@ -49,7 +49,7 @@ def get_light_map_drawer(img):
highPass = gray.astype(int) - blur.astype(int) + 255
highPass[highPass < 0] = 0
highPass[highPass > 255] = 255
- highPass = highPass.astype(np.float)
+ highPass = highPass.astype(np.float32)
highPass = highPass / 255.0
highPass = 1 - highPass
highPass = highPass[None]
@@ -58,7 +58,7 @@ def get_light_map_drawer(img):
def get_light_map_drawer2(img):
ret = img.copy()
- ret = ret.astype(np.float)
+ ret = ret.astype(np.float32)
ret[:, :, 0] = get_light_map_drawer3(img[:, :, 0])
ret[:, :, 1] = get_light_map_drawer3(img[:, :, 1])
ret[:, :, 2] = get_light_map_drawer3(img[:, :, 2])
@@ -72,7 +72,7 @@ def get_light_map_drawer3(img):
highPass = gray.astype(int) - blur.astype(int) + 255
highPass[highPass < 0] = 0
highPass[highPass > 255] = 255
- highPass = highPass.astype(np.float)
+ highPass = highPass.astype(np.float32)
highPass = highPass / 255.0
highPass = 1 - highPass
return highPass
@@ -91,7 +91,7 @@ def superlize_pic(img):
def mask_pic(img, mask):
mask_mat = mask
- mask_mat = mask_mat.astype(np.float)
+ mask_mat = mask_mat.astype(np.float32)
mask_mat = cv2.GaussianBlur(mask_mat, (0, 0), 1)
mask_mat = mask_mat / np.max(mask_mat)
mask_mat = mask_mat * 255
@@ -106,14 +106,14 @@ def mask_pic(img, mask):
def resize_img_512(img):
- zeros = np.zeros((512, 512, img.shape[2]), dtype=np.float)
+ zeros = np.zeros((512, 512, img.shape[2]), dtype=np.float32)
zeros[:img.shape[0], :img.shape[1]] = img
return zeros
def resize_img_512_3d(img):
- zeros = np.zeros((1, 3, 512, 512), dtype=np.float)
- zeros[0, 0:img.shape[0], 0:img.shape[1], 0:img.shape[2]] = img
+ zeros = np.zeros((1, 3, 512, 512), dtype=np.float32)
+ zeros[0, 0: img.shape[0], 0: img.shape[1], 0: img.shape[2]] = img
return zeros.transpose((1, 2, 3, 0))
@@ -122,8 +122,8 @@ def denoise_mat(img, i):
def show_active_img_and_save_denoise(img, path):
- mat = img.astype(np.float)
- mat = -mat + 1
+ mat = img.astype(np.float32)
+ mat = - mat + 1
mat = mat * 255.0
mat[mat < 0] = 0
mat[mat > 255] = 255
@@ -134,8 +134,8 @@ def show_active_img_and_save_denoise(img, path):
def show_active_img(name, img):
- mat = img.astype(np.float)
- mat = -mat + 1
+ mat = img.astype(np.float32)
+ mat = - mat + 1
mat = mat * 255.0
mat[mat < 0] = 0
mat[mat > 255] = 255
@@ -145,8 +145,8 @@ def show_active_img(name, img):
def get_active_img(img):
- mat = img.astype(np.float)
- mat = -mat + 1
+ mat = img.astype(np.float32)
+ mat = - mat + 1
mat = mat * 255.0
mat[mat < 0] = 0
mat[mat > 255] = 255
@@ -155,9 +155,9 @@ def get_active_img(img):
def get_active_img_fil(img):
- mat = img.astype(np.float)
+ mat = img.astype(np.float32)
mat[mat < 0.18] = 0
- mat = -mat + 1
+ mat = - mat + 1
mat = mat * 255.0
mat[mat < 0] = 0
mat[mat > 255] = 255
@@ -166,7 +166,7 @@ def get_active_img_fil(img):
def show_double_active_img(name, img):
- mat = img.astype(np.float)
+ mat = img.astype(np.float32)
mat = mat * 128.0
mat = mat + 127.0
mat[mat < 0] = 0
diff --git a/modules/image/semantic_segmentation/Extract_Line_Draft/module.py b/modules/image/semantic_segmentation/Extract_Line_Draft/module.py
index f1aa37769f14281a9d1e4e99980c76494aa81af9..8096a33a78b01c781e3e4420523d3f14bbb7f19a 100644
--- a/modules/image/semantic_segmentation/Extract_Line_Draft/module.py
+++ b/modules/image/semantic_segmentation/Extract_Line_Draft/module.py
@@ -1,34 +1,30 @@
import argparse
import ast
import os
-import math
-import six
-import time
+import cv2
from pathlib import Path
-from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
-from paddlehub.module.module import runnable, serving, moduleinfo
-from paddlehub.io.parser import txt_parser
+from paddle.inference import Config, create_predictor
+from paddlehub.module.module import runnable, moduleinfo
import numpy as np
-import paddle.fluid as fluid
-import paddlehub as hub
-from Extract_Line_Draft.function import *
+from .function import get_light_map_single, normalize_pic, resize_img_512_3d, show_active_img_and_save_denoise
@moduleinfo(
name="Extract_Line_Draft",
- version="1.0.0",
+ version="1.1.0",
type="cv/segmentation",
summary="Import the color picture and generate the line draft of the picture",
author="彭兆帅,郑博培",
author_email="1084667371@qq.com,2733821739@qq.com")
-class ExtractLineDraft(hub.Module):
- def _initialize(self):
+class ExtractLineDraft:
+ def __init__(self):
"""
Initialize with the necessary elements
"""
# 加载模型路径
- self.default_pretrained_model_path = os.path.join(self.directory, "assets", "infer_model")
+ self.default_pretrained_model_path = os.path.join(
+ self.directory, "assets", "infer_model", "model")
self._set_config()
def _set_config(self):
@@ -36,7 +32,9 @@ class ExtractLineDraft(hub.Module):
predictor config setting
"""
self.model_file_path = self.default_pretrained_model_path
- cpu_config = AnalysisConfig(self.model_file_path)
+        model = self.default_pretrained_model_path + '.pdmodel'
+        params = self.default_pretrained_model_path + '.pdiparams'
+ cpu_config = Config(model, params)
cpu_config.disable_glog_info()
cpu_config.switch_ir_optim(True)
cpu_config.enable_memory_optim()
@@ -44,7 +42,7 @@ class ExtractLineDraft(hub.Module):
cpu_config.switch_specify_input_names(True)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
- self.cpu_predictor = create_paddle_predictor(cpu_config)
+ self.cpu_predictor = create_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
@@ -53,7 +51,7 @@ class ExtractLineDraft(hub.Module):
except:
use_gpu = False
if use_gpu:
- gpu_config = AnalysisConfig(self.model_file_path)
+ gpu_config = Config(model, params)
gpu_config.disable_glog_info()
gpu_config.switch_ir_optim(True)
gpu_config.enable_memory_optim()
@@ -61,7 +59,7 @@ class ExtractLineDraft(hub.Module):
gpu_config.switch_specify_input_names(True)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(100, 0)
- self.gpu_predictor = create_paddle_predictor(gpu_config)
+ self.gpu_predictor = create_predictor(gpu_config)
# 模型预测函数
def predict(self, input_datas):
@@ -69,9 +67,9 @@ class ExtractLineDraft(hub.Module):
# 遍历输入数据进行预测
for input_data in input_datas:
inputs = input_data.copy()
- self.input_tensor.copy_from_cpu(inputs)
- self.predictor.zero_copy_run()
- output = self.output_tensor.copy_to_cpu()
+ self.input_handle.copy_from_cpu(inputs)
+ self.predictor.run()
+ output = self.output_handle.copy_to_cpu()
outputs.append(output)
# 预测结果合并
@@ -85,7 +83,7 @@ class ExtractLineDraft(hub.Module):
Get the input and program of the infer model
Args:
- image (list(numpy.ndarray)): images data, shape of each is [H, W, C], the color space is BGR.
+ image (str): image path
use_gpu(bool): Weather to use gpu
"""
if use_gpu:
@@ -103,16 +101,18 @@ class ExtractLineDraft(hub.Module):
new_width = 0
new_height = 0
if (width > height):
- from_mat = cv2.resize(from_mat, (512, int(512 / width * height)), interpolation=cv2.INTER_AREA)
+ from_mat = cv2.resize(
+ from_mat, (512, int(512 / width * height)), interpolation=cv2.INTER_AREA)
new_width = 512
new_height = int(512 / width * height)
else:
- from_mat = cv2.resize(from_mat, (int(512 / height * width), 512), interpolation=cv2.INTER_AREA)
+ from_mat = cv2.resize(
+ from_mat, (int(512 / height * width), 512), interpolation=cv2.INTER_AREA)
new_width = int(512 / height * width)
new_height = 512
from_mat = from_mat.transpose((2, 0, 1))
- light_map = np.zeros(from_mat.shape, dtype=np.float)
+ light_map = np.zeros(from_mat.shape, dtype=np.float32)
for channel in range(3):
light_map[channel] = get_light_map_single(from_mat[channel])
light_map = normalize_pic(light_map)
@@ -127,9 +127,12 @@ class ExtractLineDraft(hub.Module):
self.input_names = self.predictor.get_input_names()
self.output_names = self.predictor.get_output_names()
- self.input_tensor = self.predictor.get_input_tensor(self.input_names[0])
- self.output_tensor = self.predictor.get_output_tensor(self.output_names[0])
- line_mat = self.predict(np.expand_dims(light_map, axis=0).astype('float32'))
+ self.input_handle = self.predictor.get_input_handle(
+ self.input_names[0])
+ self.output_handle = self.predictor.get_output_handle(
+ self.output_names[0])
+ line_mat = self.predict(np.expand_dims(
+ light_map, axis=0).astype('float32'))
# 去除 batch 维度 (512, 512, 3)
line_mat = line_mat.transpose((3, 1, 2, 0))[0]
# 裁剪 (512, 384, 3)
@@ -137,10 +140,12 @@ class ExtractLineDraft(hub.Module):
line_mat = np.amax(line_mat, 2)
# 保存图片
if Path('./output/').exists():
- show_active_img_and_save_denoise(line_mat, './output/' + 'output.png')
+ show_active_img_and_save_denoise(
+ line_mat, './output/' + 'output.png')
else:
os.makedirs('./output/')
- show_active_img_and_save_denoise(line_mat, './output/' + 'output.png')
+ show_active_img_and_save_denoise(
+ line_mat, './output/' + 'output.png')
print('图片已经完成')
@runnable
@@ -154,9 +159,11 @@ class ExtractLineDraft(hub.Module):
usage='%(prog)s',
add_help=True)
- self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
+ self.arg_input_group = self.parser.add_argument_group(
+ title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
- title="Config options", description="Run configuration for controlling module behavior, not required.")
+ title="Config options",
+ description="Run configuration for controlling module behavior, not required.")
self.add_module_input_arg()
@@ -175,8 +182,16 @@ class ExtractLineDraft(hub.Module):
"""
Add the command input options
"""
- self.arg_input_group.add_argument('--image', type=str, default=None, help="file contain input data")
- self.arg_input_group.add_argument('--use_gpu', type=ast.literal_eval, default=None, help="weather to use gpu")
+        self.arg_input_group.add_argument(
+            '--image',
+            type=str,
+            default=None,
+            help="path of the input image")
+        self.arg_input_group.add_argument(
+            '--use_gpu',
+            type=ast.literal_eval,
+            default=None,
+            help="whether to use GPU")
def check_input_data(self, args):
input_data = []
diff --git a/modules/image/semantic_segmentation/Extract_Line_Draft/test.py b/modules/image/semantic_segmentation/Extract_Line_Draft/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6235b38ab95b11cd597054cd347856598d100080
--- /dev/null
+++ b/modules/image/semantic_segmentation/Extract_Line_Draft/test.py
@@ -0,0 +1,66 @@
+import os
+import shutil
+import unittest
+
+import cv2
+import requests
+import paddlehub as hub
+
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+class TestHubModule(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ img_url = 'https://ai-studio-static-online.cdn.bcebos.com/1c30757e069541a18dc89b92f0750983b77ad762560849afa0170046672e57a3'
+ if not os.path.exists('tests'):
+ os.makedirs('tests')
+ response = requests.get(img_url)
+ assert response.status_code == 200, 'Network Error.'
+ with open('tests/test.jpg', 'wb') as f:
+ f.write(response.content)
+ cls.module = hub.Module(name="Extract_Line_Draft")
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ shutil.rmtree('tests')
+ shutil.rmtree('inference')
+ shutil.rmtree('output')
+
+ def test_ExtractLine1(self):
+ self.module.ExtractLine(
+ image='tests/test.jpg',
+ use_gpu=False
+ )
+ self.assertTrue(os.path.exists('output/output.png'))
+
+ def test_ExtractLine2(self):
+ self.module.ExtractLine(
+ image='tests/test.jpg',
+ use_gpu=True
+ )
+ self.assertTrue(os.path.exists('output/output.png'))
+
+ def test_ExtractLine3(self):
+ self.assertRaises(
+ AttributeError,
+ self.module.ExtractLine,
+ image='no.jpg'
+ )
+
+ def test_ExtractLine4(self):
+ self.assertRaises(
+ TypeError,
+ self.module.ExtractLine,
+ image=['tests/test.jpg']
+ )
+
+ def test_save_inference_model(self):
+ self.module.save_inference_model('./inference/model')
+
+ self.assertTrue(os.path.exists('./inference/model.pdmodel'))
+ self.assertTrue(os.path.exists('./inference/model.pdiparams'))
+
+
+if __name__ == "__main__":
+ unittest.main()