diff --git a/tools/infer.py b/tools/infer.py index 89591080d21137197589e2a5ab42b64f58c0dee9..3835264b2099562e4dfcb3accf07ebf02119b886 100644 --- a/tools/infer.py +++ b/tools/infer.py @@ -74,20 +74,20 @@ def get_test_images(infer_dir, infer_img): "{} is not a file".format(infer_img) assert infer_dir is None or os.path.isdir(infer_dir), \ "{} is not a directory".format(infer_dir) - images = [] # infer_img has a higher priority if infer_img and os.path.isfile(infer_img): - images.append(infer_img) - return images + return [infer_img] + images = set() infer_dir = os.path.abspath(infer_dir) assert os.path.isdir(infer_dir), \ "infer_dir {} is not a directory".format(infer_dir) exts = ['jpg', 'jpeg', 'png', 'bmp'] exts += [ext.upper() for ext in exts] for ext in exts: - images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext))) + images.update(glob.glob('{}/*.{}'.format(infer_dir, ext))) + images = list(images) assert len(images) > 0, "no image found in {}".format(infer_dir) logger.info("Found {} inference images in total.".format(len(images)))