diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/README.md b/modules/image/Image_gan/gan/pixel2style2pixel/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa0c3925e23e62f30d6c4b3635c62a0ba1dfb6dd
--- /dev/null
+++ b/modules/image/Image_gan/gan/pixel2style2pixel/README.md
@@ -0,0 +1,133 @@
+# pixel2style2pixel
+
+|Module Name|pixel2style2pixel|
+| :--- | :---: |
+|Category|image - image generation|
+|Network|Pixel2Style2Pixel|
+|Dataset|-|
+|Fine-tuning supported or not|No|
+|Module Size|1.7GB|
+|Latest update date|2021-12-14|
+|Data indicators|-|
+
+
+## I. Basic Information
+
+- ### Application Effect Display
+  - Sample results:
+
+
+    Input image
+
+    Output image
+
+- ### Module Introduction
+
+  - Pixel2Style2Pixel encodes an image with a fairly large model, mapping it into the style vector space of StyleGAN V2, so that the image before encoding and the image after decoding are strongly correlated. This module is applied to the face frontalization task.
+
+
+
+## II. Installation
+
+- ### 1. Environmental Dependence
+
+  - paddlepaddle >= 2.1.0
+  - paddlehub >= 2.1.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
+- ### 2. Installation
+
+ - ```shell
+ $ hub install pixel2style2pixel
+ ```
+  - In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
+    | [Linux_Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS_Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
+
+## III. Module API Prediction
+
+- ### 1. Command line Prediction
+
+ - ```shell
+ # Read from a file
+ $ hub run pixel2style2pixel --input_path "/PATH/TO/IMAGE"
+ ```
+  - This runs the face frontalization model through the command line. For more information, please refer to [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
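+  - Optional config flags (defined in the module's argument parser): `--use_gpu` (use GPU for prediction), `--output_dir` (default `transfer_result`) and `--visualization` (default `False`).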
+
+- ### 2. Prediction Code Example
+
+ - ```python
+ import paddlehub as hub
+
+ module = hub.Module(name="pixel2style2pixel")
+ input_path = ["/PATH/TO/IMAGE"]
+ # Read from a file
+ module.style_transfer(paths=input_path, output_dir='./transfer_result/', use_gpu=True)
+ ```
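+  - With `visualization=True` (the default), the generated image and its latent codes are saved to `output_dir` as `output_{i}.png` and `output_{i}.npy` for the i-th input.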
+
+- ### 3. API
+
+ - ```python
+    style_transfer(images=None, paths=None, output_dir='./transfer_result/', use_gpu=False, visualization=True)
+ ```
+  - Face frontalization API.
+
+  - **Parameters**
+
+    - images (list\[numpy.ndarray\]): image data, ndarray.shape is \[H, W, C\], BGR format;
+    - paths (list\[str\]): image paths;
+    - output\_dir (str): directory for saving the results;
+    - use\_gpu (bool): whether to use GPU;
+    - visualization (bool): whether to save the results to a local folder.
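+
+  - The call returns a list with one `(image, latents)` pair per input, as produced by the underlying network. A minimal sketch of passing ndarray input and saving the result by hand (the output filename is illustrative):
+
+  - ```python
+    import cv2
+    import numpy as np
+    import paddlehub as hub
+
+    module = hub.Module(name="pixel2style2pixel")
+    # cv2.imread returns BGR, which is what the module expects
+    image = cv2.imread("/PATH/TO/IMAGE")
+    results = module.style_transfer(images=[image], visualization=False)
+    out_img, latents = results[0]  # RGB float array, latent codes
+    out_img = np.clip(out_img, 0, 255).astype(np.uint8)
+    cv2.imwrite("output.png", out_img[:, :, ::-1])  # RGB -> BGR for cv2
+    ```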
+
+
+## IV. Server Deployment
+
+- PaddleHub Serving can deploy an online service of face frontalization.
+
+- ### Step 1: Start PaddleHub Serving
+
+  - Run the startup command:
+ - ```shell
+ $ hub serving start -m pixel2style2pixel
+ ```
+
+  - The servitization API is now deployed and the default port number is 8866.
+
+  - **NOTE:** If GPU is used for prediction, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it need not be set.
+
+- ### Step 2: Send a predictive request
+
+  - With a configured server, use the following lines of code to send the prediction request and obtain the result:
+
+ - ```python
+ import requests
+ import json
+ import cv2
+ import base64
+
+
+ def cv2_to_base64(image):
+ data = cv2.imencode('.jpg', image)[1]
+        return base64.b64encode(data.tobytes()).decode('utf8')
+
+    # Send an HTTP request
+ data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
+ headers = {"Content-type": "application/json"}
+ url = "http://127.0.0.1:8866/predict/pixel2style2pixel"
+ r = requests.post(url=url, headers=headers, data=json.dumps(data))
+
+    # Print the prediction results
+    print(r.json()["results"])
+    ```
+
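+  - Based on the serving implementation, each returned result is the `tolist()` serialization of an `(image, latents)` pair. A minimal sketch for restoring and saving the generated image (the output filename is illustrative):
+
+  - ```python
+    import cv2
+    import numpy as np
+
+    img_list, latents = r.json()["results"][0]
+    out_img = np.clip(np.array(img_list), 0, 255).astype(np.uint8)
+    cv2.imwrite("serving_output.png", out_img[:, :, ::-1])  # RGB -> BGR
+    ```
+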
+## V. Release Note
+
+* 1.0.0
+
+  First release
+
+ - ```shell
+ $ hub install pixel2style2pixel==1.0.0
+ ```
diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/model.py b/modules/image/Image_gan/gan/pixel2style2pixel/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e82fbc8ead5e2545628e59fff817b3a378d63560
--- /dev/null
+++ b/modules/image/Image_gan/gan/pixel2style2pixel/model.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import scipy.ndimage
+import numpy as np
+import paddle
+import paddle.vision.transforms as T
+import ppgan.faceutils as futils
+from ppgan.models.generators import Pixel2Style2Pixel
+from ppgan.utils.download import get_path_from_url
+from PIL import Image
+
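+# Per-weight preprocessing and StyleGAN V2 generator hyper-parameters. Each
+# transform resizes to 256x256, converts HWC -> CHW and maps pixels to [-1, 1].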
+model_cfgs = {
+ 'ffhq-inversion': {
+ 'model_urls':
+ 'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-inversion.pdparams',
+ 'transform':
+ T.Compose([T.Resize((256, 256)),
+ T.Transpose(),
+ T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])]),
+ 'size':
+ 1024,
+ 'style_dim':
+ 512,
+ 'n_mlp':
+ 8,
+ 'channel_multiplier':
+ 2
+ },
+ 'ffhq-toonify': {
+ 'model_urls':
+ 'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-toonify.pdparams',
+ 'transform':
+ T.Compose([T.Resize((256, 256)),
+ T.Transpose(),
+ T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])]),
+ 'size':
+ 1024,
+ 'style_dim':
+ 512,
+ 'n_mlp':
+ 8,
+ 'channel_multiplier':
+ 2
+ },
+ 'default': {
+ 'transform':
+ T.Compose([T.Resize((256, 256)),
+ T.Transpose(),
+ T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])])
+ }
+}
+
+
+def run_alignment(image):
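+    """Align and crop the face in ``image`` (FFHQ-style alignment).
+
+    Detects the first face with dlib, builds an oriented crop rectangle from
+    the eye and mouth landmarks, then shrinks, crops and pads the image
+    before warping it to a transform_size x transform_size square.
+    """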
+ img = Image.fromarray(image).convert("RGB")
+ face = futils.dlib.detect(img)
+ if not face:
+ raise Exception('Could not find a face in the given image.')
+ face_on_image = face[0]
+ lm = futils.dlib.landmarks(img, face_on_image)
+ lm = np.array(lm)[:, ::-1]
+ lm_eye_left = lm[36:42]
+ lm_eye_right = lm[42:48]
+ lm_mouth_outer = lm[48:60]
+
+ output_size = 1024
+ transform_size = 4096
+ enable_padding = True
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ x /= np.hypot(*x)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ qsize = np.hypot(*x) * 2
+
+ # Shrink.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+ min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+ max(pad[3] - img.size[1] + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w, _ = img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = qsize * 0.02
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ quad += pad[:2]
+
+ # Transform.
+ img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
+
+ return img
+
+
+class AttrDict(dict):
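+    """dict subclass that also exposes its keys as attributes (checkpoint opts)."""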
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+class Pixel2Style2PixelPredictor:
+ def __init__(self,
+ weight_path=None,
+ model_type=None,
+ seed=None,
+ size=1024,
+ style_dim=512,
+ n_mlp=8,
+ channel_multiplier=2):
+
+ if weight_path is None and model_type != 'default':
+ if model_type in model_cfgs.keys():
+ weight_path = get_path_from_url(model_cfgs[model_type]['model_urls'])
+ size = model_cfgs[model_type].get('size', size)
+ style_dim = model_cfgs[model_type].get('style_dim', style_dim)
+ n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp)
+ channel_multiplier = model_cfgs[model_type].get('channel_multiplier', channel_multiplier)
+ checkpoint = paddle.load(weight_path)
+ else:
+ raise ValueError('Predictor need a weight path or a pretrained model type')
+ else:
+ checkpoint = paddle.load(weight_path)
+
+ opts = checkpoint.pop('opts')
+ opts = AttrDict(opts)
+ opts['size'] = size
+ opts['style_dim'] = style_dim
+ opts['n_mlp'] = n_mlp
+ opts['channel_multiplier'] = channel_multiplier
+
+ self.generator = Pixel2Style2Pixel(opts)
+ self.generator.set_state_dict(checkpoint)
+ self.generator.eval()
+
+ if seed is not None:
+ paddle.seed(seed)
+ random.seed(seed)
+ np.random.seed(seed)
+
+ self.model_type = 'default' if model_type is None else model_type
+
+ def run(self, image):
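+        """Align the input RGB face image, encode it with pSp and decode it with
+        the StyleGAN V2 generator; returns (generated RGB image, latent codes)."""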
+ src_img = run_alignment(image)
+ src_img = np.asarray(src_img)
+ transformed_image = model_cfgs[self.model_type]['transform'](src_img)
+ dst_img, latents = self.generator(
+ paddle.to_tensor(transformed_image[None, ...]), resize=False, return_latents=True)
+ dst_img = (dst_img * 0.5 + 0.5)[0].numpy() * 255
+ dst_img = dst_img.transpose((1, 2, 0))
+ dst_npy = latents[0].numpy()
+
+ return dst_img, dst_npy
diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/module.py b/modules/image/Image_gan/gan/pixel2style2pixel/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb054a6f09becd52790df9437abb6de28f42118d
--- /dev/null
+++ b/modules/image/Image_gan/gan/pixel2style2pixel/module.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import ast
+
+import paddle
+from paddlehub.module.module import moduleinfo, runnable, serving
+import numpy as np
+import cv2
+
+from .model import Pixel2Style2PixelPredictor
+from .util import base64_to_cv2
+
+
+@moduleinfo(
+ name="pixel2style2pixel",
+ type="CV/style_transfer",
+ author="paddlepaddle",
+ author_email="",
+ summary="",
+ version="1.0.0")
+class pixel2style2pixel:
+ def __init__(self):
+ self.pretrained_model = os.path.join(self.directory, "pSp-ffhq-inversion.pdparams")
+
+ self.network = Pixel2Style2PixelPredictor(weight_path=self.pretrained_model, model_type='ffhq-inversion')
+
+ def style_transfer(self,
+ images=None,
+ paths=None,
+ output_dir='./transfer_result/',
+ use_gpu=False,
+ visualization=True):
+        '''
+        Face frontalization API.
+
+        Args:
+            images (list[numpy.ndarray]): data of images, shape of each is [H, W, C], color space must be BGR (read by cv2).
+            paths (list[str]): paths to images.
+            output_dir (str): the dir to save the results.
+            use_gpu (bool): if True, use gpu to perform the computation, otherwise cpu.
+            visualization (bool): if True, save results in output_dir.
+
+        Returns:
+            results (list): each element is a (generated image, latent codes) pair.
+        '''
+ results = []
+ paddle.disable_static()
+ place = 'gpu:0' if use_gpu else 'cpu'
+ place = paddle.set_device(place)
+        if images is None and paths is None:
+            print('No image provided. Please provide an image or an image path.')
+            return
+
+        if images is not None:
+ for image in images:
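+                # inputs are BGR (cv2 convention); convert to RGB for the network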
+ image = image[:, :, ::-1]
+ out = self.network.run(image)
+ results.append(out)
+
+        if paths is not None:
+ for path in paths:
+ image = cv2.imread(path)[:, :, ::-1]
+ out = self.network.run(image)
+ results.append(out)
+
+        if visualization:
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir, exist_ok=True)
+ for i, out in enumerate(results):
+ if out is not None:
+ cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[0][:, :, ::-1])
+ np.save(os.path.join(output_dir, 'output_{}.npy'.format(i)), out[1])
+
+ return results
+
+ @runnable
+ def run_cmd(self, argvs: list):
+ """
+ Run as a command.
+ """
+ self.parser = argparse.ArgumentParser(
+ description="Run the {} module.".format(self.name),
+ prog='hub run {}'.format(self.name),
+ usage='%(prog)s',
+ add_help=True)
+
+ self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
+ self.arg_config_group = self.parser.add_argument_group(
+ title="Config options", description="Run configuration for controlling module behavior, not required.")
+ self.add_module_config_arg()
+ self.add_module_input_arg()
+ self.args = self.parser.parse_args(argvs)
+ results = self.style_transfer(
+ paths=[self.args.input_path],
+ output_dir=self.args.output_dir,
+ use_gpu=self.args.use_gpu,
+ visualization=self.args.visualization)
+ return results
+
+ @serving
+ def serving_method(self, images, **kwargs):
+ """
+ Run as a service.
+ """
+ images_decode = [base64_to_cv2(image) for image in images]
+ results = self.style_transfer(images=images_decode, **kwargs)
+        tolist = [[img.tolist(), latents.tolist()] for img, latents in results]
+ return tolist
+
+ def add_module_config_arg(self):
+ """
+ Add the command config options.
+ """
+ self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
+
+ self.arg_config_group.add_argument(
+ '--output_dir', type=str, default='transfer_result', help='output directory for saving result.')
+        self.arg_config_group.add_argument('--visualization', type=ast.literal_eval, default=False, help='save results or not.')
+
+ def add_module_input_arg(self):
+ """
+ Add the command input options.
+ """
+ self.arg_input_group.add_argument('--input_path', type=str, help="path to input image.")
diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/requirements.txt b/modules/image/Image_gan/gan/pixel2style2pixel/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d9bfc85782a3ee323241fe7beb87a9f281c120fe
--- /dev/null
+++ b/modules/image/Image_gan/gan/pixel2style2pixel/requirements.txt
@@ -0,0 +1,2 @@
+ppgan
+dlib
diff --git a/modules/image/Image_gan/gan/pixel2style2pixel/util.py b/modules/image/Image_gan/gan/pixel2style2pixel/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88ac3562b74cadc1d4d6459a56097ca4a938a0b
--- /dev/null
+++ b/modules/image/Image_gan/gan/pixel2style2pixel/util.py
@@ -0,0 +1,10 @@
+import base64
+import cv2
+import numpy as np
+
+
+def base64_to_cv2(b64str):
+ data = base64.b64decode(b64str.encode('utf8'))
+    data = np.frombuffer(data, np.uint8)
+ data = cv2.imdecode(data, cv2.IMREAD_COLOR)
+ return data