From a4138a6206d6888e5489311eee7571d866b0a2f8 Mon Sep 17 00:00:00 2001 From: lijianshe02 Date: Wed, 12 Aug 2020 10:23:56 +0000 Subject: [PATCH] support variable length input --- applications/EDVR/configs/edvr_L.yaml | 3 +-- applications/EDVR/predict.py | 10 +++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/applications/EDVR/configs/edvr_L.yaml b/applications/EDVR/configs/edvr_L.yaml index 056a152..91b05f9 100644 --- a/applications/EDVR/configs/edvr_L.yaml +++ b/applications/EDVR/configs/edvr_L.yaml @@ -19,7 +19,6 @@ INFER: number_frames: 5 batch_size: 1 file_root: "/workspace/color/input_frames" - #file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic" - #gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT" + inference_model: "/workspace/PaddleGAN/applications/EDVR/data/inference_model" use_flip: False use_rot: False diff --git a/applications/EDVR/predict.py b/applications/EDVR/predict.py index e8bb6be..c45904a 100644 --- a/applications/EDVR/predict.py +++ b/applications/EDVR/predict.py @@ -46,6 +46,11 @@ def parse_args(): type=str, default='AttentionCluster', help='name of model to train.') + parser.add_argument( + '--inference_model', + type=str, + default='./data/inference_model', + help='path of inference_model.') parser.add_argument( '--config', type=str, @@ -111,14 +116,13 @@ def infer(args): config = parse_config(args.config) infer_config = merge_configs(config, 'infer', vars(args)) print_configs(infer_config, "Infer") - - model_path = '/workspace/PaddleGAN/applications/EDVR/data/inference_model' + inference_model = args.inference_model model_filename = 'EDVR_model.pdmodel' params_filename = 'EDVR_params.pdparams' place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) - [inference_program, feed_list, fetch_list] = fluid.io.load_inference_model(dirname=model_path, model_filename=model_filename, params_filename=params_filename, executor=exe) + [inference_program, feed_list, fetch_list] = fluid.io.load_inference_model(dirname=inference_model, model_filename=model_filename, params_filename=params_filename, executor=exe) infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config) #infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config) -- GitLab