From 3e17a8c2369fc4bf445aa3c3cdc44910312fb905 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Thu, 13 Jan 2022 09:22:08 +0800 Subject: [PATCH] fix mprnet app (#558) --- applications/tools/image_restoration.py | 4 ++-- ppgan/apps/mpr_predictor.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/applications/tools/image_restoration.py b/applications/tools/image_restoration.py index d64a376..bf1e610 100644 --- a/applications/tools/image_restoration.py +++ b/applications/tools/image_restoration.py @@ -59,9 +59,9 @@ if __name__ == "__main__": if args.cpu: paddle.set_device('cpu') - predictor = MPRPredictor(images_path=args.images_path, + predictor = MPRPredictor( output_path=args.output_path, weight_path=args.weight_path, seed=args.seed, task=args.task) - predictor.run() + predictor.run(images_path=args.images_path) diff --git a/ppgan/apps/mpr_predictor.py b/ppgan/apps/mpr_predictor.py index 6ae12df..2ec7ee8 100644 --- a/ppgan/apps/mpr_predictor.py +++ b/ppgan/apps/mpr_predictor.py @@ -89,6 +89,7 @@ class MPRPredictor(BasePredictor): def get_images(self, images_path): if os.path.isdir(images_path): return natsorted( + glob(os.path.join(images_path, '*.jpeg')) + glob(os.path.join(images_path, '*.jpg')) + glob(os.path.join(images_path, '*.JPG')) + glob(os.path.join(images_path, '*.png')) + -- GitLab