diff --git a/modules/image/Image_gan/gan/photopen/README.md b/modules/image/Image_gan/gan/photopen/README.md new file mode 100644 index 0000000000000000000000000000000000000000..73c80f9ad381b2adaeb7ab28d95c702b6cc55102 --- /dev/null +++ b/modules/image/Image_gan/gan/photopen/README.md @@ -0,0 +1,126 @@ +# photopen + +|模型名称|photopen| +| :--- | :---: | +|类别|图像 - 图像生成| +|网络|SPADEGenerator| +|数据集|coco_stuff| +|是否支持Fine-tuning|否| +|模型大小|74MB| +|最新更新日期|2021-12-14| +|数据指标|-| + + +## 一、模型基本信息 + +- ### 应用效果展示 + - 样例结果示例: +

+ +
+ +- ### 模型介绍 + + - 本模块采用一个像素风格迁移网络 Pix2PixHD,能够根据输入的语义分割标签生成照片风格的图片。为了解决模型归一化层导致标签语义信息丢失的问题,向 Pix2PixHD 的生成器网络中添加了 SPADE(Spatially-Adaptive + Normalization)空间自适应归一化模块,通过两个卷积层保留了归一化时训练的缩放与偏置参数的空间维度,以增强生成图片的质量。语义风格标签图像可以参考[coco_stuff数据集](https://github.com/nightrome/cocostuff)获取, 也可以通过[PaddleGAN repo中的该项目](https://github.com/PaddlePaddle/PaddleGAN/blob/87537ad9d4eeda17eaa5916c6a585534ab989ea8/docs/zh_CN/tutorials/photopen.md)来自定义生成图像进行体验。 + + + +## 二、安装 + +- ### 1、环境依赖 + - ppgan + +- ### 2、安装 + + - ```shell + $ hub install photopen + ``` + - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) + | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) + +## 三、模型API预测 + +- ### 1、命令行预测 + + - ```shell + # Read from a file + $ hub run photopen --input_path "/PATH/TO/IMAGE" + ``` + - 通过命令行方式实现图像生成模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst) + +- ### 2、预测代码示例 + + - ```python + import paddlehub as hub + + module = hub.Module(name="photopen") + input_path = ["/PATH/TO/IMAGE"] + # Read from a file + module.photo_transfer(paths=input_path, output_dir='./transfer_result/', use_gpu=True) + ``` + +- ### 3、API + + - ```python + photo_transfer(images=None, paths=None, output_dir='./transfer_result/', use_gpu=False, visualization=True): + ``` + - 图像转换生成API。 + + - **参数** + + - images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\];
+ - paths (list\[str\]): 图片的路径;
+ - output\_dir (str): 结果保存的路径;
+ - use\_gpu (bool): 是否使用 GPU;
+ - visualization(bool): 是否保存结果到本地文件夹 + + +## 四、服务部署 + +- PaddleHub Serving可以部署一个在线图像转换生成服务。 + +- ### 第一步:启动PaddleHub Serving + + - 运行启动命令: + - ```shell + $ hub serving start -m photopen + ``` + + - 这样就完成了一个图像转换生成的在线服务API的部署,默认端口号为8866。 + + - **NOTE:** 如使用GPU预测,则需要在启动服务之前设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。 + +- ### 第二步:发送预测请求 + + - 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果 + + - ```python + import requests + import json + import cv2 + import base64 + + + def cv2_to_base64(image): + data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(data.tostring()).decode('utf8') + + # 发送HTTP请求 + data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]} + headers = {"Content-type": "application/json"} + url = "http://127.0.0.1:8866/predict/photopen" + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + + # 打印预测结果 + print(r.json()["results"]) + +## 五、更新历史 + +* 1.0.0 + + 初始发布 + + - ```shell + $ hub install photopen==1.0.0 + ``` diff --git a/modules/image/Image_gan/gan/photopen/model.py b/modules/image/Image_gan/gan/photopen/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0b0a4836b010ca4d72995c8857a8bb0ddd7aa2 --- /dev/null +++ b/modules/image/Image_gan/gan/photopen/model.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class PhotoPenPredictor:
    """Wrap a SPADEGenerator and run it on a single semantic label image."""

    def __init__(self, weight_path, gen_cfg):
        """Build the generator from ``gen_cfg`` and load weights from ``weight_path``.

        Args:
            weight_path: Path to the saved parameters, read with ``ppgan``'s ``load``.
            gen_cfg: Config object exposing ``ngf``, ``num_upsampling_layers``,
                ``crop_size``, ``aspect_ratio``, ``norm_G``, ``semantic_nc``,
                ``use_vae`` and ``nef`` as attributes.
        """
        # Initialise the model.
        gen = SPADEGenerator(
            gen_cfg.ngf,
            gen_cfg.num_upsampling_layers,
            gen_cfg.crop_size,
            gen_cfg.aspect_ratio,
            gen_cfg.norm_G,
            gen_cfg.semantic_nc,
            gen_cfg.use_vae,
            gen_cfg.nef,
        )
        gen.eval()
        para = load(weight_path)
        # The checkpoint may either nest the generator weights under 'net_gen'
        # or be the generator state dict itself.
        if 'net_gen' in para:
            gen.set_state_dict(para['net_gen'])
        else:
            gen.set_state_dict(para)

        self.gen = gen
        self.gen_cfg = gen_cfg

    def run(self, image):
        """Generate an image from a semantic label image.

        Args:
            image: numpy array accepted by ``PIL.Image.fromarray``; it is
                converted to single-channel ('L') and resized to
                ``crop_size`` x ``crop_size`` with nearest-neighbour sampling.

        Returns:
            numpy.ndarray of dtype uint8 laid out as [H, W, C].
        """
        size = self.gen_cfg.crop_size

        sem = Image.fromarray(image).convert('L')
        # Nearest-neighbour keeps label ids intact (no interpolated values).
        sem = sem.resize((size, size), Image.NEAREST)
        sem = np.array(sem).astype('float32')
        sem = paddle.to_tensor(sem)
        sem = sem.reshape([1, 1, size, size])

        # Inference only: no need to track gradients.
        with paddle.no_grad():
            one_hot = data_onehot_pro(sem, self.gen_cfg)
            predicted = self.gen(one_hot)

        # Take the single batch item and move channels last. The spatial size
        # is taken from the tensor itself instead of a hard-coded 256 so that
        # other crop_size values work.
        pic = predicted.numpy()[0].transpose((1, 2, 0))
        # Linear map from [-1, 1] to [0, 255].
        # NOTE(review): assumes the generator output lies in [-1, 1] - confirm.
        pic = ((pic + 1.) / 2. * 255).astype('uint8')

        return pic
