diff --git a/applications/tools/image_restoration.py b/applications/tools/image_restoration.py index d64a376d0604e3c5aa7584021b4663511e110abf..bf1e610c257e006f57661e393fcd0a650cc08f4b 100644 --- a/applications/tools/image_restoration.py +++ b/applications/tools/image_restoration.py @@ -59,9 +59,9 @@ if __name__ == "__main__": if args.cpu: paddle.set_device('cpu') - predictor = MPRPredictor(images_path=args.images_path, + predictor = MPRPredictor( output_path=args.output_path, weight_path=args.weight_path, seed=args.seed, task=args.task) - predictor.run() + predictor.run(images_path=args.images_path) diff --git a/ppgan/apps/mpr_predictor.py b/ppgan/apps/mpr_predictor.py index 6ae12df54921f8a7186cfcb76edccef1f799f8bf..2ec7ee868fc9e081d8d18f4fd350713bd826800a 100644 --- a/ppgan/apps/mpr_predictor.py +++ b/ppgan/apps/mpr_predictor.py @@ -89,6 +89,7 @@ class MPRPredictor(BasePredictor): def get_images(self, images_path): if os.path.isdir(images_path): return natsorted( + glob(os.path.join(images_path, '*.jpeg')) + glob(os.path.join(images_path, '*.jpg')) + glob(os.path.join(images_path, '*.JPG')) + glob(os.path.join(images_path, '*.png')) +