diff --git a/applications/EDVR/configs/edvr_L.yaml b/applications/EDVR/configs/edvr_L.yaml index 056a1520700f36ca064ebc877365dfa0380a6779..91b05f945e5e8fcc751ce878fb67c513c752e5f4 100644 --- a/applications/EDVR/configs/edvr_L.yaml +++ b/applications/EDVR/configs/edvr_L.yaml @@ -19,7 +19,6 @@ INFER: number_frames: 5 batch_size: 1 file_root: "/workspace/color/input_frames" - #file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic" - #gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT" + inference_model: "/workspace/PaddleGAN/applications/EDVR/data/inference_model" use_flip: False use_rot: False diff --git a/applications/EDVR/predict.py b/applications/EDVR/predict.py index e8bb6be336dcbf5d2435de61eea5844c6697632b..c45904a2698d2e38702df2f91dc96b7ed1cf0ae8 100644 --- a/applications/EDVR/predict.py +++ b/applications/EDVR/predict.py @@ -46,6 +46,11 @@ def parse_args(): type=str, default='AttentionCluster', help='name of model to train.') + parser.add_argument( + '--inference_model', + type=str, + default='./data/inference_model', + help='path of inference_model.') parser.add_argument( '--config', type=str, @@ -111,14 +116,13 @@ def infer(args): config = parse_config(args.config) infer_config = merge_configs(config, 'infer', vars(args)) print_configs(infer_config, "Infer") - - model_path = '/workspace/PaddleGAN/applications/EDVR/data/inference_model' + inference_model = args.inference_model model_filename = 'EDVR_model.pdmodel' params_filename = 'EDVR_params.pdparams' place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) - [inference_program, feed_list, fetch_list] = fluid.io.load_inference_model(dirname=model_path, model_filename=model_filename, params_filename=params_filename, executor=exe) + [inference_program, feed_list, fetch_list] = fluid.io.load_inference_model(dirname=inference_model, model_filename=model_filename, params_filename=params_filename, executor=exe) infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config) #infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config)