提交 a4138a62 编写于 作者: L lijianshe02

support variable length input

上级 542c8839
...@@ -19,7 +19,6 @@ INFER: ...@@ -19,7 +19,6 @@ INFER:
number_frames: 5 number_frames: 5
batch_size: 1 batch_size: 1
file_root: "/workspace/color/input_frames" file_root: "/workspace/color/input_frames"
#file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic" inference_model: "/workspace/PaddleGAN/applications/EDVR/data/inference_model"
#gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT"
use_flip: False use_flip: False
use_rot: False use_rot: False
...@@ -46,6 +46,11 @@ def parse_args(): ...@@ -46,6 +46,11 @@ def parse_args():
type=str, type=str,
default='AttentionCluster', default='AttentionCluster',
help='name of model to train.') help='name of model to train.')
parser.add_argument(
'--inference_model',
type=str,
default='./data/inference_model',
help='path of inference_model.')
parser.add_argument( parser.add_argument(
'--config', '--config',
type=str, type=str,
...@@ -111,14 +116,13 @@ def infer(args): ...@@ -111,14 +116,13 @@ def infer(args):
config = parse_config(args.config) config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args)) infer_config = merge_configs(config, 'infer', vars(args))
print_configs(infer_config, "Infer") print_configs(infer_config, "Infer")
inference_model = args.inference_model
model_path = '/workspace/PaddleGAN/applications/EDVR/data/inference_model'
model_filename = 'EDVR_model.pdmodel' model_filename = 'EDVR_model.pdmodel'
params_filename = 'EDVR_params.pdparams' params_filename = 'EDVR_params.pdparams'
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
[inference_program, feed_list, fetch_list] = fluid.io.load_inference_model(dirname=model_path, model_filename=model_filename, params_filename=params_filename, executor=exe) [inference_program, feed_list, fetch_list] = fluid.io.load_inference_model(dirname=inference_model, model_filename=model_filename, params_filename=params_filename, executor=exe)
infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config) infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config)
#infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config) #infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册