Commit 56e73de4, authored by chenjian

add psgan module

Parent: b3f19bfc
# psgan
|Module name|psgan|
| :--- | :---: |
|Category|Image - makeup transfer|
|Network|PSGAN|
|Dataset|-|
|Fine-tuning supported|No|
|Model size|121MB|
|Last updated|2021-12-07|
|Metrics|-|
## I. Basic Model Information
- ### Demo
- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/145003964-6d6572e0-3103-4898-a738-eb6c61c90be4.jpg" width = "30%" hspace='10'/>
<br />
Input content image
<br />
<img src="https://user-images.githubusercontent.com/22424850/145003966-c5c2e6ad-d306-4eaf-89a2-965a3dbf3675.jpg" width = "30%" hspace='10'/>
<br />
Input makeup image
<br />
<img src="https://user-images.githubusercontent.com/22424850/145003965-288d56f9-49a2-43cb-8647-4a112a8e0dfb.png" width = "30%" hspace='10'/>
<br />
Output image
<br />
</p>
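- ### Model Introduction
- PSGAN performs makeup transfer: it transfers the makeup from an arbitrary reference image onto a source image without makeup. Many portrait beautification applications rely on this technique.
- For more details, see [PSGAN: Pose and Expression Robust Spatial-Aware GAN for Customizable Makeup Transfer](https://arxiv.org/pdf/1909.06956.pdf)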
## II. Installation
- ### 1. Dependencies
- ppgan
- dlib
- ### 2. Installation
- ```shell
$ hub install psgan
```
- If you run into problems during installation, see: [Windows quick start](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux quick start](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS quick start](../../../../docs/docs_ch/get_start/mac_quickstart.md)
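Because `ppgan` and `dlib` are listed as dependencies, it can help to confirm that both are importable before installing the module. The following is a minimal sketch using only the standard library; the helper name `missing_dependencies` is hypothetical and not part of PaddleHub.

```python
import importlib.util


def missing_dependencies(names):
    """Return the subset of `names` that cannot be found as importable modules."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # `ppgan` and `dlib` are the dependencies listed above.
    missing = missing_dependencies(["ppgan", "dlib"])
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All dependencies are importable.")
```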
## III. Module API Prediction
- ### 1. Command-Line Prediction
- ```shell
# Read from a file
$ hub run psgan --content "/PATH/TO/IMAGE" --style "/PATH/TO/IMAGE1"
```
- This invokes the makeup transfer module from the command line. For more, see the [PaddleHub command-line instructions](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
- ### 2. Prediction Code Example
- ```python
import cv2
import paddlehub as hub

module = hub.Module(name="psgan")
# cv2.imread returns BGR arrays, which is the format the API expects
content = cv2.imread("/PATH/TO/IMAGE")
style = cv2.imread("/PATH/TO/IMAGE1")
results = module.makeup_transfer(images=[{'content':content, 'style':style}], output_dir='./transfer_result', use_gpu=True)
```
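Judging from the module source included in this commit, when `visualization=True` each result is written to `output_dir` as `output_{i}.png`, where `i` is the index of the result. A small sketch for predicting those paths follows; the helper `expected_output_paths` is hypothetical, not a PaddleHub API.

```python
import os


def expected_output_paths(output_dir, num_results):
    """Paths where result images are expected to be saved, one per result."""
    return [os.path.join(output_dir, "output_{}.png".format(i)) for i in range(num_results)]


if __name__ == "__main__":
    for path in expected_output_paths("./transfer_result", 2):
        print(path)
```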
- ### 3. API
- ```python
makeup_transfer(images=None, paths=None, output_dir='./transfer_result/', use_gpu=False, visualization=True)
```
- Makeup transfer API.
- **Parameters**
- images (list[dict]): image data. Each element is a dict with the keys content and style, whose values are:
- content (numpy.ndarray): the image to transform, with shape \[H, W, C\] in BGR format;<br/>
- style (numpy.ndarray): the makeup reference image, with shape \[H, W, C\] in BGR format;<br/>
- paths (list[dict]): image paths. Each element is a dict with the keys content and style, whose values are:
- content (str): path to the image to transform;<br/>
- style (str): path to the makeup reference image;<br/>
- output\_dir (str): directory where results are saved;<br/>
- use\_gpu (bool): whether to use the GPU;<br/>
- visualization (bool): whether to save the results to a local folder.
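The `paths` argument takes the same list-of-dicts shape as `images`, with file paths in place of arrays. Below is a hedged sketch of assembling it from (content, style) pairs; `build_paths` is a hypothetical helper, and the file names are placeholders.

```python
def build_paths(pairs):
    """Turn (content_path, style_path) tuples into the list of dicts taken by `paths`."""
    items = []
    for content_path, style_path in pairs:
        if not content_path or not style_path:
            raise ValueError("both a content path and a style path are required")
        items.append({"content": content_path, "style": style_path})
    return items


if __name__ == "__main__":
    paths = build_paths([("/PATH/TO/IMAGE", "/PATH/TO/IMAGE1")])
    print(paths)
    # With the module loaded via hub.Module(name="psgan"):
    # results = module.makeup_transfer(paths=paths, output_dir="./transfer_result")
```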
## IV. Service Deployment
- PaddleHub Serving can deploy an online makeup transfer service.
- ### Step 1: Start PaddleHub Serving
- Run the start command:
- ```shell
$ hub serving start -m psgan
```
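- This deploys the makeup transfer service as an online API; the default port is 8866.
- **NOTE:** To predict on a GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise no setting is needed.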
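One way to set `CUDA_VISIBLE_DEVICES` for the serving process alone is to launch it from Python with a modified environment. This is a minimal sketch; `serving_env` and `start_serving` are hypothetical helpers, and GPU id `0` is an assumption about the machine.

```python
import os
import subprocess


def serving_env(gpu_ids, base_env=None):
    """Copy the environment and set CUDA_VISIBLE_DEVICES to the given GPU ids."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


def start_serving(gpu_ids):
    """Start `hub serving start -m psgan` with only the chosen GPUs visible."""
    return subprocess.Popen(["hub", "serving", "start", "-m", "psgan"], env=serving_env(gpu_ids))


if __name__ == "__main__":
    # Only build and show the variable here; call start_serving([0]) to launch the service.
    print(serving_env([0], base_env={})["CUDA_VISIBLE_DEVICES"])
```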
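- ### Step 2: Send a Prediction Request
- With the server configured, the following few lines of code send a prediction request and fetch the result: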
- ```python
import requests
import json
import cv2
import base64

def cv2_to_base64(image):
    # Encode the image as JPEG, then base64-encode the resulting bytes
    data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')

# Send the HTTP request
data = {'images':[{'content': cv2_to_base64(cv2.imread("/PATH/TO/IMAGE")), 'style': cv2_to_base64(cv2.imread("/PATH/TO/IMAGE1"))}]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/psgan"
r = requests.post(url=url, headers=headers, data=json.dumps(data))

# Print the prediction results
print(r.json()["results"])
```
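Judging from `base64_to_cv2` in this module's `util` helper, the server base64-decodes each field and passes the bytes to `cv2.imdecode`, so already-encoded image bytes (for example the raw contents of a JPEG or PNG file) should be accepted as well. Under that assumption, the request body can be built without OpenCV on the client. `encode_bytes` and `build_payload` are hypothetical helpers for illustration.

```python
import base64
import json


def encode_bytes(image_bytes):
    """Base64-encode already-encoded image bytes (e.g. a JPEG file's contents)."""
    return base64.b64encode(image_bytes).decode("utf8")


def build_payload(content_bytes, style_bytes):
    """Build the JSON body for the psgan predict endpoint from two encoded images."""
    return json.dumps({"images": [{"content": encode_bytes(content_bytes),
                                   "style": encode_bytes(style_bytes)}]})


if __name__ == "__main__":
    # Placeholder bytes stand in for real file contents read with open(path, "rb").
    body = build_payload(b"content-bytes", b"style-bytes")
    print(body)
```

The resulting string can be passed as `data=` to `requests.post` with the same URL and headers as in the request example.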
## V. Release History
* 1.0.0
Initial release
- ```shell
$ hub install psgan==1.0.0
```
epochs: 100
output_dir: tmp
checkpoints_dir: checkpoints
find_unused_parameters: True

model:
  name: MakeupModel
  generator:
    name: GeneratorPSGANAttention
    conv_dim: 64
    repeat_num: 6
  discriminator:
    name: NLayerDiscriminator
    ndf: 64
    n_layers: 3
    input_nc: 3
    norm_type: spectral
  cycle_criterion:
    name: L1Loss
  idt_criterion:
    name: L1Loss
    loss_weight: 0.5
  l1_criterion:
    name: L1Loss
  l2_criterion:
    name: MSELoss
  gan_criterion:
    name: GANLoss
    gan_mode: lsgan

dataset:
  train:
    name: MakeupDataset
    trans_size: 256
    dataroot: data/MT-Dataset
    cls_list: [non-makeup, makeup]
    phase: train
  test:
    name: MakeupDataset
    trans_size: 256
    dataroot: data/MT-Dataset
    cls_list: [non-makeup, makeup]
    phase: test

lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100
  # will get from real dataset
  iters_per_epoch: 1

optimizer:
  optimizer_G:
    name: Adam
    net_names:
      - netG
    beta1: 0.5
  optimizer_DA:
    name: Adam
    net_names:
      - netD_A
    beta1: 0.5
  optimizer_DB:
    name: Adam
    net_names:
      - netD_B
    beta1: 0.5

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
from pathlib import Path
from PIL import Image
# cv2 is used below; import it explicitly rather than relying on the wildcard import
import cv2
import numpy as np
import paddle
import paddle.vision.transforms as T
from paddle.utils.download import get_weights_path_from_url
import ppgan.faceutils as futils
from ppgan.utils.options import parse_args
from ppgan.utils.config import get_config
from ppgan.utils.filesystem import load
from ppgan.models.builder import build_model
from ppgan.utils.preprocess import *
def toImage(net_output):
    img = net_output.squeeze(0).transpose((1, 2, 0)).numpy()  # [1,c,h,w]->[h,w,c]
    img = (img * 255.0).clip(0, 255)
    img = np.uint8(img)
    img = Image.fromarray(img, mode='RGB')
    return img
PS_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/models/psgan_weight.pdparams"
class PreProcess:
    def __init__(self, config, need_parser=True):
        self.img_size = 256
        self.transform = T.Compose([
            T.Resize(size=256),
            T.ToTensor(),
        ])
        self.norm = T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        if need_parser:
            self.face_parser = futils.mask.FaceParser()
        self.up_ratio = 0.6 / 0.85
        self.down_ratio = 0.2 / 0.85
        self.width_ratio = 0.2 / 0.85

    def __call__(self, image):
        face = futils.dlib.detect(image)
        # No face detected: return None so the caller can skip this image
        if not face:
            return
        face_on_image = face[0]
        image, face, crop_face = futils.dlib.crop(image, face_on_image, self.up_ratio, self.down_ratio,
                                                  self.width_ratio)
        np_image = np.array(image)
        image_trans = self.transform(np_image)
        mask = self.face_parser.parse(np.float32(cv2.resize(np_image, (512, 512))))
        mask = cv2.resize(mask.numpy(), (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
        mask = mask.astype(np.uint8)
        mask_tensor = paddle.to_tensor(mask)
        lms = futils.dlib.landmarks(image, face) / image_trans.shape[:2] * self.img_size
        lms = lms.round()
        P_np = generate_P_from_lmks(lms, self.img_size, self.img_size, self.img_size)
        mask_aug = generate_mask_aug(mask, lms)
        return [self.norm(image_trans).unsqueeze(0),
                np.float32(mask_aug),
                np.float32(P_np),
                np.float32(mask)], face_on_image, crop_face
class PostProcess:
    def __init__(self, config):
        self.denoise = True
        self.img_size = 256

    def __call__(self, source: Image, result: Image):
        # TODO: Refract -> name, resize
        source = np.array(source)
        result = np.array(result)
        height, width = source.shape[:2]
        small_source = cv2.resize(source, (self.img_size, self.img_size))
        # High-frequency detail lost by downscaling the source; added back to the result below.
        # np.float64 replaces the np.float alias, which was removed in NumPy 1.24.
        laplacian_diff = source.astype(np.float64) - cv2.resize(small_source, (width, height)).astype(np.float64)
        result = (cv2.resize(result, (width, height)) + laplacian_diff).round().clip(0, 255).astype(np.uint8)
        if self.denoise:
            result = cv2.fastNlMeansDenoisingColored(result)
        result = Image.fromarray(result).convert('RGB')
        return result
class Inference:
    def __init__(self, config, model_path=''):
        self.model = build_model(config.model)
        self.preprocess = PreProcess(config)
        self.model_path = model_path

    def transfer(self, source, reference, with_face=False):
        source_out = self.preprocess(source)
        reference_out = self.preprocess(reference)
        # preprocess returns None when no face is detected; check before unpacking
        if source_out is None or reference_out is None:
            if with_face:
                return None, None
            return
        # Keep the face crop of the source image: the caller uses it to crop the source
        source_input, face, crop_face = source_out
        reference_input, _, _ = reference_out
        consis_mask = np.float32(calculate_consis_mask(source_input[1], reference_input[1]))
        consis_mask = paddle.to_tensor(np.expand_dims(consis_mask, 0))
        # Add a batch dimension to the augmented mask and P entries and convert them to tensors
        for i in range(1, len(source_input) - 1):
            source_input[i] = paddle.to_tensor(np.expand_dims(source_input[i], 0))
        for i in range(1, len(reference_input) - 1):
            reference_input[i] = paddle.to_tensor(np.expand_dims(reference_input[i], 0))
        input_data = {
            'image_A': source_input[0],
            'image_B': reference_input[0],
            'mask_A_aug': source_input[1],
            'mask_B_aug': reference_input[1],
            'P_A': source_input[2],
            'P_B': reference_input[2],
            'consis_mask': consis_mask
        }
        state_dicts = load(self.model_path)
        for net_name, net in self.model.nets.items():
            net.set_state_dict(state_dicts[net_name])
        result, _ = self.model.test(input_data)
        # Min-max normalize the network output to [0, 1]
        min_, max_ = result.min(), result.max()
        result += -min_
        result = paddle.divide(result, max_ - min_ + 1e-5)
        img = toImage(result)
        if with_face:
            return img, crop_face
        return img
class PSGANPredictor:
    def __init__(self, cfg, weight_path):
        self.cfg = cfg
        self.weight_path = weight_path

    def run(self, source, reference):
        source = Image.fromarray(source)
        reference = Image.fromarray(reference)
        inference = Inference(self.cfg, self.weight_path)
        postprocess = PostProcess(self.cfg)
        # Transfer the psgan from reference to source.
        image, face = inference.transfer(source, reference, with_face=True)
        source_crop = source.crop((face.left(), face.top(), face.right(), face.bottom()))
        image = postprocess(source_crop, image)
        image = np.array(image)
        return image
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import copy
import paddle
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
import numpy as np
import cv2
from skimage.io import imread
from skimage.transform import rescale, resize
from ppgan.utils.config import get_config
from .model import PSGANPredictor
from .util import base64_to_cv2
@moduleinfo(name="psgan", type="CV/gan", author="paddlepaddle", author_email="", summary="", version="1.0.0")
class psgan:
    def __init__(self):
        self.pretrained_model = os.path.join(self.directory, "psgan_weight.pdparams")
        cfg = get_config(os.path.join(self.directory, 'makeup.yaml'))
        self.network = PSGANPredictor(cfg, self.pretrained_model)

    def makeup_transfer(self,
                        images=None,
                        paths=None,
                        output_dir='./transfer_result/',
                        use_gpu=False,
                        visualization=True):
        '''
        Transfer the makeup of a style image onto a content image.

        images (list[dict]): image data; each element is a dict with the keys content and style:
          - content (numpy.ndarray): the image to transform, shape [H, W, C], BGR format;
          - style (numpy.ndarray): the makeup image, shape [H, W, C], BGR format;
        paths (list[dict]): image paths; each element is a dict with the keys content and style:
          - content (str): path to the image to transform;
          - style (str): path to the makeup image;
        output_dir: the dir to save the results
        use_gpu: if True, use gpu to perform the computation, otherwise cpu.
        visualization: if True, save results in output_dir.
        '''
        results = []
        paddle.disable_static()
        place = 'gpu:0' if use_gpu else 'cpu'
        place = paddle.set_device(place)
        if images is None and paths is None:
            print('No image provided. Please input an image or a image path.')
            return
        if images is not None:
            for image_dict in images:
                # Inputs are BGR; reverse the channel axis to get RGB for the network
                content_img = image_dict['content'][:, :, ::-1]
                style_img = image_dict['style'][:, :, ::-1]
                results.append(self.network.run(content_img, style_img))
        if paths is not None:
            for path_dict in paths:
                content_img = cv2.imread(path_dict['content'])[:, :, ::-1]
                style_img = cv2.imread(path_dict['style'])[:, :, ::-1]
                results.append(self.network.run(content_img, style_img))
        if visualization:
            os.makedirs(output_dir, exist_ok=True)
            for i, out in enumerate(results):
                # Results are RGB; convert back to BGR before writing with OpenCV
                cv2.imwrite(os.path.join(output_dir, 'output_{}.png'.format(i)), out[:, :, ::-1])
        return results

    @runnable
    def run_cmd(self, argvs: list):
        """
        Run as a command.
        """
        self.parser = argparse.ArgumentParser(
            description="Run the {} module.".format(self.name),
            prog='hub run {}'.format(self.name),
            usage='%(prog)s',
            add_help=True)
        self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
        self.arg_config_group = self.parser.add_argument_group(
            title="Config options", description="Run configuration for controlling module behavior, not required.")
        self.add_module_config_arg()
        self.add_module_input_arg()
        self.args = self.parser.parse_args(argvs)
        self.makeup_transfer(
            paths=[{
                'content': self.args.content,
                'style': self.args.style
            }],
            output_dir=self.args.output_dir,
            use_gpu=self.args.use_gpu,
            visualization=self.args.visualization)

    @serving
    def serving_method(self, images, **kwargs):
        """
        Run as a service.
        """
        images_decode = copy.deepcopy(images)
        for image in images_decode:
            image['content'] = base64_to_cv2(image['content'])
            image['style'] = base64_to_cv2(image['style'])
        results = self.makeup_transfer(images_decode, **kwargs)
        tolist = [result.tolist() for result in results]
        return tolist

    def add_module_config_arg(self):
        """
        Add the command config options.
        """
        self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
        self.arg_config_group.add_argument(
            '--output_dir', type=str, default='transfer_result', help='output directory for saving result.')
        # Note: with type=bool, argparse treats any non-empty string (including "False") as True.
        self.arg_config_group.add_argument('--visualization', type=bool, default=False, help='save results or not.')

    def add_module_input_arg(self):
        """
        Add the command input options.
        """
        self.arg_input_group.add_argument('--content', type=str, help="path to content image.")
        self.arg_input_group.add_argument('--style', type=str, help="path to style image.")
import base64
import cv2
import numpy as np
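def base64_to_cv2(b64str):
    # Decode the base64 string to raw bytes, then decode those bytes into a BGR image
    data = base64.b64decode(b64str.encode('utf8'))
    # np.frombuffer replaces np.fromstring, which is deprecated for binary input
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data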