Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
0ded43f7
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0ded43f7
编写于
5月 14, 2021
作者:
W
wangna11BD
提交者:
GitHub
5月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add LapStyle Model from vis (#307)
* add LapStyle Model
上级
c12e50aa
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
722 addition
and
3 deletion
+722
-3
configs/lapstyle_draft.yaml
configs/lapstyle_draft.yaml
+67
-0
ppgan/datasets/__init__.py
ppgan/datasets/__init__.py
+1
-0
ppgan/datasets/lapstyle_dataset.py
ppgan/datasets/lapstyle_dataset.py
+90
-0
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+4
-1
ppgan/models/__init__.py
ppgan/models/__init__.py
+1
-0
ppgan/models/criterions/__init__.py
ppgan/models/criterions/__init__.py
+3
-1
ppgan/models/criterions/pixel_loss.py
ppgan/models/criterions/pixel_loss.py
+107
-0
ppgan/models/discriminators/__init__.py
ppgan/models/discriminators/__init__.py
+1
-0
ppgan/models/discriminators/discriminator_lapstyle.py
ppgan/models/discriminators/discriminator_lapstyle.py
+54
-0
ppgan/models/generators/__init__.py
ppgan/models/generators/__init__.py
+1
-0
ppgan/models/generators/generater_lapstyle.py
ppgan/models/generators/generater_lapstyle.py
+263
-0
ppgan/models/lapstyle_model.py
ppgan/models/lapstyle_model.py
+118
-0
ppgan/solver/__init__.py
ppgan/solver/__init__.py
+1
-1
ppgan/solver/lr_scheduler.py
ppgan/solver/lr_scheduler.py
+11
-0
未找到文件。
configs/lapstyle_draft.yaml
0 → 100644
浏览文件 @
0ded43f7
total_iters
:
30000
output_dir
:
output_dir
checkpoints_dir
:
checkpoints
min_max
:
(0., 1.)
model
:
name
:
LapStyleModel
generator_encode
:
name
:
Encoder
generator_decode
:
name
:
DecoderNet
calc_style_emd_loss
:
name
:
CalcStyleEmdLoss
calc_content_relt_loss
:
name
:
CalcContentReltLoss
calc_content_loss
:
name
:
CalcContentLoss
calc_style_loss
:
name
:
CalcStyleLoss
content_layers
:
[
'
r11'
,
'
r21'
,
'
r31'
,
'
r41'
,
'
r51'
]
style_layers
:
[
'
r11'
,
'
r21'
,
'
r31'
,
'
r41'
,
'
r51'
]
content_weight
:
1.0
style_weight
:
3.0
dataset
:
train
:
name
:
LapStyleDataset
content_root
:
data/coco/train2017/
style_root
:
data/starrynew.png
load_size
:
136
crop_size
:
128
num_workers
:
16
batch_size
:
5
test
:
name
:
LapStyleDataset
content_root
:
data/coco/test2017/
style_root
:
data/starrynew.png
load_size
:
136
crop_size
:
128
num_workers
:
0
batch_size
:
1
lr_scheduler
:
name
:
NonLinearDecay
learning_rate
:
1e-4
lr_decay
:
5e-5
optimizer
:
optimG
:
name
:
Adam
net_names
:
-
net_dec
beta1
:
0.9
beta2
:
0.999
validate
:
interval
:
5000
save_img
:
false
log_config
:
interval
:
10
visiual_interval
:
5000
snapshot_config
:
interval
:
5000
ppgan/datasets/__init__.py
浏览文件 @
0ded43f7
...
...
@@ -23,3 +23,4 @@ from .wav2lip_dataset import Wav2LipDataset
from
.starganv2_dataset
import
StarGANv2Dataset
from
.edvr_dataset
import
REDSDataset
from
.firstorder_dataset
import
FirstOrderDataset
from
.lapstyle_dataset
import
LapStyleDataset
ppgan/datasets/lapstyle_dataset.py
0 → 100644
浏览文件 @
0ded43f7
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
import
os
import
numpy
as
np
from
PIL
import
Image
import
paddle
import
paddle.vision.transforms
as
T
from
paddle.io
import
Dataset
from
.builder
import
DATASETS
logger
=
logging
.
getLogger
(
__name__
)
def
data_transform
(
crop_size
):
transform_list
=
[
T
.
RandomCrop
(
crop_size
)]
return
T
.
Compose
(
transform_list
)
@
DATASETS
.
register
()
class
LapStyleDataset
(
Dataset
):
"""
coco2017 dataset for LapStyle model
"""
def
__init__
(
self
,
content_root
,
style_root
,
load_size
,
crop_size
):
super
(
LapStyleDataset
,
self
).
__init__
()
self
.
content_root
=
content_root
self
.
paths
=
os
.
listdir
(
self
.
content_root
)
self
.
style_root
=
style_root
self
.
load_size
=
load_size
self
.
crop_size
=
crop_size
self
.
transform
=
data_transform
(
self
.
crop_size
)
def
__getitem__
(
self
,
index
):
"""Get training sample
return:
ci: content image with shape [C,W,H],
si: style image with shape [C,W,H],
ci_path: str
"""
path
=
self
.
paths
[
index
]
content_img
=
Image
.
open
(
os
.
path
.
join
(
self
.
content_root
,
path
)).
convert
(
'RGB'
)
content_img
=
content_img
.
resize
((
self
.
load_size
,
self
.
load_size
),
Image
.
BILINEAR
)
content_img
=
np
.
array
(
content_img
)
style_img
=
Image
.
open
(
self
.
style_root
).
convert
(
'RGB'
)
style_img
=
style_img
.
resize
((
self
.
load_size
,
self
.
load_size
),
Image
.
BILINEAR
)
style_img
=
np
.
array
(
style_img
)
content_img
=
self
.
transform
(
content_img
)
style_img
=
self
.
transform
(
style_img
)
content_img
=
self
.
img
(
content_img
)
style_img
=
self
.
img
(
style_img
)
return
{
'ci'
:
content_img
,
'si'
:
style_img
,
'ci_path'
:
path
}
def
img
(
self
,
img
):
"""make image with [0,255] and HWC to [0,1] and CHW
return:
img: image with shape [3,W,H] and value [0, 1].
"""
# [0,255] to [0,1]
img
=
img
.
astype
(
np
.
float32
)
/
255.
# some images have 4 channels
if
img
.
shape
[
2
]
>
3
:
img
=
img
[:,
:,
:
3
]
# HWC to CHW
img
=
np
.
transpose
(
img
,
(
2
,
0
,
1
)).
astype
(
'float32'
)
return
img
def
__len__
(
self
):
return
len
(
self
.
paths
)
def
name
(
self
):
return
'LapStyleDataset'
ppgan/engine/trainer.py
浏览文件 @
0ded43f7
...
...
@@ -347,7 +347,10 @@ class Trainer:
dataformats
=
"HWC"
if
image_num
==
1
else
"NCHW"
)
else
:
if
self
.
cfg
.
is_train
:
msg
=
'epoch%.3d_'
%
self
.
current_epoch
if
self
.
by_epoch
:
msg
=
'epoch%.3d_'
%
self
.
current_epoch
else
:
msg
=
'iter%.3d_'
%
self
.
current_iter
else
:
msg
=
''
makedirs
(
os
.
path
.
join
(
self
.
output_dir
,
results_dir
))
...
...
ppgan/models/__init__.py
浏览文件 @
0ded43f7
...
...
@@ -29,3 +29,4 @@ from .wav2lip_hq_model import Wav2LipModelHq
from
.starganv2_model
import
StarGANv2Model
from
.edvr_model
import
EDVRModel
from
.firstorder_model
import
FirstOrderModel
from
.lapstyle_model
import
LapStyleModel
ppgan/models/criterions/__init__.py
浏览文件 @
0ded43f7
from
.gan_loss
import
GANLoss
from
.perceptual_loss
import
PerceptualLoss
from
.pixel_loss
import
L1Loss
,
MSELoss
,
CharbonnierLoss
from
.pixel_loss
import
L1Loss
,
MSELoss
,
CharbonnierLoss
,
\
CalcStyleEmdLoss
,
CalcContentReltLoss
,
\
CalcContentLoss
,
CalcStyleLoss
from
.builder
import
build_criterion
ppgan/models/criterions/pixel_loss.py
浏览文件 @
0ded43f7
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
numpy
as
np
from
..generators.generater_lapstyle
import
calc_mean_std
,
mean_variance_norm
import
paddle
import
paddle.nn
as
nn
...
...
@@ -127,3 +128,109 @@ class BCEWithLogitsLoss():
weights. Default: None.
"""
return
self
.
loss_weight
*
self
.
_bce_loss
(
pred
,
target
)
def
calc_emd_loss
(
pred
,
target
):
"""calculate emd loss.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
b
,
_
,
h
,
w
=
pred
.
shape
pred
=
pred
.
reshape
([
b
,
-
1
,
w
*
h
])
pred_norm
=
paddle
.
sqrt
((
pred
**
2
).
sum
(
1
).
reshape
([
b
,
-
1
,
1
]))
pred
=
pred
.
transpose
([
0
,
2
,
1
])
target_t
=
target
.
reshape
([
b
,
-
1
,
w
*
h
])
target_norm
=
paddle
.
sqrt
((
target
**
2
).
sum
(
1
).
reshape
([
b
,
1
,
-
1
]))
similarity
=
paddle
.
bmm
(
pred
,
target_t
)
/
pred_norm
/
target_norm
dist
=
1.
-
similarity
return
dist
@
CRITERIONS
.
register
()
class
CalcStyleEmdLoss
():
"""Calc Style Emd Loss.
"""
def
__init__
(
self
):
super
(
CalcStyleEmdLoss
,
self
).
__init__
()
def
__call__
(
self
,
pred
,
target
):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
CX_M
=
calc_emd_loss
(
pred
,
target
)
m1
=
CX_M
.
min
(
2
)
m2
=
CX_M
.
min
(
1
)
m
=
paddle
.
concat
([
m1
.
mean
(),
m2
.
mean
()])
loss_remd
=
paddle
.
max
(
m
)
return
loss_remd
@
CRITERIONS
.
register
()
class
CalcContentReltLoss
():
"""Calc Content Relt Loss.
"""
def
__init__
(
self
):
super
(
CalcContentReltLoss
,
self
).
__init__
()
def
__call__
(
self
,
pred
,
target
):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
dM
=
1.
Mx
=
calc_emd_loss
(
pred
,
pred
)
Mx
=
Mx
/
Mx
.
sum
(
1
,
keepdim
=
True
)
My
=
calc_emd_loss
(
target
,
target
)
My
=
My
/
My
.
sum
(
1
,
keepdim
=
True
)
loss_content
=
paddle
.
abs
(
dM
*
(
Mx
-
My
)).
mean
()
*
pred
.
shape
[
2
]
*
pred
.
shape
[
3
]
return
loss_content
@
CRITERIONS
.
register
()
class
CalcContentLoss
():
"""Calc Content Loss.
"""
def
__init__
(
self
):
self
.
mse_loss
=
nn
.
MSELoss
()
def
__call__
(
self
,
pred
,
target
,
norm
=
False
):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
norm(Bool): whether use mean_variance_norm for pred and target
"""
if
(
norm
==
False
):
return
self
.
mse_loss
(
pred
,
target
)
else
:
return
self
.
mse_loss
(
mean_variance_norm
(
pred
),
mean_variance_norm
(
target
))
@
CRITERIONS
.
register
()
class
CalcStyleLoss
():
"""Calc Style Loss.
"""
def
__init__
(
self
):
self
.
mse_loss
=
nn
.
MSELoss
()
def
__call__
(
self
,
pred
,
target
):
"""Forward Function.
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
"""
pred_mean
,
pred_std
=
calc_mean_std
(
pred
)
target_mean
,
target_std
=
calc_mean_std
(
target
)
return
self
.
mse_loss
(
pred_mean
,
target_mean
)
+
self
.
mse_loss
(
pred_std
,
target_std
)
ppgan/models/discriminators/__init__.py
浏览文件 @
0ded43f7
...
...
@@ -22,3 +22,4 @@ from .syncnet import SyncNetColor
from
.wav2lip_disc_qual
import
Wav2LipDiscQual
from
.discriminator_starganv2
import
StarGANv2Discriminator
from
.discriminator_firstorder
import
FirstOrderDiscriminator
from
.discriminator_lapstyle
import
LapStyleDiscriminator
ppgan/models/discriminators/discriminator_lapstyle.py
0 → 100644
浏览文件 @
0ded43f7
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
from
.builder
import
DISCRIMINATORS
@
DISCRIMINATORS
.
register
()
class
LapStyleDiscriminator
(
nn
.
Layer
):
def
__init__
(
self
):
super
(
LapStyleDiscriminator
,
self
).
__init__
()
num_layer
=
3
num_channel
=
32
self
.
head
=
nn
.
Sequential
(
(
'conv'
,
nn
.
Conv2D
(
3
,
num_channel
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)),
(
'norm'
,
nn
.
BatchNorm2D
(
num_channel
)),
(
'LeakyRelu'
,
nn
.
LeakyReLU
(
0.2
)))
self
.
body
=
nn
.
Sequential
()
for
i
in
range
(
num_layer
-
2
):
self
.
body
.
add_sublayer
(
'conv%d'
%
(
i
+
1
),
nn
.
Conv2D
(
num_channel
,
num_channel
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
))
self
.
body
.
add_sublayer
(
'norm%d'
%
(
i
+
1
),
nn
.
BatchNorm2D
(
num_channel
))
self
.
body
.
add_sublayer
(
'LeakyRelu%d'
%
(
i
+
1
),
nn
.
LeakyReLU
(
0.2
))
self
.
tail
=
nn
.
Conv2D
(
num_channel
,
1
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
):
x
=
self
.
head
(
x
)
x
=
self
.
body
(
x
)
x
=
self
.
tail
(
x
)
return
x
ppgan/models/generators/__init__.py
浏览文件 @
0ded43f7
...
...
@@ -29,3 +29,4 @@ from .drn import DRNGenerator
from
.generator_starganv2
import
StarGANv2Generator
,
StarGANv2Style
,
StarGANv2Mapping
,
FAN
from
.edvr
import
EDVRNet
from
.generator_firstorder
import
FirstOrderGenerator
from
.generater_lapstyle
import
DecoderNet
,
Encoder
ppgan/models/generators/generater_lapstyle.py
0 → 100644
浏览文件 @
0ded43f7
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle
import
paddle.nn
as
nn
from
...utils.download
import
get_path_from_url
from
.builder
import
GENERATORS
def
calc_mean_std
(
feat
,
eps
=
1e-5
):
"""calculate mean and standard deviation.
Args:
feat (Tensor): Tensor with shape (N, C, H, W).
eps (float): Default: 1e-5.
Return:
mean and std of feat
shape: [N, C, 1, 1]
"""
size
=
feat
.
shape
assert
(
len
(
size
)
==
4
)
N
,
C
=
size
[:
2
]
feat_var
=
feat
.
reshape
([
N
,
C
,
-
1
])
feat_var
=
paddle
.
var
(
feat_var
,
axis
=
2
)
+
eps
feat_std
=
paddle
.
sqrt
(
feat_var
)
feat_std
=
feat_std
.
reshape
([
N
,
C
,
1
,
1
])
feat_mean
=
feat
.
reshape
([
N
,
C
,
-
1
])
feat_mean
=
paddle
.
mean
(
feat_mean
,
axis
=
2
)
feat_mean
=
feat_mean
.
reshape
([
N
,
C
,
1
,
1
])
return
feat_mean
,
feat_std
def
mean_variance_norm
(
feat
):
"""mean_variance_norm.
Args:
feat (Tensor): Tensor with shape (N, C, H, W).
Return:
Normalized feat with shape (N, C, H, W)
"""
size
=
feat
.
shape
mean
,
std
=
calc_mean_std
(
feat
)
normalized_feat
=
(
feat
-
mean
.
expand
(
size
))
/
std
.
expand
(
size
)
return
normalized_feat
def
adaptive_instance_normalization
(
content_feat
,
style_feat
):
"""adaptive_instance_normalization.
Args:
content_feat (Tensor): Tensor with shape (N, C, H, W).
style_feat (Tensor): Tensor with shape (N, C, H, W).
Return:
Normalized content_feat with shape (N, C, H, W)
"""
assert
(
content_feat
.
shape
[:
2
]
==
style_feat
.
shape
[:
2
])
size
=
content_feat
.
shape
style_mean
,
style_std
=
calc_mean_std
(
style_feat
)
content_mean
,
content_std
=
calc_mean_std
(
content_feat
)
normalized_feat
=
(
content_feat
-
content_mean
.
expand
(
size
))
/
content_std
.
expand
(
size
)
return
normalized_feat
*
style_std
.
expand
(
size
)
+
style_mean
.
expand
(
size
)
class
ResnetBlock
(
nn
.
Layer
):
"""Residual block.
It has a style of:
---Pad-Conv-ReLU-Pad-Conv-+-
|________________________|
Args:
dim (int): Channel number of intermediate features.
"""
def
__init__
(
self
,
dim
):
super
(
ResnetBlock
,
self
).
__init__
()
self
.
conv_block
=
nn
.
Sequential
(
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
dim
,
dim
,
(
3
,
3
)),
nn
.
ReLU
(),
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
dim
,
dim
,
(
3
,
3
)))
def
forward
(
self
,
x
):
out
=
x
+
self
.
conv_block
(
x
)
return
out
class
ConvBlock
(
nn
.
Layer
):
"""convolution block.
It has a style of:
---Pad-Conv-ReLU---
Args:
dim1 (int): Channel number of input features.
dim2 (int): Channel number of output features.
"""
def
__init__
(
self
,
dim1
,
dim2
):
super
(
ConvBlock
,
self
).
__init__
()
self
.
conv_block
=
nn
.
Sequential
(
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
dim1
,
dim2
,
(
3
,
3
)),
nn
.
ReLU
())
def
forward
(
self
,
x
):
out
=
self
.
conv_block
(
x
)
return
out
@
GENERATORS
.
register
()
class
DecoderNet
(
nn
.
Layer
):
"""Decoder of Drafting module.
Paper:
Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality
Artistic Style Transfer.
"""
def
__init__
(
self
):
super
(
DecoderNet
,
self
).
__init__
()
self
.
resblock_41
=
ResnetBlock
(
512
)
self
.
convblock_41
=
ConvBlock
(
512
,
256
)
self
.
resblock_31
=
ResnetBlock
(
256
)
self
.
convblock_31
=
ConvBlock
(
256
,
128
)
self
.
convblock_21
=
ConvBlock
(
128
,
128
)
self
.
convblock_22
=
ConvBlock
(
128
,
64
)
self
.
convblock_11
=
ConvBlock
(
64
,
64
)
self
.
upsample
=
nn
.
Upsample
(
scale_factor
=
2
,
mode
=
'nearest'
)
self
.
final_conv
=
nn
.
Sequential
(
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
64
,
3
,
(
3
,
3
)))
def
forward
(
self
,
cF
,
sF
):
out
=
adaptive_instance_normalization
(
cF
[
'r41'
],
sF
[
'r41'
])
out
=
self
.
resblock_41
(
out
)
out
=
self
.
convblock_41
(
out
)
out
=
self
.
upsample
(
out
)
out
+=
adaptive_instance_normalization
(
cF
[
'r31'
],
sF
[
'r31'
])
out
=
self
.
resblock_31
(
out
)
out
=
self
.
convblock_31
(
out
)
out
=
self
.
upsample
(
out
)
out
+=
adaptive_instance_normalization
(
cF
[
'r21'
],
sF
[
'r21'
])
out
=
self
.
convblock_21
(
out
)
out
=
self
.
convblock_22
(
out
)
out
=
self
.
upsample
(
out
)
out
=
self
.
convblock_11
(
out
)
out
=
self
.
final_conv
(
out
)
return
out
vgg
=
nn
.
Sequential
(
nn
.
Conv2D
(
3
,
3
,
(
1
,
1
)),
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
3
,
64
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu1-1
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
64
,
64
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu1-2
nn
.
MaxPool2D
((
2
,
2
),
(
2
,
2
),
(
0
,
0
),
ceil_mode
=
True
),
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
64
,
128
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu2-1
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
128
,
128
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu2-2
nn
.
MaxPool2D
((
2
,
2
),
(
2
,
2
),
(
0
,
0
),
ceil_mode
=
True
),
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
128
,
256
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu3-1
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
256
,
256
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu3-2
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
256
,
256
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu3-3
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
256
,
256
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu3-4
nn
.
MaxPool2D
((
2
,
2
),
(
2
,
2
),
(
0
,
0
),
ceil_mode
=
True
),
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
256
,
512
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu4-1, this is the last layer used
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
512
,
512
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu4-2
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
512
,
512
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu4-3
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
512
,
512
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu4-4
nn
.
MaxPool2D
((
2
,
2
),
(
2
,
2
),
(
0
,
0
),
ceil_mode
=
True
),
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
512
,
512
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu5-1
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
512
,
512
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu5-2
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
512
,
512
,
(
3
,
3
)),
nn
.
ReLU
(),
# relu5-3
nn
.
Pad2D
([
1
,
1
,
1
,
1
],
mode
=
'reflect'
),
nn
.
Conv2D
(
512
,
512
,
(
3
,
3
)),
nn
.
ReLU
()
# relu5-4
)
@
GENERATORS
.
register
()
class
Encoder
(
nn
.
Layer
):
"""Encoder of Drafting module.
Paper:
Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality
Artistic Style Transfer.
"""
def
__init__
(
self
):
super
(
Encoder
,
self
).
__init__
()
vgg_net
=
vgg
weight_path
=
get_path_from_url
(
'https://paddlegan.bj.bcebos.com/models/vgg_normalised.pdparams'
)
vgg_net
.
set_dict
(
paddle
.
load
(
weight_path
))
self
.
enc_1
=
nn
.
Sequential
(
*
list
(
vgg_net
.
children
())[:
4
])
# input -> relu1_1
self
.
enc_2
=
nn
.
Sequential
(
*
list
(
vgg_net
.
children
())[
4
:
11
])
# relu1_1 -> relu2_1
self
.
enc_3
=
nn
.
Sequential
(
*
list
(
vgg_net
.
children
())[
11
:
18
])
# relu2_1 -> relu3_1
self
.
enc_4
=
nn
.
Sequential
(
*
list
(
vgg_net
.
children
())[
18
:
31
])
# relu3_1 -> relu4_1
self
.
enc_5
=
nn
.
Sequential
(
*
list
(
vgg_net
.
children
())[
31
:
44
])
# relu4_1 -> relu5_1
def
forward
(
self
,
x
):
out
=
{}
x
=
self
.
enc_1
(
x
)
out
[
'r11'
]
=
x
x
=
self
.
enc_2
(
x
)
out
[
'r21'
]
=
x
x
=
self
.
enc_3
(
x
)
out
[
'r31'
]
=
x
x
=
self
.
enc_4
(
x
)
out
[
'r41'
]
=
x
x
=
self
.
enc_5
(
x
)
out
[
'r51'
]
=
x
return
out
ppgan/models/lapstyle_model.py
0 → 100644
浏览文件 @
0ded43f7
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
from
.base_model
import
BaseModel
from
.builder
import
MODELS
from
.generators.builder
import
build_generator
from
.criterions
import
build_criterion
from
..modules.init
import
init_weights
@
MODELS
.
register
()
class
LapStyleModel
(
BaseModel
):
def
__init__
(
self
,
generator_encode
,
generator_decode
,
calc_style_emd_loss
=
None
,
calc_content_relt_loss
=
None
,
calc_content_loss
=
None
,
calc_style_loss
=
None
,
content_layers
=
[
'r11'
,
'r21'
,
'r31'
,
'r41'
,
'r51'
],
style_layers
=
[
'r11'
,
'r21'
,
'r31'
,
'r41'
,
'r51'
],
content_weight
=
1.0
,
style_weight
=
3.0
):
super
(
LapStyleModel
,
self
).
__init__
()
# define generators
self
.
nets
[
'net_enc'
]
=
build_generator
(
generator_encode
)
self
.
nets
[
'net_dec'
]
=
build_generator
(
generator_decode
)
init_weights
(
self
.
nets
[
'net_dec'
])
self
.
set_requires_grad
([
self
.
nets
[
'net_enc'
]],
False
)
# define loss functions
self
.
calc_style_emd_loss
=
build_criterion
(
calc_style_emd_loss
)
self
.
calc_content_relt_loss
=
build_criterion
(
calc_content_relt_loss
)
self
.
calc_content_loss
=
build_criterion
(
calc_content_loss
)
self
.
calc_style_loss
=
build_criterion
(
calc_style_loss
)
self
.
content_layers
=
content_layers
self
.
style_layers
=
style_layers
self
.
content_weight
=
content_weight
self
.
style_weight
=
style_weight
def
setup_input
(
self
,
input
):
self
.
ci
=
paddle
.
to_tensor
(
input
[
'ci'
])
self
.
visual_items
[
'ci'
]
=
self
.
ci
self
.
si
=
paddle
.
to_tensor
(
input
[
'si'
])
self
.
visual_items
[
'si'
]
=
self
.
si
self
.
image_paths
=
input
[
'ci_path'
]
def
forward
(
self
):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self
.
cF
=
self
.
nets
[
'net_enc'
](
self
.
ci
)
self
.
sF
=
self
.
nets
[
'net_enc'
](
self
.
si
)
self
.
stylized
=
self
.
nets
[
'net_dec'
](
self
.
cF
,
self
.
sF
)
self
.
visual_items
[
'stylized'
]
=
self
.
stylized
def
backward_dnc
(
self
):
self
.
tF
=
self
.
nets
[
'net_enc'
](
self
.
stylized
)
"""content loss"""
self
.
loss_c
=
0
for
layer
in
self
.
content_layers
:
self
.
loss_c
+=
self
.
calc_content_loss
(
self
.
tF
[
layer
],
self
.
cF
[
layer
],
norm
=
True
)
self
.
losses
[
'loss_c'
]
=
self
.
loss_c
"""style loss"""
self
.
loss_s
=
0
for
layer
in
self
.
style_layers
:
self
.
loss_s
+=
self
.
calc_style_loss
(
self
.
tF
[
layer
],
self
.
sF
[
layer
])
self
.
losses
[
'loss_s'
]
=
self
.
loss_s
"""IDENTITY LOSSES"""
self
.
Icc
=
self
.
nets
[
'net_dec'
](
self
.
cF
,
self
.
cF
)
self
.
l_identity1
=
self
.
calc_content_loss
(
self
.
Icc
,
self
.
ci
)
self
.
Fcc
=
self
.
nets
[
'net_enc'
](
self
.
Icc
)
self
.
l_identity2
=
0
for
layer
in
self
.
content_layers
:
self
.
l_identity2
+=
self
.
calc_content_loss
(
self
.
Fcc
[
layer
],
self
.
cF
[
layer
])
self
.
losses
[
'l_identity1'
]
=
self
.
l_identity1
self
.
losses
[
'l_identity2'
]
=
self
.
l_identity2
"""relative loss"""
self
.
loss_style_remd
=
self
.
calc_style_emd_loss
(
self
.
tF
[
'r31'
],
self
.
sF
[
'r31'
])
+
self
.
calc_style_emd_loss
(
self
.
tF
[
'r41'
],
self
.
sF
[
'r41'
])
self
.
loss_content_relt
=
self
.
calc_content_relt_loss
(
self
.
tF
[
'r31'
],
self
.
cF
[
'r31'
])
+
self
.
calc_content_relt_loss
(
self
.
tF
[
'r41'
],
self
.
cF
[
'r41'
])
self
.
losses
[
'loss_style_remd'
]
=
self
.
loss_style_remd
self
.
losses
[
'loss_content_relt'
]
=
self
.
loss_content_relt
self
.
loss
=
self
.
loss_c
*
self
.
content_weight
+
self
.
loss_s
*
self
.
style_weight
+
\
self
.
l_identity1
*
50
+
self
.
l_identity2
*
1
+
self
.
loss_style_remd
*
10
+
\
self
.
loss_content_relt
*
16
self
.
loss
.
backward
()
return
self
.
loss
def
train_iter
(
self
,
optimizers
=
None
):
"""Calculate losses, gradients, and update network weights"""
self
.
forward
()
optimizers
[
'optimG'
].
clear_grad
()
self
.
backward_dnc
()
self
.
optimizers
[
'optimG'
].
step
()
ppgan/solver/__init__.py
浏览文件 @
0ded43f7
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
.lr_scheduler
import
CosineAnnealingRestartLR
,
LinearDecay
from
.lr_scheduler
import
CosineAnnealingRestartLR
,
LinearDecay
,
NonLinearDecay
from
.optimizer
import
*
from
.builder
import
build_lr_scheduler
from
.builder
import
build_optimizer
ppgan/solver/lr_scheduler.py
浏览文件 @
0ded43f7
...
...
@@ -21,6 +21,17 @@ from .builder import LRSCHEDULERS
LRSCHEDULERS
.
register
(
MultiStepDecay
)
@
LRSCHEDULERS
.
register
()
class
NonLinearDecay
(
LRScheduler
):
def
__init__
(
self
,
learning_rate
,
lr_decay
,
last_epoch
=-
1
):
self
.
lr_decay
=
lr_decay
super
(
NonLinearDecay
,
self
).
__init__
(
learning_rate
,
last_epoch
)
def
get_lr
(
self
):
lr
=
self
.
base_lr
/
(
1.0
+
self
.
lr_decay
*
self
.
last_epoch
)
return
lr
@
LRSCHEDULERS
.
register
()
class
LinearDecay
(
LambdaDecay
):
def
__init__
(
self
,
learning_rate
,
start_epoch
,
decay_epochs
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录