Unverified commit 0d07d011 authored by chenjian, committed by GitHub

add styleganv2editing module (#1737)

* add styleganv2editing module

* modify according to review

* modify demo image
Co-authored-by: wuzewu <wuzewu@baidu.com>
Co-authored-by: KP <109694228@qq.com>
Parent f622ba3b
# styleganv2_editing
|Module Name|styleganv2_editing|
| :--- | :---: |
|Category|Image - Image Generation|
|Network|StyleGAN V2|
|Dataset|-|
|Fine-tuning supported or not|No|
|Module Size|190MB|
|Latest update date|2021-12-15|
|Data indicators|-|
## I. Basic Information
- ### Application Effect Display
  - Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/146483720-fb0ea3c0-b259-4ad6-b176-966675b9b164.png" width = "40%" hspace='10'/>
<br />
Input image
<br />
<img src="https://user-images.githubusercontent.com/22424850/146483730-3104795e-4ee6-43de-b4dc-b7760d502b50.png" width = "40%" hspace='10'/>
<br />
Output image (age edited)
<br />
</p>
- ### Module Introduction
  - StyleGAN V2 generates images from style vectors. The Editing module manipulates an attribute of the generated image by shifting its latent code along an attribute direction vector, obtained in advance by classifying and regressing the style vectors of many images. The core editing step reduces to a single vector operation, as the sketch below illustrates.
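  - A minimal sketch of that operation (the names here are illustrative; the actual implementation lives in this module's model.py):
  - ```python
    import numpy as np

    def edit_latent(latent, direction, offset):
        # Moving the latent code along the attribute direction changes the
        # corresponding attribute (e.g. age) in the generated image.
        return latent + offset * direction
    ```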
## II. Installation
- ### 1. Environmental Dependence
  - ppgan
- ### 2. Installation
- ```shell
$ hub install styleganv2_editing
```
- In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux_Quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1. Command line Prediction
- ```shell
# Read from a file
$ hub run styleganv2_editing --input_path "/PATH/TO/IMAGE" --direction_name age --direction_offset 5
```
- This runs the face editing model from the command line. For more information, please refer to [PaddleHub Command Line Instruction](../../../../docs/docs_ch/tutorial/cmd_usage.rst).
- ### 2. Prediction Code Example
- ```python
import paddlehub as hub
module = hub.Module(name="styleganv2_editing")
input_path = ["/PATH/TO/IMAGE"]
# Read from a file
module.generate(paths=input_path, direction_name='age', direction_offset=5, output_dir='./editing_result/', use_gpu=True)
```
- ### 3. API
- ```python
generate(self, images=None, paths=None, direction_name='age', direction_offset=0.0, output_dir='./editing_result/', use_gpu=False, visualization=True)
```
- Face editing generation API.
- **Parameters**
  - images (list\[numpy.ndarray\]): image data, each a BGR array of shape \[H, W, C\] (as read by cv2); <br/>
  - paths (list\[str\]): image paths; <br/>
  - direction_name (str): name of the attribute to edit. For ffhq-config-f the prepared attributes are: age, eyes_open, eye_distance, eye_eyebrow_distance, eye_ratio, gender, lip_ratio, mouth_open, mouth_ratio, nose_mouth_distance, nose_ratio, nose_tip, pitch, roll, smile, yaw; <br/>
  - direction_offset (float): offset strength of the attribute; <br/>
  - output\_dir (str): directory in which to save the results; <br/>
  - use\_gpu (bool): whether to use GPU; <br/>
  - visualization (bool): whether to save the results to a local folder.
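- Calling generate with in-memory images is a minimal variation of the example above (the image path is a placeholder):
- ```python
  import cv2
  import paddlehub as hub

  module = hub.Module(name="styleganv2_editing")
  # The API expects BGR images, as returned by cv2.imread.
  image = cv2.imread("/PATH/TO/IMAGE")
  results = module.generate(images=[image], direction_name='age', direction_offset=5.0,
                            output_dir='./editing_result/', use_gpu=False, visualization=True)
  ```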
## 四、服务部署
- PaddleHub Serving可以部署一个在线人脸编辑服务。
- ### 第一步:启动PaddleHub Serving
- 运行启动命令:
- ```shell
$ hub serving start -m styleganv2_editing
```
- 这样就完成了一个人脸编辑的在线服务API的部署,默认端口号为8866。
- **NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA\_VISIBLE\_DEVICES环境变量,否则不用设置。
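- For example (assuming GPU card 0 is to be used):
- ```shell
  $ export CUDA_VISIBLE_DEVICES=0
  $ hub serving start -m styleganv2_editing
  ```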
- ### Step 2: Send a prediction request
- With the server configured, the few lines of code below send a prediction request and print the result:
- ```python
  import requests
  import json
  import cv2
  import base64


  def cv2_to_base64(image):
      data = cv2.imencode('.jpg', image)[1]
      return base64.b64encode(data.tobytes()).decode('utf8')


  # Send an HTTP request
  data = {'images': [cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
  headers = {"Content-type": "application/json"}
  url = "http://127.0.0.1:8866/predict/styleganv2_editing"
  r = requests.post(url=url, headers=headers, data=json.dumps(data))

  # Print the prediction results
  print(r.json()["results"])
  ```
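- Each entry in the returned results holds the source image, the edited image and the edited latent code as nested lists. A sketch of decoding and saving the edited image on the client side:
- ```python
  import cv2
  import numpy as np

  # Images come back as RGB; flip to BGR before saving with cv2.
  src_img, dst_img, dst_latent = r.json()["results"][0]
  dst = np.array(dst_img, dtype=np.uint8)
  cv2.imwrite('dst_0.png', dst[:, :, ::-1])
  ```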
## V. Release Note
* 1.0.0
  First release
- ```shell
$ hub install styleganv2_editing==1.0.0
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy as np
import paddle
from ppgan.models.generators import StyleGANv2Generator
from ppgan.utils.download import get_path_from_url
from ppgan.utils.visual import make_grid, tensor2img, save_image
model_cfgs = {
'ffhq-config-f': {
'model_urls': 'https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f.pdparams',
'size': 1024,
'style_dim': 512,
'n_mlp': 8,
'channel_multiplier': 2
},
'animeface-512': {
'model_urls': 'https://paddlegan.bj.bcebos.com/models/stylegan2-animeface-512.pdparams',
'size': 512,
'style_dim': 512,
'n_mlp': 8,
'channel_multiplier': 2
}
}
@paddle.no_grad()
def get_mean_style(generator):
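    # Average 10 estimates of the generator's mean latent vector (each estimate
    # drawn from 1024 random samples) to get a stable truncation target.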
mean_style = None
for i in range(10):
style = generator.mean_latent(1024)
if mean_style is None:
mean_style = style
else:
mean_style += style
mean_style /= 10
return mean_style
@paddle.no_grad()
def sample(generator, mean_style, n_sample):
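    # Generate n_sample images from random style vectors, truncated towards
    # the precomputed mean style for higher sample quality.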
image = generator(
[paddle.randn([n_sample, generator.style_dim])],
truncation=0.7,
truncation_latent=mean_style,
)[0]
return image
@paddle.no_grad()
def style_mixing(generator, mean_style, n_source, n_target):
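    # Build a style-mixing grid: the first row holds a blank cell followed by
    # the source images; each following row shows one target image followed by
    # images mixing that target's code with every source code.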
source_code = paddle.randn([n_source, generator.style_dim])
target_code = paddle.randn([n_target, generator.style_dim])
resolution = 2**((generator.n_latent + 2) // 2)
images = [paddle.ones([1, 3, resolution, resolution]) * -1]
source_image = generator([source_code], truncation_latent=mean_style, truncation=0.7)[0]
target_image = generator([target_code], truncation_latent=mean_style, truncation=0.7)[0]
images.append(source_image)
for i in range(n_target):
image = generator(
[target_code[i].unsqueeze(0).tile([n_source, 1]), source_code],
truncation_latent=mean_style,
truncation=0.7,
)[0]
images.append(target_image[i].unsqueeze(0))
images.append(image)
images = paddle.concat(images, 0)
return images
class StyleGANv2Predictor:
def __init__(self,
output_path='output_dir',
weight_path=None,
model_type=None,
seed=None,
size=1024,
style_dim=512,
n_mlp=8,
channel_multiplier=2):
self.output_path = output_path
if weight_path is None:
if model_type in model_cfgs.keys():
weight_path = get_path_from_url(model_cfgs[model_type]['model_urls'])
size = model_cfgs[model_type].get('size', size)
style_dim = model_cfgs[model_type].get('style_dim', style_dim)
n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp)
channel_multiplier = model_cfgs[model_type].get('channel_multiplier', channel_multiplier)
checkpoint = paddle.load(weight_path)
else:
raise ValueError('Predictor need a weight path or a pretrained model type')
else:
checkpoint = paddle.load(weight_path)
self.generator = StyleGANv2Generator(size, style_dim, n_mlp, channel_multiplier)
self.generator.set_state_dict(checkpoint)
self.generator.eval()
if seed is not None:
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
def run(self, n_row=3, n_col=5):
os.makedirs(self.output_path, exist_ok=True)
mean_style = get_mean_style(self.generator)
img = sample(self.generator, mean_style, n_row * n_col)
save_image(tensor2img(make_grid(img, nrow=n_col)), f'{self.output_path}/sample.png')
for j in range(2):
img = style_mixing(self.generator, mean_style, n_col, n_row)
save_image(tensor2img(make_grid(img, nrow=n_col + 1)), f'{self.output_path}/sample_mixing_{j}.png')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import paddle
from ppgan.utils.download import get_path_from_url
from .basemodel import StyleGANv2Predictor
model_cfgs = {
'ffhq-config-f': {
'direction_urls': 'https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f-directions.pdparams'
}
}
def make_image(tensor):
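    # Map generator output from [-1, 1] to uint8 [0, 255] and NCHW to NHWC (RGB).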
return (((tensor.detach() + 1) / 2 * 255).clip(min=0, max=255).transpose((0, 2, 3, 1)).numpy().astype('uint8'))
class StyleGANv2EditingPredictor(StyleGANv2Predictor):
def __init__(self, model_type=None, direction_path=None, **kwargs):
super().__init__(model_type=model_type, **kwargs)
if direction_path is None and model_type is not None:
assert model_type in model_cfgs, f'There is not any pretrained direction file for {model_type} model.'
direction_path = get_path_from_url(model_cfgs[model_type]['direction_urls'])
self.directions = paddle.load(direction_path)
@paddle.no_grad()
def run(self, latent, direction, offset):
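        # Shift the latent code along the attribute direction, then decode the
        # original and shifted codes in a single generator pass.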
latent = paddle.to_tensor(latent).unsqueeze(0).astype('float32')
direction = self.directions[direction].unsqueeze(0).astype('float32')
latent_n = paddle.concat([latent, latent + offset * direction], 0)
generator = self.generator
img_gen, _ = generator([latent_n], input_is_latent=True, randomize_noise=False)
imgs = make_image(img_gen)
src_img = imgs[0]
dst_img = imgs[1]
dst_latent = (latent + offset * direction)[0].numpy().astype('float32')
return src_img, dst_img, dst_latent
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import paddle
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
import numpy as np
import cv2
from .model import StyleGANv2EditingPredictor
from .util import base64_to_cv2
@moduleinfo(
name="styleganv2_editing",
type="CV/style_transfer",
author="paddlepaddle",
author_email="",
summary="",
version="1.0.0")
class styleganv2_editing:
def __init__(self):
self.pretrained_model = os.path.join(self.directory, "stylegan2-ffhq-config-f-directions.pdparams")
self.network = StyleGANv2EditingPredictor(direction_path=self.pretrained_model, model_type='ffhq-config-f')
self.pixel2style2pixel_module = hub.Module(name='pixel2style2pixel')
def generate(self,
images=None,
paths=None,
direction_name='age',
direction_offset=0.0,
output_dir='./editing_result/',
use_gpu=False,
visualization=True):
'''
images (list[numpy.ndarray]): data of images, shape of each is [H, W, C], color space must be BGR(read by cv2).
paths (list[str]): paths to image.
        direction_name(str): Attribute to be manipulated. For ffhq-config-f, we have: age, eyes_open, eye_distance, eye_eyebrow_distance, eye_ratio, gender, lip_ratio, mouth_open, mouth_ratio, nose_mouth_distance, nose_ratio, nose_tip, pitch, roll, smile, yaw.
direction_offset(float): Offset strength of the attribute.
output_dir: the dir to save the results
use_gpu: if True, use gpu to perform the computation, otherwise cpu.
visualization: if True, save results in output_dir.
'''
results = []
paddle.disable_static()
place = 'gpu:0' if use_gpu else 'cpu'
place = paddle.set_device(place)
        if images is None and paths is None:
            print('No image provided. Please provide an image or an image path.')
            return

        if images is not None:
            for image in images:
                # Convert BGR (as read by cv2) to RGB for the encoder.
                image = image[:, :, ::-1]
_, latent = self.pixel2style2pixel_module.network.run(image)
out = self.network.run(latent, direction_name, direction_offset)
results.append(out)
        if paths is not None:
            for path in paths:
                # cv2 reads BGR; flip to RGB for the encoder.
                image = cv2.imread(path)[:, :, ::-1]
_, latent = self.pixel2style2pixel_module.network.run(image)
out = self.network.run(latent, direction_name, direction_offset)
results.append(out)
        if visualization:
            os.makedirs(output_dir, exist_ok=True)
for i, out in enumerate(results):
if out is not None:
cv2.imwrite(os.path.join(output_dir, 'src_{}.png'.format(i)), out[0][:, :, ::-1])
cv2.imwrite(os.path.join(output_dir, 'dst_{}.png'.format(i)), out[1][:, :, ::-1])
np.save(os.path.join(output_dir, 'dst_{}.npy'.format(i)), out[2])
return results
@runnable
def run_cmd(self, argvs: list):
"""
Run as a command.
"""
self.parser = argparse.ArgumentParser(
description="Run the {} module.".format(self.name),
prog='hub run {}'.format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
self.args = self.parser.parse_args(argvs)
results = self.generate(
paths=[self.args.input_path],
direction_name=self.args.direction_name,
direction_offset=self.args.direction_offset,
output_dir=self.args.output_dir,
use_gpu=self.args.use_gpu,
visualization=self.args.visualization)
return results
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.generate(images=images_decode, **kwargs)
        # Each result is a (src_img, dst_img, dst_latent) tuple of ndarrays;
        # convert every array to nested lists so the response is JSON-serializable.
        tolist = [[arr.tolist() for arr in result] for result in results]
return tolist
def add_module_config_arg(self):
"""
Add the command config options.
"""
self.arg_config_group.add_argument('--use_gpu', action='store_true', help="use GPU or not")
self.arg_config_group.add_argument(
'--output_dir', type=str, default='editing_result', help='output directory for saving result.')
self.arg_config_group.add_argument('--visualization', type=bool, default=False, help='save results or not.')
def add_module_input_arg(self):
"""
Add the command input options.
"""
self.arg_input_group.add_argument('--input_path', type=str, help="path to input image.")
self.arg_input_group.add_argument(
'--direction_name',
type=str,
default='age',
help=
"Attribute to be manipulated,For ffhq-conf-f, we have: age, eyes_open, eye_distance, eye_eyebrow_distance, eye_ratio, gender, lip_ratio, mouth_open, mouth_ratio, nose_mouth_distance, nose_ratio, nose_tip, pitch, roll, smile, yaw."
)
self.arg_input_group.add_argument('--direction_offset', type=float, help="Offset strength of the attribute.")
import base64
import cv2
import numpy as np
def base64_to_cv2(b64str):
    # Decode a base64 string into a BGR image array.
    data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data