diff --git a/PaddleCV/object_detection/tools/infer.py b/PaddleCV/object_detection/tools/infer.py index be43ea41ee45bdf3045c7ee08977b9c35231fb3f..c4ee32d3e62b1a5c37061d1795126b1913b4870a 100644 --- a/PaddleCV/object_detection/tools/infer.py +++ b/PaddleCV/object_detection/tools/infer.py @@ -56,6 +56,10 @@ def get_test_images(infer_dir, infer_img): """ assert infer_img is not None or infer_dir is not None, \ "--infer_img or --infer_dir should be set" + assert infer_img is None or os.path.isfile(infer_img), \ + "{} is not a file".format(infer_img) + assert infer_dir is None or os.path.isdir(infer_dir), \ + "{} is not a directory".format(infer_dir) images = [] # infer_img has a higher priority @@ -162,8 +166,8 @@ def main(): for im_id in im_ids: image_path = imid2path[int(im_id)] image = Image.open(image_path).convert('RGB') - visualize_results(image, int(im_id), catid2name, 0.5, bbox_results, - mask_results, is_bbox_normalized) + image = visualize_results(image, int(im_id), catid2name, 0.5, + bbox_results, mask_results, is_bbox_normalized) save_name = get_save_image_name(FLAGS.output_dir, image_path) logger.info("Detection bbox results save in {}".format(save_name)) image.save(save_name)