diff --git a/modules/image/Image_gan/style_transfer/lapstyle_starrynew/README.md b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6c082ac3af73eeab46e7a4025c8ddee8a804b5b1 --- /dev/null +++ b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/README.md @@ -0,0 +1,142 @@ +# lapstyle_starrynew + +|模型名称|lapstyle_starrynew| +| :--- | :---: | +|类别|图像 - 风格迁移| +|网络|LapStyle| +|数据集|-| +|是否支持Fine-tuning|否| +|模型大小|121MB| +|最新更新日期|2021-12-07| +|数据指标|-| + + +## 一、模型基本信息 + +- ### 应用效果展示 + - 样例结果示例: +

+ +
+ 输入内容图形 +
+ +
+ 输入风格图形 +
+ +
+ 输出图像 +
+

+ +- ### 模型介绍 + + - LapStyle--拉普拉斯金字塔风格化网络,是一种能够生成高质量风格化图的快速前馈风格化网络,能渐进地生成复杂的纹理迁移效果,同时能够在512分辨率下达到100fps的速度。可实现多种不同艺术风格的快速迁移,在艺术图像生成、滤镜等领域有广泛的应用。 + + - 更多详情参考:[Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality Artistic Style Transfer](https://arxiv.org/pdf/2104.05376.pdf) + + + +## 二、安装 + +- ### 1、环境依赖 + - ppgan + +- ### 2、安装 + + - ```shell + $ hub install lapstyle_starrynew + ``` + - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) + | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) + +## 三、模型API预测 + +- ### 1、命令行预测 + + - ```shell + # Read from a file + $ hub run lapstyle_starrynew --content "/PATH/TO/IMAGE" --style "/PATH/TO/IMAGE1" + ``` + - 通过命令行方式实现风格转换模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst) + +- ### 2、预测代码示例 + + - ```python + import paddlehub as hub + + module = hub.Module(name="lapstyle_starrynew") + content = cv2.imread("/PATH/TO/IMAGE") + style = cv2.imread("/PATH/TO/IMAGE1") + results = module.style_transfer(images=[{'content':content, 'style':style}], output_dir='./transfer_result', use_gpu=True) + ``` + +- ### 3、API + + - ```python + style_transfer(images=None, paths=None, output_dir='./transfer_result/', use_gpu=False, visualization=True) + ``` + - 风格转换API。 + + - **参数** + + - images (list[dict]): data of images, 每一个元素都为一个 dict,有关键字 content, style, 相应取值为: + - content (numpy.ndarray): 待转换的图片,shape 为 \[H, W, C\],BGR格式;
+ - style (numpy.ndarray) : 风格图像,shape为 \[H, W, C\],BGR格式;
+ - paths (list[str]): paths to images, 每一个元素都为一个dict, 有关键字 content, style, 相应取值为: + - content (str): 待转换的图片的路径;
+ - style (str) : 风格图像的路径;
+ - output\_dir (str): 结果保存的路径;
+ - use\_gpu (bool): 是否使用 GPU;
+ - visualization(bool): 是否保存结果到本地文件夹 + + +## 四、服务部署 + +- PaddleHub Serving可以部署一个在线图像风格转换服务。 + +- ### 第一步:启动PaddleHub Serving + + - 运行启动命令: + - ```shell + $ hub serving start -m lapstyle_starrynew + ``` + + - 这样就完成了一个图像风格转换的在线服务API的部署,默认端口号为8866。 + + - **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。 + +- ### 第二步:发送预测请求 + + - 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果 + + - ```python + import requests + import json + import cv2 + import base64 + + + def cv2_to_base64(image): + data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(data.tostring()).decode('utf8') + + # 发送HTTP请求 + data = {'images':[{'content': cv2_to_base64(cv2.imread("/PATH/TO/IMAGE")), 'style': cv2_to_base64(cv2.imread("/PATH/TO/IMAGE1"))}]} + headers = {"Content-type": "application/json"} + url = "http://127.0.0.1:8866/predict/lapstyle_starrynew" + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + + # 打印预测结果 + print(r.json()["results"]) + +## 五、更新历史 + +* 1.0.0 + + 初始发布 + + - ```shell + $ hub install lapstyle_starrynew==1.0.0 + ``` diff --git a/modules/image/Image_gan/style_transfer/lapstyle_starrynew/model.py b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/model.py new file mode 100644 index 0000000000000000000000000000000000000000..ac37f5e677d092cadde1432b777bcd6a7a8877d3 --- /dev/null +++ b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/model.py @@ -0,0 +1,140 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 as cv +import numpy as np +import urllib.request +from PIL import Image + +import paddle +import paddle.nn.functional as F +from paddle.vision.transforms import functional + +from ppgan.utils.visual import tensor2img +from ppgan.models.generators import DecoderNet, Encoder, RevisionNet + + +def img(img): + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + # HWC to CHW + return img + + +def img_totensor(content_img, style_img): + if content_img.ndim == 2: + content_img = cv.cvtColor(content_img, cv.COLOR_GRAY2RGB) + else: + content_img = cv.cvtColor(content_img, cv.COLOR_BGR2RGB) + h, w, c = content_img.shape + content_img = Image.fromarray(content_img) + content_img = content_img.resize((512, 512), Image.BILINEAR) + content_img = np.array(content_img) + content_img = img(content_img) + content_img = functional.to_tensor(content_img) + + style_img = cv.cvtColor(style_img, cv.COLOR_BGR2RGB) + style_img = Image.fromarray(style_img) + style_img = style_img.resize((512, 512), Image.BILINEAR) + style_img = np.array(style_img) + style_img = img(style_img) + style_img = functional.to_tensor(style_img) + + content_img = paddle.unsqueeze(content_img, axis=0) + style_img = paddle.unsqueeze(style_img, axis=0) + return content_img, style_img, h, w + + +def tensor_resample(tensor, dst_size, mode='bilinear'): + return F.interpolate(tensor, dst_size, mode=mode, align_corners=False) + + +def laplacian(x): + """ + Laplacian + + return: + x - upsample(downsample(x)) + """ + return x - tensor_resample(tensor_resample(x, [x.shape[2] // 2, x.shape[3] // 2]), [x.shape[2], x.shape[3]]) + + +def make_laplace_pyramid(x, levels): + """ + Make Laplacian Pyramid + """ + pyramid = [] + current = x + for i in range(levels): + pyramid.append(laplacian(current)) + current = tensor_resample(current, (max(current.shape[2] // 2, 1), max(current.shape[3] // 2, 1))) + pyramid.append(current) + return pyramid + + +def fold_laplace_pyramid(pyramid): + """ + Fold Laplacian Pyramid + """ + current = pyramid[-1] + for i in range(len(pyramid) - 2, -1, -1): # iterate from len-2 to 0 + up_h, up_w = pyramid[i].shape[2], pyramid[i].shape[3] + current = pyramid[i] + tensor_resample(current, (up_h, up_w)) + return current + + +class LapStylePredictor: + def __init__(self, weight_path=None): + + self.net_enc = Encoder() + self.net_dec = DecoderNet() + self.net_rev = RevisionNet() + self.net_rev_2 = RevisionNet() + + self.net_enc.set_dict(paddle.load(weight_path)['net_enc']) + self.net_enc.eval() + self.net_dec.set_dict(paddle.load(weight_path)['net_dec']) + self.net_dec.eval() + self.net_rev.set_dict(paddle.load(weight_path)['net_rev']) + self.net_rev.eval() + self.net_rev_2.set_dict(paddle.load(weight_path)['net_rev_2']) + self.net_rev_2.eval() + + def run(self, content_img, style_image): + content_img, style_img, h, w = img_totensor(content_img, style_image) + pyr_ci = make_laplace_pyramid(content_img, 2) + pyr_si = make_laplace_pyramid(style_img, 2) + pyr_ci.append(content_img) + pyr_si.append(style_img) + cF = self.net_enc(pyr_ci[2]) + sF = self.net_enc(pyr_si[2]) + stylized_small = self.net_dec(cF, sF) + stylized_up = F.interpolate(stylized_small, scale_factor=2) + + revnet_input = paddle.concat(x=[pyr_ci[1], stylized_up], axis=1) + stylized_rev_lap = self.net_rev(revnet_input) + stylized_rev = fold_laplace_pyramid([stylized_rev_lap, stylized_small]) + + stylized_up = F.interpolate(stylized_rev, scale_factor=2) + + revnet_input = paddle.concat(x=[pyr_ci[0], stylized_up], axis=1) + stylized_rev_lap_second = self.net_rev_2(revnet_input) + stylized_rev_second = fold_laplace_pyramid([stylized_rev_lap_second, stylized_rev_lap, stylized_small]) + + stylized = stylized_rev_second + stylized_visual = tensor2img(stylized, min_max=(0., 1.)) + + return stylized_visual diff --git a/modules/image/Image_gan/style_transfer/lapstyle_starrynew/module.py b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/module.py new file mode 100644 index 0000000000000000000000000000000000000000..84a5645c59eeb314a23522bf5e405abdf40a38a2 --- /dev/null +++ b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/module.py @@ -0,0 +1,147 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import copy + +import paddle +import paddlehub as hub +from paddlehub.module.module import moduleinfo, runnable, serving +import numpy as np +import cv2 +from skimage.io import imread +from skimage.transform import rescale, resize + +from .model import LapStylePredictor +from .util import base64_to_cv2 + + +@moduleinfo( + name="lapstyle_starrynew", + type="CV/style_transfer", + author="paddlepaddle", + author_email="", + summary="", + version="1.0.0") +class Lapstyle_starrynew: + def __init__(self): + self.pretrained_model = os.path.join(self.directory, "lapstyle_starrynew.pdparams") + + self.network = LapStylePredictor(weight_path=self.pretrained_model) + + def style_transfer(self, + images=None, + paths=None, + output_dir='./transfer_result/', + use_gpu=False, + visualization=True): + ''' + Transfer a image to starrynew style. + + images (list[dict]): data of images, 每一个元素都为一个 dict,有关键字 content, style, 相应取值为: + - content (numpy.ndarray): 待转换的图片,shape 为 \[H, W, C\],BGR格式;
+ - style (numpy.ndarray) : 风格图像,shape为 \[H, W, C\],BGR格式;
+ paths (list[str]): paths to images, 每一个元素都为一个dict, 有关键字 content, style, 相应取值为: + - content (str): 待转换的图片的路径;
+ - style (str) : 风格图像的路径;
+ + output_dir: the dir to save the results + use_gpu: if True, use gpu to perform the computation, otherwise cpu. + visualization: if True, save results in output_dir. + ''' + results = [] + paddle.disable_static() + place = 'gpu:0' if use_gpu else 'cpu' + place = paddle.set_device(place) + if images == None and paths == None: + print('No image provided. Please input an image or a image path.') + return + + if images != None: + for image_dict in images: + content_img = image_dict['content'] + style_img = image_dict['style'] + results.append(self.network.run(content_img, style_img)) + + if paths != None: + for path_dict in paths: + content_img = cv2.imread(path_dict['content']) + style_img = cv2.imread(path_dict['style']) + results.append(self.network.run(content_img, style_img)) + + if visualization == True: + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + for i, out in enumerate(results): + cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[:, :, ::-1]) + + return results + + @runnable + def run_cmd(self, argvs: list): + """ + Run as a command. + """ + self.parser = argparse.ArgumentParser( + description="Run the {} module.".format(self.name), + prog='hub run {}'.format(self.name), + usage='%(prog)s', + add_help=True) + + self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") + self.arg_config_group = self.parser.add_argument_group( + title="Config options", description="Run configuration for controlling module behavior, not required.") + self.add_module_config_arg() + self.add_module_input_arg() + self.args = self.parser.parse_args(argvs) + + self.style_transfer( + paths=[{ + 'content': self.args.content, + 'style': self.args.style + }], + output_dir=self.args.output_dir, + use_gpu=self.args.use_gpu, + visualization=self.args.visualization) + + @serving + def serving_method(self, images, **kwargs): + """ + Run as a service. + """ + images_decode = copy.deepcopy(images) + for image in images_decode: + image['content'] = base64_to_cv2(image['content']) + image['style'] = base64_to_cv2(image['style']) + results = self.style_transfer(images_decode, **kwargs) + tolist = [result.tolist() for result in results] + return tolist + + def add_module_config_arg(self): + """ + Add the command config options. + """ + self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not") + + self.arg_config_group.add_argument( + '--output_dir', type=str, default='transfer_result', help='output directory for saving result.') + self.arg_config_group.add_argument('--visualization', type=bool, default=False, help='save results or not.') + + def add_module_input_arg(self): + """ + Add the command input options. + """ + self.arg_input_group.add_argument('--content', type=str, help="path to content image.") + self.arg_input_group.add_argument('--style', type=str, help="path to style image.") diff --git a/modules/image/Image_gan/style_transfer/lapstyle_starrynew/requirements.txt b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..67e9bb6fa840355e9ed0d44b7134850f1fe22fe1 --- /dev/null +++ b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/requirements.txt @@ -0,0 +1 @@ +ppgan diff --git a/modules/image/Image_gan/style_transfer/lapstyle_starrynew/util.py b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b88ac3562b74cadc1d4d6459a56097ca4a938a0b --- /dev/null +++ b/modules/image/Image_gan/style_transfer/lapstyle_starrynew/util.py @@ -0,0 +1,10 @@ +import base64 +import cv2 +import numpy as np + + +def base64_to_cv2(b64str): + data = base64.b64decode(b64str.encode('utf8')) + data = np.fromstring(data, np.uint8) + data = cv2.imdecode(data, cv2.IMREAD_COLOR) + return data