未验证 提交 d1225d09 编写于 作者: qq_19291021's avatar qq_19291021 提交者: GitHub

add styleclip (#643)

* add styleclip

* update 2022

* add weight url

* update doc & img url
上级 0541ace7
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle
from ppgan.apps import StyleGANv2ClipPredictor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--latent",
type=str,
help="path to first image latent codes")
parser.add_argument("--neutral", type=str, help="neutral description")
parser.add_argument("--target", type=str, help="neutral description")
parser.add_argument("--beta_threshold",
type=float,
default=0.12,
help="beta threshold for channel editing")
parser.add_argument("--direction_offset",
type=float,
default=5.0,
help="offset value of edited attribute")
parser.add_argument("--direction_path",
type=str,
default=None,
help="path to latent editing directions")
parser.add_argument("--output_path",
type=str,
default='output_dir',
help="path to output image dir")
parser.add_argument("--weight_path",
type=str,
default=None,
help="path to model checkpoint path")
parser.add_argument("--model_type",
type=str,
default=None,
help="type of model for loading pretrained model")
parser.add_argument("--size",
type=int,
default=1024,
help="resolution of output image")
parser.add_argument("--style_dim",
type=int,
default=512,
help="number of style dimension")
parser.add_argument("--n_mlp",
type=int,
default=8,
help="number of mlp layer depth")
parser.add_argument("--channel_multiplier",
type=int,
default=2,
help="number of channel multiplier")
parser.add_argument("--cpu",
dest="cpu",
action="store_true",
help="cpu mode.")
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
predictor = StyleGANv2ClipPredictor(
output_path=args.output_path,
weight_path=args.weight_path,
model_type=args.model_type,
seed=None,
size=args.size,
style_dim=args.style_dim,
n_mlp=args.n_mlp,
channel_multiplier=args.channel_multiplier,
direction_path=args.direction_path)
predictor.run(args.latent, args.neutral, args.target, args.direction_offset,
args.beta_threshold)
# StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery
## Introduction
The task of StyleGAN V2 is image generation while the Clip guided Editing module uses the attribute manipulation vector obtained by CLIP (Contrastive Language-Image Pre-training) Model for mapping text prompts to input-agnostic directions in StyleGAN’s style space, enabling interactive text-driven image manipulation.
This model uses pretrained StyleGAN V2 generator and uses Pixel2Style2Pixel model for image encoding. At present, only the models of portrait editing (trained on FFHQ dataset) is available.
Paddle-CLIP and dlib package is needed for this module.
```
pip install -e .
pip install paddleclip
pip install dlib-bin
```
## How to use
### Editing
```
cd applications/
python -u tools/styleganv2clip.py \
--latent <<PATH TO STYLE VECTOR> \
--output_path <DIRECTORY TO STORE OUTPUT IMAGE> \
--weight_path <YOUR PRETRAINED MODEL PATH> \
--model_type ffhq-config-f \
--size 1024 \
--style_dim 512 \
--n_mlp 8 \
--channel_multiplier 2 \
--direction_path <PATH TO STORE ATTRIBUTE DIRECTIONS> \
--neutral <DESCRIPTION OF THE SOURCE IMAGE> \
--target <DESCRIPTION OF THE TARGET IMAGE> \
--beta_threshold 0.12 \
--direction_offset 5
--cpu
```
**params:**
- latent: The path of the style vector which represents an image. Come from `dst.npy` generated by Pixel2Style2Pixel or `dst.fitting.npy` generated by StyleGANv2 Fitting module
- output_path: the directory where the generated images are stored
- weight_path: pretrained StyleGANv2 model path
- model_type: inner model type, currently only `ffhq-config-f` is available.
- direction_path: The path of CLIP mapping vector
- stat_path: The path of latent statisitc file
- neutral: Description of the source image,for example: face
- target: Description of the target image,for example: young face
- beta_threshold: editing threshold of the attribute channels
- direction_offset: Offset strength of the attribute
- cpu: whether to use cpu inference, if not, please remove it from the command
>inherited params for the pretrained StyleGAN model
- size: model parameters, output image resolution
- style_dim: model parameters, dimensions of style z
- n_mlp: model parameters, the number of multi-layer perception layers for style z
- channel_multiplier: model parameters, channel product, affect model size and the quality of generated pictures
### Results
Input portrait:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample.png" width="300"/>
</div>
with
> direction_offset = [ -1, 0, 1, 2, 3, 4, 5]
> beta_threshold = 0.1
edit from 'face' to 'boy face':
![stylegan2clip-sample-boy](https://user-images.githubusercontent.com/29187613/187344690-6709fba5-6e21-4bc0-83d1-5996947c99a4.png)
edit from 'face' to 'happy face':
![stylegan2clip-sample-happy](https://user-images.githubusercontent.com/29187613/187344681-6509f01b-0d9e-4dea-8a97-ee9ca75d152e.png)
edit from 'face' to 'angry face':
![stylegan2clip-sample-angry](https://user-images.githubusercontent.com/29187613/187344686-ff5047ab-5499-420d-ad02-e0908ac71bf7.png)
edit from 'face' to 'face with long hair':
![stylegan2clip-sample-long-hair](https://user-images.githubusercontent.com/29187613/187344684-4e452631-52b0-47cf-966e-3216c0392815.png)
edit from 'face' to 'face with curly hair':
![stylegan2clip-sample-curl-hair](https://user-images.githubusercontent.com/29187613/187344677-c9a3aa9f-1f3c-41b3-a1f0-fcd48a9c627b.png)
edit from 'head with black hair' to 'head with gold hair':
![stylegan2clip-sample-gold-hair](https://user-images.githubusercontent.com/29187613/187344678-5220e8b2-b1c9-4f2f-8655-621b6272c457.png)
## Make Attribute Direction Vector
For details, please refer to [Puzer/stylegan-encoder](https://github.com/Puzer/stylegan-encoder/blob/master/Learn_direction_in_latent_space.ipynb)
Currently pretrained weight for `stylegan2` & `ffhq-config-f` dataset is provided:
direction: https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f-styleclip-global-directions.pdparams
stats: https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f-styleclip-stats.pdparams
## Training
1. extract style latent vector stats
```
python styleclip_getf.py
```
2. calcuate mapping vector using CLIP model
```
python ppgan/apps/styleganv2clip_predictor.py extract
```
# Reference
- 1. [StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery](https://arxiv.org/abs/2103.17249)
```
@article{Patashnik2021StyleCLIPTM,
title={StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery},
author={Or Patashnik and Zongze Wu and Eli Shechtman and Daniel Cohen-Or and D. Lischinski},
journal={2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
year={2021},
pages={2065-2074}
}
```
- 2. [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](hhttps://arxiv.org/abs/2008.00951)
```
@article{richardson2020encoding,
title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2008.00951},
year={2020}
}
```
# StyleCLIP: 文本驱动的图像处理
## 1. 简介
StyleGAN V2 的任务是使用风格向量进行image generation,而Clip guided Editing 则是利用CLIP (Contrastive Language-Image Pre-training ) 多模态预训练模型计算文本输入对应的风格向量变化,用文字表述来对图像进行编辑操纵风格向量进而操纵生成图像的属性。相比于Editing 模块,StyleCLIP不受预先统计的标注属性限制,可以通过语言描述自由控制图像编辑。
原论文中使用 Pixel2Style2Pixel 的 升级模型 Encode4Editing 计算要编辑的代表图像的风格向量,为尽量利用PaddleGAN提供的预训练模型本次复现中仍使用Pixel2Style2Pixel计算得到风格向量进行实验,重构效果略有下降,期待PaddleGAN跟进e4e相关工作。
## 2. 复现
StyleCLIP 模型 需要使用简介中对应提到的几个预训练模型,
本次复现使用PPGAN 提供的 在FFHQ数据集上进行预训练的StyleGAN V2 模型作为生成器,并使用Pixel2Style2Pixel模型将待编辑图像转换为对应风格向量。
CLIP模型依赖Paddle-CLIP实现。
pSp模型包含人脸检测步骤,依赖dlib框架。
除本repo外还需要安装 Paddle-CLIP 和 dlib 依赖。
整体安装方法如下。
```
pip install -e .
pip install paddleclip
pip install dlib-bin
```
### 编辑结果展示
风格向量对应的图像:
<div align="center">
<img src="../../imgs/stylegan2fitting-sample.png" width="300"/>
</div>
设置
> direction_offset = [ -1, 0, 1, 2, 3, 4, 5]
> beta_threshold = 0.1
从 'face' 到 'boy face' 编辑得到的图像:
![stylegan2clip-sample-boy](https://user-images.githubusercontent.com/29187613/187344690-6709fba5-6e21-4bc0-83d1-5996947c99a4.png)
从'face' 到 'happy face' 编辑得到的图像:
![stylegan2clip-sample-happy](https://user-images.githubusercontent.com/29187613/187344681-6509f01b-0d9e-4dea-8a97-ee9ca75d152e.png)
从'face' 到 'angry face' 编辑得到的图像:
![stylegan2clip-sample-angry](https://user-images.githubusercontent.com/29187613/187344686-ff5047ab-5499-420d-ad02-e0908ac71bf7.png)
从'face' 到 'face with long hair' 编辑得到的图像:
![stylegan2clip-sample-long-hair](https://user-images.githubusercontent.com/29187613/187344684-4e452631-52b0-47cf-966e-3216c0392815.png)
从'face' 到 'face with curl hair' (卷发) 编辑得到的图像:
![stylegan2clip-sample-curl-hair](https://user-images.githubusercontent.com/29187613/187344677-c9a3aa9f-1f3c-41b3-a1f0-fcd48a9c627b.png)
从'head with black hair'(黑发) 到 'head with gold hair'(金发)编辑得到的图像:
![stylegan2clip-sample-gold-hair](https://user-images.githubusercontent.com/29187613/187344678-5220e8b2-b1c9-4f2f-8655-621b6272c457.png)
## 3. 使用方法
### 制作属性向量
具体可以参考[Puzer/stylegan-encoder](https://github.com/Puzer/stylegan-encoder/blob/master/Learn_direction_in_latent_space.ipynb)中的做法。
当前提供与`stylegan2`对应`ffhq-config-f`数据集上的权重参数:
direction: https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f-styleclip-global-directions.pdparams
stats: https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f-styleclip-stats.pdparams
### 训练
在StyleCLIP论文中作者研究了 3 种结合 StyleGAN 和 CLIP 的方法:
1. 文本引导的风格向量优化,使用 CLIP 模型作为损失网络对现有风格向量进行多次迭代更新,但该方法对每次处理都需要重新训练。
2. 训练 风格向量映射器,使CLIP文本特征向量映射至StyleGAN 风格向量空间,避免(1)方法的训练问题,但可控性较差,经论文对比其生成质量也不如(3)。
3. 在 StyleGAN 的 StyleSpace 中,把文本描述映射到输入图像的全局方向 (Global Direction),进而运行自由控制图像操作强度以及分离程度,实现类似于StyleGAN Editing 模块的使用体验。
本次仅复现论文中效果最好的 (3)Global Direction 方法。
StyleCLIP Global Direction 训练过程分两步:
1. 提取风格向量并统计
```
python styleclip_getf.py
```
2. 结合CLIP模型计算转换矩阵
```
python ppgan/apps/styleganv2clip_predictor.py extract
```
### 编辑
用户使用如下命令中对图像属性进行编辑:
```
cd applications/
python -u tools/styleganv2clip.py \
--latent <替换为要编辑的风格向量的路径> \
--output_path <替换为生成图片存放的文件夹> \
--weight_path <替换为你的预训练模型路径> \
--model_type ffhq-config-f \
--size 1024 \
--style_dim 512 \
--n_mlp 8 \
--channel_multiplier 2 \
--direction_path <替换为存放统计数据的文件路径> \
--neutral <替换为对原图像的描述,如face> \
--target <替换为对目标图像的描述> \
--beta_threshold 0.12 \
--direction_offset 5
--cpu
```
**参数说明:**
- latent: 要编辑的代表图像的风格向量的路径。可来自于Pixel2Style2Pixel生成的`dst.npy`或StyleGANv2 Fitting模块生成的`dst.fitting.npy`
- output_path: 生成图片存放的文件夹
- weight_path: 或StyleGANv2 预训练模型路径
- model_type: 模型类型,当前使用: `ffhq-config-f`
- direction_path: 存放CLIP统计向量的文件路径
- stat_path: 存放StyleGAN向量统计数据的文件路径
- neutral: 对原图像的中性描述,如 face
- target: 为对目标图像的描述,如 young face
- beta_threshold: 向量调整阈值
- direction_offset: 属性的偏移强度
- cpu: 是否使用cpu推理,若不使用,请在命令中去除
!以下 参数需与StyleGAN 预训练模型保持一致
- size: 模型参数,输出图片的分辨率
- style_dim: 模型参数,风格z的维度
- n_mlp: 模型参数,风格z所输入的多层感知层的层数
- channel_multiplier: 模型参数,通道乘积,影响模型大小和生成图片质量
## 复现记录
1. PaddleGAN 实现中的StyleGAN模型将Style Affine层进行了模块耦合,而论文中使用到的S Space 需要用到,因此对StyleGAN 生成器代码也进行了魔改,增加style_affine 及 synthesis_from_styles 方法同时尽量兼容现有接口。
2. StyleCLIP论文中表示使用100张图像进行Global Direction 训练在V1080Ti需要约4h,但使用V100的训练数据及官方repo中也有issue提到实际需要约24h,该问题但作者还未能给出解答。
3. Paddle Resize处理对Tensor和ndarray的处理方法不同,默认Tensor使用BCHW模式存储而非图像的BHWC。
4. 现有 uppfirdn2d 模块中似乎存在多次不必要的Tensor拷贝、reshape过程,希望后续能够优化运算及显存占用。
5. 切片拷贝:paddle中对Tensor进行切片时(有时)会创建新的拷贝,此时再对其进行赋值很可能不生效,两种写法`a[ind1][ind2]=0``a[ind1, ind2]=0` 前者并不改变a中的参数。
# 参考文献
- 1. [StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery](https://arxiv.org/abs/2103.17249)
```
@article{Patashnik2021StyleCLIPTM,
title={StyleCLIP: Text-Driven Manipulation of StyleGAN Imagery},
author={Or Patashnik and Zongze Wu and Eli Shechtman and Daniel Cohen-Or and D. Lischinski},
journal={2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
year={2021},
pages={2065-2074}
}
```
- 2. [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](hhttps://arxiv.org/abs/2008.00951)
```
@article{richardson2020encoding,
title={Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation},
author={Richardson, Elad and Alaluf, Yuval and Patashnik, Or and Nitzan, Yotam and Azar, Yaniv and Shapiro, Stav and Cohen-Or, Daniel},
journal={arXiv preprint arXiv:2008.00951},
year={2020}
}
```
......@@ -23,6 +23,7 @@ from .animegan_predictor import AnimeGANPredictor
from .midas_predictor import MiDaSPredictor
from .photo2cartoon_predictor import Photo2CartoonPredictor
from .styleganv2_predictor import StyleGANv2Predictor
from .styleganv2clip_predictor import StyleGANv2ClipPredictor
from .styleganv2fitting_predictor import StyleGANv2FittingPredictor
from .styleganv2mixing_predictor import StyleGANv2MixingPredictor
from .styleganv2editing_predictor import StyleGANv2EditingPredictor
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import cv2
import numpy as np
import paddle
from ppgan.apps.styleganv2_predictor import StyleGANv2Predictor
from ppgan.utils.download import get_path_from_url
from clip import tokenize, load_model
model_cfgs = {
'ffhq-config-f': {
'direction_urls':
'https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f-styleclip-global-directions.pdparams',
'stat_urls':
'https://paddlegan.bj.bcebos.com/models/stylegan2-ffhq-config-f-styleclip-stats.pdparams'
}
}
def make_image(tensor):
return (((tensor.detach() + 1) / 2 * 255).clip(min=0, max=255).transpose(
(0, 2, 3, 1)).numpy().astype('uint8'))
# prompt engineering
prompt_templates = [
'a bad photo of a {}.',
'a photo of the hard to see {}.',
'a low resolution photo of the {}.',
'a rendering of a {}.',
'graffiti of a {}.',
'a bad photo of the {}.',
'a cropped photo of the {}.',
'a photo of a hard to see {}.',
'a bright photo of a {}.',
'a photo of a clean {}.',
'a photo of a dirty {}.',
'a dark photo of the {}.',
'a drawing of a {}.',
'a photo of my {}.',
'a photo of the cool {}.',
'a close-up photo of a {}.',
'a black and white photo of the {}.',
'a painting of the {}.',
'a painting of a {}.',
'a pixelated photo of the {}.',
'a sculpture of the {}.',
'a bright photo of the {}.',
'a cropped photo of a {}.',
'a jpeg corrupted photo of a {}.',
'a blurry photo of the {}.',
'a photo of the {}.',
'a good photo of the {}.',
'a rendering of the {}.',
'a close-up photo of the {}.',
'a photo of a {}.',
'a low resolution photo of a {}.',
'a photo of the clean {}.',
'a photo of a large {}.',
'a photo of a nice {}.',
'a blurry photo of a {}.',
'a cartoon {}.',
'art of a {}.',
'a good photo of a {}.',
'a photo of the nice {}.',
'a photo of the small {}.',
'a photo of the weird {}.',
'art of the {}.',
'a drawing of the {}.',
'a photo of the large {}.',
'a dark photo of a {}.',
'graffiti of the {}.',
]
@paddle.no_grad()
def get_delta_t(neutral, target, model, templates=prompt_templates):
text_features = []
for classname in [neutral, target]:
texts = [template.format(classname)
for template in templates] #format with class
texts = tokenize(texts) #tokenize
class_embeddings = model.encode_text(texts) #embed with text encoder
class_embeddings /= class_embeddings.norm(axis=-1, keepdim=True)
class_embedding = class_embeddings.mean(axis=0)
class_embedding /= class_embedding.norm()
text_features.append(class_embedding)
text_features = paddle.stack(text_features, axis=1).t()
delta_t = (text_features[1] - text_features[0])
delta_t = delta_t / delta_t.norm()
return delta_t
@paddle.no_grad()
def get_ds_from_dt(global_style_direction,
delta_t,
generator,
beta_threshold,
relative=False,
soft_threshold=False):
delta_s = global_style_direction @ delta_t
delta_s_max = delta_s.abs().max()
print(f'max delta_s is {delta_s_max.item()}')
if relative: beta_threshold *= delta_s_max
# apply beta threshold (disentangle)
select = delta_s.abs() < beta_threshold
num_channel = paddle.sum(~select).item()
# threshold in style direction
delta_s[select] = delta_s[select] * soft_threshold
delta_s /= delta_s_max # normalize
# delta_s -> style dict
dic = []
ind = 0
for layer in range(len(generator.w_idx_lst)): # 26
dim = generator.channels_lst[layer]
if layer in generator.style_layers:
dic.append(paddle.to_tensor(delta_s[ind:ind + dim]))
ind += dim
else:
dic.append(paddle.zeros([dim]))
return dic, num_channel
class StyleGANv2ClipPredictor(StyleGANv2Predictor):
def __init__(self, model_type=None, direction_path=None, stat_path=None, **kwargs):
super().__init__(model_type=model_type, **kwargs)
if direction_path is None and model_type is not None:
assert model_type in model_cfgs, f'There is not any pretrained direction file for {model_type} model.'
direction_path = get_path_from_url(
model_cfgs[model_type]['direction_urls'])
self.fs3 = paddle.load(direction_path)
self.clip_model, _ = load_model('ViT_B_32', pretrained=True)
self.manipulator = Manipulator(self.generator, model_type=model_type, stat_path=stat_path)
def get_delta_s(self,
neutral,
target,
beta_threshold,
relative=False,
soft_threshold=0):
# get delta_t in CLIP text space (text directions)
delta_t = get_delta_t(neutral, target, self.clip_model)
# get delta_s in global image directions
delta_s, num_channel = get_ds_from_dt(self.fs3, delta_t, self.generator,
beta_threshold, relative,
soft_threshold)
print(
f'{num_channel} channels will be manipulated under the {"relative" if relative else ""} beta threshold {beta_threshold}'
)
return delta_s
@paddle.no_grad()
def gengrate(self, latent: paddle.Tensor, delta_s, lst_alpha):
styles = self.generator.style_affine(latent)
styles = self.manipulator.manipulate(styles, delta_s, lst_alpha)
# synthesis images from manipulated styles
img_gen = self.manipulator.synthesis_from_styles(styles)
return img_gen, styles
@paddle.no_grad()
def run(self, latent, neutral, target, offset, beta_threshold=0.8):
latent = paddle.to_tensor(
np.load(latent)).unsqueeze(0).astype('float32')
delta_s = self.get_delta_s(neutral, target, beta_threshold)
img_gen, styles = self.gengrate(latent, delta_s, [0, offset])
imgs = make_image(paddle.concat(img_gen))
src_img = imgs[0]
dst_img = imgs[1]
dst_latent = styles[1]
os.makedirs(self.output_path, exist_ok=True)
save_src_path = os.path.join(self.output_path, 'src.editing.png')
cv2.imwrite(save_src_path, cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR))
save_dst_path = os.path.join(self.output_path, 'dst.editing.png')
cv2.imwrite(save_dst_path, cv2.cvtColor(dst_img, cv2.COLOR_RGB2BGR))
save_path = os.path.join(self.output_path, 'dst.editing.pd')
paddle.save(dst_latent, save_path)
return src_img, dst_img, dst_latent
@paddle.no_grad()
def extract_global_direction(G,
lst_alpha,
batchsize=5,
num=100,
dataset_name='',
seed=None):
from tqdm import tqdm
import PIL
"""Extract global style direction in 100 images
"""
assert len(lst_alpha) == 2 #[-5, 5]
assert num < 200
#np.random.seed(0)
# get intermediate latent of n samples
try:
S = paddle.load(f'S-{dataset_name}.pdparams')
S = [S[i][:num] for i in range(len(G.w_idx_lst))]
except:
print('No pre-computed S, run tools/styleclip_getf.py first!')
exit()
# total channel used: 1024 -> 6048 channels, 256 -> 4928 channels
print(
f"total channels to manipulate: {sum([G.channels_lst[i] for i in G.style_layers])}"
)
manipulator = Manipulator(G, model_type=dataset_name,
stat_path=f'stylegan2-{dataset_name}-styleclip-stats.pdparams')
model, preprocess = load_model('ViT_B_32', pretrained=True)
nbatch = int(num / batchsize)
all_feats = list()
for layer in G.style_layers:
print(f'\nStyle manipulation in layer "{layer}"')
for channel_ind in tqdm(range(G.channels_lst[layer])):
styles = manipulator.manipulate_one_channel(copy.deepcopy(S), layer,
channel_ind, lst_alpha,
num)
# 2 * num images
feats = list()
for img_ind in range(nbatch): # batch size * nbatch * 2
start = img_ind * batchsize
end = img_ind * batchsize + batchsize
synth_imgs = manipulator.synthesis_from_styles(
styles, [start, end])
synth_imgs = [(synth_img.transpose((0, 2, 3, 1)) * 127.5 +
128).clip(0, 255).astype('uint8').numpy()
for synth_img in synth_imgs]
imgs = list()
for i in range(batchsize):
img0 = PIL.Image.fromarray(synth_imgs[0][i])
img1 = PIL.Image.fromarray(synth_imgs[1][i])
imgs.append(preprocess(img0).unsqueeze(0))
imgs.append(preprocess(img1).unsqueeze(0))
feat = model.encode_image(paddle.concat(imgs))
feats.append(feat.numpy())
all_feats.append(np.concatenate(feats).reshape([-1, 2, 512]))
all_feats = np.stack(all_feats)
np.save(f'fs-{dataset_name}.npy', all_feats)
fs = all_feats #L B 2 512
fs1 = fs / np.linalg.norm(fs, axis=-1)[:, :, :, None]
fs2 = fs1[:, :, 1, :] - fs1[:, :, 0, :] # 5*sigma - (-5)* sigma
fs3 = fs2 / np.linalg.norm(fs2, axis=-1)[:, :, None]
fs3 = fs3.mean(axis=1)
fs3 = fs3 / np.linalg.norm(fs3, axis=-1)[:, None]
paddle.save(paddle.to_tensor(fs3),
f'stylegan2-{dataset_name}-styleclip-global-directions.pdparams'
) # global style direction
class Manipulator():
"""Manipulator for style editing
The paper uses 100 image pairs to estimate the mean for alpha(magnitude of the perturbation) [-5, 5]
"""
def __init__(self, generator, model_type='ffhq-config-f', stat_path=None):
self.generator = generator
if stat_path is None and model_type is not None:
assert model_type in model_cfgs, f'There is not any pretrained stat file for {model_type} model.'
stat_path = get_path_from_url(
model_cfgs[model_type]['direction_urls'])
data = paddle.load(stat_path)
self.S_mean = data['mean']
self.S_std = data['std']
@paddle.no_grad()
def manipulate(self, styles, delta_s, lst_alpha):
"""Edit style by given delta_style
- use perturbation (delta s) * (alpha) as a boundary
"""
styles = [copy.deepcopy(styles) for _ in range(len(lst_alpha))]
for (alpha, style) in zip(lst_alpha, styles):
for i in range(len(self.generator.w_idx_lst)):
style[i] += delta_s[i] * alpha
return styles
@paddle.no_grad()
def manipulate_one_channel(self,
styles,
layer_ind,
channel_ind: int,
lst_alpha=[0],
num_images=100):
"""Edit style from given layer, channel index
- use mean value of pre-saved style
- use perturbation (pre-saved style std) * (alpha) as a boundary
"""
assert 0 <= channel_ind < styles[layer_ind].shape[1]
boundary = self.S_std[layer_ind][channel_ind].item()
# apply self.S_mean value for given layer, channel_ind
for img_ind in range(num_images):
styles[layer_ind][img_ind,
channel_ind] = self.S_mean[layer_ind][channel_ind]
styles = [copy.deepcopy(styles) for _ in range(len(lst_alpha))]
perturbation = (paddle.to_tensor(lst_alpha) * boundary).numpy().tolist()
# apply one channel manipulation
for img_ind in range(num_images):
for edit_ind, delta in enumerate(perturbation):
styles[edit_ind][layer_ind][img_ind, channel_ind] += delta
return styles
@paddle.no_grad()
def synthesis_from_styles(self, styles, slice=None, randomize_noise=True):
"""Synthesis edited styles from styles, lst_alpha
"""
imgs = list()
if slice is not None:
for style in styles:
style_ = [list() for _ in range(len(self.generator.w_idx_lst))]
for i in range(len(self.generator.w_idx_lst)):
style_[i] = style[i][slice[0]:slice[1]]
imgs.append(
self.generator.synthesis(style_,
randomize_noise=randomize_noise))
else:
for style in styles:
imgs.append(
self.generator.synthesis(style,
randomize_noise=randomize_noise))
return imgs
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('runtype',
type=str,
default='generate',
choices=['generate', 'test', 'extract'])
parser.add_argument("--latent",
type=str,
default='output_dir/sample/dst.npy',
help="path to first image latent codes")
parser.add_argument("--neutral",
type=str,
default=None,
help="neutral description")
parser.add_argument("--target",
type=str,
default=None,
help="neutral description")
parser.add_argument("--direction_path",
type=str,
default=None,
help="path to latent editing directions")
parser.add_argument("--stat_path",
type=str,
default=None,
help="path to latent stat files")
parser.add_argument("--direction_offset",
type=float,
default=5,
help="offset value of edited attribute")
parser.add_argument("--beta_threshold",
type=float,
default=0.12,
help="beta threshold for channel editing")
parser.add_argument('--dataset_name', type=str,
default='ffhq-config-f') #'animeface-512')
args = parser.parse_args()
runtype = args.runtype
if runtype in ['test', 'extract']:
dataset_name = args.dataset_name
G = StyleGANv2Predictor(model_type=dataset_name).generator
if runtype == 'test': # test manipulator
from ppgan.utils.visual import make_grid, tensor2img, save_image
num_images = 2
lst_alpha = [-5, 0, 5]
layer = 6
channel_ind = 501
manipulator = Manipulator(G, model_type=dataset_name, stat_path=args.stat_path)
styles = manipulator.manipulate_one_channel(layer, channel_ind,
lst_alpha, num_images)
imgs = manipulator.synthesis_from_styles(styles)
print(len(imgs), imgs[0].shape)
save_image(
tensor2img(make_grid(paddle.concat(imgs), nrow=num_images)),
f'sample.png')
elif runtype == 'extract': # train: extract global style direction
batchsize = 10
num_images = 100
lst_alpha = [-5, 5]
extract_global_direction(G,
lst_alpha,
batchsize,
num_images,
dataset_name=dataset_name)
else:
predictor = StyleGANv2ClipPredictor(model_type=args.dataset_name,
seed=None,
direction_path=args.direction_path,
stat_path=args.stat_path)
predictor.run(args.latent, args.neutral, args.target,
args.direction_offset, args.beta_threshold)
......@@ -93,11 +93,13 @@ class ModulatedConv2D(nn.Layer):
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
f"upsample={self.upsample}, downsample={self.downsample})")
def forward(self, inputs, style):
def forward(self, inputs, style, apply_modulation=False):
batch, in_channel, height, width = inputs.shape
style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
if apply_modulation: style = self.modulation(style)
style = style.reshape((batch, 1, in_channel, 1, 1))
weight = self.scale * self.weight * style
del style
if self.demodulate:
demod = paddle.rsqrt((weight * weight).sum([2, 3, 4]) + 1e-8)
......@@ -165,8 +167,7 @@ class ConstantInput(nn.Layer):
(1, channel, size, size),
default_initializer=nn.initializer.Normal())
def forward(self, inputs):
batch = inputs.shape[0]
def forward(self, batch):
out = self.input.tile((batch, 1, 1, 1))
return out
......@@ -250,8 +251,9 @@ class StyleGANv2Generator(nn.Layer):
super().__init__()
self.size = size
self.style_dim = style_dim
self.log_size = int(math.log(size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
layers = [PixelNorm()]
......@@ -275,6 +277,33 @@ class StyleGANv2Generator(nn.Layer):
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
self.channels_lst = []
self.w_idx_lst = [
0,1, # 4
1,2,3, # 8
3,4,5, # 16
5,6,7, # 32
7,8,9, # 64
9,10,11, # 128
11,12,13, # 256
13,14,15, # 512
15,16,17, # 1024
]
self.style_layers = [
0, #1,
2, 3, #4,
5, 6, #7,
8, 9, #10,
11, 12,# 13,
14, 15,# 16,
17, 18,# 19,
20, 21,# 22,
23, 24,# 25
]
if self.log_size != 10:
self.w_idx_lst = self.w_idx_lst[:-(3 * (10 - self.log_size))]
self.style_layers = self.style_layers[:-(2 * (10 - self.log_size))]
self.input = ConstantInput(self.channels[4])
self.conv1 = StyledConv(self.channels[4],
......@@ -287,9 +316,7 @@ class StyleGANv2Generator(nn.Layer):
2 if is_concat else self.channels[4],
style_dim,
upsample=False)
self.log_size = int(math.log(size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1
self.channels_lst.extend([self.channels[4], self.channels[4]])
self.convs = nn.LayerList()
self.upsamples = nn.LayerList()
......@@ -329,6 +356,7 @@ class StyleGANv2Generator(nn.Layer):
self.to_rgbs.append(
ToRGB(out_channel * 2 if is_concat else out_channel, style_dim))
self.channels_lst.extend([in_channel, out_channel, out_channel])
in_channel = out_channel
self.n_latent = self.log_size * 2 - 2
......@@ -352,25 +380,115 @@ class StyleGANv2Generator(nn.Layer):
def get_latent(self, inputs):
return self.style(inputs)
def get_mean_style(self):
def get_latents(
self,
inputs,
truncation=1.0,
truncation_cutoff=None,
truncation_latent=None,
input_is_latent=False,
):
assert truncation >= 0, "truncation should be a float in range [0, 1]"
if not input_is_latent:
style = self.style(inputs)
if truncation < 1.0:
if truncation_latent is None:
truncation_latent = self.get_mean_style()
cutoff = truncation_cutoff
if truncation_cutoff is None:
style = truncation_latent + \
truncation * (style - truncation_latent)
else:
style[:, :cutoff] = truncation_latent[:, :cutoff] + \
truncation * (style[:, :cutoff] - truncation_latent[:, :cutoff])
return style
@paddle.no_grad()
def get_mean_style(self, n_sample=10, n_latent=1024):
mean_style = None
with paddle.no_grad():
for i in range(10):
style = self.mean_latent(1024)
if mean_style is None:
mean_style = style
else:
mean_style += style
for i in range(n_sample):
style = self.mean_latent(n_latent)
if mean_style is None:
mean_style = style
else:
mean_style += style
mean_style /= 10
mean_style /= n_sample
return mean_style
def get_latent_S(self, inputs):
return self.style_affine(self.style(inputs))
def style_affine(self, latent):
if latent.ndim < 3:
latent = latent.unsqueeze(1).tile((1, self.n_latent, 1))
latent_ = []
latent_.append(self.conv1.conv.modulation(latent[:, 0]))
latent_.append(self.to_rgb1.conv.modulation(latent[:, 1]))
i = 1
for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2],
self.to_rgbs):
latent_.append(conv1.conv.modulation(latent[:, i + 0]))
latent_.append(conv2.conv.modulation(latent[:, i + 1]))
latent_.append(to_rgb.conv.modulation(latent[:, i + 2]))
i += 2
return latent_ #paddle.concat(latent_, axis=1)
def synthesis(self,
latent,
noise=None,
randomize_noise=True,
is_w_latent=False):
out = self.input(latent[0].shape[0])
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
#noise = [paddle.randn(getattr(self.noises, f"noise_{i}").shape) for i in range(self.num_layers)]
else:
noise = [
getattr(self.noises, f"noise_{i}")
for i in range(self.num_layers)
]
out = self.conv1(out, latent[0], noise=noise[0])
skip = self.to_rgb1(out, latent[1])
i = 2
if self.is_concat:
noise_i = 1
for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2],
self.to_rgbs):
out = conv1(out, latent[i],
noise=noise[(noise_i + 1) // 2]) ### 1 for 2
out = conv2(out, latent[i + 1],
noise=noise[(noise_i + 2) // 2]) ### 1 for 2
skip = to_rgb(out, latent[i + 2], skip)
i += 3
noise_i += 2
else:
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],
self.to_rgbs):
out = conv1(out, latent[i], noise=noise1)
out = conv2(out, latent[i + 1], noise=noise2)
skip = to_rgb(out, latent[i + 2], skip)
i += 3
return skip #image = skip
def forward(
self,
styles,
return_latents=False,
inject_index=None,
truncation=1.0,
truncation_cutoff=None,
truncation_latent=None,
input_is_latent=False,
noise=None,
......@@ -379,23 +497,19 @@ class StyleGANv2Generator(nn.Layer):
if not input_is_latent:
styles = [self.style(s) for s in styles]
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f"noise_{i}")
for i in range(self.num_layers)
]
if truncation < 1.0:
style_t = []
if truncation_latent is None:
truncation_latent = self.get_mean_style()
cutoff = truncation_cutoff
for style in styles:
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
if truncation_cutoff is None:
style = truncation_latent + \
truncation * (style - truncation_latent)
else:
style[:, :cutoff] = truncation_latent[:, :cutoff] + \
truncation * (style[:, :cutoff] - truncation_latent[:, :cutoff])
style_t.append(style)
styles = style_t
if len(styles) < 2:
......@@ -417,41 +531,12 @@ class StyleGANv2Generator(nn.Layer):
latent = paddle.concat([latent, latent2], 1)
out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
if self.is_concat:
noise_i = 1
outs = []
for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2],
self.to_rgbs):
out = conv1(out, latent[:, i],
noise=noise[(noise_i + 1) // 2]) ### 1 for 2
out = conv2(out,
latent[:, i + 1],
noise=noise[(noise_i + 2) // 2]) ### 1 for 2
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
noise_i += 2
else:
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
#if not input_is_affined_latent:
styles = self.style_affine(latent)
image = skip
image = self.synthesis(styles, noise, randomize_noise)
if return_latents:
return image, latent
else:
return image, None
import argparse
from tqdm import tqdm
import paddle
import numpy as np
from ppgan.apps.styleganv2_predictor import StyleGANv2Predictor
def concat_style_paddle(s_lst, n_layers):
result = [list() for _ in range(n_layers)]
assert n_layers == len(s_lst[0])
for i in range(n_layers):
for s_ in s_lst:
result[i].append(s_[i])
for i in range(n_layers):
result[i] = paddle.concat(result[i])
return result
def to_np(s_lst):
for i in range(len(s_lst)):
s_lst[i] = s_lst[i].numpy()
return s_lst
def concat_style_np(s_lst, n_layers):
result = [list() for _ in range(n_layers)]
assert n_layers == len(s_lst[0])
for i in range(n_layers):
for s_ in s_lst:
result[i].append(s_[i])
for i in range(n_layers):
result[i] = np.concatenate(result[i])
return result
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='ffhq-config-f')
parser.add_argument('--seed', type=int, default=1234)
args = parser.parse_args()
dataset_name = args.dataset_name
G = StyleGANv2Predictor(model_type=dataset_name).generator
w_idx_lst = G.w_idx_lst
with paddle.no_grad():
# get intermediate latent of 100000 samples
w_lst = list()
z = paddle.to_tensor(
np.random.RandomState(args.seed).randn(
1000, 100, G.style_dim).astype('float32'))
#z = paddle.randn([1000, 100, G.style_dim])
for i in tqdm(range(1000)): # 100 * 1000 = 100000 # 1000
# apply truncation_psi=.7 truncation_cutoff=8
w_lst.append(
G.get_latents(z[i], truncation=0.7, truncation_cutoff=8))
#paddle.save(paddle.concat(w_lst[:20]), f'W-{dataset_name}.pdparams')
s_lst = []
# get style of first 2000 sample in W
for i in tqdm(range(20)): # 2*1000
s_ = G.style_affine(w_lst[i])
s_lst.append(s_)
paddle.save(concat_style_paddle(s_lst, len(w_idx_lst)),
f'S-{dataset_name}.pdparams')
for i in tqdm(range(20)): # 2*1000
s_lst[i] = to_np(s_lst[i])
# get std, mean of 100000 style samples
for i in tqdm(range(20, 1000)): # 100 * 1000
s_ = G.style_affine(w_lst[i])
s_lst.append(to_np(s_))
del w_lst, z, s_, G
s_lst = concat_style_np(s_lst, len(w_idx_lst))
s_mean = [
paddle.mean(paddle.to_tensor(s_lst[i]), axis=0)
for i in range(len(w_idx_lst))
]
s_std = [
paddle.std(paddle.to_tensor(s_lst[i]), axis=0)
for i in range(len(w_idx_lst))
]
paddle.save({
'mean': s_mean,
'std': s_std
}, f'stylegan2-{dataset_name}-styleclip-stats.pdparams')
print("Done.")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册