diff --git a/pdseg/check.py b/pdseg/check.py index 7937e046134548765a3b2a3c046798fd8997ee05..ab690fb1fd2d30a4b122b7c15709fd248e025cd2 100644 --- a/pdseg/check.py +++ b/pdseg/check.py @@ -34,6 +34,7 @@ def init_global_variable(): global list_wrong #文件名格式错误列表 global imread_failed #图片读取失败列表, 二元列表 global label_wrong # 标注图片出错列表 + global label_gray_wrong # 标注图非灰度图列表 png_format_right_num = 0 png_format_wrong_num = 0 @@ -49,6 +50,7 @@ def init_global_variable(): list_wrong = [] imread_failed = [] label_wrong = [] + label_gray_wrong = [] def parse_args(): parser = argparse.ArgumentParser(description='PaddleSeg check') @@ -68,10 +70,13 @@ def correct_print(str): return "".join(["\nPASS ", str]) def cv2_imread(file_path, flag=cv2.IMREAD_COLOR): - # resolve cv2.imread open Chinese file path issues on Windows Platform. + """ + 解决 cv2.imread 在window平台打开中文路径的问题. + """ return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) def get_image_max_height_width(img): + """获取图片最大宽和高""" global max_width, max_height img_shape = img.shape height, width = img_shape[0], img_shape[1] @@ -79,6 +84,7 @@ def get_image_max_height_width(img): max_width = max(width, max_width) def get_image_min_max_aspectratio(img): + """计算图片最大宽高比""" global min_aspectratio, max_aspectratio img_shape = img.shape height, width = img_shape[0], img_shape[1] @@ -87,11 +93,19 @@ def get_image_min_max_aspectratio(img): return min_aspectratio, max_aspectratio def get_image_dim(img): - """获取图像的维度""" + """获取图像的通道数""" img_shape = img.shape if img_shape[-1] not in img_dim: img_dim.append(img_shape[-1]) +def is_label_gray(grt): + """判断标签是否为灰度图""" + grt_shape = grt.shape + if len(grt_shape) == 2: + return True + else: + return False + def image_label_shape_check(img, grt): """ 验证图像和标注的大小是否匹配 @@ -110,17 +124,15 @@ def image_label_shape_check(img, grt): def ground_truth_check(grt, grt_path): """ - 验证标注是否重零开始,标注值为0,1,...,num_classes-1, ingnore_idx 验证标注图像的格式 - 返回标注的像素数 - 检查图像是否都是ignore_index + 统计标注图类别和像素数 params: grt: 标注图 grt_path: 标注图路径 return: png_format: 返回是否是png格式图片 - label_correct: 返回标注是否是正确的 - label_pixel_num: 返回标注的像素数 + unique: 返回标注类别 + counts: 返回标注的像素数 """ if imghdr.what(grt_path) == "png": png_format = True @@ -135,7 +147,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): """ 统计所有标注图上的格式、类别和每个类别的像素数 params: - png_format: 返回是否是png格式图片 + png_format: 是否是png格式图片 grt_classes: 标注类别 num_of_each_class: 各个类别的像素数目 """ @@ -188,7 +200,7 @@ def gt_check(): total_nc = sorted(zip(total_grt_classes, total_num_of_each_class)) - logger.info("\nDoing label pixel statistics...\nTotal label calsses " + logger.info("\nDoing label pixel statistics...\nTotal label classes " "and their corresponding numbers:\n{} ".format(total_nc)) if len(label_wrong) == 0 and not total_nc[0][0]: @@ -322,6 +334,17 @@ def imread_check(): for i in imread_failed: logger.debug(i) +def label_gray_check(): + if len(label_gray_wrong) == 0: + logger.info(correct_print("label gray check")) + logger.info("All label images are gray") + else: + logger.info(error_print("label gray check")) + logger.info("{} label images are not gray\nLabel pixel statistics may " + "be insignificant".format(len(label_gray_wrong))) + for i in label_gray_wrong: + logger.debug(i) + def check_train_dataset(): @@ -340,11 +363,15 @@ def check_train_dataset(): grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) try: img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) - grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) + grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED) except Exception as e: imread_failed.append((line, str(e))) continue + is_gray = is_label_gray(grt) + if not is_gray: + label_gray_wrong.append(line) + grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY) get_image_dim(img) is_equal_img_grt_shape = image_label_shape_check(img, grt) if not is_equal_img_grt_shape: @@ -359,6 +386,7 @@ def check_train_dataset(): file_list_check(list_file) imread_check() + label_gray_check() gt_check() image_type_check(img_dim) shape_check() @@ -383,9 +411,14 @@ def check_val_dataset(): grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) try: img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) - grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) + grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED) except Exception as e: imread_failed.append((line, e.message)) + + is_gray = is_label_gray(grt) + if not is_gray: + label_gray_wrong.append(line) + grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY) get_image_max_height_width(img) get_image_min_max_aspectratio(img) get_image_dim(img) @@ -401,6 +434,7 @@ def check_val_dataset(): file_list_check(list_file) imread_check() + label_gray_check() gt_check() image_type_check(img_dim) shape_check() @@ -430,10 +464,15 @@ def check_test_dataset(): grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) try: img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) - grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) + grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED) except Exception as e: imread_failed.append((line, e.message)) continue + + is_gray = is_label_gray(grt) + if not is_gray: + label_gray_wrong.append(line) + grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY) is_equal_img_grt_shape = image_label_shape_check(img, grt) if not is_equal_img_grt_shape: shape_unequal_image.append(line) @@ -452,6 +491,8 @@ def check_test_dataset(): file_list_check(list_file) imread_check() + if has_label: + label_gray_check() if has_label: gt_check() image_type_check(img_dim) diff --git a/requirements.txt b/requirements.txt index 0d438f9949547df1a477b468382cc79d55866bb6..0fdc3d573a2655ded061aa348303ae7e79c94b3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ Pillow numpy six opencv-python +tqdm diff --git a/test/local_test_cityscapes.py b/test/local_test_cityscapes.py index 0bf3c6aeb04e551da0cba454c6e9db7575efb1bb..6618695a60aae5f07230c546337b611d7c1cc78a 100644 --- a/test/local_test_cityscapes.py +++ b/test/local_test_cityscapes.py @@ -14,6 +14,7 @@ from test_utils import download_file_and_uncompress, train, eval, vis, export_model import os +import argparse LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset") @@ -43,7 +44,16 @@ if __name__ == "__main__": vis_dir = os.path.join(LOCAL_PATH, "visual", model_name) saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name) - devices = ['0'] + parser = argparse.ArgumentParser(description="PaddleSeg loacl test") + parser.add_argument("--devices", + dest="devices", + help="GPU id of running. if more than one, use spacing to separate.", + nargs="+", + default=0, + type=int) + args = parser.parse_args() + + devices = [str(x) for x in args.devices] export_model( flags=["--cfg", cfg], diff --git a/test/local_test_pet.py b/test/local_test_pet.py index 1596f4b2b7be92c32dd8cad1d8be1d79e794d142..7d0cf58cd1235575fc960769d5142865993b5763 100644 --- a/test/local_test_pet.py +++ b/test/local_test_pet.py @@ -14,6 +14,7 @@ from test_utils import download_file_and_uncompress, train, eval, vis, export_model import os +import argparse LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset") @@ -44,7 +45,16 @@ if __name__ == "__main__": vis_dir = os.path.join(LOCAL_PATH, "visual", model_name) saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name) - devices = ['0'] + parser = argparse.ArgumentParser(description="PaddleSeg loacl test") + parser.add_argument("--devices", + dest="devices", + help="GPU id of running. if more than one, use spacing to separate.", + nargs="+", + default=0, + type=int) + args = parser.parse_args() + + devices = [str(x) for x in args.devices] train( flags=["--cfg", cfg, "--use_gpu", "--log_steps", "10"],