未验证 提交 6cbd3e50 编写于 作者: 艾梦 提交者: GitHub

add styleganv2 editor (#455)

上级 a3838811
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle
from ppgan.apps import StyleGANv2EditingPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--latent",
type=str,
help="path to first image latent codes")
parser.add_argument("--direction_name",
type=str,
default=None,
help="name in directions dictionary")
parser.add_argument("--direction_offset",
type=float,
default=0.0,
help="offset value of edited attribute")
parser.add_argument("--direction_path",
type=str,
default=None,
help="path to latent editing directions")
parser.add_argument("--output_path",
type=str,
default='output_dir',
help="path to output image dir")
parser.add_argument("--weight_path",
type=str,
default=None,
help="path to model checkpoint path")
parser.add_argument("--model_type",
type=str,
default=None,
help="type of model for loading pretrained model")
parser.add_argument("--size",
type=int,
default=1024,
help="resolution of output image")
parser.add_argument("--style_dim",
type=int,
default=512,
help="number of style dimension")
parser.add_argument("--n_mlp",
type=int,
default=8,
help="number of mlp layer depth")
parser.add_argument("--channel_multiplier",
type=int,
default=2,
help="number of channel multiplier")
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
help="cpu mode.")
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = StyleGANv2EditingPredictor(
output_path=args.output_path,
weight_path=args.weight_path,
model_type=args.model_type,
seed=None,
size=args.size,
style_dim=args.style_dim,
n_mlp=args.n_mlp,
channel_multiplier=args.channel_multiplier,
direction_path=args.direction_path)
predictor.run(args.latent, args.direction_name, args.direction_offset)
# StyleGAN V2 Editing Module
## StyleGAN V2 Editing introduction
The task of StyleGAN V2 is image generation while the Editing module uses the attribute manipulation vector obtained by pre-classifying and regressing the style vector of the multi-image to manipulate the attributes of the generated image.
## How to use
### Editing
The user can use the following command to edit images:
```
cd applications/
python -u tools/styleganv2editing.py \
--latent <PATH TO STYLE VECTOR> \
--output_path <DIRECTORY TO STORE OUTPUT IMAGE> \
--weight_path <YOUR PRETRAINED MODEL PATH> \
--model_type ffhq-config-f \
--size 1024 \
--style_dim 512 \
--n_mlp 8 \
--channel_multiplier 2 \
--direction_path <PATH TO STORE ATTRIBUTE DIRECTIONS> \
--direction_name <ATTRIBUTE NAME TO BE MANIPULATED> \
--direction_offset 0.0 \
--cpu
```
**params:**
- latent: The path of the style vector which represents an image. Come from `dst.npy` generated by Pixel2Style2Pixel or `dst.fitting.npy` generated by StyleGANv2 Fitting module
- output_path: the directory where the generated images are stored
- weight_path: pretrained model path
- model_type: inner model type in PaddleGAN. If you use an existing model type, `weight_path` will have no effect.
Currently recommended use: `ffhq-config-f`
- size: model parameters, output image resolution
- style_dim: model parameters, dimensions of style z
- n_mlp: model parameters, the number of multi-layer perception layers for style z
- channel_multiplier: model parameters, channel product, affect model size and the quality of generated pictures
- direction_path: The path of the file storing a series of attribute names and object attribute vectors. The default is empty, that is, the file that comes with ppgan is used. If you don’t use it, please remove it from the command
- direction_name: Attribute to be manipulated,For `ffhq-conf-f`, we have: age, eyes_open, eye_distance, eye_eyebrow_distance, eye_ratio, gender, lip_ratio, mouth_open, mouth_ratio, nose_mouth_distance, nose_ratio, nose_tip, pitch, roll, smile, yaw
- direction_offset: Offset strength of the attribute
- cpu: whether to use cpu inference, if not, please remove it from the command
## Editing Results
The image corresponding to the style vector:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample.png" width="300"/>
</div>
The image obtained by editing the `age` attribute according to [-5,-2.5,0,2.5,5]:
<div align="center">
<img src="../../imgs/stylegan2editing-sample1.png" width="640"/>
</div>
The image obtained by further editing the `gender` to the style vector obtained by the `-5` offset:
<div align="center">
<img src="../../imgs/stylegan2editing-sample2.png" width="640"/>
</div>
## Make Attribute Direction Vector
For details, please refer to [Puzer/stylegan-encoder](https://github.com/Puzer/stylegan-encoder/blob/master/Learn_direction_in_latent_space.ipynb)
## Reference
- 1. [Analyzing and Improving the Image Quality of StyleGAN](https://arxiv.org/abs/1912.04958)
```
@article{Karras2019stylegan2,
title={Analyzing and Improving the Image Quality of {StyleGAN}},
author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
booktitle={Proc. CVPR},
year={2020}
}
```
- 2. [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](hhttps://arxiv.org/abs/2008.00951)
```
@article{richardson2020encoding,
title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2008.00951},
year={2020}
}
```
......@@ -9,7 +9,7 @@ The task of StyleGAN V2 is image generation while the Mixing module uses its sty
### Mixing
The user can use the following command to fit images:
The user can use the following command to mix images:
```
cd applications/
......@@ -35,13 +35,6 @@ python -u tools/styleganv2mixing.py \
- latent2: The path of the second style vector. The source is the same as the first style vector
- weights: The two style vectors are mixed in different proportions at different levels. For a resolution of 1024, there are 18 levels. For a resolution of 512, there are 16 levels, and so on.
The more in front, the more it affects the whole of the mixed image. The more behind, the more it affects the details of the mixed image. In the figure below we show the fusion results of different weights for reference.
- need_align: whether to crop the image to an image that can be recognized by the model. For an image that has been cropped, such as the `src.png` that is pre-generated when Pixel2Style2Pixel is used to generate the style vector, the need_align parameter may not be filled in
- start_lr: learning rate at the begin of training
- final_lr: learning rate at the end of training
- latent_level: The style vector level involved in fitting is 0~17 at 1024 resolution, 0~15 at 512 resolution, and so on. The lower the level, the more biased toward the overall style change. The higher the level, the more biased toward the detail style change
- step: the number of steps required to fit the image, the larger the number of steps, the longer it takes and the better the effect
- mse_weight: weight of MSE loss
- pre_latent: The pre-made style vector files are saved to facilitate better fitting. The default is empty, you can fill in the file path of `dst.npy` generated by Pixel2Style2Pixel
- output_path: the directory where the generated images are stored
- weight_path: pretrained model path
- model_type: inner model type in PaddleGAN. If you use an existing model type, `weight_path` will have no effect.
......@@ -52,7 +45,7 @@ python -u tools/styleganv2mixing.py \
- channel_multiplier: model parameters, channel product, affect model size and the quality of generated pictures
- cpu: whether to use cpu inference, if not, please remove it from the command
## Fitting Results
## Mixing Results
The image corresponding to the first style vector:
......
# StyleGAN V2 Editing 模块
## StyleGAN V2 Editing 原理
StyleGAN V2 的任务是使用风格向量进行image generation,而Editing模块则是利用预先对多图的风格向量进行分类回归得到的属性操纵向量来操纵生成图像的属性
## 使用方法
### 编辑
用户使用如下命令中对图像属性进行编辑:
```
cd applications/
python -u tools/styleganv2editing.py \
--latent <替换为要编辑的风格向量的路径> \
--output_path <替换为生成图片存放的文件夹> \
--weight_path <替换为你的预训练模型路径> \
--model_type ffhq-config-f \
--size 1024 \
--style_dim 512 \
--n_mlp 8 \
--channel_multiplier 2 \
--direction_path <替换为存放属性向量的文件路径> \
--direction_name <替换为你操纵的属性名称> \
--direction_offset 0.0 \
--cpu
```
**参数说明:**
- latent: 要编辑的代表图像的风格向量的路径。可来自于Pixel2Style2Pixel生成的`dst.npy`或StyleGANv2 Fitting模块生成的`dst.fitting.npy`
- latent2: 第二个风格向量的路径。来源同第一个风格向量
- output_path: 生成图片存放的文件夹
- weight_path: 预训练模型路径
- model_type: PaddleGAN内置模型类型,若输入PaddleGAN已存在的模型类型,`weight_path`将失效。当前建议使用: `ffhq-config-f`
- size: 模型参数,输出图片的分辨率
- style_dim: 模型参数,风格z的维度
- n_mlp: 模型参数,风格z所输入的多层感知层的层数
- channel_multiplier: 模型参数,通道乘积,影响模型大小和生成图片质量
- direction_path: 存放一系列属性名称及对象属性向量的文件的路径。默认为空,即使用ppgan自带的文件。若不使用,请在命令中去除
- direction_name: 要编辑的属性名称,对于`ffhq-conf-f`有预先准备的这些属性: age、eyes_open、eye_distance、eye_eyebrow_distance、eye_ratio、gender、lip_ratio、mouth_open、mouth_ratio、nose_mouth_distance、nose_ratio、nose_tip、pitch、roll、smile、yaw
- direction_offset: 属性的偏移强度
- cpu: 是否使用cpu推理,若不使用,请在命令中去除
## 编辑结果展示
风格向量对应的图像:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample.png" width="300"/>
</div>
按[-5,-2.5,0,2.5,5]进行`age`(年龄)属性编辑得到的图像:
<div align="center">
<img src="../../imgs/stylegan2editing-sample1.png" width="640"/>
</div>
`-5`偏移得到的风格向量进一步进行`gender`(性别)编辑得到的图像:
<div align="center">
<img src="../../imgs/stylegan2editing-sample2.png" width="640"/>
</div>
## 制作属性向量
具体可以参考[Puzer/stylegan-encoder](https://github.com/Puzer/stylegan-encoder/blob/master/Learn_direction_in_latent_space.ipynb)中的做法。
# 参考文献
- 1. [Analyzing and Improving the Image Quality of StyleGAN](https://arxiv.org/abs/1912.04958)
```
@article{Karras2019stylegan2,
title={Analyzing and Improving the Image Quality of {StyleGAN}},
author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
booktitle={Proc. CVPR},
year={2020}
}
```
- 2. [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](hhttps://arxiv.org/abs/2008.00951)
```
@article{richardson2020encoding,
title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2008.00951},
year={2020}
}
```
......@@ -8,7 +8,7 @@ StyleGAN V2 的任务是使用风格向量进行image generation,而Mixing模
### 混合
用户使用如下命令中进行合:
用户使用如下命令中进行合:
```
cd applications/
......@@ -43,7 +43,7 @@ python -u tools/styleganv2mixing.py \
- channel_multiplier: 模型参数,通道乘积,影响模型大小和生成图片质量
- cpu: 是否使用cpu推理,若不使用,请在命令中去除
## 合结果展示
## 合结果展示
第一个风格向量对应的图像:
......
......@@ -25,6 +25,7 @@ from .photo2cartoon_predictor import Photo2CartoonPredictor
from .styleganv2_predictor import StyleGANv2Predictor
from .styleganv2fitting_predictor import StyleGANv2FittingPredictor
from .styleganv2mixing_predictor import StyleGANv2MixingPredictor
from .styleganv2editing_predictor import StyleGANv2EditingPredictor
from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
from .wav2lip_predictor import Wav2LipPredictor
from .mpr_predictor import MPRPredictor
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import paddle
from ppgan.utils.download import get_path_from_url
from .styleganv2_predictor import StyleGANv2Predictor
model_cfgs = {
'ffhq-config-f': {
'direction_urls':
'https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f-directions.pdparams'
}
}
def make_image(tensor):
return (((tensor.detach() + 1) / 2 * 255).clip(min=0, max=255).transpose(
(0, 2, 3, 1)).numpy().astype('uint8'))
class StyleGANv2EditingPredictor(StyleGANv2Predictor):
def __init__(self, model_type=None, direction_path=None, **kwargs):
super().__init__(model_type=model_type, **kwargs)
if direction_path is None and model_type is not None:
assert model_type in model_cfgs, f'There is not any pretrained direction file for {model_type} model.'
direction_path = get_path_from_url(
model_cfgs[model_type]['direction_urls'])
self.directions = paddle.load(direction_path)
@paddle.no_grad()
def run(self, latent, direction, offset):
latent = paddle.to_tensor(
np.load(latent)).unsqueeze(0).astype('float32')
direction = self.directions[direction].unsqueeze(0).astype('float32')
latent_n = paddle.concat([latent, latent + offset * direction], 0)
generator = self.generator
img_gen, _ = generator([latent_n],
input_is_latent=True,
randomize_noise=False)
imgs = make_image(img_gen)
src_img = imgs[0]
dst_img = imgs[1]
dst_latent = (latent + offset * direction)[0].numpy().astype('float32')
os.makedirs(self.output_path, exist_ok=True)
save_src_path = os.path.join(self.output_path, 'src.editing.png')
cv2.imwrite(save_src_path, cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR))
save_dst_path = os.path.join(self.output_path, 'dst.editing.png')
cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR))
save_npy_path = os.path.join(self.output_path, 'dst.editing.npy')
np.save(save_npy_path, dst_latent)
return src_img, dst_img, dst_latent
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册