Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
cd642c08
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
cd642c08
编写于
1月 05, 2021
作者:
L
LielinJiang
提交者:
GitHub
1月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Lesrcnn model (#136)
* add lesrcnn model
上级
ffb1a225
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
540 addition
and
46 deletion
+540
-46
configs/esrgan_psnr_x4_div2k.yaml
configs/esrgan_psnr_x4_div2k.yaml
+2
-2
configs/lesrcnn_psnr_x4_div2k.yaml
configs/lesrcnn_psnr_x4_div2k.yaml
+101
-0
ppgan/datasets/preprocess/transforms.py
ppgan/datasets/preprocess/transforms.py
+16
-2
ppgan/datasets/transforms/__init__.py
ppgan/datasets/transforms/__init__.py
+1
-1
ppgan/engine/trainer.py
ppgan/engine/trainer.py
+0
-1
ppgan/metrics/psnr_ssim.py
ppgan/metrics/psnr_ssim.py
+43
-1
ppgan/models/esrgan_model.py
ppgan/models/esrgan_model.py
+45
-39
ppgan/models/generators/__init__.py
ppgan/models/generators/__init__.py
+1
-0
ppgan/models/generators/lesrcnn.py
ppgan/models/generators/lesrcnn.py
+331
-0
未找到文件。
configs/esrgan_psnr_x4_div2k.yaml
浏览文件 @
cd642c08
...
...
@@ -91,11 +91,11 @@ validate:
psnr
:
# metric name, can be arbitrary
name
:
PSNR
crop_border
:
4
test_y_channel
:
fals
e
test_y_channel
:
Tru
e
ssim
:
name
:
SSIM
crop_border
:
4
test_y_channel
:
fals
e
test_y_channel
:
Tru
e
log_config
:
interval
:
10
...
...
configs/lesrcnn_psnr_x4_div2k.yaml
0 → 100644
浏览文件 @
cd642c08
total_iters
:
1000000
output_dir
:
output_dir
# tensor range for function tensor2img
min_max
:
(0., 1.)
model
:
name
:
BaseSRModel
generator
:
name
:
LESRCNNGenerator
pixel_criterion
:
name
:
L1Loss
dataset
:
train
:
name
:
SRDataset
gt_folder
:
data/DIV2K/DIV2K_train_HR_sub
lq_folder
:
data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
num_workers
:
4
batch_size
:
16
scale
:
4
preprocess
:
-
name
:
LoadImageFromFile
key
:
lq
-
name
:
LoadImageFromFile
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
SRPairedRandomCrop
gt_patch_size
:
128
scale
:
4
keys
:
[
image
,
image
]
-
name
:
PairedRandomHorizontalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomVerticalFlip
keys
:
[
image
,
image
]
-
name
:
PairedRandomTransposeHW
keys
:
[
image
,
image
]
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
test
:
name
:
SRDataset
gt_folder
:
data/DIV2K/val_set14/Set14
lq_folder
:
data/DIV2K/val_set14/Set14_bicLRx4
scale
:
4
preprocess
:
-
name
:
LoadImageFromFile
key
:
lq
-
name
:
LoadImageFromFile
key
:
gt
-
name
:
Transforms
input_keys
:
[
lq
,
gt
]
pipeline
:
-
name
:
Transpose
keys
:
[
image
,
image
]
-
name
:
Normalize
mean
:
[
0.
,
.0
,
0.
]
std
:
[
255.
,
255.
,
255.
]
keys
:
[
image
,
image
]
lr_scheduler
:
name
:
CosineAnnealingRestartLR
learning_rate
:
0.0002
periods
:
[
250000
,
250000
,
250000
,
250000
]
restart_weights
:
[
1
,
1
,
1
,
1
]
eta_min
:
!!float
1e-7
optimizer
:
name
:
Adam
# add parameters of net_name to optim
# name should in self.nets
net_names
:
-
generator
beta1
:
0.9
beta2
:
0.99
validate
:
interval
:
5000
save_img
:
false
metrics
:
psnr
:
# metric name, can be arbitrary
name
:
PSNR
crop_border
:
4
test_y_channel
:
True
ssim
:
name
:
SSIM
crop_border
:
4
test_y_channel
:
True
log_config
:
interval
:
100
visiual_interval
:
5000
snapshot_config
:
interval
:
5000
ppgan/datasets/preprocess/transforms.py
浏览文件 @
cd642c08
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
sys
import
cv2
import
random
import
numbers
import
collections
...
...
@@ -40,8 +41,9 @@ TRANSFORMS.register(T.Transpose)
@
PREPROCESS
.
register
()
class
Transforms
():
def
__init__
(
self
,
pipeline
,
input_keys
):
def
__init__
(
self
,
pipeline
,
input_keys
,
output_keys
=
None
):
self
.
input_keys
=
input_keys
self
.
output_keys
=
output_keys
self
.
transforms
=
[]
for
transform_cfg
in
pipeline
:
self
.
transforms
.
append
(
build_from_config
(
transform_cfg
,
TRANSFORMS
))
...
...
@@ -58,6 +60,11 @@ class Transforms():
transform
.
params
,
dict
):
datas
.
update
(
transform
.
params
)
if
self
.
output_keys
is
not
None
:
for
i
,
k
in
enumerate
(
self
.
output_keys
):
datas
[
k
]
=
data
[
i
]
return
datas
for
i
,
k
in
enumerate
(
self
.
input_keys
):
datas
[
k
]
=
data
[
i
]
...
...
@@ -183,10 +190,11 @@ class SRPairedRandomCrop(T.BaseTransform):
scale (int): model upscale factor.
gt_patch_size (int): cropped gt patch size.
"""
def
__init__
(
self
,
scale
,
gt_patch_size
,
keys
=
None
):
def
__init__
(
self
,
scale
,
gt_patch_size
,
scale_list
=
False
,
keys
=
None
):
self
.
gt_patch_size
=
gt_patch_size
self
.
scale
=
scale
self
.
keys
=
keys
self
.
scale_list
=
scale_list
def
__call__
(
self
,
inputs
):
"""inputs must be (lq_img, gt_img)"""
...
...
@@ -214,5 +222,11 @@ class SRPairedRandomCrop(T.BaseTransform):
gt
=
gt
[
top_gt
:
top_gt
+
self
.
gt_patch_size
,
left_gt
:
left_gt
+
self
.
gt_patch_size
,
...]
if
self
.
scale_list
and
self
.
scale
==
4
:
lqx2
=
F
.
resize
(
gt
,
(
lq_patch_size
*
2
,
lq_patch_size
*
2
),
'bicubic'
)
outputs
=
(
lq
,
lqx2
,
gt
)
return
outputs
outputs
=
(
lq
,
gt
)
return
outputs
ppgan/datasets/transforms/__init__.py
浏览文件 @
cd642c08
from
.transforms
import
ResizeToScale
,
PairedRandomCrop
,
PairedRandomHorizontalFlip
,
Add
\ No newline at end of file
from
.transforms
import
ResizeToScale
,
PairedRandomCrop
,
PairedRandomHorizontalFlip
,
Add
ppgan/engine/trainer.py
浏览文件 @
cd642c08
...
...
@@ -386,7 +386,6 @@ class Trainer:
self
.
logger
.
warning
(
'Can not find state dict of net {}. Skip load pretrained weight for net {}'
.
format
(
net_name
,
net_name
))
net
.
set_state_dict
(
state_dicts
[
net_name
])
def
close
(
self
):
"""
...
...
ppgan/metrics/psnr_ssim.py
浏览文件 @
cd642c08
...
...
@@ -270,6 +270,48 @@ def bgr2ycbcr(img, y_only=False):
return
out_img
def
rgb2ycbcr
(
img
,
y_only
=
False
):
"""Convert a RGB image to YCbCr image.
The RGB version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type
=
img
.
dtype
if
img_type
!=
np
.
uint8
:
img
*=
255.
if
y_only
:
out_img
=
np
.
dot
(
img
,
[
65.481
,
128.553
,
24.966
])
/
255.
+
16.0
else
:
out_img
=
np
.
matmul
(
img
,
[[
24.966
,
112.0
,
-
18.214
],
[
128.553
,
-
74.203
,
-
93.786
],
[
65.481
,
-
37.797
,
112.0
]])
+
[
16
,
128
,
128
]
if
img_type
!=
np
.
uint8
:
out_img
/=
255.
else
:
out_img
=
out_img
.
round
()
return
out_img
def
to_y_channel
(
img
):
"""Change to Y channel of YCbCr.
...
...
@@ -281,6 +323,6 @@ def to_y_channel(img):
"""
img
=
img
.
astype
(
np
.
float32
)
/
255.
if
img
.
ndim
==
3
and
img
.
shape
[
2
]
==
3
:
img
=
bgr
2ycbcr
(
img
,
y_only
=
True
)
img
=
rgb
2ycbcr
(
img
,
y_only
=
True
)
img
=
img
[...,
None
]
return
img
*
255.
ppgan/models/esrgan_model.py
浏览文件 @
cd642c08
...
...
@@ -61,7 +61,6 @@ class ESRGAN(BaseSRModel):
self
.
gan_criterion
=
build_criterion
(
gan_criterion
)
def
train_iter
(
self
,
optimizers
=
None
):
self
.
set_requires_grad
(
self
.
nets
[
'discriminator'
],
False
)
optimizers
[
'optimG'
].
clear_grad
()
l_total
=
0
self
.
output
=
self
.
nets
[
'generator'
](
self
.
lq
)
...
...
@@ -83,41 +82,48 @@ class ESRGAN(BaseSRModel):
self
.
losses
[
'loss_style'
]
=
l_g_style
# gan loss (relativistic gan)
real_d_pred
=
self
.
nets
[
'discriminator'
](
self
.
gt
).
detach
()
fake_g_pred
=
self
.
nets
[
'discriminator'
](
self
.
output
)
l_g_real
=
self
.
gan_criterion
(
real_d_pred
-
paddle
.
mean
(
fake_g_pred
),
False
,
is_disc
=
False
)
l_g_fake
=
self
.
gan_criterion
(
fake_g_pred
-
paddle
.
mean
(
real_d_pred
),
True
,
is_disc
=
False
)
l_g_gan
=
(
l_g_real
+
l_g_fake
)
/
2
l_total
+=
l_g_gan
self
.
losses
[
'l_g_gan'
]
=
l_g_gan
l_total
.
backward
()
optimizers
[
'optimG'
].
step
()
self
.
set_requires_grad
(
self
.
nets
[
'discriminator'
],
True
)
optimizers
[
'optimD'
].
clear_grad
()
# real
fake_d_pred
=
self
.
nets
[
'discriminator'
](
self
.
output
).
detach
()
real_d_pred
=
self
.
nets
[
'discriminator'
](
self
.
gt
)
l_d_real
=
self
.
gan_criterion
(
real_d_pred
-
paddle
.
mean
(
fake_d_pred
),
True
,
is_disc
=
True
)
*
0.5
# fake
fake_d_pred
=
self
.
nets
[
'discriminator'
](
self
.
output
.
detach
())
l_d_fake
=
self
.
gan_criterion
(
fake_d_pred
-
paddle
.
mean
(
real_d_pred
.
detach
()),
False
,
is_disc
=
True
)
*
0.5
(
l_d_real
+
l_d_fake
).
backward
()
optimizers
[
'optimD'
].
step
()
self
.
losses
[
'l_d_real'
]
=
l_d_real
self
.
losses
[
'l_d_fake'
]
=
l_d_fake
self
.
losses
[
'out_d_real'
]
=
paddle
.
mean
(
real_d_pred
.
detach
())
self
.
losses
[
'out_d_fake'
]
=
paddle
.
mean
(
fake_d_pred
.
detach
())
if
hasattr
(
self
,
'gan_criterion'
):
self
.
set_requires_grad
(
self
.
nets
[
'discriminator'
],
False
)
real_d_pred
=
self
.
nets
[
'discriminator'
](
self
.
gt
).
detach
()
fake_g_pred
=
self
.
nets
[
'discriminator'
](
self
.
output
)
l_g_real
=
self
.
gan_criterion
(
real_d_pred
-
paddle
.
mean
(
fake_g_pred
),
False
,
is_disc
=
False
)
l_g_fake
=
self
.
gan_criterion
(
fake_g_pred
-
paddle
.
mean
(
real_d_pred
),
True
,
is_disc
=
False
)
l_g_gan
=
(
l_g_real
+
l_g_fake
)
/
2
l_total
+=
l_g_gan
self
.
losses
[
'l_g_gan'
]
=
l_g_gan
l_total
.
backward
()
optimizers
[
'optimG'
].
step
()
self
.
set_requires_grad
(
self
.
nets
[
'discriminator'
],
True
)
optimizers
[
'optimD'
].
clear_grad
()
# real
fake_d_pred
=
self
.
nets
[
'discriminator'
](
self
.
output
).
detach
()
real_d_pred
=
self
.
nets
[
'discriminator'
](
self
.
gt
)
l_d_real
=
self
.
gan_criterion
(
real_d_pred
-
paddle
.
mean
(
fake_d_pred
),
True
,
is_disc
=
True
)
*
0.5
# fake
fake_d_pred
=
self
.
nets
[
'discriminator'
](
self
.
output
.
detach
())
l_d_fake
=
self
.
gan_criterion
(
fake_d_pred
-
paddle
.
mean
(
real_d_pred
.
detach
()),
False
,
is_disc
=
True
)
*
0.5
(
l_d_real
+
l_d_fake
).
backward
()
optimizers
[
'optimD'
].
step
()
self
.
losses
[
'l_d_real'
]
=
l_d_real
self
.
losses
[
'l_d_fake'
]
=
l_d_fake
self
.
losses
[
'out_d_real'
]
=
paddle
.
mean
(
real_d_pred
.
detach
())
self
.
losses
[
'out_d_fake'
]
=
paddle
.
mean
(
fake_d_pred
.
detach
())
else
:
l_total
.
backward
()
optimizers
[
'optimG'
].
step
()
ppgan/models/generators/__init__.py
浏览文件 @
cd642c08
...
...
@@ -21,6 +21,7 @@ from .resnet_ugatit import ResnetUGATITGenerator
from
.dcgenerator
import
DCGenerator
from
.generater_animegan
import
AnimeGenerator
,
AnimeGeneratorLite
from
.wav2lip
import
Wav2Lip
from
.lesrcnn
import
LESRCNNGenerator
from
.resnet_ugatit_p2c
import
ResnetUGATITP2CGenerator
from
.generator_styleganv2
import
StyleGANv2Generator
from
.generator_pixel2style2pixel
import
Pixel2Style2Pixel
ppgan/models/generators/lesrcnn.py
0 → 100644
浏览文件 @
cd642c08
import
math
import
numpy
as
np
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
.builder
import
GENERATORS
class
MeanShift
(
nn
.
Layer
):
def
__init__
(
self
,
mean_rgb
,
sub
):
super
(
MeanShift
,
self
).
__init__
()
sign
=
-
1
if
sub
else
1
r
=
mean_rgb
[
0
]
*
sign
g
=
mean_rgb
[
1
]
*
sign
b
=
mean_rgb
[
2
]
*
sign
self
.
shifter
=
nn
.
Conv2D
(
3
,
3
,
1
,
1
,
0
)
self
.
shifter
.
weight
.
set_value
(
paddle
.
eye
(
3
).
reshape
([
3
,
3
,
1
,
1
]))
self
.
shifter
.
bias
.
set_value
(
np
.
array
([
r
,
g
,
b
]).
astype
(
'float32'
))
# Freeze the mean shift layer
for
params
in
self
.
shifter
.
parameters
():
params
.
trainable
=
False
def
forward
(
self
,
x
):
x
=
self
.
shifter
(
x
)
return
x
class
UpsampleBlock
(
nn
.
Layer
):
def
__init__
(
self
,
n_channels
,
scale
,
multi_scale
,
group
=
1
):
super
(
UpsampleBlock
,
self
).
__init__
()
if
multi_scale
:
self
.
up2
=
_UpsampleBlock
(
n_channels
,
scale
=
2
,
group
=
group
)
self
.
up3
=
_UpsampleBlock
(
n_channels
,
scale
=
3
,
group
=
group
)
self
.
up4
=
_UpsampleBlock
(
n_channels
,
scale
=
4
,
group
=
group
)
else
:
self
.
up
=
_UpsampleBlock
(
n_channels
,
scale
=
scale
,
group
=
group
)
self
.
multi_scale
=
multi_scale
def
forward
(
self
,
x
,
scale
):
if
self
.
multi_scale
:
if
scale
==
2
:
return
self
.
up2
(
x
)
elif
scale
==
3
:
return
self
.
up3
(
x
)
elif
scale
==
4
:
return
self
.
up4
(
x
)
else
:
return
self
.
up
(
x
)
class
_UpsampleBlock
(
nn
.
Layer
):
def
__init__
(
self
,
n_channels
,
scale
,
group
=
1
):
super
(
_UpsampleBlock
,
self
).
__init__
()
modules
=
[]
if
scale
==
2
or
scale
==
4
or
scale
==
8
:
for
_
in
range
(
int
(
math
.
log
(
scale
,
2
))):
modules
+=
[
nn
.
Conv2D
(
n_channels
,
4
*
n_channels
,
3
,
1
,
1
,
groups
=
group
)
]
modules
+=
[
nn
.
PixelShuffle
(
2
)]
elif
scale
==
3
:
modules
+=
[
nn
.
Conv2D
(
n_channels
,
9
*
n_channels
,
3
,
1
,
1
,
groups
=
group
)
]
modules
+=
[
nn
.
PixelShuffle
(
3
)]
self
.
body
=
nn
.
Sequential
(
*
modules
)
def
forward
(
self
,
x
):
out
=
self
.
body
(
x
)
return
out
@
GENERATORS
.
register
()
class
LESRCNNGenerator
(
nn
.
Layer
):
"""Construct a Resnet-based generator that consists of residual blocks
between a few downsampling/upsampling operations.
Args:
scale (int): scale of upsample.
multi_scale (bool): Whether to train multi scale model.
group (int): group option for convolution.
"""
def
__init__
(
self
,
scale
=
4
,
multi_scale
=
False
,
group
=
1
,
):
super
(
LESRCNNGenerator
,
self
).
__init__
()
kernel_size
=
3
kernel_size1
=
1
padding1
=
0
padding
=
1
features
=
64
groups
=
1
channels
=
3
features1
=
64
self
.
scale
=
scale
self
.
sub_mean
=
MeanShift
((
0.4488
,
0.4371
,
0.4040
),
sub
=
True
)
self
.
add_mean
=
MeanShift
((
0.4488
,
0.4371
,
0.4040
),
sub
=
False
)
self
.
conv1
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
channels
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
padding
,
groups
=
1
,
bias_attr
=
False
))
self
.
conv2
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv3
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
conv4
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv5
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
conv6
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv7
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
conv8
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv9
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
conv10
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv11
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
conv12
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv13
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
conv14
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv15
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
conv16
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv17
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size1
,
padding
=
0
,
groups
=
groups
,
bias_attr
=
False
))
self
.
conv17_1
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv17_2
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv17_3
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv17_4
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
features
,
kernel_size
=
kernel_size
,
padding
=
1
,
groups
=
1
,
bias_attr
=
False
),
nn
.
ReLU
())
self
.
conv18
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_channels
=
features
,
out_channels
=
3
,
kernel_size
=
kernel_size
,
padding
=
padding
,
groups
=
groups
,
bias_attr
=
False
))
self
.
ReLU
=
nn
.
ReLU
()
self
.
upsample
=
UpsampleBlock
(
64
,
scale
=
scale
,
multi_scale
=
multi_scale
,
group
=
1
)
def
forward
(
self
,
x
,
scale
=
None
):
if
scale
is
None
:
scale
=
self
.
scale
x
=
self
.
sub_mean
(
x
)
x1
=
self
.
conv1
(
x
)
x1_1
=
self
.
ReLU
(
x1
)
x2
=
self
.
conv2
(
x1_1
)
x3
=
self
.
conv3
(
x2
)
x2_3
=
x1
+
x3
x2_4
=
self
.
ReLU
(
x2_3
)
x4
=
self
.
conv4
(
x2_4
)
x5
=
self
.
conv5
(
x4
)
x3_5
=
x2_3
+
x5
x3_6
=
self
.
ReLU
(
x3_5
)
x6
=
self
.
conv6
(
x3_6
)
x7
=
self
.
conv7
(
x6
)
x7_1
=
x3_5
+
x7
x7_2
=
self
.
ReLU
(
x7_1
)
x8
=
self
.
conv8
(
x7_2
)
x9
=
self
.
conv9
(
x8
)
x9_2
=
x7_1
+
x9
x9_1
=
self
.
ReLU
(
x9_2
)
x10
=
self
.
conv10
(
x9_1
)
x11
=
self
.
conv11
(
x10
)
x11_1
=
x9_2
+
x11
x11_2
=
self
.
ReLU
(
x11_1
)
x12
=
self
.
conv12
(
x11_2
)
x13
=
self
.
conv13
(
x12
)
x13_1
=
x11_1
+
x13
x13_2
=
self
.
ReLU
(
x13_1
)
x14
=
self
.
conv14
(
x13_2
)
x15
=
self
.
conv15
(
x14
)
x15_1
=
x15
+
x13_1
x15_2
=
self
.
ReLU
(
x15_1
)
x16
=
self
.
conv16
(
x15_2
)
x17
=
self
.
conv17
(
x16
)
x17_2
=
x17
+
x15_1
x17_3
=
self
.
ReLU
(
x17_2
)
temp
=
self
.
upsample
(
x17_3
,
scale
=
scale
)
x1111
=
self
.
upsample
(
x1_1
,
scale
=
scale
)
temp1
=
x1111
+
temp
temp2
=
self
.
ReLU
(
temp1
)
temp3
=
self
.
conv17_1
(
temp2
)
temp4
=
self
.
conv17_2
(
temp3
)
temp5
=
self
.
conv17_3
(
temp4
)
temp6
=
self
.
conv17_4
(
temp5
)
x18
=
self
.
conv18
(
temp6
)
out
=
self
.
add_mean
(
x18
)
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录