From 423af9840164ad7ba86901c605f31521a562f932 Mon Sep 17 00:00:00 2001 From: lzzyzlbb <287246233@qq.com> Date: Mon, 6 Dec 2021 17:29:05 +0800 Subject: [PATCH] add msvsr static model (#510) --- configs/basicvsr_reds.yaml | 1 + configs/msvsr_reds.yaml | 5 ++ ppgan/models/generators/msvsr.py | 18 ++++--- .../configs/msvsr/train_infer_python.txt | 51 +++++++++++++++++++ test_tipc/prepare.sh | 15 ++++++ .../results/python_msvsr_results_fp32.txt | 2 + tools/inference.py | 10 ++-- 7 files changed, 89 insertions(+), 13 deletions(-) create mode 100644 test_tipc/configs/msvsr/train_infer_python.txt create mode 100644 test_tipc/results/python_msvsr_results_fp32.txt diff --git a/configs/basicvsr_reds.yaml b/configs/basicvsr_reds.yaml index 0b6b835..172764f 100644 --- a/configs/basicvsr_reds.yaml +++ b/configs/basicvsr_reds.yaml @@ -54,6 +54,7 @@ dataset: val_partition: REDS4 num_workers: 0 batch_size: 1 + num_clips: 270 lr_scheduler: name: CosineAnnealingRestartLR diff --git a/configs/msvsr_reds.yaml b/configs/msvsr_reds.yaml index 0ddc333..7767780 100644 --- a/configs/msvsr_reds.yaml +++ b/configs/msvsr_reds.yaml @@ -48,6 +48,7 @@ dataset: use_rot: True scale: 4 val_partition: REDS4 + num_clips: 270 test: name: SRREDSMultipleGTDataset @@ -63,6 +64,7 @@ dataset: val_partition: REDS4 num_workers: 0 batch_size: 1 + num_clips: 270 lr_scheduler: name: CosineAnnealingRestartLR @@ -100,3 +102,6 @@ log_config: snapshot_config: interval: 5000 + +export_model: + - {name: 'generator', inputs_num: 1} \ No newline at end of file diff --git a/ppgan/models/generators/msvsr.py b/ppgan/models/generators/msvsr.py index 40ce3c8..79e841c 100644 --- a/ppgan/models/generators/msvsr.py +++ b/ppgan/models/generators/msvsr.py @@ -303,7 +303,9 @@ class MSVSR(nn.Layer): pre_mask = {} # propagation branches module - for prop_name in ['stage2_backward', 'stage2_forward']: + prop_names = ['stage2_backward', 'stage2_forward'] + for index in range(2): + prop_name = prop_names[index] pre_offset[prop_name] = [0 for _ in range(t)] pre_mask[prop_name] = [0 for _ in range(t)] feats[prop_name] = [] @@ -372,7 +374,9 @@ class MSVSR(nn.Layer): n, t, _, h, w = flows_backward.shape # propagation branches module - for prop_name in ['stage3_backward', 'stage3_forward']: + prop_names = ['stage3_backward', 'stage3_forward'] + for index in range(2): + prop_name = prop_names[index] feats[prop_name] = [] frame_idx = range(0, t + 1) flow_idx = range(-1, t) @@ -439,7 +443,8 @@ class MSVSR(nn.Layer): mapping_idx = list(range(0, num_outputs)) mapping_idx += mapping_idx[::-1] - for i in range(0, lqs.shape[1]): + t = lqs.shape[1] + for i in range(0, t): hr = [feats[k][i] for k in feats if (k != 'spatial')] feat_current = feats['spatial'][mapping_idx[i]] hr.insert(0, feat_current) @@ -479,16 +484,13 @@ class MSVSR(nn.Layer): """ outputs = [] - outputs_head = [] num_outputs = len(feats['spatial']) mapping_idx = list(range(0, num_outputs)) mapping_idx += mapping_idx[::-1] - cas_outs = [] - pas = [] - hrs = [] - for i in range(0, lqs.shape[1]): + t = lqs.shape[1] + for i in range(0, t): hr = [ feats[k].pop(0) for k in feats if (k != 'spatial' and k != 'feat_stage1') diff --git a/test_tipc/configs/msvsr/train_infer_python.txt b/test_tipc/configs/msvsr/train_infer_python.txt new file mode 100644 index 0000000..a848037 --- /dev/null +++ b/test_tipc/configs/msvsr/train_infer_python.txt @@ -0,0 +1,51 @@ +===========================train_params=========================== +model_name:msvsr +python:python3.7 +gpu_list:0 +## +auto_cast:null +total_iters:lite_train_lite_infer=10|whole_train_whole_infer=200 +output_dir:./output/ +dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1 +pretrained_model:null +train_model_name:msvsr_reds*/*checkpoint.pdparams +train_infer_img_dir:./data/msvsr_reds/test +null:null +## +trainer:norm_train +norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.dataset.num_clips=2 dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +--output_dir:./output/ +load:null +norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,4,3,180,320" --load +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +inference_dir:multistagevsrmodel_generator +train_model:./inference/msvsr/multistagevsrmodel_generator +infer_export:null +infer_quant:False +inference:tools/inference.py --model_type msvsr -c configs/msvsr_reds.yaml --seed 123 -o dataset.test.num_clips=2 dataset.test.number_frames=4 --output_path test_tipc/output/ +--device:gpu +null:null +null:null +null:null +null:null +null:null +--model_path: +null:null +null:null +--benchmark:True +null:null \ No newline at end of file diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 27f7567..0f2248d 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -56,6 +56,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then rm -rf ./data/basicvsr* wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate cd ./data/ && tar xf basicvsr_lite.tar && cd ../ + elif [ ${model_name} == "msvsr" ]; then + rm -rf ./data/basicvsr* + wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate + cd ./data/ && tar xf basicvsr_lite.tar && cd ../ fi elif [ ${MODE} = "whole_train_whole_infer" ];then @@ -89,6 +93,10 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then rm -rf ./data/REDS* wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate cd ./data/ && tar xf basicvsr_lite.tar && cd ../ + elif [ ${model_name} == "msvsr" ]; then + rm -rf ./data/REDS* + wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite.tar --no-check-certificate + cd ./data/ && tar xf basicvsr_lite.tar && cd ../ fi elif [ ${MODE} = "whole_infer" ];then if [ ${model_name} = "pix2pix" ]; then @@ -125,6 +133,13 @@ elif [ ${MODE} = "whole_infer" ];then wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/basicvsr.tar --no-check-certificate cd ./inference && tar xf basicvsr.tar && cd ../ cd ./data/ && tar xf basicvsr_lite_test.tar && cd ../ + elif [ ${model_name} == "msvsr" ]; then + rm -rf ./data/basic* + rm -rf ./inference/msvsr* + wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/basicvsr_lite_test.tar --no-check-certificate + wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/msvsr.tar --no-check-certificate + cd ./inference && tar xf msvsr.tar && cd ../ + cd ./data/ && tar xf basicvsr_lite_test.tar && cd ../ fi fi diff --git a/test_tipc/results/python_msvsr_results_fp32.txt b/test_tipc/results/python_msvsr_results_fp32.txt new file mode 100644 index 0000000..0de2d2f --- /dev/null +++ b/test_tipc/results/python_msvsr_results_fp32.txt @@ -0,0 +1,2 @@ +Metric psnr: 27.3670 +Metric ssim: 0.8021 \ No newline at end of file diff --git a/tools/inference.py b/tools/inference.py index 9a6fbff..5c4883a 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -14,7 +14,8 @@ from ppgan.utils.filesystem import makedirs from ppgan.metrics import build_metric -MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", "edvr", "fom", "stylegan2", "basicvsr"] +MODEL_CLASSES = ["pix2pix", "cyclegan", "wav2lip", "esrgan", \ + "edvr", "fom", "stylegan2", "basicvsr", "msvsr"] def parse_args(): @@ -106,7 +107,6 @@ def main(): max_eval_steps = len(test_dataloader) iter_loader = IterLoader(test_dataloader) - min_max = cfg.get('min_max', None) if min_max is None: min_max = (-1., 1.) @@ -192,7 +192,7 @@ def main(): real_img = paddle.to_tensor(data['A']) for metric in metrics.values(): metric.update(prediction, real_img) - elif model_type == "basicvsr": + elif model_type in ["basicvsr", "msvsr"]: lq = data['lq'].numpy() input_handles[0].copy_from_cpu(lq) predictor.run() @@ -208,9 +208,9 @@ def main(): gt_img.append(tensor2img(gt_tensor, (0.,1.))) image_numpy = tensor2img(prediction[0], min_max) - save_image(image_numpy, os.path.join(args.output_path, "basicvsr/{}.png".format(i))) + save_image(image_numpy, os.path.join(args.output_path, model_type, "{}.png".format(i))) - metric_file = os.path.join(args.output_path, "basicvsr/metric.txt") + metric_file = os.path.join(args.output_path, model_type, "metric.txt") for metric in metrics.values(): metric.update(out_img, gt_img, is_seq=True) -- GitLab