Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
423af984
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
423af984
编写于
12月 06, 2021
作者:
L
lzzyzlbb
提交者:
GitHub
12月 06, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add msvsr static model (#510)
上级
1091e63e
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
89 addition
and
13 deletion
+89
-13
configs/basicvsr_reds.yaml
configs/basicvsr_reds.yaml
+1
-0
configs/msvsr_reds.yaml
configs/msvsr_reds.yaml
+5
-0
ppgan/models/generators/msvsr.py
ppgan/models/generators/msvsr.py
+10
-8
test_tipc/configs/msvsr/train_infer_python.txt
test_tipc/configs/msvsr/train_infer_python.txt
+51
-0
test_tipc/prepare.sh
test_tipc/prepare.sh
+15
-0
test_tipc/results/python_msvsr_results_fp32.txt
test_tipc/results/python_msvsr_results_fp32.txt
+2
-0
tools/inference.py
tools/inference.py
+5
-5
未找到文件。
configs/basicvsr_reds.yaml
浏览文件 @
423af984
...
...
@@ -54,6 +54,7 @@ dataset:
val_partition
:
REDS4
num_workers
:
0
batch_size
:
1
num_clips
:
270
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
configs/msvsr_reds.yaml
浏览文件 @
423af984
...
...
@@ -48,6 +48,7 @@ dataset:
use_rot
:
True
scale
:
4
val_partition
:
REDS4
num_clips
:
270
test
:
name
:
SRREDSMultipleGTDataset
...
...
@@ -63,6 +64,7 @@ dataset:
val_partition
:
REDS4
num_workers
:
0
batch_size
:
1
num_clips
:
270
lr_scheduler
:
name
:
CosineAnnealingRestartLR
...
...
@@ -100,3 +102,6 @@ log_config:
snapshot_config
:
interval
:
5000
export_model
:
-
{
name
:
'
generator'
,
inputs_num
:
1
}
\ No newline at end of file
ppgan/models/generators/msvsr.py
浏览文件 @
423af984
...
...
@@ -303,7 +303,9 @@ class MSVSR(nn.Layer):
pre_mask
=
{}
# propagation branches module
for
prop_name
in
[
'stage2_backward'
,
'stage2_forward'
]:
prop_names
=
[
'stage2_backward'
,
'stage2_forward'
]
for
index
in
range
(
2
):
prop_name
=
prop_names
[
index
]
pre_offset
[
prop_name
]
=
[
0
for
_
in
range
(
t
)]
pre_mask
[
prop_name
]
=
[
0
for
_
in
range
(
t
)]
feats
[
prop_name
]
=
[]
...
...
@@ -372,7 +374,9 @@ class MSVSR(nn.Layer):
n
,
t
,
_
,
h
,
w
=
flows_backward
.
shape
# propagation branches module
for
prop_name
in
[
'stage3_backward'
,
'stage3_forward'
]:
prop_names
=
[
'stage3_backward'
,
'stage3_forward'
]
for
index
in
range
(
2
):
prop_name
=
prop_names
[
index
]
feats
[
prop_name
]
=
[]
frame_idx
=
range
(
0
,
t
+
1
)
flow_idx
=
range
(
-
1
,
t
)
...
...
@@ -439,7 +443,8 @@ class MSVSR(nn.Layer):
mapping_idx
=
list
(
range
(
0
,
num_outputs
))
mapping_idx
+=
mapping_idx
[::
-
1
]
for
i
in
range
(
0
,
lqs
.
shape
[
1
]):
t
=
lqs
.
shape
[
1
]
for
i
in
range
(
0
,
t
):
hr
=
[
feats
[
k
][
i
]
for
k
in
feats
if
(
k
!=
'spatial'
)]
feat_current
=
feats
[
'spatial'
][
mapping_idx
[
i
]]
hr
.
insert
(
0
,
feat_current
)
...
...
@@ -479,16 +484,13 @@ class MSVSR(nn.Layer):
"""
outputs
=
[]
outputs_head
=
[]
num_outputs
=
len
(
feats
[
'spatial'
])
mapping_idx
=
list
(
range
(
0
,
num_outputs
))
mapping_idx
+=
mapping_idx
[::
-
1
]
cas_outs
=
[]
pas
=
[]
hrs
=
[]
for
i
in
range
(
0
,
lqs
.
shape
[
1
]):
t
=
lqs
.
shape
[
1
]
for
i
in
range
(
0
,
t
):
hr
=
[
feats
[
k
].
pop
(
0
)
for
k
in
feats
if
(
k
!=
'spatial'
and
k
!=
'feat_stage1'
)
...
...
test_tipc/configs/msvsr/train_infer_python.txt
0 → 100644
浏览文件 @
423af984
===========================train_params===========================
model_name:msvsr
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10|whole_train_whole_infer=200
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
pretrained_model:null
train_model_name:msvsr_reds*/*checkpoint.pdparams
train_infer_img_dir:./data/msvsr_reds/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.dataset.num_clips=2 dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,4,3,180,320" --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:multistagevsrmodel_generator
train_model:./inference/msvsr/multistagevsrmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type msvsr -c configs/msvsr_reds.yaml --seed 123 -o dataset.test.num_clips=2 dataset.test.number_frames=4 --output_path test_tipc/output/
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
\ No newline at end of file
test_tipc/prepare.sh
浏览文件 @
423af984
...
...
@@ -56,6 +56,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
rm
-rf
./data/basicvsr
*
wget
-nc
-P
./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar
--no-check-certificate
cd
./data/
&&
tar
xf basicvsr_lite.tar
&&
cd
../
elif
[
${
model_name
}
==
"msvsr"
]
;
then
rm
-rf
./data/basicvsr
*
wget
-nc
-P
./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar
--no-check-certificate
cd
./data/
&&
tar
xf basicvsr_lite.tar
&&
cd
../
fi
elif
[
${
MODE
}
=
"whole_train_whole_infer"
]
;
then
...
...
@@ -89,6 +93,10 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
rm
-rf
./data/REDS
*
wget
-nc
-P
./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar
--no-check-certificate
cd
./data/
&&
tar
xf basicvsr_lite.tar
&&
cd
../
elif
[
${
model_name
}
==
"msvsr"
]
;
then
rm
-rf
./data/REDS
*
wget
-nc
-P
./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar
--no-check-certificate
cd
./data/
&&
tar
xf basicvsr_lite.tar
&&
cd
../
fi
elif
[
${
MODE
}
=
"whole_infer"
]
;
then
if
[
${
model_name
}
=
"pix2pix"
]
;
then
...
...
@@ -125,6 +133,13 @@ elif [ ${MODE} = "whole_infer" ];then
wget
-nc
-P
./inference https://paddlegan.bj.bcebos.com/static_model/basicvsr.tar
--no-check-certificate
cd
./inference
&&
tar
xf basicvsr.tar
&&
cd
../
cd
./data/
&&
tar
xf basicvsr_lite_test.tar
&&
cd
../
elif
[
${
model_name
}
==
"msvsr"
]
;
then
rm
-rf
./data/basic
*
rm
-rf
./inference/msvsr
*
wget
-nc
-P
./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite_test.tar
--no-check-certificate
wget
-nc
-P
./inference https://paddlegan.bj.bcebos.com/static_model/msvsr.tar
--no-check-certificate
cd
./inference
&&
tar
xf msvsr.tar
&&
cd
../
cd
./data/
&&
tar
xf basicvsr_lite_test.tar
&&
cd
../
fi
fi
test_tipc/results/python_msvsr_results_fp32.txt
0 → 100644
浏览文件 @
423af984
Metric psnr: 27.3670
Metric ssim: 0.8021
\ No newline at end of file
tools/inference.py
浏览文件 @
423af984
...
...
@@ -14,7 +14,8 @@ from ppgan.utils.filesystem import makedirs
from
ppgan.metrics
import
build_metric
MODEL_CLASSES
=
[
"pix2pix"
,
"cyclegan"
,
"wav2lip"
,
"esrgan"
,
"edvr"
,
"fom"
,
"stylegan2"
,
"basicvsr"
]
MODEL_CLASSES
=
[
"pix2pix"
,
"cyclegan"
,
"wav2lip"
,
"esrgan"
,
\
"edvr"
,
"fom"
,
"stylegan2"
,
"basicvsr"
,
"msvsr"
]
def
parse_args
():
...
...
@@ -106,7 +107,6 @@ def main():
max_eval_steps
=
len
(
test_dataloader
)
iter_loader
=
IterLoader
(
test_dataloader
)
min_max
=
cfg
.
get
(
'min_max'
,
None
)
if
min_max
is
None
:
min_max
=
(
-
1.
,
1.
)
...
...
@@ -192,7 +192,7 @@ def main():
real_img
=
paddle
.
to_tensor
(
data
[
'A'
])
for
metric
in
metrics
.
values
():
metric
.
update
(
prediction
,
real_img
)
elif
model_type
==
"basicvsr"
:
elif
model_type
in
[
"basicvsr"
,
"msvsr"
]
:
lq
=
data
[
'lq'
].
numpy
()
input_handles
[
0
].
copy_from_cpu
(
lq
)
predictor
.
run
()
...
...
@@ -208,9 +208,9 @@ def main():
gt_img
.
append
(
tensor2img
(
gt_tensor
,
(
0.
,
1.
)))
image_numpy
=
tensor2img
(
prediction
[
0
],
min_max
)
save_image
(
image_numpy
,
os
.
path
.
join
(
args
.
output_path
,
"basicvsr/
{}.png"
.
format
(
i
)))
save_image
(
image_numpy
,
os
.
path
.
join
(
args
.
output_path
,
model_type
,
"
{}.png"
.
format
(
i
)))
metric_file
=
os
.
path
.
join
(
args
.
output_path
,
"basicvsr/
metric.txt"
)
metric_file
=
os
.
path
.
join
(
args
.
output_path
,
model_type
,
"
metric.txt"
)
for
metric
in
metrics
.
values
():
metric
.
update
(
out_img
,
gt_img
,
is_seq
=
True
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录