未验证 提交 24e9287e 编写于 作者: 郑博培 提交者: GitHub

Add Extract_Line_Draft module

上级 8fa0291f
Extract_Line_Draft
类别 图像 - 图像分割
# 模型概述
提取线稿(Extract_Line_Draft),该模型可自动根据彩色图生成线稿图。该PaddleHub Module支持API预测及命令行预测。
# 选择模型版本进行安装
$ hub install Extract_Line_Draft==1.0.0
# 命令行预测示例
$ hub run Extract_Line_Draft --image 1.png --use_gpu True
# Module API说明
## ExtractLine(self, image, use_gpu=False)
提取线稿预测接口,预测输入一张图像,输出该图像的线稿
### 参数
- image(str): 待检测的图片路径
- use_gpu (bool): 是否使用 GPU
# 代码示例
## API调用
~~~
import paddlehub as hub
Extract_Line_Draft_test = hub.Module(name="Extract_Line_Draft")
test_img = "testImage.png"
# execute predict
Extract_Line_Draft_test.ExtractLine(test_img, use_gpu=True)
~~~
## 命令行调用
~~~
!hub run Extract_Line_Draft --image "testImage.png" --use_gpu True
~~~
# 效果展示
## 原图
![](https://ai-studio-static-online.cdn.bcebos.com/1c30757e069541a18dc89b92f0750983b77ad762560849afa0170046672e57a3)
![](https://ai-studio-static-online.cdn.bcebos.com/4a544c9ecd79461bbc1d1556d100b21d28b41b4f23db440ab776af78764292f2)
## 线稿图
![](https://ai-studio-static-online.cdn.bcebos.com/7ef00637e5974be2847317053f8abe97236cec75fba14f77be2c095529a1eeb3)
![](https://ai-studio-static-online.cdn.bcebos.com/074ea02d89bc4b5c9004a077b61301fa49583c13af734bd6a49e81f59f9cd322)
# 贡献者
彭兆帅、郑博培
# 依赖
paddlepaddle >= 1.8.2
paddlehub >= 1.8.0
import numpy as np
import cv2
from scipy import ndimage
def get_normal_map(img):
    """Return the inverted, unit-scaled version of ``img``.

    The input is converted to float64, divided by 255, inverted
    (``1 - x``) and clamped to the closed interval [0, 1].

    Args:
        img: numpy array of pixel values (any numeric dtype).

    Returns:
        A new float64 array with the same shape as ``img``.
    """
    # ``np.float`` was only an alias of the builtin ``float`` and no longer
    # exists in current NumPy releases; float64 is the dtype it resolved to.
    inverted = -(img.astype(np.float64) / 255.0) + 1
    inverted[inverted < 0] = 0
    inverted[inverted > 1] = 1
    return inverted
def get_gray_map(img):
    """Convert a BGR image to an inverted grayscale map with a channel axis.

    The image is cast to uint8, converted with ``cv2.COLOR_BGR2GRAY``,
    scaled by 1/255 and inverted (``1 - x``).

    Args:
        img: numpy array accepted by ``cv2.cvtColor(..., COLOR_BGR2GRAY)``
            once cast to uint8.

    Returns:
        float64 array with a trailing singleton axis appended to the
        grayscale result (the leading axis added by ``[None]`` is moved to
        the end).
    """
    gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY)
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    inverted = 1 - gray.astype(np.float64) / 255.0
    return inverted[None].transpose((1, 2, 0))
def get_light_map(img):
    """Compute a high-pass ("light") map of a BGR image.

    The grayscale image minus its Gaussian-blurred copy (sigma 3) is
    divided by 128.

    Args:
        img: numpy array accepted by ``cv2.cvtColor(..., COLOR_BGR2GRAY)``
            once cast to uint8.

    Returns:
        float64 array with a trailing singleton axis appended to the
        grayscale result.
    """
    gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY)
    blur = cv2.GaussianBlur(gray, (0, 0), 3)
    # Subtract as signed ints so the difference can go negative instead of
    # wrapping around in uint8.
    high_pass = gray.astype(int) - blur.astype(int)
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    high_pass = high_pass.astype(np.float64) / 128.0
    return high_pass[None].transpose((1, 2, 0))
def get_light_map_single(img):
    """Compute a high-pass map for one 2-D image channel.

    The channel is given a trailing singleton axis, Gaussian-blurred
    (sigma 3), and the blurred copy is subtracted from the original; the
    difference is divided by 128.

    Args:
        img: 2-D numpy array (a single channel).

    Returns:
        2-D float64 array of the scaled difference.
    """
    gray = img[None].transpose((1, 2, 0))
    blur = cv2.GaussianBlur(gray, (0, 0), 3)
    # Drop the singleton axis again before subtracting.
    gray = gray.reshape((gray.shape[0], gray.shape[1]))
    # Signed subtraction: the difference may be negative.
    high_pass = gray.astype(int) - blur.astype(int)
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    return high_pass.astype(np.float64) / 128.0
def get_light_map_drawer(img):
    """Compute an inverted, offset high-pass map of a BGR image.

    The grayscale image minus its Gaussian blur (sigma 3) is shifted by
    +255, clamped to [0, 255], scaled by 1/255 and inverted (``1 - x``).

    Args:
        img: numpy array accepted by ``cv2.cvtColor(..., COLOR_BGR2GRAY)``
            once cast to uint8.

    Returns:
        float64 array with a trailing singleton axis appended to the
        grayscale result.
    """
    gray = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2GRAY)
    blur = cv2.GaussianBlur(gray, (0, 0), 3)
    high_pass = gray.astype(int) - blur.astype(int) + 255
    high_pass[high_pass < 0] = 0
    high_pass[high_pass > 255] = 255
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    high_pass = 1 - high_pass.astype(np.float64) / 255.0
    return high_pass[None].transpose((1, 2, 0))
def get_light_map_drawer2(img):
    """Apply ``get_light_map_drawer3`` to the first three channels and take the max.

    Args:
        img: 3-D numpy array indexed as ``img[:, :, channel]`` with at
            least three channels.

    Returns:
        2-D float64 array: the maximum over the last axis after channels
        0-2 have been replaced by their ``get_light_map_drawer3`` result.
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    ret = img.copy().astype(np.float64)
    for channel in range(3):
        ret[:, :, channel] = get_light_map_drawer3(img[:, :, channel])
    return np.amax(ret, 2)
def get_light_map_drawer3(img):
    """Compute an inverted, offset high-pass map using a 5x5 box blur.

    The input minus its ``cv2.blur`` (5x5 kernel) is shifted by +255,
    clamped to [0, 255], scaled by 1/255 and inverted (``1 - x``).

    Args:
        img: numpy array accepted by ``cv2.blur``.

    Returns:
        float64 array of the inverted map.
    """
    blur = cv2.blur(img, ksize=(5, 5))
    high_pass = img.astype(int) - blur.astype(int) + 255
    high_pass[high_pass < 0] = 0
    high_pass[high_pass > 255] = 255
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    return 1 - high_pass.astype(np.float64) / 255.0
def normalize_pic(img):
    """Scale ``img`` so that its maximum value becomes 1.

    Args:
        img: numpy array.

    Returns:
        ``img`` divided by its maximum. When the maximum is exactly 0 the
        values are returned unscaled, because dividing by zero would turn
        the whole image into NaN/inf.
    """
    peak = np.max(img)
    divisor = peak if peak != 0 else 1
    return img / divisor
def superlize_pic(img):
    """Amplify ``img`` by a factor of 2.33333 and cap the result at 1.

    Args:
        img: numpy array.

    Returns:
        A new array; the input is not modified.
    """
    boosted = img * 2.33333
    return np.minimum(boosted, 1)
def mask_pic(img, mask):
    """Multiply ``img`` by a softened mask derived from ``mask``.

    Steps, in order: blur the mask (sigma 1), rescale so its maximum is
    255, zero every value below 255, blur again (sigma 3), pass the result
    through ``get_gray_map``, ``normalize_pic`` and ``resize_img_512``, and
    multiply element-wise with ``img``.

    Args:
        img: array multiplied with the processed mask.
        mask: numpy array used to build the mask.
            NOTE(review): ``get_gray_map`` runs a BGR->gray conversion, so
            this presumably expects a 3-channel mask — confirm with callers.

    Returns:
        The element-wise product of ``img`` and the processed mask.
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    mask_mat = mask.astype(np.float64)
    mask_mat = cv2.GaussianBlur(mask_mat, (0, 0), 1)
    mask_mat = mask_mat / np.max(mask_mat)
    mask_mat = mask_mat * 255
    # Keep only the values that reached the peak; everything else is zeroed.
    mask_mat[mask_mat < 255] = 0
    mask_mat = mask_mat.astype(np.uint8)
    mask_mat = cv2.GaussianBlur(mask_mat, (0, 0), 3)
    mask_mat = get_gray_map(mask_mat)
    mask_mat = normalize_pic(mask_mat)
    mask_mat = resize_img_512(mask_mat)
    return np.multiply(img, mask_mat)
def resize_img_512(img):
    """Place ``img`` in the top-left corner of a zero-filled 512x512 canvas.

    Despite the name, no interpolation happens: the image is copied as-is
    and the remainder of the canvas stays zero.

    Args:
        img: 3-D numpy array ``(rows, cols, channels)``; rows and cols must
            not exceed 512, otherwise the assignment fails.

    Returns:
        float64 array of shape ``(512, 512, channels)``.
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    canvas = np.zeros((512, 512, img.shape[2]), dtype=np.float64)
    canvas[:img.shape[0], :img.shape[1]] = img
    return canvas
def resize_img_512_3d(img):
    """Zero-pad a 3-D array into a ``(3, 512, 512, 1)`` tensor.

    ``img`` is copied into the leading corner of a ``(1, 3, 512, 512)``
    zero buffer, whose first axis is then moved to the end.

    Args:
        img: 3-D numpy array whose axes fit inside ``(3, 512, 512)``.

    Returns:
        float64 array of shape ``(3, 512, 512, 1)``.
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    buffer = np.zeros((1, 3, 512, 512), dtype=np.float64)
    buffer[0, 0:img.shape[0], 0:img.shape[1], 0:img.shape[2]] = img
    return buffer.transpose((1, 2, 3, 0))
def denoise_mat(img, i):
    """Median-filter ``img`` with a window of size ``i``.

    Args:
        img: array to filter.
        i: window size forwarded to ``scipy.ndimage.median_filter``.

    Returns:
        The filtered array.
    """
    filtered = ndimage.median_filter(img, size=i)
    return filtered
def show_active_img_and_save_denoise(img, path):
    """Invert ``img``, convert it to 8-bit and write it to ``path``.

    The values are inverted (``1 - x``), scaled by 255, clamped to
    [0, 255], cast to uint8, passed through a size-1 median filter and
    written with ``cv2.imwrite``. Nothing is displayed.

    Args:
        img: numpy array of activations.
        path: destination file path handed to ``cv2.imwrite``.

    Returns:
        None.
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    mat = (-img.astype(np.float64) + 1) * 255.0
    mat[mat < 0] = 0
    mat[mat > 255] = 255
    mat = mat.astype(np.uint8)
    mat = ndimage.median_filter(mat, 1)
    cv2.imwrite(path, mat)
    return
def show_active_img(name, img):
    """Display the inverted 8-bit rendering of ``img`` in a window.

    The values are inverted (``1 - x``), scaled by 255, clamped to
    [0, 255], cast to uint8 and shown with ``cv2.imshow``.

    Args:
        name: window name passed to ``cv2.imshow``.
        img: numpy array of activations.

    Returns:
        None.
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    mat = (-img.astype(np.float64) + 1) * 255.0
    mat[mat < 0] = 0
    mat[mat > 255] = 255
    cv2.imshow(name, mat.astype(np.uint8))
    return
def get_active_img(img):
    """Return the inverted 8-bit rendering of ``img``.

    The values are inverted (``1 - x``), scaled by 255, clamped to
    [0, 255] and cast to uint8 (fractions are truncated by the cast).

    Args:
        img: numpy array of activations.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    mat = (-img.astype(np.float64) + 1) * 255.0
    mat[mat < 0] = 0
    mat[mat > 255] = 255
    return mat.astype(np.uint8)
def get_active_img_fil(img):
    """Return the inverted 8-bit rendering of ``img`` after thresholding.

    Values below 0.18 are zeroed first; the result is then inverted
    (``1 - x``), scaled by 255, clamped to [0, 255] and cast to uint8.

    Args:
        img: numpy array of activations.

    Returns:
        uint8 array with the same shape as ``img``. The input is not
        modified (the float conversion works on a copy).
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    mat = img.astype(np.float64)
    mat[mat < 0.18] = 0
    mat = (-mat + 1) * 255.0
    mat[mat < 0] = 0
    mat[mat > 255] = 255
    return mat.astype(np.uint8)
def show_double_active_img(name, img):
    """Display a signed map as an 8-bit image centred on mid-grey.

    The values are multiplied by 128, offset by 127, clamped to [0, 255],
    cast to uint8 and shown with ``cv2.imshow``.

    Args:
        name: window name passed to ``cv2.imshow``.
        img: numpy array of signed values.

    Returns:
        None.
    """
    # ``np.float`` no longer exists in current NumPy; use float64 explicitly.
    mat = img.astype(np.float64) * 128.0 + 127.0
    mat[mat < 0] = 0
    mat[mat > 255] = 255
    cv2.imshow(name, mat.astype(np.uint8))
    return
def debug_pic_helper():
    """Preview the colour/gray training pairs ``0.jpg`` .. ``1129.jpg``.

    For each index the colour image is shown as a light map and the gray
    image as a gray map, both normalized and padded to 512x512; each pair
    stays on screen for one second.

    Pairs whose files cannot be read are skipped: ``cv2.imread`` returns
    ``None`` for a missing or unreadable file, which would otherwise crash
    inside the conversion helpers.

    NOTE(review): the one-second ``cv2.waitKey`` is placed inside the loop
    (once per pair); the original layout was lost, so confirm this matches
    the intended behaviour.
    """
    # Local import so this helper does not depend on the module importing os.
    import os

    for index in range(1130):
        # os.path.join keeps the same relative layout as the original
        # backslash-separated literals on Windows while also working on
        # POSIX systems.
        file_name = str(index) + '.jpg'
        gray_path = os.path.join('data', 'gray', file_name)
        color_path = os.path.join('data', 'color', file_name)

        mat_color = cv2.imread(color_path)
        mat_gray = cv2.imread(gray_path)
        if mat_color is None or mat_gray is None:
            continue

        mat_color = get_light_map(mat_color)
        mat_color = normalize_pic(mat_color)
        mat_color = resize_img_512(mat_color)
        show_double_active_img('mat_color', mat_color)

        mat_gray = get_gray_map(mat_gray)
        mat_gray = normalize_pic(mat_gray)
        mat_gray = resize_img_512(mat_gray)
        show_active_img('mat_gray', mat_gray)

        cv2.waitKey(1000)
import argparse
import ast
import os
import math
import six
import time
from pathlib import Path
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.module.module import runnable, serving, moduleinfo
from paddlehub.io.parser import txt_parser
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from Extract_Line_Draft.function import *
@moduleinfo(
    name="Extract_Line_Draft",
    version="1.0.0",
    type="cv/segmentation",
    summary=
    "Import the color picture and generate the line draft of the picture",
    author="彭兆帅,郑博培",
    author_email="1084667371@qq.com,2733821739@qq.com")
class ExtractLineDraft(hub.Module):
    """PaddleHub module that produces a line draft from a colour image file."""

    def _initialize(self):
        """
        Initialize with the necessary elements
        """
        # Location of the inference model shipped with the module.
        self.default_pretrained_model_path = os.path.join(
            self.directory, "assets", "infer_model")
        self._set_config()

    @staticmethod
    def _cuda_device_configured():
        """Return True when CUDA_VISIBLE_DEVICES starts with a digit."""
        try:
            int(os.environ["CUDA_VISIBLE_DEVICES"][0])
        except (KeyError, IndexError, ValueError):
            # Unset, empty, or not starting with a device id.
            return False
        return True

    def _build_config(self, use_gpu):
        """Create the AnalysisConfig shared by the CPU and GPU predictors."""
        config = AnalysisConfig(self.model_file_path)
        config.disable_glog_info()
        config.switch_ir_optim(True)
        config.enable_memory_optim()
        config.switch_use_feed_fetch_ops(False)
        config.switch_specify_input_names(True)
        if use_gpu:
            config.enable_use_gpu(100, 0)
        else:
            config.disable_gpu()
        return config

    def _set_config(self):
        """
        predictor config setting
        """
        self.model_file_path = self.default_pretrained_model_path
        self.cpu_predictor = create_paddle_predictor(
            self._build_config(use_gpu=False))
        # Always define the attribute so later code can test it instead of
        # hitting an AttributeError when no GPU was visible at load time.
        self.gpu_predictor = None
        if self._cuda_device_configured():
            self.gpu_predictor = create_paddle_predictor(
                self._build_config(use_gpu=True))

    def predict(self, input_datas):
        """Run every item of ``input_datas`` through ``self.predictor``.

        ``self.predictor``, ``self.input_tensor`` and ``self.output_tensor``
        must already be set (``ExtractLine`` does this before calling).

        Args:
            input_datas: iterable of numpy arrays; each is copied and fed
                to the input tensor.

        Returns:
            The per-item outputs concatenated along axis 0.
        """
        outputs = []
        for input_data in input_datas:
            inputs = input_data.copy()
            self.input_tensor.copy_from_cpu(inputs)
            self.predictor.zero_copy_run()
            outputs.append(self.output_tensor.copy_to_cpu())
        return np.concatenate(outputs, 0)

    def ExtractLine(self, image, use_gpu=False):
        """
        Extract the line draft of an image and save it to ./output/output.png

        Args:
            image (str): path of the image file to process.
            use_gpu (bool): whether to run the prediction on GPU.

        Raises:
            RuntimeError: if ``use_gpu`` is set but CUDA_VISIBLE_DEVICES is
                not configured, or if the image cannot be read.
        """
        if use_gpu and not self._cuda_device_configured():
            raise RuntimeError(
                "Environment Variable CUDA_VISIBLE_DEVICES is not set correctly. If you wanna use gpu, please set CUDA_VISIBLE_DEVICES as cuda_device_id."
            )

        from_mat = cv2.imread(image)
        if from_mat is None:
            # cv2.imread signals a missing or undecodable file by returning
            # None rather than raising.
            raise RuntimeError("Failed to read image: %s" % image)

        # Scale the longer side to 512 while keeping the aspect ratio; the
        # shorter side is kept at >= 1 px so cv2.resize never receives a
        # zero-sized target for extreme aspect ratios.
        width = float(from_mat.shape[1])
        height = float(from_mat.shape[0])
        if width > height:
            new_width = 512
            new_height = max(1, int(512 / width * height))
        else:
            new_width = max(1, int(512 / height * width))
            new_height = 512
        from_mat = cv2.resize(
            from_mat, (new_width, new_height), interpolation=cv2.INTER_AREA)

        # Channel-first layout, one high-pass map per colour channel.
        from_mat = from_mat.transpose((2, 0, 1))
        # ``np.float`` no longer exists in current NumPy; use float64.
        light_map = np.zeros(from_mat.shape, dtype=np.float64)
        for channel in range(3):
            light_map[channel] = get_light_map_single(from_mat[channel])
        light_map = normalize_pic(light_map)
        light_map = resize_img_512_3d(light_map)
        light_map = light_map.astype('float32')

        # Select the predictor and fetch its input/output tensors.
        if use_gpu:
            if getattr(self, 'gpu_predictor', None) is None:
                # CUDA_VISIBLE_DEVICES was configured after the module was
                # initialised; build the GPU predictor on first use.
                self.gpu_predictor = create_paddle_predictor(
                    self._build_config(use_gpu=True))
            self.predictor = self.gpu_predictor
        else:
            self.predictor = self.cpu_predictor
        self.input_names = self.predictor.get_input_names()
        self.output_names = self.predictor.get_output_names()
        self.input_tensor = self.predictor.get_input_tensor(
            self.input_names[0])
        self.output_tensor = self.predictor.get_output_tensor(
            self.output_names[0])

        line_mat = self.predict(
            np.expand_dims(light_map, axis=0).astype('float32'))
        # Move the last axis to the front and drop it, leaving the three
        # channel maps on the last axis.
        line_mat = line_mat.transpose((3, 1, 2, 0))[0]
        # Crop away the zero padding added by resize_img_512_3d.
        line_mat = line_mat[0:int(new_height), 0:int(new_width), :]
        line_mat = np.amax(line_mat, 2)

        # Save the result.
        os.makedirs('./output/', exist_ok=True)
        show_active_img_and_save_denoise(line_mat, './output/' + 'output.png')
        print('图片已经完成')

    @runnable
    def run_cmd(self, argvs):
        """
        Run as a command.
        """
        self.parser = argparse.ArgumentParser(
            description='Run the %s module.' % self.name,
            prog='hub run %s' % self.name,
            usage='%(prog)s',
            add_help=True)
        self.arg_input_group = self.parser.add_argument_group(
            title="Input options", description="Input data. Required")
        self.arg_config_group = self.parser.add_argument_group(
            title="Config options",
            description=
            "Run configuration for controlling module behavior, not required.")
        self.add_module_input_arg()
        args = self.parser.parse_args(argvs)
        try:
            input_data = self.check_input_data(args)
        except RuntimeError:
            self.parser.print_help()
            return None
        use_gpu = args.use_gpu
        self.ExtractLine(image=input_data, use_gpu=use_gpu)

    def add_module_input_arg(self):
        """
        Add the command input options
        """
        self.arg_input_group.add_argument(
            '--image',
            type=str,
            default=None,
            help="file contain input data")
        self.arg_input_group.add_argument(
            '--use_gpu',
            type=ast.literal_eval,
            default=None,
            help="weather to use gpu")

    def check_input_data(self, args):
        """Validate the ``--image`` argument and return it as a path string.

        Raises:
            RuntimeError: if no image was given or the path does not exist.
                ``run_cmd`` catches this and prints the usage help.
        """
        if not args.image:
            # Without this guard None would be handed to ExtractLine and
            # fail deep inside the image loading code.
            raise RuntimeError("Input image path is not specified.")
        if not os.path.exists(args.image):
            raise RuntimeError("Path %s is not exist." % args.image)
        return "{}".format(args.image)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册