未验证 提交 423af984 编写于 作者: L lzzyzlbb 提交者: GitHub

add msvsr static model (#510)

上级 1091e63e
......@@ -54,6 +54,7 @@ dataset:
val_partition: REDS4
num_workers: 0
batch_size: 1
num_clips: 270
lr_scheduler:
name: CosineAnnealingRestartLR
......
......@@ -48,6 +48,7 @@ dataset:
use_rot: True
scale: 4
val_partition: REDS4
num_clips: 270
test:
name: SRREDSMultipleGTDataset
......@@ -63,6 +64,7 @@ dataset:
val_partition: REDS4
num_workers: 0
batch_size: 1
num_clips: 270
lr_scheduler:
name: CosineAnnealingRestartLR
......@@ -100,3 +102,6 @@ log_config:
snapshot_config:
interval: 5000
export_model:
- {name: 'generator', inputs_num: 1}
\ No newline at end of file
......@@ -303,7 +303,9 @@ class MSVSR(nn.Layer):
pre_mask = {}
# propagation branches module
for prop_name in ['stage2_backward', 'stage2_forward']:
prop_names = ['stage2_backward', 'stage2_forward']
for index in range(2):
prop_name = prop_names[index]
pre_offset[prop_name] = [0 for _ in range(t)]
pre_mask[prop_name] = [0 for _ in range(t)]
feats[prop_name] = []
......@@ -372,7 +374,9 @@ class MSVSR(nn.Layer):
n, t, _, h, w = flows_backward.shape
# propagation branches module
for prop_name in ['stage3_backward', 'stage3_forward']:
prop_names = ['stage3_backward', 'stage3_forward']
for index in range(2):
prop_name = prop_names[index]
feats[prop_name] = []
frame_idx = range(0, t + 1)
flow_idx = range(-1, t)
......@@ -439,7 +443,8 @@ class MSVSR(nn.Layer):
mapping_idx = list(range(0, num_outputs))
mapping_idx += mapping_idx[::-1]
for i in range(0, lqs.shape[1]):
t = lqs.shape[1]
for i in range(0, t):
hr = [feats[k][i] for k in feats if (k != 'spatial')]
feat_current = feats['spatial'][mapping_idx[i]]
hr.insert(0, feat_current)
......@@ -479,16 +484,13 @@ class MSVSR(nn.Layer):
"""
outputs = []
outputs_head = []
num_outputs = len(feats['spatial'])
mapping_idx = list(range(0, num_outputs))
mapping_idx += mapping_idx[::-1]
cas_outs = []
pas = []
hrs = []
for i in range(0, lqs.shape[1]):
t = lqs.shape[1]
for i in range(0, t):
hr = [
feats[k].pop(0) for k in feats
if (k != 'spatial' and k != 'feat_stage1')
......
===========================train_params===========================
model_name:msvsr
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=10|whole_train_whole_infer=200
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1
pretrained_model:null
train_model_name:msvsr_reds*/*checkpoint.pdparams
train_infer_img_dir:./data/msvsr_reds/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.dataset.num_clips=2 dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,4,3,180,320" --load
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:multistagevsrmodel_generator
train_model:./inference/msvsr/multistagevsrmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type msvsr -c configs/msvsr_reds.yaml --seed 123 -o dataset.test.num_clips=2 dataset.test.number_frames=4 --output_path test_tipc/output/
--device:gpu
null:null
null:null
null:null
null:null
null:null
--model_path:
null:null
null:null
--benchmark:True
null:null
\ No newline at end of file
......@@ -56,6 +56,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
rm -rf ./data/basicvsr*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate
cd ./data/ && tar xf basicvsr_lite.tar && cd ../
elif [ ${model_name} == "msvsr" ]; then
rm -rf ./data/basicvsr*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate
cd ./data/ && tar xf basicvsr_lite.tar && cd ../
fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
......@@ -89,6 +93,10 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then
rm -rf ./data/REDS*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate
cd ./data/ && tar xf basicvsr_lite.tar && cd ../
elif [ ${model_name} == "msvsr" ]; then
rm -rf ./data/REDS*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate
cd ./data/ && tar xf basicvsr_lite.tar && cd ../
fi
elif [ ${MODE} = "whole_infer" ];then
if [ ${model_name} = "pix2pix" ]; then
......@@ -125,6 +133,13 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/basicvsr.tar --no-check-certificate
cd ./inference && tar xf basicvsr.tar && cd ../
cd ./data/ && tar xf basicvsr_lite_test.tar && cd ../
elif [ ${model_name} == "msvsr" ]; then
rm -rf ./data/basic*
rm -rf ./inference/msvsr*
wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite_test.tar --no-check-certificate
wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/msvsr.tar --no-check-certificate
cd ./inference && tar xf msvsr.tar && cd ../
cd ./data/ && tar xf basicvsr_lite_test.tar && cd ../
fi
fi
Metric psnr: 27.3670
Metric ssim: 0.8021
\ No newline at end of file
......@@ -14,7 +14,8 @@ from ppgan.utils.filesystem import makedirs
from ppgan.metrics import build_metric
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", "edvr", "fom", "stylegan2", "basicvsr"]
MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \
"edvr", "fom", "stylegan2", "basicvsr", "msvsr"]
def parse_args():
......@@ -106,7 +107,6 @@ def main():
max_eval_steps = len(test_dataloader)
iter_loader = IterLoader(test_dataloader)
min_max = cfg.get('min_max', None)
if min_max is None:
min_max = (-1., 1.)
......@@ -192,7 +192,7 @@ def main():
real_img = paddle.to_tensor(data['A'])
for metric in metrics.values():
metric.update(prediction, real_img)
elif model_type == "basicvsr":
elif model_type in ["basicvsr", "msvsr"]:
lq = data['lq'].numpy()
input_handles[0].copy_from_cpu(lq)
predictor.run()
......@@ -208,9 +208,9 @@ def main():
gt_img.append(tensor2img(gt_tensor, (0.,1.)))
image_numpy = tensor2img(prediction[0], min_max)
save_image(image_numpy, os.path.join(args.output_path, "basicvsr/{}.png".format(i)))
save_image(image_numpy, os.path.join(args.output_path, model_type, "{}.png".format(i)))
metric_file = os.path.join(args.output_path, "basicvsr/metric.txt")
metric_file = os.path.join(args.output_path, model_type, "metric.txt")
for metric in metrics.values():
metric.update(out_img, gt_img, is_seq=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册