@moduleinfo(
    name="photopen", type="CV/style_transfer", author="paddlepaddle", author_email="", summary="", version="1.0.0")
class Photopen:
    """PaddleHub module that generates images from semantic label images."""

    def __init__(self):
        # NOTE(review): self.directory is not set in this class; presumably it
        # is provided by the @moduleinfo machinery - confirm.
        self.pretrained_model = os.path.join(self.directory, "photopen.pdparams")
        cfg = get_config(os.path.join(self.directory, "photopen.yaml"))
        self.network = PhotoPenPredictor(weight_path=self.pretrained_model, gen_cfg=cfg.predict)

    @staticmethod
    def _str_to_bool(value):
        """Parse a command-line string into a bool.

        ``type=bool`` in argparse treats every non-empty string (including
        "False") as True, so the conversion is done explicitly here.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    def photo_transfer(self,
                       images: list = None,
                       paths: list = None,
                       output_dir: str = './transfer_result/',
                       use_gpu: bool = False,
                       visualization: bool = True):
        '''
        Run the generator on every given image and optionally save the results.

        images (list[numpy.ndarray]): data of images, shape of each is [H, W, C], color space must be BGR(read by cv2).
        paths (list[str]): paths to images

        output_dir (str): the dir to save the results
        use_gpu (bool): if True, use gpu to perform the computation, otherwise cpu.
        visualization (bool): if True, save results in output_dir.

        Returns the list of outputs of ``self.network.run`` (results for
        ``images`` first, then for ``paths``), or None when neither ``images``
        nor ``paths`` is given.

        Raises ValueError when an image in ``paths`` cannot be read.
        '''
        results = []
        paddle.disable_static()
        paddle.set_device('gpu:0' if use_gpu else 'cpu')

        # Identity checks: comparing an ndarray with None via ==/!= is
        # element-wise and cannot be used as a condition.
        if images is None and paths is None:
            print('No image provided. Please input an image or a image path.')
            return

        if images is not None:
            for image in images:
                # Reverse the channel axis (BGR -> RGB) before inference.
                results.append(self.network.run(image[:, :, ::-1]))

        if paths is not None:
            for path in paths:
                image = cv2.imread(path)
                if image is None:
                    # cv2.imread signals failure by returning None.
                    raise ValueError('Failed to read image from path: {}'.format(path))
                results.append(self.network.run(image[:, :, ::-1]))

        if visualization:
            os.makedirs(output_dir, exist_ok=True)
            for i, out in enumerate(results):
                if out is not None:
                    # Reverse channels again so cv2 writes the image in BGR order.
                    cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[:, :, ::-1])

        return results

    @runnable
    def run_cmd(self, argvs: list):
        """
        Run as a command.

        Parses ``argvs`` and forwards the options to ``photo_transfer``.
        """
        self.parser = argparse.ArgumentParser(
            description="Run the {} module.".format(self.name),
            prog='hub run {}'.format(self.name),
            usage='%(prog)s',
            add_help=True)

        self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
        self.arg_config_group = self.parser.add_argument_group(
            title="Config options", description="Run configuration for controlling module behavior, not required.")
        self.add_module_config_arg()
        self.add_module_input_arg()
        self.args = self.parser.parse_args(argvs)
        results = self.photo_transfer(
            paths=[self.args.input_path],
            output_dir=self.args.output_dir,
            use_gpu=self.args.use_gpu,
            visualization=self.args.visualization)
        return results

    @serving
    def serving_method(self, images, **kwargs):
        """
        Run as a service.

        ``images`` is a list of base64 strings; each is decoded with
        ``base64_to_cv2`` and the results are returned as nested lists.
        """
        images_decode = [base64_to_cv2(image) for image in images]
        results = self.photo_transfer(images=images_decode, **kwargs)
        tolist = [result.tolist() for result in results]
        return tolist

    def add_module_config_arg(self):
        """
        Add the command config options.
        """
        self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")

        self.arg_config_group.add_argument(
            '--output_dir', type=str, default='transfer_result', help='output directory for saving result.')
        # An explicit parser is used because type=bool would turn "False" into True.
        self.arg_config_group.add_argument(
            '--visualization', type=self._str_to_bool, default=False, help='save results or not.')

    def add_module_input_arg(self):
        """
        Add the command input options.
        """
        # Without an input path there is nothing to process, so fail at parse time.
        self.arg_input_group.add_argument('--input_path', type=str, required=True, help="path to input image.")
def base64_to_cv2(b64str):
    """Decode a base64-encoded image string into an image array.

    Args:
        b64str (str): base64 text of an encoded image file.

    Returns:
        The result of ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.
    """
    data = base64.b64decode(b64str.encode('utf8'))
    # np.frombuffer replaces np.fromstring, whose binary mode is deprecated.
    # The resulting array is a read-only view over the bytes and is only
    # passed on to the decoder.
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data