Unverified commit fc9becf0 authored by 吴何聪, committed by GitHub

Add stylegan v2 fitting module and mixing module to expand stylegan v2 application. (#361)

* add fitting module for styleganv2 and format some related codes
Parent 49e549ee
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys

import paddle

sys.path.insert(0, os.getcwd())
from ppgan.apps import Pixel2Style2PixelPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -33,17 +45,17 @@ if __name__ == "__main__":
type=int,
default=1024,
help="resolution of output image")
parser.add_argument("--style_dim",
type=int,
default=512,
help="number of style dimension")
parser.add_argument("--n_mlp",
type=int,
default=8,
help="number of mlp layer depth")
parser.add_argument("--channel_multiplier",
type=int,
default=2,
@@ -67,6 +79,5 @@ if __name__ == "__main__":
size=args.size,
style_dim=args.style_dim,
n_mlp=args.n_mlp,
channel_multiplier=args.channel_multiplier)
predictor.run(args.input_image)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys

import paddle

sys.path.insert(0, os.getcwd())
from ppgan.apps import StyleGANv2Predictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -26,22 +38,22 @@ if __name__ == "__main__":
type=int,
default=None,
help="sample random seed for model's image generation")
parser.add_argument("--size",
type=int,
default=1024,
help="resolution of output image")
parser.add_argument("--style_dim",
type=int,
default=512,
help="number of style dimension")
parser.add_argument("--n_mlp",
type=int,
default=8,
help="number of mlp layer depth")
parser.add_argument("--channel_multiplier",
type=int,
default=2,
@@ -67,14 +79,12 @@ if __name__ == "__main__":
if args.cpu:
paddle.set_device('cpu')
predictor = StyleGANv2Predictor(output_path=args.output_path,
weight_path=args.weight_path,
model_type=args.model_type,
seed=args.seed,
size=args.size,
style_dim=args.style_dim,
n_mlp=args.n_mlp,
channel_multiplier=args.channel_multiplier)
predictor.run(args.n_row, args.n_col)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle
from ppgan.apps import StyleGANv2FittingPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_image", type=str, help="path to source image")
parser.add_argument("--need_align",
action="store_true",
help="whether to align input image")
parser.add_argument("--start_lr",
type=float,
default=0.1,
help="learning rate at the begin of training")
parser.add_argument("--final_lr",
type=float,
default=0.025,
help="learning rate at the end of training")
parser.add_argument("--latent_level",
type=int,
nargs="+",
default=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
help="indices of latent code for training")
parser.add_argument("--step",
type=int,
default=100,
help="optimize iterations")
parser.add_argument("--mse_weight",
type=float,
default=1,
help="weight of the mse loss")
parser.add_argument("--pre_latent",
type=str,
default=None,
help="path to pre-prepared latent codes")
parser.add_argument("--output_path",
type=str,
default='output_dir',
help="path to output image dir")
parser.add_argument("--weight_path",
type=str,
default=None,
help="path to model checkpoint path")
parser.add_argument("--model_type",
type=str,
default=None,
help="type of model for loading pretrained model")
parser.add_argument("--size",
type=int,
default=1024,
help="resolution of output image")
parser.add_argument("--style_dim",
type=int,
default=512,
help="number of style dimension")
parser.add_argument("--n_mlp",
type=int,
default=8,
help="number of mlp layer depth")
parser.add_argument("--channel_multiplier",
type=int,
default=2,
help="number of channel multiplier")
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
help="cpu mode.")
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = StyleGANv2FittingPredictor(
output_path=args.output_path,
weight_path=args.weight_path,
model_type=args.model_type,
seed=None,
size=args.size,
style_dim=args.style_dim,
n_mlp=args.n_mlp,
channel_multiplier=args.channel_multiplier)
predictor.run(args.input_image,
need_align=args.need_align,
start_lr=args.start_lr,
final_lr=args.final_lr,
latent_level=args.latent_level,
step=args.step,
mse_weight=args.mse_weight,
pre_latent=args.pre_latent)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle
from ppgan.apps import StyleGANv2MixingPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--latent1",
type=str,
help="path to first image latent codes")
parser.add_argument("--latent2",
type=str,
help="path to second image latent codes")
parser.add_argument(
"--weights",
type=float,
nargs="+",
default=[0.5] * 18,
help="different weights at each level of two latent codes")
parser.add_argument("--output_path",
type=str,
default='output_dir',
help="path to output image dir")
parser.add_argument("--weight_path",
type=str,
default=None,
help="path to model checkpoint path")
parser.add_argument("--model_type",
type=str,
default=None,
help="type of model for loading pretrained model")
parser.add_argument("--size",
type=int,
default=1024,
help="resolution of output image")
parser.add_argument("--style_dim",
type=int,
default=512,
help="number of style dimension")
parser.add_argument("--n_mlp",
type=int,
default=8,
help="number of mlp layer depth")
parser.add_argument("--channel_multiplier",
type=int,
default=2,
help="number of channel multiplier")
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
help="cpu mode.")
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = StyleGANv2MixingPredictor(
output_path=args.output_path,
weight_path=args.weight_path,
model_type=args.model_type,
seed=None,
size=args.size,
style_dim=args.style_dim,
n_mlp=args.n_mlp,
channel_multiplier=args.channel_multiplier)
predictor.run(args.latent1, args.latent2, args.weights)
# StyleGAN V2 Fitting Module
## StyleGAN V2 Fitting introduction
StyleGAN V2 generates images from style vectors, while the Fitting module inverts an existing image back into a highly disentangled style vector. The resulting style vector can be used in tasks such as face fusion and face attribute editing.
## How to use
### Fitting
Users can run the following command to fit an image:
```
cd applications/
python -u tools/styleganv2fitting.py \
--input_image <YOUR INPUT IMAGE> \
--need_align \
--start_lr 0.1 \
--final_lr 0.025 \
--latent_level 0 1 2 3 4 5 6 7 8 9 10 11 \
--step 100 \
--mse_weight 1 \
--pre_latent <PRE-PREPARED LATENT CODE PATH> \
--output_path <DIRECTORY TO STORE OUTPUT IMAGE> \
--weight_path <YOUR PRETRAINED MODEL PATH> \
--model_type ffhq-config-f \
--size 1024 \
--style_dim 512 \
--n_mlp 8 \
--channel_multiplier 2 \
--cpu
```
**Parameters:**
- input_image: path to the input image
- need_align: whether to crop the input into an image the model can recognize; for an already-cropped input, such as the `src.png` pre-generated when Pixel2Style2Pixel produces a style vector, this flag can be omitted
- start_lr: learning rate at the beginning of the fitting
- final_lr: learning rate at the end of the fitting
- latent_level: style-vector levels involved in the fitting: 0 to 17 at 1024 resolution, 0 to 15 at 512 resolution, and so on (see the sketch after this list). Lower levels steer the overall style; higher levels steer the style details
- step: number of fitting steps; more steps take longer but give better results
- mse_weight: weight of the MSE loss
- pre_latent: a saved style-vector file that helps the fitting converge better. Empty by default; you may pass the path of a `dst.npy` file generated by Pixel2Style2Pixel
- output_path: directory where the generated images are stored
- weight_path: path to the pretrained model
- model_type: built-in PaddleGAN model type; if an existing model type is given, `weight_path` is ignored. Currently recommended: `ffhq-config-f`
- size: model parameter, output image resolution
- style_dim: model parameter, dimension of the style vector z
- n_mlp: model parameter, depth of the MLP that maps the style vector z
- channel_multiplier: model parameter, channel multiplier affecting model size and output image quality
- cpu: whether to run inference on CPU; remove this flag to use GPU
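For reference, the number of style-vector levels follows directly from the output resolution. A minimal sketch of that relationship (the helper name `n_latent_levels` is ours; the formula inverts the generator's `resolution = 2**((n_latent + 2) // 2)`):
```
import math

def n_latent_levels(size):
    # Each doubling of the output resolution beyond 4x4 adds two levels.
    return 2 * int(math.log2(size)) - 2

print(n_latent_levels(1024))  # 18 -> valid latent_level indices are 0..17
print(n_latent_levels(512))   # 16 -> valid latent_level indices are 0..15
```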
## Fitting Results
Source image:
<div align="center">
<img src="../../imgs/pSp-input.jpg" width="300"/>
</div>
Image encoded by Pixel2Style2Pixel:
<div align="center">
<img src="../../imgs/pSp-inversion.png" width="256"/>
</div>
Starting from the style vector generated by Pixel2Style2Pixel, running another 1000 fitting steps with the Fitting module gives the following result:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample.png" width="256"/>
</div>
## References
- 1. [Analyzing and Improving the Image Quality of StyleGAN](https://arxiv.org/abs/1912.04958)
```
@inproceedings{Karras2019stylegan2,
title={Analyzing and Improving the Image Quality of {StyleGAN}},
author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
booktitle={Proc. CVPR},
year={2020}
}
```
- 2. [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](https://arxiv.org/abs/2008.00951)
```
@article{richardson2020encoding,
title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2008.00951},
year={2020}
}
```
# StyleGAN V2 Mixing Module
## StyleGAN V2 Mixing introduction
StyleGAN V2 generates images from style vectors, while the Mixing module uses these style vectors to blend two generated images at different levels and in different proportions.
## How to use
### Mixing
Users can run the following command to mix two images:
```
cd applications/
python -u tools/styleganv2mixing.py \
--latent1 <PATH TO FIRST STYLE VECTOR> \
--latent2 <PATH TO SECOND STYLE VECTOR> \
--weights \
0.5 0.5 0.5 0.5 0.5 0.5 \
0.5 0.5 0.5 0.5 0.5 0.5 \
0.5 0.5 0.5 0.5 0.5 0.5 \
--output_path <DIRECTORY TO STORE OUTPUT IMAGE> \
--weight_path <YOUR PRETRAINED MODEL PATH> \
--model_type ffhq-config-f \
--size 1024 \
--style_dim 512 \
--n_mlp 8 \
--channel_multiplier 2 \
--cpu
```
**Parameters:**
- latent1: path to the first style vector; it can come from the `dst.npy` generated by Pixel2Style2Pixel or the `dst.fitting.npy` generated by the StyleGANv2 Fitting module
- latent2: path to the second style vector, obtained the same way as the first
- weights: per-level mixing proportions of the two style vectors (see the sketch after this list). At 1024 resolution there are 18 levels, at 512 resolution 16 levels, and so on. Earlier weights affect the overall look of the mixed image; later weights affect its details
- output_path: directory where the generated images are stored
- weight_path: path to the pretrained model
- model_type: built-in PaddleGAN model type; if an existing model type is given, `weight_path` is ignored. Currently recommended: `ffhq-config-f`
- size: model parameter, output image resolution
- style_dim: model parameter, dimension of the style vector z
- n_mlp: model parameter, depth of the MLP that maps the style vector z
- channel_multiplier: model parameter, channel multiplier affecting model size and output image quality
- cpu: whether to run inference on CPU; remove this flag to use GPU
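Each weight blends one level of the two style vectors; a minimal NumPy sketch of the interpolation that `StyleGANv2MixingPredictor` performs (the `.npy` file names are illustrative):
```
import numpy as np

latent1 = np.load('dst1.fitting.npy')  # shape (n_levels, style_dim), e.g. (18, 512)
latent2 = np.load('dst2.fitting.npy')
weights = np.array([0.5] * latent1.shape[0])

# Per-level linear interpolation: each weight keeps that share of latent1
# and fills the remainder from latent2.
mixed = weights[:, None] * latent1 + (1 - weights[:, None]) * latent2
```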
## Mixing Results
The image corresponding to the first style vector:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample.png" width="300"/>
</div>
The image corresponding to the second style vector:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample2.png" width="256"/>
</div>
The result of mixing the two style vectors at the given proportions:
<div align="center">
<img src="../../imgs/stylegan2mixing-sample.png" width="256"/>
</div>
## References
- 1. [Analyzing and Improving the Image Quality of StyleGAN](https://arxiv.org/abs/1912.04958)
```
@inproceedings{Karras2019stylegan2,
title={Analyzing and Improving the Image Quality of {StyleGAN}},
author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
booktitle={Proc. CVPR},
year={2020}
}
```
- 2. [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](https://arxiv.org/abs/2008.00951)
```
@article{richardson2020encoding,
title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2008.00951},
year={2020}
}
```
# StyleGAN V2 Fitting Module
## How StyleGAN V2 Fitting works
StyleGAN V2 generates images from style vectors, while the Fitting module inverts an existing image back into a highly disentangled style vector. The resulting style vector can be used in tasks such as face fusion and face attribute editing.
## How to use
### Fitting
Users can run the following command to fit an image:
```
cd applications/
python -u tools/styleganv2fitting.py \
--input_image <PATH TO YOUR INPUT IMAGE> \
--need_align \
--start_lr 0.1 \
--final_lr 0.025 \
--latent_level 0 1 2 3 4 5 6 7 8 9 10 11 \
--step 100 \
--mse_weight 1 \
--pre_latent <PATH TO A PRE-PREPARED STYLE VECTOR> \
--output_path <DIRECTORY TO STORE OUTPUT IMAGES> \
--weight_path <PATH TO YOUR PRETRAINED MODEL> \
--model_type ffhq-config-f \
--size 1024 \
--style_dim 512 \
--n_mlp 8 \
--channel_multiplier 2 \
--cpu
```
**Parameters:**
- input_image: path to the input image
- need_align: whether to crop the input into an image the model can recognize; for an already-cropped input, such as the `src.png` pre-generated when Pixel2Style2Pixel produces a style vector, this flag can be omitted
- start_lr: initial learning rate of the fitting
- final_lr: learning rate at the end of the fitting (see the sketch after this list)
- latent_level: style-vector levels involved in the fitting: 0 to 17 at 1024 resolution, 0 to 15 at 512 resolution, and so on. Lower levels steer the overall style; higher levels steer the style details
- step: number of fitting steps; more steps take longer but give better results
- mse_weight: weight of the MSE loss
- pre_latent: a saved style-vector file that helps the fitting converge better. Empty by default; you may pass the path of a `dst.npy` file generated by Pixel2Style2Pixel
- output_path: directory where the generated images are stored
- weight_path: path to the pretrained model
- model_type: built-in PaddleGAN model type; if an existing model type is given, `weight_path` is ignored. Currently recommended: `ffhq-config-f`
- size: model parameter, output image resolution
- style_dim: model parameter, dimension of the style vector z
- n_mlp: model parameter, depth of the MLP that maps the style vector z
- channel_multiplier: model parameter, channel multiplier affecting model size and output image quality
- cpu: whether to run inference on CPU; remove this flag to use GPU
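`start_lr` and `final_lr` define a geometric decay over the fitting steps. A minimal sketch, mathematically equivalent to the `get_lr` helper in the fitting predictor (the name `fitting_lr` is ours):
```
def fitting_lr(t, start_lr, final_lr):
    # Geometric decay from start_lr (t == 0) to final_lr (t == 1),
    # where t = current_step / total_steps.
    return start_lr * (final_lr / start_lr)**t

for i in (0, 50, 100):
    print(fitting_lr(i / 100, 0.1, 0.025))  # 0.1 -> ~0.05 -> 0.025
```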
## Fitting results
Source image:
<div align="center">
<img src="../../imgs/pSp-input.jpg" width="300"/>
</div>
Result encoded by Pixel2Style2Pixel:
<div align="center">
<img src="../../imgs/pSp-inversion.png" width="256"/>
</div>
Starting from the style vector produced by Pixel2Style2Pixel, running another 1000 fitting steps with the Fitting module gives the following result:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample.png" width="256"/>
</div>
## References
- 1. [Analyzing and Improving the Image Quality of StyleGAN](https://arxiv.org/abs/1912.04958)
```
@inproceedings{Karras2019stylegan2,
title={Analyzing and Improving the Image Quality of {StyleGAN}},
author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
booktitle={Proc. CVPR},
year={2020}
}
```
- 2. [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](https://arxiv.org/abs/2008.00951)
```
@article{richardson2020encoding,
title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2008.00951},
year={2020}
}
```
# StyleGAN V2 Mixing Module
## How StyleGAN V2 Mixing works
StyleGAN V2 generates images from style vectors, while the Mixing module uses these style vectors to blend two generated images at different levels and in different proportions.
## How to use
### Mixing
Users can run the following command to mix two images:
```
cd applications/
python -u tools/styleganv2mixing.py \
--latent1 <PATH TO THE FIRST STYLE VECTOR> \
--latent2 <PATH TO THE SECOND STYLE VECTOR> \
--weights \
0.5 0.5 0.5 0.5 0.5 0.5 \
0.5 0.5 0.5 0.5 0.5 0.5 \
0.5 0.5 0.5 0.5 0.5 0.5 \
--output_path <DIRECTORY TO STORE OUTPUT IMAGES> \
--weight_path <PATH TO YOUR PRETRAINED MODEL> \
--model_type ffhq-config-f \
--size 1024 \
--style_dim 512 \
--n_mlp 8 \
--channel_multiplier 2 \
--cpu
```
**Parameters:**
- latent1: path to the first style vector; it can come from the `dst.npy` generated by Pixel2Style2Pixel or the `dst.fitting.npy` generated by the StyleGANv2 Fitting module
- latent2: path to the second style vector, obtained the same way as the first
- weights: per-level mixing proportions of the two style vectors. At 1024 resolution there are 18 levels, at 512 resolution 16 levels, and so on. Earlier weights affect the overall look of the mixed image; later weights affect its details
- output_path: directory where the generated images are stored
- weight_path: path to the pretrained model
- model_type: built-in PaddleGAN model type; if an existing model type is given, `weight_path` is ignored. Currently recommended: `ffhq-config-f`
- size: model parameter, output image resolution
- style_dim: model parameter, dimension of the style vector z
- n_mlp: model parameter, depth of the MLP that maps the style vector z
- channel_multiplier: model parameter, channel multiplier affecting model size and output image quality
- cpu: whether to run inference on CPU; remove this flag to use GPU
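The same mixing can also be driven from the Python API added in this commit; a minimal sketch, assuming `ppgan` is importable, the remaining constructor arguments keep their defaults, and the latent file names are illustrative:
```
from ppgan.apps import StyleGANv2MixingPredictor

predictor = StyleGANv2MixingPredictor(output_path='output_dir',
                                      model_type='ffhq-config-f')
# run() takes the two latent-file paths and one weight per level;
# it returns the two source images and the mixed image.
src1, src2, dst = predictor.run('dst1.fitting.npy', 'dst2.fitting.npy',
                                weights=[0.5] * 18)
```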
## Mixing results
Image corresponding to the first style vector:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample.png" width="300"/>
</div>
Image corresponding to the second style vector:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample2.png" width="256"/>
</div>
The result of mixing the two style vectors at the given proportions:
<div align="center">
<img src="../../imgs/stylegan2mixing-sample.png" width="256"/>
</div>
## References
- 1. [Analyzing and Improving the Image Quality of StyleGAN](https://arxiv.org/abs/1912.04958)
```
@inproceedings{Karras2019stylegan2,
title={Analyzing and Improving the Image Quality of {StyleGAN}},
author={Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
booktitle={Proc. CVPR},
year={2020}
}
```
- 2. [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](https://arxiv.org/abs/2008.00951)
```
@article{richardson2020encoding,
title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2008.00951},
year={2020}
}
```
@@ -23,6 +23,8 @@ from .animegan_predictor import AnimeGANPredictor
from .midas_predictor import MiDaSPredictor
from .photo2cartoon_predictor import Photo2CartoonPredictor
from .styleganv2_predictor import StyleGANv2Predictor
from .styleganv2fitting_predictor import StyleGANv2FittingPredictor
from .styleganv2mixing_predictor import StyleGANv2MixingPredictor
from .pixel2style2pixel_predictor import Pixel2Style2PixelPredictor
from .wav2lip_predictor import Wav2LipPredictor
from .mpr_predictor import MPRPredictor
@@ -25,34 +25,46 @@ from ppgan.models.generators import Pixel2Style2Pixel
from ppgan.utils.download import get_path_from_url
from PIL import Image
model_cfgs = {
    'ffhq-inversion': {
        'model_urls':
        'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-inversion.pdparams',
        'transform':
        T.Compose([
            T.Resize((256, 256)),
            T.Transpose(),
            T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
        ]),
        'size': 1024,
        'style_dim': 512,
        'n_mlp': 8,
        'channel_multiplier': 2
    },
    'ffhq-toonify': {
        'model_urls':
        'https://paddlegan.bj.bcebos.com/models/pSp-ffhq-toonify.pdparams',
        'transform':
        T.Compose([
            T.Resize((256, 256)),
            T.Transpose(),
            T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
        ]),
        'size': 1024,
        'style_dim': 512,
        'n_mlp': 8,
        'channel_multiplier': 2
    },
    'default': {
        'transform':
        T.Compose([
            T.Resize((256, 256)),
            T.Transpose(),
            T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
@@ -68,23 +80,23 @@ def run_alignment(image_path):
raise Exception('Could not find a face in the given image.')
face_on_image = face[0]
lm = futils.dlib.landmarks(img, face_on_image)
lm = np.array(lm)[:, ::-1]
lm_eye_left = lm[36:42]
lm_eye_right = lm[42:48]
lm_mouth_outer = lm[48:60]
output_size = 1024
transform_size = 4096
enable_padding = True
# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg
# Choose oriented crop rectangle.
@@ -99,36 +111,52 @@ def run_alignment(image_path):
# Shrink.
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)),
int(np.rint(float(img.size[1]) / shrink)))
img = img.resize(rsize, Image.ANTIALIAS)
quad /= shrink
qsize /= shrink
# Crop.
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
min(crop[2] + border,
img.size[0]), min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]
# Pad.
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
pad = (max(-pad[0] + border,
0), max(-pad[1] + border,
0), max(pad[2] - img.size[0] + border,
0), max(pad[3] - img.size[1] + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(np.float32(img),
((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w, _ = img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(
1.0 -
np.minimum(np.float32(x) / pad[0],
np.float32(w - 1 - x) / pad[2]), 1.0 -
np.minimum(np.float32(y) / pad[1],
np.float32(h - 1 - y) / pad[3]))
blur = qsize * 0.02
img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
quad += pad[:2]
# Transform.
img = img.transform((transform_size, transform_size), Image.QUAD,
(quad + 0.5).flatten(), Image.BILINEAR)
return img
@@ -153,14 +181,17 @@ class Pixel2Style2PixelPredictor(BasePredictor):
if weight_path is None and model_type != 'default':
if model_type in model_cfgs.keys():
weight_path = get_path_from_url(
model_cfgs[model_type]['model_urls'])
size = model_cfgs[model_type].get('size', size)
style_dim = model_cfgs[model_type].get('style_dim', style_dim)
n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp)
channel_multiplier = model_cfgs[model_type].get(
'channel_multiplier', channel_multiplier)
checkpoint = paddle.load(weight_path)
else:
raise ValueError(
'Predictor needs a weight path or a pretrained model type')
else:
checkpoint = paddle.load(weight_path)
@@ -174,7 +205,7 @@ class Pixel2Style2PixelPredictor(BasePredictor):
self.generator = Pixel2Style2Pixel(opts)
self.generator.set_state_dict(checkpoint)
self.generator.eval()
if seed is not None:
paddle.seed(seed)
random.seed(seed)
@@ -186,14 +217,20 @@ class Pixel2Style2PixelPredictor(BasePredictor):
src_img = run_alignment(image)
src_img = np.asarray(src_img)
transformed_image = model_cfgs[self.model_type]['transform'](src_img)
dst_img, latents = self.generator(paddle.to_tensor(
transformed_image[None, ...]),
resize=False,
return_latents=True)
dst_img = (dst_img * 0.5 + 0.5)[0].numpy() * 255
dst_img = dst_img.transpose((1, 2, 0))
dst_npy = latents[0].numpy()
os.makedirs(self.output_path, exist_ok=True)
save_src_path = os.path.join(self.output_path, 'src.png')
cv2.imwrite(save_src_path, cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR))
save_dst_path = os.path.join(self.output_path, 'dst.png')
cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR))
save_npy_path = os.path.join(self.output_path, 'dst.npy')
np.save(save_npy_path, dst_npy)
return src_img, dst_img, dst_npy
@@ -21,17 +21,18 @@ from ppgan.models.generators import StyleGANv2Generator
from ppgan.utils.download import get_path_from_url
from ppgan.utils.visual import make_grid, tensor2img, save_image
model_cfgs = {
'ffhq-config-f': {
'model_urls':
'https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f.pdparams',
'size': 1024,
'style_dim': 512,
'n_mlp': 8,
'channel_multiplier': 2
},
'animeface-512': {
'model_urls':
'https://paddlegan.bj.bcebos.com/models/stylegan2-animeface-512.pdparams',
'size': 512,
'style_dim': 512,
'n_mlp': 8,
@@ -64,7 +65,7 @@ def sample(generator, mean_style, n_sample):
truncation=0.7,
truncation_latent=mean_style,
)[0]
return image
@@ -73,16 +74,16 @@ def style_mixing(generator, mean_style, n_source, n_target):
source_code = paddle.randn([n_source, generator.style_dim])
target_code = paddle.randn([n_target, generator.style_dim])
resolution = 2**((generator.n_latent + 2) // 2)
images = [paddle.ones([1, 3, resolution, resolution]) * -1]
source_image = generator([source_code],
truncation_latent=mean_style,
truncation=0.7)[0]
target_image = generator([target_code],
truncation_latent=mean_style,
truncation=0.7)[0]
images.append(source_image)
@@ -96,7 +97,7 @@ def style_mixing(generator, mean_style, n_source, n_target):
images.append(image)
images = paddle.concat(images, 0)
return images
@@ -114,21 +115,25 @@ class StyleGANv2Predictor(BasePredictor):
if weight_path is None:
if model_type in model_cfgs.keys():
weight_path = get_path_from_url(
model_cfgs[model_type]['model_urls'])
size = model_cfgs[model_type].get('size', size)
style_dim = model_cfgs[model_type].get('style_dim', style_dim)
n_mlp = model_cfgs[model_type].get('n_mlp', n_mlp)
channel_multiplier = model_cfgs[model_type].get(
'channel_multiplier', channel_multiplier)
checkpoint = paddle.load(weight_path)
else:
raise ValueError(
'Predictor needs a weight path or a pretrained model type')
else:
checkpoint = paddle.load(weight_path)
self.generator = StyleGANv2Generator(size, style_dim, n_mlp,
channel_multiplier)
self.generator.set_state_dict(checkpoint)
self.generator.eval()
if seed is not None:
paddle.seed(seed)
random.seed(seed)
@@ -139,10 +144,10 @@ class StyleGANv2Predictor(BasePredictor):
mean_style = get_mean_style(self.generator)
img = sample(self.generator, mean_style, n_row * n_col)
save_image(tensor2img(make_grid(img, nrow=n_col)),
f'{self.output_path}/sample.png')
for j in range(2):
img = style_mixing(self.generator, mean_style, n_col, n_row)
save_image(tensor2img(make_grid(img, nrow=n_col + 1)),
f'{self.output_path}/sample_mixing_{j}.png')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import paddle
from paddle import optimizer as optim
from paddle.nn import functional as F
from paddle.vision import transforms
from tqdm import tqdm
from PIL import Image
from .styleganv2_predictor import StyleGANv2Predictor
from .pixel2style2pixel_predictor import run_alignment
from ..metrics.lpips import LPIPS
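# Geometric learning-rate decay: returns initial_lr at t == 0 and final_lr
# at t == 1, where t is the fraction of optimization steps completed.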
def get_lr(t, ts, initial_lr, final_lr):
alpha = pow(final_lr / initial_lr, 1 / ts)**(t * ts)
return initial_lr * alpha
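# Convert a batch of [-1, 1] NCHW tensors into uint8 HWC images.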
def make_image(tensor):
return (((tensor.detach() + 1) / 2 * 255).clip(min=0, max=255).transpose(
(0, 2, 3, 1)).numpy().astype('uint8'))
class StyleGANv2FittingPredictor(StyleGANv2Predictor):
def run(
self,
image,
need_align=False,
start_lr=0.1,
final_lr=0.025,
latent_level=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11], # for ffhq (0~17)
step=100,
mse_weight=1,
pre_latent=None):
if need_align:
src_img = run_alignment(image)
else:
src_img = Image.open(image).convert("RGB")
generator = self.generator
generator.train()
percept = LPIPS(net='vgg')
# On PaddlePaddle, LPIPS's default eval mode disables gradients, so keep it in train mode.
percept.train()
n_mean_latent = 4096
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.Transpose(),
transforms.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]),
])
imgs = paddle.to_tensor(transform(src_img)).unsqueeze(0)
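        # Without a pre-computed latent, start from the mean W vector, estimated
        # by pushing n_mean_latent random z samples through the mapping network.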
if pre_latent is None:
with paddle.no_grad():
noise_sample = paddle.randn(
(n_mean_latent, generator.style_dim))
latent_out = generator.style(noise_sample)
latent_mean = latent_out.mean(0)
latent_in = latent_mean.detach().clone().unsqueeze(0).tile(
(imgs.shape[0], 1))
latent_in = latent_in.unsqueeze(1).tile(
(1, generator.n_latent, 1)).detach()
else:
latent_in = paddle.to_tensor(np.load(pre_latent)).unsqueeze(0)
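        # Split the latent levels into trainable (var_levels) and frozen
        # (const_levels) groups according to latent_level.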
var_levels = list(latent_level)
const_levels = [
i for i in range(generator.n_latent) if i not in var_levels
]
assert len(var_levels) > 0
if len(const_levels) > 0:
latent_fix = latent_in.index_select(paddle.to_tensor(const_levels),
1).detach().clone()
latent_in = latent_in.index_select(paddle.to_tensor(var_levels),
1).detach().clone()
latent_in.stop_gradient = False
optimizer = optim.Adam(parameters=[latent_in], learning_rate=start_lr)
pbar = tqdm(range(step))
for i in pbar:
t = i / step
lr = get_lr(t, step, start_lr, final_lr)
optimizer.set_lr(lr)
if len(const_levels) > 0:
latent_dict = {}
for idx, idx2 in enumerate(var_levels):
latent_dict[idx2] = latent_in[:, idx:idx + 1]
for idx, idx2 in enumerate(const_levels):
latent_dict[idx2] = (latent_fix[:, idx:idx + 1]).detach()
latent_list = []
for idx in range(generator.n_latent):
latent_list.append(latent_dict[idx])
latent_n = paddle.concat(latent_list, 1)
else:
latent_n = latent_in
img_gen, _ = generator([latent_n],
input_is_latent=True,
randomize_noise=False)
batch, channel, height, width = img_gen.shape
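            # Average-pool the generated image down to 256x256 so it matches
            # the resolution LPIPS expects.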
if height > 256:
factor = height // 256
img_gen = img_gen.reshape((batch, channel, height // factor,
factor, width // factor, factor))
img_gen = img_gen.mean([3, 5])
p_loss = percept(img_gen, imgs).sum()
mse_loss = F.mse_loss(img_gen, imgs)
loss = p_loss + mse_weight * mse_loss
optimizer.clear_grad()
loss.backward()
optimizer.step()
pbar.set_description(
(f"perceptual: {p_loss.numpy()[0]:.4f}; "
f"mse: {mse_loss.numpy()[0]:.4f}; lr: {lr:.4f}"))
img_gen, _ = generator([latent_n],
input_is_latent=True,
randomize_noise=False)
dst_img = make_image(img_gen)[0]
dst_latent = latent_n.numpy()[0]
os.makedirs(self.output_path, exist_ok=True)
save_src_path = os.path.join(self.output_path, 'src.fitting.png')
cv2.imwrite(save_src_path,
cv2.cvtColor(np.asarray(src_img), cv2.COLOR_RGB2BGR))
save_dst_path = os.path.join(self.output_path, 'dst.fitting.png')
cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR))
save_npy_path = os.path.join(self.output_path, 'dst.fitting.npy')
np.save(save_npy_path, dst_latent)
return np.asarray(src_img), dst_img, dst_latent
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import paddle
from .styleganv2_predictor import StyleGANv2Predictor
def make_image(tensor):
return (((tensor.detach() + 1) / 2 * 255).clip(min=0, max=255).transpose(
(0, 2, 3, 1)).numpy().astype('uint8'))
class StyleGANv2MixingPredictor(StyleGANv2Predictor):
@paddle.no_grad()
def run(self, latent1, latent2, weights=[0.5] * 18):
latent1 = paddle.to_tensor(np.load(latent1)).unsqueeze(0)
latent2 = paddle.to_tensor(np.load(latent2)).unsqueeze(0)
assert latent1.shape[1] == latent2.shape[1] == len(
weights
), 'latents and their weights must have the same number of levels.'
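        # Blend the latents level by level: each weight keeps that share of
        # latent1 and takes the remainder from latent2.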
mix_latent = []
for i, weight in enumerate(weights):
mix_latent.append(latent1[:, i:i + 1] * weight +
latent2[:, i:i + 1] * (1 - weight))
mix_latent = paddle.concat(mix_latent, 1)
latent_n = paddle.concat([latent1, latent2, mix_latent], 0)
generator = self.generator
img_gen, _ = generator([latent_n],
input_is_latent=True,
randomize_noise=False)
imgs = make_image(img_gen)
src_img1 = imgs[0]
src_img2 = imgs[1]
dst_img = imgs[2]
os.makedirs(self.output_path, exist_ok=True)
save_src_path = os.path.join(self.output_path, 'src1.mixing.png')
cv2.imwrite(save_src_path, cv2.cvtColor(src_img1, cv2.COLOR_RGB2BGR))
save_src_path = os.path.join(self.output_path, 'src2.mixing.png')
cv2.imwrite(save_src_path, cv2.cvtColor(src_img2, cv2.COLOR_RGB2BGR))
save_dst_path = os.path.join(self.output_path, 'dst.mixing.png')
cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR))
return src_img1, src_img2, dst_img