Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
浅灬忆灬
PaddleGAN
提交
099c595e
P
PaddleGAN
项目概览
浅灬忆灬
/
PaddleGAN
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleGAN
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
099c595e
编写于
8月 06, 2021
作者:
L
LielinJiang
提交者:
GitHub
8月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement IconVSR (#384)
* add iconvsr
上级
1ffa3fba
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
523 additions
and
7 deletions
+523
-7
configs/iconvsr_reds.yaml
configs/iconvsr_reds.yaml
+91
-0
docs/en_US/tutorials/video_super_resolution.md
docs/en_US/tutorials/video_super_resolution.md
+2
-0
docs/zh_CN/tutorials/video_super_resolution.md
docs/zh_CN/tutorials/video_super_resolution.md
+2
-2
ppgan/models/basicvsr_model.py
ppgan/models/basicvsr_model.py
+4
-3
ppgan/models/generators/__init__.py
ppgan/models/generators/__init__.py
+1
-0
ppgan/models/generators/basicvsr.py
ppgan/models/generators/basicvsr.py
+0
-2
ppgan/models/generators/iconvsr.py
ppgan/models/generators/iconvsr.py
+423
-0
未找到文件。
configs/iconvsr_reds.yaml
0 → 100644
浏览文件 @
099c595e
# Training configuration for IconVSR on the REDS dataset (x4 video super-resolution).
total_iters: 300000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: BasicVSRModel
  # number of initial iterations during which the flow network stays frozen
  fix_iter: 5000
  generator:
    name: IconVSR
    mid_channels: 64
    num_blocks: 30
  pixel_criterion:
    name: CharbonnierLoss
    reduction: mean

dataset:
  train:
    name: RepeatDataset
    times: 1000
    num_workers: 4  # 6
    batch_size: 2  # 4*2
    dataset:
      name: SRREDSMultipleGTDataset
      mode: train
      lq_folder: data/REDS/train_sharp_bicubic/X4
      gt_folder: data/REDS/train_sharp/X4
      crop_size: 256
      interval_list: [1]
      random_reverse: False
      number_frames: 15
      use_flip: True
      use_rot: True
      scale: 4
      val_partition: REDS4

  test:
    name: SRREDSMultipleGTDataset
    mode: test
    lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
    gt_folder: data/REDS/REDS4_test_sharp/X4
    interval_list: [1]
    random_reverse: False
    number_frames: 100
    use_flip: False
    use_rot: False
    scale: 4
    val_partition: REDS4
    num_workers: 0
    batch_size: 1

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: !!float 2e-4
  periods: [300000]
  restart_weights: [1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 5000
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 0
      test_y_channel: False
    ssim:
      name: SSIM
      crop_border: 0
      test_y_channel: False

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5000
docs/en_US/tutorials/video_super_resolution.md
浏览文件 @
099c595e
...
...
@@ -78,6 +78,7 @@ The metrics are PSNR / SSIM.
| EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 |
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
## 1.4 Model Download
...
...
@@ -90,6 +91,7 @@ The metrics are PSNR / SSIM.
| EDVR_L_wo_tsa_deblur | REDS |
[
EDVR_L_wo_tsa_deblur
](
https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_deblur.pdparams
)
| EDVR_L_w_tsa_deblur | REDS |
[
EDVR_L_w_tsa_deblur
](
https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams
)
| BasicVSR_x4 | REDS |
[
BasicVSR_x4
](
https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams
)
| IconVSR_x4 | REDS |
[
IconVSR_x4
](
https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams
)
...
...
docs/zh_CN/tutorials/video_super_resolution.md
浏览文件 @
099c595e
...
...
@@ -74,7 +74,7 @@
| EDVR_L_wo_tsa_deblur | 34.9587 / 0.9509 |
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
## 1.4 模型下载
| 模型 | 数据集 | 下载地址 |
...
...
@@ -86,7 +86,7 @@
| EDVR_L_wo_tsa_deblur | REDS |
[
EDVR_L_wo_tsa_deblur
](
https://paddlegan.bj.bcebos.com/models/EDVR_L_wo_tsa_deblur.pdparams
)
| EDVR_L_w_tsa_deblur | REDS |
[
EDVR_L_w_tsa_deblur
](
https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams
)
| BasicVSR_x4 | REDS |
[
BasicVSR_x4
](
https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams
)
| IconVSR_x4 | REDS |
[
IconVSR_x4
](
https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams
)
...
...
ppgan/models/basicvsr_model.py
浏览文件 @
099c595e
...
...
@@ -17,7 +17,8 @@ import paddle.nn as nn
from
.builder
import
MODELS
from
.sr_model
import
BaseSRModel
from
.generators.basicvsr
import
ResidualBlockNoBN
,
PixelShufflePack
,
SPyNet
from
.generators.iconvsr
import
EDVRFeatureExtractor
from
.generators.basicvsr
import
ResidualBlockNoBN
,
PixelShufflePack
,
SPyNet
from
..modules.init
import
reset_parameters
from
..utils.visual
import
tensor2img
...
...
@@ -57,7 +58,7 @@ class BasicVSRModel(BaseSRModel):
print
(
'Train BasicVSR with fixed spynet for'
,
self
.
fix_iter
,
'iters.'
)
for
name
,
param
in
self
.
nets
[
'generator'
].
named_parameters
():
if
'spynet'
in
name
:
if
'spynet'
in
name
or
'edvr'
in
name
:
param
.
trainable
=
False
elif
self
.
current_iter
>=
self
.
fix_iter
+
1
and
self
.
flag
:
print
(
'Train all the parameters.'
)
...
...
@@ -107,5 +108,5 @@ def init_basicvsr_weight(net):
continue
if
(
not
isinstance
(
m
,
(
ResidualBlockNoBN
,
PixelShufflePack
,
SPyNet
))):
m
,
(
ResidualBlockNoBN
,
PixelShufflePack
,
SPyNet
,
EDVRFeatureExtractor
))):
init_basicvsr_weight
(
m
)
ppgan/models/generators/__init__.py
浏览文件 @
099c595e
...
...
@@ -32,4 +32,5 @@ from .generator_firstorder import FirstOrderGenerator
from
.generater_lapstyle
import
DecoderNet
,
Encoder
,
RevisionNet
from
.basicvsr
import
BasicVSRNet
from
.mpr
import
MPRNet
from
.iconvsr
import
IconVSR
from
.gpen
import
GPEN
ppgan/models/generators/basicvsr.py
浏览文件 @
099c595e
...
...
@@ -15,11 +15,9 @@
import
paddle
import
numpy
as
np
import
scipy.io
as
scio
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
initializer
from
...utils.download
import
get_path_from_url
from
...modules.init
import
kaiming_normal_
,
constant_
...
...
ppgan/models/generators/iconvsr.py
0 → 100644
浏览文件 @
099c595e
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# basicvsr and iconvsr code are heavily based on mmedit
import
paddle
import
numpy
as
np
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
.builder
import
GENERATORS
from
.edvr
import
PCDAlign
,
TSAFusion
from
.basicvsr
import
SPyNet
,
PixelShufflePack
,
ResidualBlockNoBN
,
\
ResidualBlocksWithInputConv
,
flow_warp
from
...utils.download
import
get_path_from_url
@GENERATORS.register()
class IconVSR(nn.Layer):
    """IconVSR network structure for video super-resolution.

    Supports only x4 upsampling. IconVSR extends BasicVSR with an
    information-refill mechanism (an EDVR feature extractor applied at
    keyframes) and coupled propagation.

    Paper:
        BasicVSR: The Search for Essential Components in Video
        Super-Resolution and Beyond, CVPR, 2021

    Args:
        mid_channels (int): Channel number of the intermediate features.
            Default: 64.
        num_blocks (int): Number of residual blocks in each propagation branch.
            Default: 30.
        padding (int): Number of frames to be padded at two ends of the
            sequence. 2 for REDS and 3 for Vimeo-90K. Default: 2.
        keyframe_stride (int): Number determining the keyframes. If stride=5,
            then the (0, 5, 10, 15, ...)-th frame will be the keyframes.
            Default: 5.
    """
    def __init__(self, mid_channels=64, num_blocks=30, padding=2,
                 keyframe_stride=5):
        super().__init__()

        self.mid_channels = mid_channels
        self.padding = padding
        self.keyframe_stride = keyframe_stride

        # optical flow network for feature alignment (initialized from a
        # pretrained SPyNet checkpoint)
        self.spynet = SPyNet()
        weight_path = get_path_from_url(
            'https://paddlegan.bj.bcebos.com/models/spynet.pdparams')
        self.spynet.set_state_dict(paddle.load(weight_path))

        # information-refill: pretrained EDVR-M features at keyframes
        self.edvr = EDVRFeatureExtractor(num_frames=padding * 2 + 1,
                                         center_frame_idx=padding)
        edvr_weight_path = get_path_from_url(
            'https://paddlegan.bj.bcebos.com/models/edvrm.pdparams')
        self.edvr.set_state_dict(paddle.load(edvr_weight_path))

        # fuse refilled keyframe features with the propagated features
        self.backward_fusion = nn.Conv2D(2 * mid_channels,
                                         mid_channels,
                                         3,
                                         1,
                                         1,
                                         bias_attr=True)
        self.forward_fusion = nn.Conv2D(2 * mid_channels,
                                        mid_channels,
                                        3,
                                        1,
                                        1,
                                        bias_attr=True)

        # propagation branches
        self.backward_resblocks = ResidualBlocksWithInputConv(
            mid_channels + 3, mid_channels, num_blocks)
        # forward branch additionally consumes the backward-branch output,
        # hence 2 * mid_channels + 3 input channels
        self.forward_resblocks = ResidualBlocksWithInputConv(
            2 * mid_channels + 3, mid_channels, num_blocks)

        # upsample
        self.upsample1 = PixelShufflePack(mid_channels,
                                          mid_channels,
                                          2,
                                          upsample_kernel=3)
        self.upsample2 = PixelShufflePack(mid_channels,
                                          64,
                                          2,
                                          upsample_kernel=3)
        self.conv_hr = nn.Conv2D(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2D(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(scale_factor=4,
                                        mode='bilinear',
                                        align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1)

    def spatial_padding(self, lrs):
        """Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a multiple
        of 4, we apply reflect padding to the input LR images if their
        resolution is not divisible by 4.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = lrs.shape
        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # padding (fold the time axis into the batch so F.pad sees 4-D input)
        lrs = lrs.reshape([-1, c, h, w])
        lrs = F.pad(lrs, [0, pad_w, 0, pad_h], mode='reflect')

        return lrs.reshape([n, t, c, h + pad_h, w + pad_w])

    def check_if_mirror_extended(self, lrs):
        """Check whether the input is a mirror-extended sequence.

        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
        (t-1-i)-th frame. Sets ``self.is_mirror_extended`` accordingly.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)
        """
        self.is_mirror_extended = False
        if lrs.shape[1] % 2 == 0:
            lrs_1, lrs_2 = paddle.chunk(lrs, 2, axis=1)
            lrs_2 = paddle.flip(lrs_2, [1])
            if paddle.norm(lrs_1 - lrs_2) == 0:
                self.is_mirror_extended = True

    def compute_refill_features(self, lrs, keyframe_idx):
        """Compute keyframe features for information-refill.

        Since EDVR-M is used, temporal padding is performed before feature
        computation so every keyframe has a full window of neighbors.

        Args:
            lrs (Tensor): Input LR images with shape (n, t, c, h, w)
            keyframe_idx (list(int)): The indices specifying the keyframes.

        Returns:
            dict(Tensor): The keyframe features. Each key corresponds to the
                indices in keyframe_idx.
        """
        if self.padding == 2:
            # reflect-pad two frames at each end of the sequence
            lrs = [
                lrs[:, 4:5, :, :], lrs[:, 3:4, :, :], lrs,
                lrs[:, -4:-3, :, :], lrs[:, -5:-4, :, :]
            ]
        elif self.padding == 3:
            # reflect-pad three frames at each end of the sequence
            lrs = [lrs[:, [6, 5, 4]], lrs, lrs[:, [-5, -6, -7]]]
        lrs = paddle.concat(lrs, axis=1)

        num_frames = 2 * self.padding + 1
        feats_refill = {}
        for i in keyframe_idx:
            feats_refill[i] = self.edvr(lrs[:, i:i + num_frames])
        return feats_refill

    def compute_flow(self, lrs):
        """Compute optical flow using SPyNet for feature warping.

        Note that if the input is a mirror-extended sequence, 'flows_forward'
        is not needed, since it is equal to 'flows_backward.flip(1)'.

        Args:
            lrs (tensor): Input LR images with shape (n, t, c, h, w)

        Returns:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
                flows used for forward-time propagation (current to previous).
                'flows_backward' corresponds to the flows used for
                backward-time propagation (current to next).
        """
        n, t, c, h, w = lrs.shape
        lrs_1 = lrs[:, :-1, :, :, :].reshape([-1, c, h, w])
        lrs_2 = lrs[:, 1:, :, :, :].reshape([-1, c, h, w])

        flows_backward = self.spynet(lrs_1, lrs_2).reshape([n, t - 1, 2, h, w])

        if self.is_mirror_extended:
            # flows_forward = flows_backward.flip(1)
            flows_forward = None
        else:
            flows_forward = self.spynet(lrs_2,
                                        lrs_1).reshape([n, t - 1, 2, h, w])

        return flows_forward, flows_backward

    def forward(self, lrs):
        """Forward function for IconVSR.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
                NOTE(review): when the input resolution is not divisible by 4
                the output keeps the padded size (4*h_pad, 4*w_pad) — confirm
                whether callers expect a crop back to 4*h_input, 4*w_input.
        """
        n, t, c, h_input, w_input = lrs.shape
        assert h_input >= 64 and w_input >= 64, (
            'The height and width of inputs should be at least 64, '
            f'but got {h_input} and {w_input}.')

        # check whether the input is an extended sequence
        self.check_if_mirror_extended(lrs)

        lrs = self.spatial_padding(lrs)
        h, w = lrs.shape[3], lrs.shape[4]

        # get the keyframe indices for information-refill
        keyframe_idx = list(range(0, t, self.keyframe_stride))
        if keyframe_idx[-1] != t - 1:
            keyframe_idx.append(t - 1)  # the last frame must be a keyframe

        # compute optical flow and compute features for information-refill
        # (fix: flow was previously computed twice; the redundant second
        # compute_flow call has been removed)
        flows_forward, flows_backward = self.compute_flow(lrs)
        feats_refill = self.compute_refill_features(lrs, keyframe_idx)

        # backward-time propagation
        outputs = []
        feat_prop = paddle.zeros([n, self.mid_channels, h, w], dtype='float32')
        for i in range(t - 1, -1, -1):
            # no warping required for the last timestep
            if i < t - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop,
                                      flow.transpose([0, 2, 3, 1]))

            # information refill at keyframes
            if i in keyframe_idx:
                feat_prop = paddle.concat([feat_prop, feats_refill[i]],
                                          axis=1)
                feat_prop = self.backward_fusion(feat_prop)

            feat_prop = paddle.concat([lrs[:, i, :, :, :], feat_prop], axis=1)
            feat_prop = self.backward_resblocks(feat_prop)

            outputs.append(feat_prop)
        outputs = outputs[::-1]

        # forward-time propagation and upsampling
        feat_prop = paddle.zeros_like(feat_prop)
        for i in range(0, t):
            lr_curr = lrs[:, i, :, :, :]
            if i > 0:
                # no warping required for the first timestep
                if flows_forward is not None:
                    flow = flows_forward[:, i - 1, :, :, :]
                else:
                    # mirror-extended input: reuse the flipped backward flows
                    flow = flows_backward[:, -i, :, :, :]
                feat_prop = flow_warp(feat_prop,
                                      flow.transpose([0, 2, 3, 1]))

            # information refill at keyframes
            if i in keyframe_idx:
                feat_prop = paddle.concat([feat_prop, feats_refill[i]],
                                          axis=1)
                feat_prop = self.forward_fusion(feat_prop)

            feat_prop = paddle.concat([lr_curr, outputs[i], feat_prop],
                                      axis=1)
            feat_prop = self.forward_resblocks(feat_prop)

            # upsampling given the backward and forward features
            out = self.lrelu(self.upsample1(feat_prop))
            out = self.lrelu(self.upsample2(out))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = self.img_upsample(lr_curr)
            out += base
            outputs[i] = out

        return paddle.stack(outputs, axis=1)
class EDVRFeatureExtractor(nn.Layer):
    """EDVR feature extractor for information-refill in IconVSR.

    We use EDVR-M in IconVSR. NOTE(review): the original docstring said to
    "specify 'pretrained'" to adopt pretrained models, but this class takes no
    such argument — the caller (IconVSR) loads the checkpoint via
    ``set_state_dict`` instead.

    Paper:
    EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.

    Args:
        in_channels (int): Channel number of inputs. Default: 3.
        out_channel (int): Channel number of outputs. Default: 3.
            NOTE(review): this parameter is currently unused by the body.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_frames (int): Number of input frames. Default: 5.
        deform_groups (int): Deformable groups. Defaults: 8.
        num_blocks_extraction (int): Number of blocks for feature extraction.
            Default: 5.
        num_blocks_reconstruction (int): Number of blocks for reconstruction.
            Default: 10. NOTE(review): also unused here — reconstruction lives
            in the full EDVR network, not this extractor.
        center_frame_idx (int): The index of center frame. Frame counting from
            0. Default: 2.
        with_tsa (bool): Whether to use TSA module. Default: True.
    """
    def __init__(self,
                 in_channels=3,
                 out_channel=3,
                 mid_channels=64,
                 num_frames=5,
                 deform_groups=8,
                 num_blocks_extraction=5,
                 num_blocks_reconstruction=10,
                 center_frame_idx=2,
                 with_tsa=True):
        super().__init__()

        self.center_frame_idx = center_frame_idx
        self.with_tsa = with_tsa

        # shallow feature extraction shared by all frames
        self.conv_first = nn.Conv2D(in_channels, mid_channels, 3, 1, 1)
        self.feature_extraction = make_layer(ResidualBlockNoBN,
                                             num_blocks_extraction,
                                             nf=mid_channels)

        # generate pyramid features (stride-2 convs halve resolution per level)
        self.feat_l2_conv1 = nn.Conv2D(mid_channels, mid_channels, 3, 2, 1)
        self.feat_l2_conv2 = nn.Conv2D(mid_channels, mid_channels, 3, 1, 1)
        self.feat_l3_conv1 = nn.Conv2D(mid_channels, mid_channels, 3, 2, 1)
        self.feat_l3_conv2 = nn.Conv2D(mid_channels, mid_channels, 3, 1, 1)

        # pcd alignment
        self.pcd_alignment = PCDAlign(nf=mid_channels, groups=deform_groups)

        # fusion: TSA attention fusion, or a plain 1x1 conv over the
        # channel-concatenated frames when TSA is disabled
        if self.with_tsa:
            self.fusion = TSAFusion(nf=mid_channels,
                                    nframes=num_frames,
                                    center=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2D(num_frames * mid_channels, mid_channels,
                                    1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x):
        """Forward function for EDVRFeatureExtractor.

        Args:
            x (Tensor): Input tensor with shape (n, t, 3, h, w).

        Returns:
            Tensor: Intermediate feature with shape (n, mid_channels, h, w).
        """
        n, t, c, h, w = x.shape

        # extract LR features
        # L1: fold time into batch so the 2-D convs process every frame
        l1_feat = self.lrelu(self.conv_first(x.reshape([-1, c, h, w])))
        l1_feat = self.feature_extraction(l1_feat)
        # L2
        l2_feat = self.lrelu(
            self.feat_l2_conv2(self.lrelu(self.feat_l2_conv1(l1_feat))))
        # L3
        l3_feat = self.lrelu(
            self.feat_l3_conv2(self.lrelu(self.feat_l3_conv1(l2_feat))))

        l1_feat = l1_feat.reshape([n, t, -1, h, w])
        l2_feat = l2_feat.reshape([n, t, -1, h // 2, w // 2])
        l3_feat = l3_feat.reshape([n, t, -1, h // 4, w // 4])

        # pcd alignment: align every frame's pyramid to the center frame's
        ref_feats = [  # reference feature list
            l1_feat[:, self.center_frame_idx, :, :, :].clone(),
            l2_feat[:, self.center_frame_idx, :, :, :].clone(),
            l3_feat[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(t):
            neighbor_feats = [
                l1_feat[:, i, :, :, :].clone(),
                l2_feat[:, i, :, :, :].clone(),
                l3_feat[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_alignment(neighbor_feats, ref_feats))
        aligned_feat = paddle.stack(aligned_feat, axis=1)  # (n, t, c, h, w)

        if self.with_tsa:
            feat = self.fusion(aligned_feat)
        else:
            # flatten frames into channels for the 1x1-conv fusion
            aligned_feat = aligned_feat.reshape([n, -1, h, w])
            feat = self.fusion(aligned_feat)

        return feat
def make_layer(block, num_blocks, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        block (nn.Layer): nn.Layer class for the basic block.
        num_blocks (int): number of blocks to stack.
        **kwarg: keyword arguments forwarded to each block's constructor.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    return nn.Sequential(*(block(**kwarg) for _ in range(num_blocks)))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录