diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/README.md b/modules/image/Image_gan/gan/pixel2style2pixel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fa0c3925e23e62f30d6c4b3635c62a0ba1dfb6dd --- /dev/null +++ b/modules/image/Image_gan/gan/pixel2style2pixel/README.md @@ -0,0 +1,133 @@ +# pixel2style2pixel + +|模型名称|pixel2style2pixel| +| :--- | :---: | +|类别|图像 - 图像生成| +|网络|Pixel2Style2Pixel| +|数据集|-| +|是否支持Fine-tuning|否| +|模型大小|1.7GB| +|最新更新日期|2021-12-14| +|数据指标|-| + + +## 一、模型基本信息 + +- ### 应用效果展示 + - 样例结果示例: +

+ +
+ 输入图像 +
+ +
+ 输出图像 +
+

+ +- ### 模型介绍 + + - Pixel2Style2Pixel使用相当大的模型对图像进行编码,将图像编码到StyleGAN V2的风格向量空间中,使编码前的图像和解码后的图像具有强关联性。该模块应用于人脸转正任务。 + + + +## 二、安装 + +- ### 1、环境依赖 + + - paddlepaddle >= 2.1.0 + - paddlehub >= 2.1.0 | [如何安装PaddleHub](../../../../docs/docs_ch/get_start/installation.rst) +- ### 2、安装 + + - ```shell + $ hub install pixel2style2pixel + ``` + - 如您安装时遇到问题,可参考:[零基础windows安装](../../../../docs/docs_ch/get_start/windows_quickstart.md) + | [零基础Linux安装](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [零基础MacOS安装](../../../../docs/docs_ch/get_start/mac_quickstart.md) + +## 三、模型API预测 + +- ### 1、命令行预测 + + - ```shell + # Read from a file + $ hub run pixel2style2pixel --input_path "/PATH/TO/IMAGE" + ``` + - 通过命令行方式实现人脸转正模型的调用,更多请见 [PaddleHub命令行指令](../../../../docs/docs_ch/tutorial/cmd_usage.rst) + +- ### 2、预测代码示例 + + - ```python + import paddlehub as hub + + module = hub.Module(name="pixel2style2pixel") + input_path = ["/PATH/TO/IMAGE"] + # Read from a file + module.style_transfer(paths=input_path, output_dir='./transfer_result/', use_gpu=True) + ``` + +- ### 3、API + + - ```python + style_transfer(images=None, paths=None, output_dir='./transfer_result/', use_gpu=False, visualization=True): + ``` + - 人脸转正生成API。 + + - **参数** + + - images (list\[numpy.ndarray\]): 图片数据,ndarray.shape 为 \[H, W, C\];
+ - paths (list\[str\]): 图片的路径;
+ - output\_dir (str): 结果保存的路径;
+ - use\_gpu (bool): 是否使用 GPU;
+ - visualization(bool): 是否保存结果到本地文件夹 + + +## 四、服务部署 + +- PaddleHub Serving可以部署一个在线人脸转正服务。 + +- ### 第一步:启动PaddleHub Serving + + - 运行启动命令: + - ```shell + $ hub serving start -m pixel2style2pixel + ``` + + - 这样就完成了一个人脸转正的在线服务API的部署,默认端口号为8866。 + + - **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。 + +- ### 第二步:发送预测请求 + + - 配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果 + + - ```python + import requests + import json + import cv2 + import base64 + + + def cv2_to_base64(image): + data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(data.tostring()).decode('utf8') + + # 发送HTTP请求 + data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]} + headers = {"Content-type": "application/json"} + url = "http://127.0.0.1:8866/predict/pixel2style2pixel" + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + + # 打印预测结果 + print(r.json()["results"]) + +## 五、更新历史 + +* 1.0.0 + + 初始发布 + + - ```shell + $ hub install pixel2style2pixel==1.0.0 + ``` diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/model.py b/modules/image/Image_gan/gan/pixel2style2pixel/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e82fbc8ead5e2545628e59fff817b3a378d63560 --- /dev/null +++ b/modules/image/Image_gan/gan/pixel2style2pixel/model.py @@ -0,0 +1,205 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import scipy +import random +import numpy as np +import paddle +import paddle.vision.transforms as T +import ppgan.faceutils as futils +from ppgan.models.generators import Pixel2Style2Pixel +from ppgan.utils.download import get_path_from_url +from PIL import Image + +model_cfgs = { + 'ffhq-inversion': { + 'model_urls': + 'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-inversion.pdparams', + 'transform': + T.Compose([T.Resize((256, 256)), + T.Transpose(), + T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])]), + 'size': + 1024, + 'style_dim': + 512, + 'n_mlp': + 8, + 'channel_multiplier': + 2 + }, + 'ffhq-toonify': { + 'model_urls': + 'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-toonify.pdparams', + 'transform': + T.Compose([T.Resize((256, 256)), + T.Transpose(), + T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])]), + 'size': + 1024, + 'style_dim': + 512, + 'n_mlp': + 8, + 'channel_multiplier': + 2 + }, + 'default': { + 'transform': + T.Compose([T.Resize((256, 256)), + T.Transpose(), + T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])]) + } +} + + +def run_alignment(image): + img = Image.fromarray(image).convert("RGB") + face = futils.dlib.detect(img) + if not face: + raise Exception('Could not find a face in the given image.') + face_on_image = face[0] + lm = futils.dlib.landmarks(img, face_on_image) + lm = np.array(lm)[:, ::-1] + lm_eye_left = lm[36:42] + lm_eye_right = lm[42:48] + lm_mouth_outer = lm[48:60] + + output_size = 1024 + transform_size = 4096 + enable_padding = True + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), + min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), + max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + # Transform. + img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR) + + return img + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +class Pixel2Style2PixelPredictor: + def __init__(self, + weight_path=None, + model_type=None, + seed=None, + size=1024, + style_dim=512, + n_mlp=8, + channel_multiplier=2): + + if weight_path is None and model_type != 'default': + if model_type in model_cfgs.keys(): + weight_path = get_path_from_url(model_cfgs[model_type]['model_urls']) + size = model_cfgs[model_type].get('size', size) + style_dim = model_cfgs[model_type].get('style_dim', style_dim) + n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp) + channel_multiplier = model_cfgs[model_type].get('channel_multiplier', channel_multiplier) + checkpoint = paddle.load(weight_path) + else: + raise ValueError('Predictor need a weight path or a pretrained model type') + else: + checkpoint = paddle.load(weight_path) + + opts = checkpoint.pop('opts') + opts = AttrDict(opts) + opts['size'] = size + opts['style_dim'] = style_dim + opts['n_mlp'] = n_mlp + opts['channel_multiplier'] = channel_multiplier + + self.generator = Pixel2Style2Pixel(opts) + self.generator.set_state_dict(checkpoint) + self.generator.eval() + + if seed is not None: + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + + self.model_type = 'default' if model_type is None else model_type + + def run(self, image): + src_img = run_alignment(image) + src_img = np.asarray(src_img) + transformed_image = model_cfgs[self.model_type]['transform'](src_img) + dst_img, latents = self.generator( + paddle.to_tensor(transformed_image[None, ...]), resize=False, return_latents=True) + dst_img = (dst_img * 0.5 + 0.5)[0].numpy() * 255 + dst_img = dst_img.transpose((1, 2, 0)) + dst_npy = latents[0].numpy() + + return dst_img, dst_npy diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/module.py b/modules/image/Image_gan/gan/pixel2style2pixel/module.py new file mode 100644 index 0000000000000000000000000000000000000000..fb054a6f09becd52790df9437abb6de28f42118d --- /dev/null +++ b/modules/image/Image_gan/gan/pixel2style2pixel/module.py @@ -0,0 +1,137 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import copy + +import paddle +import paddlehub as hub +from paddlehub.module.module import moduleinfo, runnable, serving +import numpy as np +import cv2 +from skimage.io import imread +from skimage.transform import rescale, resize + +from .model import Pixel2Style2PixelPredictor +from .util import base64_to_cv2 + + +@moduleinfo( + name="pixel2style2pixel", + type="CV/style_transfer", + author="paddlepaddle", + author_email="", + summary="", + version="1.0.0") +class pixel2style2pixel: + def __init__(self): + self.pretrained_model = os.path.join(self.directory, "pSp-ffhq-inversion.pdparams") + + self.network = Pixel2Style2PixelPredictor(weight_path=self.pretrained_model, model_type='ffhq-inversion') + + def style_transfer(self, + images=None, + paths=None, + output_dir='./transfer_result/', + use_gpu=False, + visualization=True): + ''' + + + images (list[numpy.ndarray]): data of images, shape of each is [H, W, C], color space must be BGR(read by cv2). + paths (list[str]): paths to images + output_dir: the dir to save the results + use_gpu: if True, use gpu to perform the computation, otherwise cpu. + visualization: if True, save results in output_dir. + ''' + results = [] + paddle.disable_static() + place = 'gpu:0' if use_gpu else 'cpu' + place = paddle.set_device(place) + if images == None and paths == None: + print('No image provided. Please input an image or a image path.') + return + + if images != None: + for image in images: + image = image[:, :, ::-1] + out = self.network.run(image) + results.append(out) + + if paths != None: + for path in paths: + image = cv2.imread(path)[:, :, ::-1] + out = self.network.run(image) + results.append(out) + + if visualization == True: + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + for i, out in enumerate(results): + if out is not None: + cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[0][:, :, ::-1]) + np.save(os.path.join(output_dir, 'output_{}.npy'.format(i)), out[1]) + + return results + + @runnable + def run_cmd(self, argvs: list): + """ + Run as a command. + """ + self.parser = argparse.ArgumentParser( + description="Run the {} module.".format(self.name), + prog='hub run {}'.format(self.name), + usage='%(prog)s', + add_help=True) + + self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required") + self.arg_config_group = self.parser.add_argument_group( + title="Config options", description="Run configuration for controlling module behavior, not required.") + self.add_module_config_arg() + self.add_module_input_arg() + self.args = self.parser.parse_args(argvs) + results = self.style_transfer( + paths=[self.args.input_path], + output_dir=self.args.output_dir, + use_gpu=self.args.use_gpu, + visualization=self.args.visualization) + return results + + @serving + def serving_method(self, images, **kwargs): + """ + Run as a service. + """ + images_decode = [base64_to_cv2(image) for image in images] + results = self.style_transfer(images=images_decode, **kwargs) + tolist = [result.tolist() for result in results] + return tolist + + def add_module_config_arg(self): + """ + Add the command config options. + """ + self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not") + + self.arg_config_group.add_argument( + '--output_dir', type=str, default='transfer_result', help='output directory for saving result.') + self.arg_config_group.add_argument('--visualization', type=bool, default=False, help='save results or not.') + + def add_module_input_arg(self): + """ + Add the command input options. + """ + self.arg_input_group.add_argument('--input_path', type=str, help="path to input image.") diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/requirements.txt b/modules/image/Image_gan/gan/pixel2style2pixel/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d9bfc85782a3ee323241fe7beb87a9f281c120fe --- /dev/null +++ b/modules/image/Image_gan/gan/pixel2style2pixel/requirements.txt @@ -0,0 +1,2 @@ +ppgan +dlib diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/util.py b/modules/image/Image_gan/gan/pixel2style2pixel/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b88ac3562b74cadc1d4d6459a56097ca4a938a0b --- /dev/null +++ b/modules/image/Image_gan/gan/pixel2style2pixel/util.py @@ -0,0 +1,10 @@ +import base64 +import cv2 +import numpy as np + + +def base64_to_cv2(b64str): + data = base64.b64decode(b64str.encode('utf8')) + data = np.fromstring(data, np.uint8) + data = cv2.imdecode(data, cv2.IMREAD_COLOR) + return data