Unverified commit d84bfbbf, authored by wangna11BD, committed by GitHub

Add vis code (#318)

* add LapStyle Model
Parent: b2d44b45
@@ -37,7 +37,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
 * [Pixel2Pixel](./docs/en_US/tutorials/pix2pix_cyclegan.md)
 * [CycleGAN](./docs/en_US/tutorials/pix2pix_cyclegan.md)
-* [LapStyle(coming soon)](./docs/en_US/tutorials/lap_style.md)
+* [LapStyle](./docs/en_US/tutorials/lap_style.md)
 * [PSGAN](./docs/en_US/tutorials/psgan.md)
 * [First Order Motion Model](./docs/en_US/tutorials/motion_driving.md)
 * [FaceParsing](./docs/en_US/tutorials/face_parse.md)
......
@@ -70,7 +70,7 @@ GAN--Generative Adversarial Network, praised by "the Father of Convolutional Networks" **Yann LeCun**
 * [Pixel2Pixel](./docs/zh_CN/tutorials/pix2pix_cyclegan.md)
 * [CycleGAN](./docs/zh_CN/tutorials/pix2pix_cyclegan.md)
-* [LapStyle(coming soon)](./docs/zh_CN/tutorials/lap_style.md)
+* [LapStyle](./docs/zh_CN/tutorials/lap_style.md)
 * [PSGAN](./docs/zh_CN/tutorials/psgan.md)
 * [First Order Motion Model](./docs/zh_CN/tutorials/motion_driving.md)
 * [FaceParsing](./docs/zh_CN/tutorials/face_parse.md)
......
configs/lapstyle_draft.yaml:
@@ -5,7 +5,7 @@ min_max:
   (0., 1.)

 model:
-  name: LapStyleModel
+  name: LapStyleDraModel
   generator_encode:
     name: Encoder
   generator_decode:
@@ -37,7 +37,7 @@ dataset:
     name: LapStyleDataset
     content_root: data/coco/test2017/
     style_root: data/starrynew.png
-    load_size: 136
+    load_size: 128
     crop_size: 128
     num_workers: 0
     batch_size: 1
@@ -56,12 +56,12 @@ optimizer:
     beta2: 0.999

 validate:
-  interval: 5000
+  interval: 500
   save_img: false

 log_config:
   interval: 10
-  visiual_interval: 5000
+  visiual_interval: 500

 snapshot_config:
   interval: 5000
configs/lapstyle_rev_first.yaml (new file):
total_iters: 30000
output_dir: output_dir
checkpoints_dir: checkpoints
min_max:
  (0., 1.)

model:
  name: LapStyleRevFirstModel
  revnet_generator:
    name: RevisionNet
  revnet_discriminator:
    name: LapStyleDiscriminator
  draftnet_encode:
    name: Encoder
  draftnet_decode:
    name: DecoderNet
  calc_style_emd_loss:
    name: CalcStyleEmdLoss
  calc_content_relt_loss:
    name: CalcContentReltLoss
  calc_content_loss:
    name: CalcContentLoss
  calc_style_loss:
    name: CalcStyleLoss
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla
  content_layers: ['r11', 'r21', 'r31', 'r41', 'r51']
  style_layers: ['r11', 'r21', 'r31', 'r41', 'r51']
  content_weight: 1.0
  style_weight: 3.0

dataset:
  train:
    name: LapStyleDataset
    content_root: data/coco/train2017/
    style_root: data/starrynew.png
    load_size: 280
    crop_size: 256
    num_workers: 16
    batch_size: 5
  test:
    name: LapStyleDataset
    content_root: data/coco/test2017/
    style_root: data/starrynew.png
    load_size: 256
    crop_size: 256
    num_workers: 0
    batch_size: 1

lr_scheduler:
  name: NonLinearDecay
  learning_rate: 1e-4
  lr_decay: 5e-5

optimizer:
  optimG:
    name: Adam
    net_names:
      - net_rev
    beta1: 0.9
    beta2: 0.999
  optimD:
    name: Adam
    net_names:
      - netD
    beta1: 0.9
    beta2: 0.999

validate:
  interval: 500
  save_img: false

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5000
configs/lapstyle_rev_second.yaml (new file):
total_iters: 30000
output_dir: output_dir
checkpoints_dir: checkpoints
min_max:
  (0., 1.)

model:
  name: LapStyleRevSecondModel
  revnet_generator:
    name: RevisionNet
  revnet_discriminator:
    name: LapStyleDiscriminator
  draftnet_encode:
    name: Encoder
  draftnet_decode:
    name: DecoderNet
  calc_style_emd_loss:
    name: CalcStyleEmdLoss
  calc_content_relt_loss:
    name: CalcContentReltLoss
  calc_content_loss:
    name: CalcContentLoss
  calc_style_loss:
    name: CalcStyleLoss
  gan_criterion:
    name: GANLoss
    gan_mode: vanilla
  content_layers: ['r11', 'r21', 'r31', 'r41', 'r51']
  style_layers: ['r11', 'r21', 'r31', 'r41', 'r51']
  content_weight: 1.0
  style_weight: 3.0

dataset:
  train:
    name: LapStyleDataset
    content_root: data/coco/train2017/
    style_root: data/starrynew.png
    load_size: 540
    crop_size: 512
    num_workers: 16
    batch_size: 2
  test:
    name: LapStyleDataset
    content_root: data/coco/test2017/
    style_root: data/starrynew.png
    load_size: 512
    crop_size: 512
    num_workers: 0
    batch_size: 1

lr_scheduler:
  name: NonLinearDecay
  learning_rate: 1e-4
  lr_decay: 5e-5

optimizer:
  optimG:
    name: Adam
    net_names:
      - net_rev_2
    beta1: 0.9
    beta2: 0.999
  optimD:
    name: Adam
    net_names:
      - netD
    beta1: 0.9
    beta2: 0.999

validate:
  interval: 500
  save_img: false

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5000
docs/en_US/tutorials/lap_style.md:
-Coming soon.
# LapStyle
This repo holds the official code for the paper "Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality Artistic Style Transfer", accepted at CVPR 2021.
## 1 Paper Introduction
Artistic style transfer aims at migrating the style from an example image to a content image. Currently, optimization-based methods achieve great stylization quality, but their expensive time cost restricts practical applications. Meanwhile, feed-forward methods still fail to synthesize complex styles, especially when holistic global and local patterns coexist. Inspired by the common painting process of drawing a draft and then revising its details, [this paper](https://arxiv.org/pdf/2104.05376.pdf) introduces a novel feed-forward method, the Laplacian Pyramid Network (LapStyle). LapStyle first transfers global style patterns at low resolution via a Drafting Network. It then revises local details at high resolution via a Revision Network, which hallucinates a residual image from the draft and the image textures extracted by Laplacian filtering. Higher-resolution details can easily be generated by stacking Revision Networks over multiple Laplacian pyramid levels. The final stylized image is obtained by aggregating the outputs of all pyramid levels. The paper also introduces a patch discriminator to better learn local patterns adversarially. Experiments demonstrate that the method can synthesize high-quality stylized images in real time, with holistic style patterns properly transferred.
![lapstyle_overview](https://user-images.githubusercontent.com/79366697/118654987-b24dc100-b81b-11eb-9430-d84630f80511.png)
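For readers who want to see the pyramid mechanics in code, below is a minimal sketch of the decomposition that Drafting and Revision operate on. It mirrors the `tensor_resample`/`laplacian`/`make_laplace_pyramid` helpers added to `lapstyle_model.py` in this commit; the function names in the sketch itself are illustrative only:
```python
import paddle
import paddle.nn.functional as F

def resample(x, size):
    # Bilinear resize, as in the tensor_resample helper in this commit.
    return F.interpolate(x, size, mode='bilinear', align_corners=False)

def decompose(x, levels):
    """Split x into high-frequency Laplacian bands plus a low-res base."""
    pyramid = []
    for _ in range(levels):
        down = resample(x, [x.shape[2] // 2, x.shape[3] // 2])
        # Detail band: what bilinear down/upsampling loses at this level.
        pyramid.append(x - resample(down, [x.shape[2], x.shape[3]]))
        x = down
    pyramid.append(x)  # low-frequency base: the Drafting Network's input
    return pyramid

content = paddle.rand([1, 3, 256, 256])
bands = decompose(content, 1)  # [256x256 detail band, 128x128 base]
```
The Drafting Network stylizes the low-frequency base, and each Revision Network predicts a residual for one detail band; summing bands back up reconstructs the full-resolution result.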
## 2 How to use
### 2.1 Prepare Datasets
To train LapStyle, we use the COCO dataset as the content image set, and you can choose any style image you like. Before training or testing, remember to modify the data path of the style image (`style_root`) in the config file.
### 2.2 Train
The dataset used in this example is COCO; you can also switch to your own dataset in the config file.
(1) Train the Drafting Network of LapStyle at 128×128 resolution:
```
python -u tools/main.py --config-file configs/lapstyle_draft.yaml
```
(2) Then, train the Revision Network of LapStyle at 256×256 resolution:
```
python -u tools/main.py --config-file configs/lapstyle_rev_first.yaml --load ${PATH_OF_LAST_STAGE_WEIGHT}
```
(3) Further, you can train the second Revision Network at 512×512 resolution:
```
python -u tools/main.py --config-file configs/lapstyle_rev_second.yaml --load ${PATH_OF_LAST_STAGE_WEIGHT}
```
### 2.3 Test
To test the trained model, you can directly test "lapstyle_rev_second", since its checkpoint also contains the trained weights of the previous stages:
```
python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
| Style | Stylized Results |
| --- | --- |
| ![starrynew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | ![chicago_stylized_starrynew](https://user-images.githubusercontent.com/79366697/118655671-59325d00-b81c-11eb-93a3-4fcc24680124.png)|
| ![ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | ![chicago_ocean_512](https://user-images.githubusercontent.com/79366697/118655625-4cae0480-b81c-11eb-83ec-30936ed3df65.png)|
| ![stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | ![chicago_stylized_stars_512](https://user-images.githubusercontent.com/79366697/118655638-50da2200-b81c-11eb-9223-58d5df022fa5.png)|
| ![circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg) | ![chicago_stylized_circuit](https://user-images.githubusercontent.com/79366697/118655660-56376c80-b81c-11eb-87f2-64ae5a82375c.png)|
## 4 Pre-trained models
We also provide several trained models.
| model | style | path |
|---|---|---|
| lapstyle_circuit | circuit | [lapstyle_circuit](https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams)
| lapstyle_ocean | ocean | [lapstyle_ocean](https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams)
| lapstyle_starrynew | starrynew | [lapstyle_starrynew](https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams)
| lapstyle_stars | stars | [lapstyle_stars](https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams)
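If you want a quick sanity check of a downloaded checkpoint before evaluation, you can inspect it with `paddle.load`. This is a hedged sketch; the key layout noted in the comment is an assumption, not a documented format:
```python
import paddle

# Load a downloaded weight file and peek at its entries. Whether the keys
# are flat parameter names or nested per-network dicts is an assumption.
state = paddle.load('lapstyle_starrynew.pdparams')
print(type(state), len(state))
for key in list(state)[:5]:
    print(key)
```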
# References
```
@inproceedings{lin2021drafting,
title={Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality Artistic Style Transfer},
author={Lin, Tianwei and Ma, Zhuoqi and Li, Fu and He, Dongliang and Li, Xin and Ding, Errui and Wang, Nannan and Li, Jie and Gao, Xinbo},
booktitle={Computer Vision and Pattern Recognition (CVPR)},
year={2021}
}
```
docs/zh_CN/tutorials/lap_style.md:
-Coming soon.
# LapStyle
This repo provides the official code for the CVPR 2021 paper "Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality Artistic Style Transfer".
## 1 Paper Introduction
Artistic style transfer aims at migrating the artistic style of an example image to a content image. Currently, optimization-based methods achieve great synthesis quality, but their expensive time cost limits practical applications.
Meanwhile, feed-forward methods still cannot synthesize complex styles, especially when both global and local patterns are present. Inspired by the common painting process of drawing a draft and revising its details, [the paper](https://arxiv.org/pdf/2104.05376.pdf) proposes a novel feed-forward method, the Laplacian Pyramid Network (LapStyle).
LapStyle first transfers the global style pattern at low resolution via a Drafting Network. It then revises local details at high resolution via a Revision Network, which produces an image residual from the draft and the image textures extracted by Laplacian filtering. Higher-resolution details can easily be generated by stacking Revision Networks over multiple Laplacian pyramid levels. The final stylized image is obtained by aggregating the outputs of all pyramid levels. The paper also introduces a patch discriminator to learn local styles adversarially. Experiments show that the method can synthesize high-quality stylized images in real time and correctly reproduces holistic style patterns.
![lapstyle_overview](https://user-images.githubusercontent.com/79366697/118654987-b24dc100-b81b-11eb-9430-d84630f80511.png)
## 2 How to Use
### 2.1 Prepare Datasets
To train LapStyle, we use the COCO dataset as the content image set. You can choose any style image you like. Before training or testing, remember to modify the data path of the style image in the config file.
### 2.2 Train
The example uses the COCO dataset. If you want to use your own dataset, you can change the dataset paths in the config file.
(1) First, train the Drafting Network of LapStyle at 128×128 resolution:
```
python -u tools/main.py --config-file configs/lapstyle_draft.yaml
```
(2) Then, train the Revision Network of LapStyle at 256×256 resolution:
```
python -u tools/main.py --config-file configs/lapstyle_rev_first.yaml --load ${PATH_OF_LAST_STAGE_WEIGHT}
```
(3) Finally, train the second Revision Network of LapStyle at 512×512 resolution:
```
python -u tools/main.py --config-file configs/lapstyle_rev_second.yaml --load ${PATH_OF_LAST_STAGE_WEIGHT}
```
### 2.3 Test
To test the trained model, you can directly test "lapstyle_rev_second", since its checkpoint contains the trained weights of the previous stages:
```
python tools/main.py --config-file configs/lapstyle_rev_second.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```
## 3 Results
| Style | Stylized Results |
| --- | --- |
| ![starrynew](https://user-images.githubusercontent.com/79366697/118655415-1ec8c000-b81c-11eb-8002-90bf8d477860.png) | ![chicago_stylized_starrynew](https://user-images.githubusercontent.com/79366697/118655671-59325d00-b81c-11eb-93a3-4fcc24680124.png)|
| ![ocean](https://user-images.githubusercontent.com/79366697/118655407-1c666600-b81c-11eb-83a6-300ee1952415.png) | ![chicago_ocean_512](https://user-images.githubusercontent.com/79366697/118655625-4cae0480-b81c-11eb-83ec-30936ed3df65.png)|
| ![stars](https://user-images.githubusercontent.com/79366697/118655423-20928380-b81c-11eb-92bd-0deeb320ff14.png) | ![chicago_stylized_stars_512](https://user-images.githubusercontent.com/79366697/118655638-50da2200-b81c-11eb-9223-58d5df022fa5.png)|
| ![circuit](https://user-images.githubusercontent.com/79366697/118655399-196b7580-b81c-11eb-8bc5-d5ece80c18ba.jpg) | ![chicago_stylized_circuit](https://user-images.githubusercontent.com/79366697/118655660-56376c80-b81c-11eb-87f2-64ae5a82375c.png)|
## 4 Pre-trained Models
We provide several pretrained models.
| Model | Style | Download |
|---|---|---|
| lapstyle_circuit | circuit | [lapstyle_circuit](https://paddlegan.bj.bcebos.com/models/lapstyle_circuit.pdparams)
| lapstyle_ocean | ocean | [lapstyle_ocean](https://paddlegan.bj.bcebos.com/models/lapstyle_ocean.pdparams)
| lapstyle_starrynew | starrynew | [lapstyle_starrynew](https://paddlegan.bj.bcebos.com/models/lapstyle_starrynew.pdparams)
| lapstyle_stars | stars | [lapstyle_stars](https://paddlegan.bj.bcebos.com/models/lapstyle_stars.pdparams)
# References
```
@inproceedings{lin2021drafting,
title={Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality Artistic Style Transfer},
author={Lin, Tianwei and Ma, Zhuoqi and Li, Fu and He, Dongliang and Li, Xin and Ding, Errui and Wang, Nannan and Li, Jie and Gao, Xinbo},
booktitle={Computer Vision and Pattern Recognition (CVPR)},
year={2021}
}
```
@@ -19,6 +19,7 @@ from PIL import Image
 import paddle
 import paddle.vision.transforms as T
 from paddle.io import Dataset
+import cv2

 from .builder import DATASETS
@@ -53,12 +54,18 @@ class LapStyleDataset(Dataset):
             ci_path: str
         """
         path = self.paths[index]
-        content_img = Image.open(os.path.join(self.content_root,
-                                              path)).convert('RGB')
+        content_img = cv2.imread(os.path.join(self.content_root, path))
+        if content_img.ndim == 2:
+            content_img = cv2.cvtColor(content_img, cv2.COLOR_GRAY2RGB)
+        else:
+            content_img = cv2.cvtColor(content_img, cv2.COLOR_BGR2RGB)
+        content_img = Image.fromarray(content_img)
         content_img = content_img.resize((self.load_size, self.load_size),
                                          Image.BILINEAR)
         content_img = np.array(content_img)
-        style_img = Image.open(self.style_root).convert('RGB')
+        style_img = cv2.imread(self.style_root)
+        style_img = cv2.cvtColor(style_img, cv2.COLOR_BGR2RGB)
+        style_img = Image.fromarray(style_img)
         style_img = style_img.resize((self.load_size, self.load_size),
                                      Image.BILINEAR)
         style_img = np.array(style_img)
......
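As a hedged usage sketch of this dataset, the snippet below builds a test-split instance from the values in the configs above. The import path and the constructor signature are assumptions inferred from the YAML keys (`content_root`, `style_root`, `load_size`, `crop_size`), not a verified API:
```python
# Hypothetical usage; argument names mirror the YAML keys and are an
# assumption, not the verified constructor signature.
from ppgan.datasets.lapstyle_dataset import LapStyleDataset  # path assumed

dataset = LapStyleDataset(content_root='data/coco/test2017/',
                          style_root='data/starrynew.png',
                          load_size=128,
                          crop_size=128)
sample = dataset[0]  # presumably a dict with 'ci', 'si', 'ci_path', per the
                     # __getitem__ shown in the diff above
```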
@@ -29,4 +29,4 @@ from .wav2lip_hq_model import Wav2LipModelHq
 from .starganv2_model import StarGANv2Model
 from .edvr_model import EDVRModel
 from .firstorder_model import FirstOrderModel
-from .lapstyle_model import LapStyleModel
+from .lapstyle_model import LapStyleDraModel, LapStyleRevFirstModel, LapStyleRevSecondModel
@@ -29,4 +29,4 @@ from .drn import DRNGenerator
 from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Mapping, FAN
 from .edvr import EDVRNet
 from .generator_firstorder import FirstOrderGenerator
-from .generater_lapstyle import DecoderNet, Encoder
+from .generater_lapstyle import DecoderNet, Encoder, RevisionNet
@@ -261,3 +261,55 @@ class Encoder(nn.Layer):
         x = self.enc_5(x)
         out['r51'] = x
         return out
@GENERATORS.register()
class RevisionNet(nn.Layer):
    """RevisionNet of Revision module.

    Paper:
        Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality
        Artistic Style Transfer.
    """
    def __init__(self, input_nc=6):
        super(RevisionNet, self).__init__()

        DownBlock = []
        DownBlock += [
            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
            nn.Conv2D(input_nc, 64, (3, 3)),
            nn.ReLU()
        ]
        DownBlock += [
            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
            nn.Conv2D(64, 64, (3, 3), stride=2),
            nn.ReLU()
        ]

        self.resblock = ResnetBlock(64)

        UpBlock = []
        UpBlock += [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
            nn.Conv2D(64, 64, (3, 3)),
            nn.ReLU()
        ]
        UpBlock += [
            nn.Pad2D([1, 1, 1, 1], mode='reflect'),
            nn.Conv2D(64, 3, (3, 3))
        ]

        self.DownBlock = nn.Sequential(*DownBlock)
        self.UpBlock = nn.Sequential(*UpBlock)

    def forward(self, input):
        """
        Args:
            input (Tensor): (b, 6, 256, 256) is concat of last input and this lap.

        Returns:
            Tensor: (b, 3, 256, 256).
        """
        out = self.DownBlock(input)
        out = self.resblock(out)
        out = self.UpBlock(out)
        return out
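To make the shape contract concrete, here is a hedged usage sketch. It assumes `RevisionNet` and its dependencies (`nn`, `ResnetBlock`) from `generater_lapstyle.py` above are importable; the tensors are random stand-ins:
```python
# Sketch: RevisionNet consumes the concatenation of a Laplacian detail band
# and the upsampled draft, and predicts a 3-channel residual.
import paddle

lap_band = paddle.rand([1, 3, 256, 256])   # Laplacian detail band of content
draft_up = paddle.rand([1, 3, 256, 256])   # draft upsampled from 128x128
rev_input = paddle.concat([lap_band, draft_up], axis=1)  # (1, 6, 256, 256)

net_rev = RevisionNet(input_nc=6)
residual = net_rev(rev_input)              # (1, 3, 256, 256)
```
This matches the `paddle.concat(x=[self.pyr_ci[0], stylized_up], axis=1)` call in the model code below: detail band first, upsampled draft second.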
@@ -13,17 +13,19 @@
 # limitations under the License.

 import paddle
+import paddle.nn.functional as F

 from .base_model import BaseModel
 from .builder import MODELS
 from .generators.builder import build_generator
 from .criterions import build_criterion
+from .discriminators.builder import build_discriminator

 from ..modules.init import init_weights


 @MODELS.register()
-class LapStyleModel(BaseModel):
+class LapStyleDraModel(BaseModel):
     def __init__(self,
                  generator_encode,
                  generator_decode,
@@ -36,7 +38,7 @@ class LapStyleModel(BaseModel):
                  content_weight=1.0,
                  style_weight=3.0):
-        super(LapStyleModel, self).__init__()
+        super(LapStyleDraModel, self).__init__()

         # define generators
         self.nets['net_enc'] = build_generator(generator_encode)
@@ -69,7 +71,7 @@ class LapStyleModel(BaseModel):
         self.stylized = self.nets['net_dec'](self.cF, self.sF)
         self.visual_items['stylized'] = self.stylized

-    def backward_dnc(self):
+    def backward_Dec(self):
         self.tF = self.nets['net_enc'](self.stylized)

         """content loss"""
         self.loss_c = 0
@@ -114,5 +116,330 @@ class LapStyleModel(BaseModel):
         """Calculate losses, gradients, and update network weights"""
         self.forward()
         optimizers['optimG'].clear_grad()
-        self.backward_dnc()
+        self.backward_Dec()
         self.optimizers['optimG'].step()
def tensor_resample(tensor, dst_size, mode='bilinear'):
    return F.interpolate(tensor, dst_size, mode=mode, align_corners=False)


def laplacian(x):
    """
    Laplacian

    return:
        x - upsample(downsample(x))
    """
    return x - tensor_resample(
        tensor_resample(x, [x.shape[2] // 2, x.shape[3] // 2]),
        [x.shape[2], x.shape[3]])


def make_laplace_pyramid(x, levels):
    """
    Make Laplacian Pyramid
    """
    pyramid = []
    current = x
    for i in range(levels):
        pyramid.append(laplacian(current))
        current = tensor_resample(
            current,
            (max(current.shape[2] // 2, 1), max(current.shape[3] // 2, 1)))
    pyramid.append(current)
    return pyramid


def fold_laplace_pyramid(pyramid):
    """
    Fold Laplacian Pyramid
    """
    current = pyramid[-1]
    for i in range(len(pyramid) - 2, -1, -1):  # iterate from len-2 to 0
        up_h, up_w = pyramid[i].shape[2], pyramid[i].shape[3]
        current = pyramid[i] + tensor_resample(current, (up_h, up_w))
    return current
@MODELS.register()
class LapStyleRevFirstModel(BaseModel):
    def __init__(self,
                 revnet_generator,
                 revnet_discriminator,
                 draftnet_encode,
                 draftnet_decode,
                 calc_style_emd_loss=None,
                 calc_content_relt_loss=None,
                 calc_content_loss=None,
                 calc_style_loss=None,
                 gan_criterion=None,
                 content_layers=['r11', 'r21', 'r31', 'r41', 'r51'],
                 style_layers=['r11', 'r21', 'r31', 'r41', 'r51'],
                 content_weight=1.0,
                 style_weight=3.0):
        super(LapStyleRevFirstModel, self).__init__()

        # define draftnet params (frozen at this stage)
        self.nets['net_enc'] = build_generator(draftnet_encode)
        self.nets['net_dec'] = build_generator(draftnet_decode)
        self.set_requires_grad([self.nets['net_enc']], False)
        self.set_requires_grad([self.nets['net_dec']], False)

        # define revision-net params
        self.nets['net_rev'] = build_generator(revnet_generator)
        init_weights(self.nets['net_rev'])
        self.nets['netD'] = build_discriminator(revnet_discriminator)
        init_weights(self.nets['netD'])

        # define loss functions
        self.calc_style_emd_loss = build_criterion(calc_style_emd_loss)
        self.calc_content_relt_loss = build_criterion(calc_content_relt_loss)
        self.calc_content_loss = build_criterion(calc_content_loss)
        self.calc_style_loss = build_criterion(calc_style_loss)
        self.gan_criterion = build_criterion(gan_criterion)

        self.content_layers = content_layers
        self.style_layers = style_layers
        self.content_weight = content_weight
        self.style_weight = style_weight

    def setup_input(self, input):
        self.ci = paddle.to_tensor(input['ci'])
        self.visual_items['ci'] = self.ci
        self.si = paddle.to_tensor(input['si'])
        self.visual_items['si'] = self.si
        self.image_paths = input['ci_path']

        # one-level Laplacian pyramids, plus the full-resolution originals
        self.pyr_ci = make_laplace_pyramid(self.ci, 1)
        self.pyr_si = make_laplace_pyramid(self.si, 1)
        self.pyr_ci.append(self.ci)
        self.pyr_si.append(self.si)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        cF = self.nets['net_enc'](self.pyr_ci[1])
        sF = self.nets['net_enc'](self.pyr_si[1])

        stylized_small = self.nets['net_dec'](cF, sF)
        self.visual_items['stylized_small'] = stylized_small
        stylized_up = F.interpolate(stylized_small, scale_factor=2)

        revnet_input = paddle.concat(x=[self.pyr_ci[0], stylized_up], axis=1)
        stylized_rev_lap = self.nets['net_rev'](revnet_input)
        stylized_rev = fold_laplace_pyramid([stylized_rev_lap, stylized_small])

        self.stylized = stylized_rev
        self.visual_items['stylized'] = self.stylized

    def backward_G(self):
        self.tF = self.nets['net_enc'](self.stylized)
        self.cF = self.nets['net_enc'](self.pyr_ci[2])
        self.sF = self.nets['net_enc'](self.pyr_si[2])

        """content loss"""
        self.loss_c = 0
        for layer in self.content_layers:
            self.loss_c += self.calc_content_loss(self.tF[layer],
                                                  self.cF[layer],
                                                  norm=True)
        self.losses['loss_c'] = self.loss_c

        """style loss"""
        self.loss_s = 0
        for layer in self.style_layers:
            self.loss_s += self.calc_style_loss(self.tF[layer], self.sF[layer])
        self.losses['loss_s'] = self.loss_s

        """relative loss"""
        self.loss_style_remd = self.calc_style_emd_loss(
            self.tF['r31'], self.sF['r31']) + self.calc_style_emd_loss(
                self.tF['r41'], self.sF['r41'])
        self.loss_content_relt = self.calc_content_relt_loss(
            self.tF['r31'], self.cF['r31']) + self.calc_content_relt_loss(
                self.tF['r41'], self.cF['r41'])
        self.losses['loss_style_remd'] = self.loss_style_remd
        self.losses['loss_content_relt'] = self.loss_content_relt

        """gan loss"""
        pred_fake = self.nets['netD'](self.stylized)
        self.loss_G_GAN = self.gan_criterion(pred_fake, True)
        self.losses['loss_gan_G'] = self.loss_G_GAN

        self.loss = self.loss_G_GAN + self.loss_c * self.content_weight + self.loss_s * self.style_weight +\
                    self.loss_style_remd * 10 + self.loss_content_relt * 16
        self.loss.backward()
        return self.loss

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        pred_fake = self.nets['netD'](self.stylized.detach())
        self.loss_D_fake = self.gan_criterion(pred_fake, False)

        pred_real = self.nets['netD'](self.pyr_si[2])
        self.loss_D_real = self.gan_criterion(pred_real, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

        self.losses['D_fake_loss'] = self.loss_D_fake
        self.losses['D_real_loss'] = self.loss_D_real

    def train_iter(self, optimizers=None):
        # compute fake images: G(A)
        self.forward()
        # update D
        self.set_requires_grad(self.nets['netD'], True)
        optimizers['optimD'].clear_grad()
        self.backward_D()
        optimizers['optimD'].step()
        # update G
        self.set_requires_grad(self.nets['netD'], False)
        optimizers['optimG'].clear_grad()
        self.backward_G()
        optimizers['optimG'].step()
@MODELS.register()
class LapStyleRevSecondModel(BaseModel):
    def __init__(self,
                 revnet_generator,
                 revnet_discriminator,
                 draftnet_encode,
                 draftnet_decode,
                 calc_style_emd_loss=None,
                 calc_content_relt_loss=None,
                 calc_content_loss=None,
                 calc_style_loss=None,
                 gan_criterion=None,
                 content_layers=['r11', 'r21', 'r31', 'r41', 'r51'],
                 style_layers=['r11', 'r21', 'r31', 'r41', 'r51'],
                 content_weight=1.0,
                 style_weight=3.0):
        super(LapStyleRevSecondModel, self).__init__()

        # define draftnet params (frozen at this stage)
        self.nets['net_enc'] = build_generator(draftnet_encode)
        self.nets['net_dec'] = build_generator(draftnet_decode)
        self.set_requires_grad([self.nets['net_enc']], False)
        self.set_requires_grad([self.nets['net_dec']], False)

        # define the first revnet params (frozen at this stage)
        self.nets['net_rev'] = build_generator(revnet_generator)
        self.set_requires_grad([self.nets['net_rev']], False)

        # define the second revnet params
        self.nets['net_rev_2'] = build_generator(revnet_generator)
        init_weights(self.nets['net_rev_2'])
        self.nets['netD'] = build_discriminator(revnet_discriminator)
        init_weights(self.nets['netD'])

        # define loss functions
        self.calc_style_emd_loss = build_criterion(calc_style_emd_loss)
        self.calc_content_relt_loss = build_criterion(calc_content_relt_loss)
        self.calc_content_loss = build_criterion(calc_content_loss)
        self.calc_style_loss = build_criterion(calc_style_loss)
        self.gan_criterion = build_criterion(gan_criterion)

        self.content_layers = content_layers
        self.style_layers = style_layers
        self.content_weight = content_weight
        self.style_weight = style_weight

    def setup_input(self, input):
        self.ci = paddle.to_tensor(input['ci'])
        self.visual_items['ci'] = self.ci
        self.si = paddle.to_tensor(input['si'])
        self.visual_items['si'] = self.si
        self.image_paths = input['ci_path']

        # two-level Laplacian pyramids, plus the full-resolution originals
        self.pyr_ci = make_laplace_pyramid(self.ci, 2)
        self.pyr_si = make_laplace_pyramid(self.si, 2)
        self.pyr_ci.append(self.ci)
        self.pyr_si.append(self.si)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        cF = self.nets['net_enc'](self.pyr_ci[2])
        sF = self.nets['net_enc'](self.pyr_si[2])

        stylized_small = self.nets['net_dec'](cF, sF)
        self.visual_items['stylized_small'] = stylized_small
        stylized_up = F.interpolate(stylized_small, scale_factor=2)

        revnet_input = paddle.concat(x=[self.pyr_ci[1], stylized_up], axis=1)
        stylized_rev_lap = self.nets['net_rev'](revnet_input)
        stylized_rev = fold_laplace_pyramid([stylized_rev_lap, stylized_small])
        self.visual_items['stylized_rev_first'] = stylized_rev
        stylized_up = F.interpolate(stylized_rev, scale_factor=2)

        revnet_input = paddle.concat(x=[self.pyr_ci[0], stylized_up], axis=1)
        stylized_rev_lap_second = self.nets['net_rev_2'](revnet_input)
        stylized_rev_second = fold_laplace_pyramid(
            [stylized_rev_lap_second, stylized_rev_lap, stylized_small])

        self.stylized = stylized_rev_second
        self.visual_items['stylized'] = self.stylized

    def backward_G(self):
        self.tF = self.nets['net_enc'](self.stylized)
        self.cF = self.nets['net_enc'](self.pyr_ci[3])
        self.sF = self.nets['net_enc'](self.pyr_si[3])

        """content loss"""
        self.loss_c = 0
        for layer in self.content_layers:
            self.loss_c += self.calc_content_loss(self.tF[layer],
                                                  self.cF[layer],
                                                  norm=True)
        self.losses['loss_c'] = self.loss_c

        """style loss"""
        self.loss_s = 0
        for layer in self.style_layers:
            self.loss_s += self.calc_style_loss(self.tF[layer], self.sF[layer])
        self.losses['loss_s'] = self.loss_s

        """relative loss"""
        self.loss_style_remd = self.calc_style_emd_loss(self.tF['r41'],
                                                        self.sF['r41'])
        self.loss_content_relt = self.calc_content_relt_loss(
            self.tF['r41'], self.cF['r41'])
        self.losses['loss_style_remd'] = self.loss_style_remd
        self.losses['loss_content_relt'] = self.loss_content_relt

        """gan loss"""
        pred_fake = self.nets['netD'](self.stylized)
        self.loss_G_GAN = self.gan_criterion(pred_fake, True)
        self.losses['loss_gan_G'] = self.loss_G_GAN

        self.loss = self.loss_G_GAN + self.loss_c * self.content_weight + self.loss_s * self.style_weight +\
                    self.loss_style_remd * 10 + self.loss_content_relt * 16
        self.loss.backward()
        return self.loss

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        pred_fake = self.nets['netD'](self.stylized.detach())
        self.loss_D_fake = self.gan_criterion(pred_fake, False)

        pred_real = self.nets['netD'](self.pyr_si[3])
        self.loss_D_real = self.gan_criterion(pred_real, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

        self.losses['D_fake_loss'] = self.loss_D_fake
        self.losses['D_real_loss'] = self.loss_D_real

    def train_iter(self, optimizers=None):
        # compute fake images: G(A)
        self.forward()
        # update D
        self.set_requires_grad(self.nets['netD'], True)
        optimizers['optimD'].clear_grad()
        self.backward_D()
        optimizers['optimD'].step()
        # update G
        self.set_requires_grad(self.nets['netD'], False)
        optimizers['optimG'].clear_grad()
        self.backward_G()
        optimizers['optimG'].step()
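To make the pyramid arithmetic above concrete, here is a minimal round-trip sketch using the helpers defined in this file: folding a freshly decomposed pyramid should reproduce the input up to floating-point error, since each fold step adds back exactly the residual that `laplacian` subtracted.
```python
import paddle

# Round-trip check for make_laplace_pyramid / fold_laplace_pyramid above.
x = paddle.rand([1, 3, 512, 512])
pyr = make_laplace_pyramid(x, 2)    # [lap_512, lap_256, low_128]
recon = fold_laplace_pyramid(pyr)   # fold back up to 512x512
err = float(paddle.abs(recon - x).max())
print('max reconstruction error:', err)  # tiny; limited by float precision
```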