From 2269b08f6975c07114110d4ab7734075ad59e3e8 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Wed, 28 Aug 2019 15:12:34 +0800 Subject: [PATCH] update check.py --- pdseg/check.py | 525 +++++++++++++++++++++++++++---------------------- 1 file changed, 292 insertions(+), 233 deletions(-) diff --git a/pdseg/check.py b/pdseg/check.py index 92498c66..bb0e5c1d 100644 --- a/pdseg/check.py +++ b/pdseg/check.py @@ -1,4 +1,4 @@ -# coding: utf8 +# -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import division @@ -12,76 +12,134 @@ import argparse import cv2 from tqdm import tqdm import imghdr +import logging from utils.config import cfg - def init_global_variable(): """ 初始化全局变量 """ - global png_format_right_num # 格式错误的标签图数量 - global png_format_wrong_num # 格式错误的标签图数量 - global total_grt_classes # 总的标签类别 + global png_format_right_num # 格式错误的标注图数量 + global png_format_wrong_num # 格式错误的标注图数量 + global total_grt_classes # 总的标注类别 global total_num_of_each_class # 每个类别总的像素数 - global shape_unequal # 图片和标签shape不一致 - global png_format_wrong # 标签格式错误 + global shape_unequal_image # 图片和标注shape不一致列表 + global png_format_wrong_image # 标注格式错误列表 + global max_width # 图片最长宽 + global max_height # 图片最长高 + global min_aspectratio # 图片最小宽高比 + global max_aspectratio # 图片最大宽高比 + global img_dim # 图片的通道数 + global list_wrong #文件名格式错误列表 + global imread_failed #图片读取失败列表, 二元列表 + global label_wrong # 标注图片出错列表 png_format_right_num = 0 png_format_wrong_num = 0 total_grt_classes = [] total_num_of_each_class = [] - shape_unequal = [] - png_format_wrong = [] - + shape_unequal_image = [] + png_format_wrong_image = [] + max_width = 0 + max_height = 0 + min_aspectratio = sys.float_info.max + max_aspectratio = 0 + img_dim = [] + list_wrong = [] + imread_failed = [] + label_wrong = [] def parse_args(): parser = argparse.ArgumentParser(description='PaddleSeg check') parser.add_argument( - '--cfg', - dest='cfg_file', - help='Config file for training (and optionally testing)', - default=None, - type=str) + '--cfg', + dest='cfg_file', + help='Config file for training (and optionally testing)', + default=None, + type=str + ) return parser.parse_args() +def error_print(str): + return "".join(["\nNOT PASS ", str]) + +def correct_print(str): + return "".join(["\nPASS ", str]) def cv2_imread(file_path, flag=cv2.IMREAD_COLOR): # resolve cv2.imread open Chinese file path issues on Windows Platform. return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) - -def get_image_max_height_width(img, max_height, max_width): +def get_image_max_height_width(img): + global max_width, max_height img_shape = img.shape height, width = img_shape[0], img_shape[1] max_height = max(height, max_height) max_width = max(width, max_width) - return max_height, max_width - -def get_image_min_max_aspectratio(img, min_aspectratio, max_aspectratio): +def get_image_min_max_aspectratio(img): + global min_aspectratio, max_aspectratio img_shape = img.shape height, width = img_shape[0], img_shape[1] - min_aspectratio = min(width / height, min_aspectratio) - max_aspectratio = max(width / height, max_aspectratio) + min_aspectratio = min(width/height, min_aspectratio) + max_aspectratio = max(width/height, max_aspectratio) return min_aspectratio, max_aspectratio - -def get_image_dim(img, img_dim): +def get_image_dim(img): """获取图像的维度""" img_shape = img.shape if img_shape[-1] not in img_dim: img_dim.append(img_shape[-1]) +def image_label_shape_check(img, grt): + """ + 验证图像和标注的大小是否匹配 + """ + + flag = True + img_height = img.shape[0] + img_width = img.shape[1] + grt_height = grt.shape[0] + grt_width = grt.shape[1] + + + if img_height != grt_height or img_width != grt_width: + flag = False + return flag + +def ground_truth_check(grt, grt_path): + """ + 验证标注是否重零开始,标注值为0,1,...,num_classes-1, ingnore_idx + 验证标注图像的格式 + 返回标注的像素数 + 检查图像是否都是ignore_index + params: + grt: 标注图 + grt_path: 标注图路径 + return: + png_format: 返回是否是png格式图片 + label_correct: 返回标注是否是正确的 + label_pixel_num: 返回标注的像素数 + """ + if imghdr.what(grt_path) == "png": + png_format = True + else: + png_format = False + + unique, counts = np.unique(grt, return_counts=True) + + return png_format, unique, counts def sum_gt_check(png_format, grt_classes, num_of_each_class): """ - 统计所有标签图上的格式、类别和每个类别的像素数 + 统计所有标注图上的格式、类别和每个类别的像素数 params: png_format: 返回是否是png格式图片 - grt_classes: 标签类别 + grt_classes: 标注类别 num_of_each_class: 各个类别的像素数目 """ + is_label_correct = True global png_format_right_num, png_format_wrong_num, total_grt_classes, total_num_of_each_class if png_format: @@ -90,12 +148,11 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): png_format_wrong_num += 1 if cfg.DATASET.IGNORE_INDEX in grt_classes: - grt_classes2 = np.delete( - grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX)) + grt_classes2 = np.delete(grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX)) + else: + grt_classes2 = grt_classes if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1: - print("fatal error: label class is out of range [0, {}]".format( - cfg.DATASET.NUM_CLASSES - 1)) - + is_label_correct = False add_class = [] add_num = [] for i in range(len(grt_classes)): @@ -108,145 +165,113 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): add_num.append(num_of_each_class[i]) total_num_of_each_class += add_num total_grt_classes += add_class - + return is_label_correct def gt_check(): """ - 对标签进行校验,输出校验结果 - params: - png_format_wrong_num: 格式错误的标签图数量 - png_format_right_num: 格式正确的标签图数量 - total_grt_classes: 总的标签类别 - total_num_of_each_class: 每个类别总的像素数目 - return: - total_nc: 按升序排序后的总标签类别和像素数目 + 对标注图像进行校验,输出校验结果 """ if png_format_wrong_num == 0: - print("Not pass label png format check!") + if png_format_right_num: + logger.info(correct_print("label format check")) + else: + logger.info(error_print("label format check")) + logger.info("No label image to check") + return else: - print("Pass label png format check!") - print( - "total {} label imgs are png format, {} label imgs are not png fromat". - format(png_format_right_num, png_format_wrong_num)) + logger.info(error_print("label format check")) + logger.info("total {} label images are png format, {} label images are not png " + "format".format(png_format_right_num, png_format_wrong_num)) + if len(png_format_wrong_image) > 0: + for i in png_format_wrong_image: + logger.debug(i) - total_nc = sorted(zip(total_grt_classes, total_num_of_each_class)) - print("total label calsses and their corresponding numbers:\n{} ".format( - total_nc)) - if total_nc[0][0]: - print( - "Not pass label class check!\nWarning: label classes should start from 0 !!!" - ) - else: - print("Pass label class check!") + total_nc = sorted(zip(total_grt_classes, total_num_of_each_class)) + logger.info("\nDoing label pixel statistics...\nTotal label calsses " + "and their corresponding numbers:\n{} ".format(total_nc)) -def ground_truth_check(grt, grt_path): - """ - 验证标签是否重零开始,标签值为0,1,...,num_classes-1, ingnore_idx - 验证标签图像的格式 - 返回标签的像素数 - 检查图像是否都是ignore_index - params: - grt: 标签图 - grt_path: 标签图路径 - return: - png_format: 返回是否是png格式图片 - label_correct: 返回标签是否是正确的 - label_pixel_num: 返回标签的像素数 - """ - if imghdr.what(grt_path) == "png": - png_format = True + if len(label_wrong) == 0 and not total_nc[0][0]: + logger.info(correct_print("label class check!")) else: - png_format = False + logger.info(error_print("label class check!")) + if total_nc[0][0]: + logger.info("Warning: label classes should start from 0") + if len(label_wrong) > 0: + logger.info("fatal error: label class is out of range [0, {}]".format(cfg.DATASET.NUM_CLASSES - 1)) + for i in label_wrong: + logger.debug(i) - unique, counts = np.unique(grt, return_counts=True) - return png_format, unique, counts - -def eval_crop_size_check(max_height, max_width, min_aspectratio, - max_aspectratio): +def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio): """ 判断eval_crop_siz与验证集及测试集的max_height, max_width的关系 param max_height: 数据集的最大高 max_width: 数据集的最大宽 """ + if cfg.AUG.AUG_METHOD == "stepscaling": - flag = True - if max_width > cfg.EVAL_CROP_SIZE[0]: - print( - "ERROR: The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!" - .format(cfg.EVAL_CROP_SIZE[0], max_width)) - flag = False - if max_height > cfg.EVAL_CROP_SIZE[1]: - print( - "ERROR: The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!" - .format(cfg.EVAL_CROP_SIZE[1], max_height)) - flag = False - if flag: - print("EVAL_CROP_SIZE setting correct") + if max_width <= cfg.EVAL_CROP_SIZE[0] or max_height <= cfg.EVAL_CROP_SIZE[1]: + logger.info(correct_print("EVAL_CROP_SIZE check")) + else: + logger.info(error_print("EVAL_CROP_SIZE check")) + if max_width > cfg.EVAL_CROP_SIZE[0]: + logger.info("The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!".format( + cfg.EVAL_CROP_SIZE[0], max_width)) + if max_height > cfg.EVAL_CROP_SIZE[1]: + logger.info(error_print("The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!".format( + cfg.EVAL_CROP_SIZE[1], max_height))) + elif cfg.AUG.AUG_METHOD == "rangescaling": if min_aspectratio <= 1 and max_aspectratio >= 1: - if cfg.EVAL_CROP_SIZE[ - 0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[ - 1] >= cfg.AUG.INF_RESIZE_VALUE: - print("EVAL_CROP_SIZE setting correct") + if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE: + logger.info(correct_print("EVAL_CROP_SIZE check")) else: - print( - "ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" + logger.info(error_print("EVAL_CROP_SIZE check")) + logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE)) elif min_aspectratio > 1: max_height_rangscaling = cfg.AUG.INF_RESIZE_VALUE / min_aspectratio max_height_rangscaling = round(max_height_rangscaling) - if cfg.EVAL_CROP_SIZE[ - 0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[ - 1] >= max_height_rangscaling: - print("EVAL_CROP_SIZE setting correct") + if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= max_height_rangscaling: + logger.info(correct_print("EVAL_CROP_SIZE check")) else: - print( - "ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" - .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], - cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling)) + logger.info(error_print("EVAL_CROP_SIZE check")) + logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" + .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], + cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling)) elif max_aspectratio < 1: max_width_rangscaling = cfg.AUG.INF_RESIZE_VALUE * max_aspectratio max_width_rangscaling = round(max_width_rangscaling) - if cfg.EVAL_CROP_SIZE[ - 0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[ - 1] >= cfg.AUG.INF_RESIZE_VALUE: - print("EVAL_CROP_SIZE setting correct") + if cfg.EVAL_CROP_SIZE[0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE: + logger.info(correct_print("EVAL_CROP_SIZE check")) else: - print( - "ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" + logger.info(error_print("EVAL_CROP_SIZE check")) + logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE)) elif cfg.AUG.AUG_METHOD == "unpadding": - if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[ - 0] and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]: - print("EVAL_CROP_SIZE setting correct") + if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[0] and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]: + logger.info(correct_print("EVAL_CROP_SIZE check")) else: - print( - "ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" - .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], - cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1])) + logger.info(error_print("EVAL_CROP_SIZE check")) + logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" + .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], + cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1])) else: - print( - "ERROR: cfg.AUG.AUG_METHOD setting wrong, it should be one of [unpadding, stepscaling, rangescaling]" - ) - + logger.info("\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of " + "[unpadding, stepscaling, rangescaling]") def inf_resize_value_check(): if cfg.AUG.AUG_METHOD == "rangescaling": if cfg.AUG.INF_RESIZE_VALUE < cfg.AUG.MIN_RESIZE_VALUE or \ cfg.AUG.INF_RESIZE_VALUE > cfg.AUG.MIN_RESIZE_VALUE: - print( - "ERROR: you set AUG.AUG_METHOD = 'rangescaling'" - "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: " - "[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE, - cfg.AUG.MIN_RESIZE_VALUE, - cfg.AUG.MAX_RESIZE_VALUE)) - + logger.info("\nWARNING! you set AUG.AUG_METHOD = 'rangescaling'" + "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: " + "[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.MIN_RESIZE_VALUE, cfg.AUG.MAX_RESIZE_VALUE)) def image_type_check(img_dim): """ @@ -256,166 +281,189 @@ def image_type_check(img_dim): return """ if (1 in img_dim or 3 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgba': - print( - "ERROR: DATASET.IMAGE_TYPE is {} but the type of image has gray or rgb\n" - .format(cfg.DATASET.IMAGE_TYPE)) - # elif (1 not in img_dim and 3 not in img_dim and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb': - # print("ERROR: DATASET.IMAGE_TYPE is {} but the type of image is rgba\n".format(cfg.DATASET.IMAGE_TYPE)) + logger.info(error_print("DATASET.IMAGE_TYPE check")) + logger.info("DATASET.IMAGE_TYPE is {} but the type of image has " + "gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE)) + elif (1 not in img_dim and 3 not in img_dim and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb': + logger.info(correct_print("DATASET.IMAGE_TYPE check")) + logger.info("\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba".format(cfg.DATASET.IMAGE_TYPE)) else: - print("DATASET.IMAGE_TYPE setting correct") + logger.info(correct_print("DATASET.IMAGE_TYPE check")) +def shape_check(): + """输出shape校验结果""" + if len(shape_unequal_image) == 0: + logger.info(correct_print("shape check")) + logger.info("All images are the same shape as the labels") + else: + logger.info(error_print("shape check")) + logger.info("Some images are not the same shape as the labels as follow: ") + for i in shape_unequal_image: + logger.debug(i) -def image_label_shape_check(img, grt): - """ - 验证图像和标签的大小是否匹配 - """ - flag = True - img_height = img.shape[0] - img_width = img.shape[1] - grt_height = grt.shape[0] - grt_width = grt.shape[1] +def file_list_check(list_name): + """检查分割符是否复合要求""" + if len(list_wrong) == 0: + logger.info(correct_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check")) + else: + logger.info(error_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check")) + logger.info("The following list is not separated by {}".format(cfg.DATASET.SEPARATOR)) + for i in list_wrong: + logger.debug(i) + +def imread_check(): + if len(imread_failed) == 0: + logger.info(correct_print("dataset reading check")) + logger.info("All images can be read successfully") + else: + logger.info(error_print("dataset reading check")) + logger.info("Failed to read {} images".format(len(imread_failed))) + for i in imread_failed: + logger.debug(i) - if img_height != grt_height or img_width != grt_width: - flag = False - return flag def check_train_dataset(): - train_list = cfg.DATASET.TRAIN_FILE_LIST - print("\ncheck train dataset...") - with open(train_list, 'r') as fid: - img_dim = [] + list_file = cfg.DATASET.TRAIN_FILE_LIST + logger.info("-----------------------------\n1. Check train dataset...") + with open(list_file, 'r') as fid: lines = fid.readlines() for line in tqdm(lines): - parts = line.strip().split(cfg.DATASET.SEPARATOR) + line = line.strip() + parts = line.split(cfg.DATASET.SEPARATOR) if len(parts) != 2: - print( - line, "File list format incorrect! It should be" - " image_name{}label_name\\n ".format(cfg.DATASET.SEPARATOR)) + list_wrong.append(line) continue img_name, grt_name = parts[0], parts[1] img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) - img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) - grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) + try: + img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) + grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) + except Exception as e: + imread_failed.append((line, str(e))) + continue - get_image_dim(img, img_dim) + get_image_dim(img) is_equal_img_grt_shape = image_label_shape_check(img, grt) if not is_equal_img_grt_shape: - print(line, - "ERROR: source img and label img must has the same size") + shape_unequal_image.append(line) - png_format, grt_classes, num_of_each_class = ground_truth_check( - grt, grt_path) - sum_gt_check(png_format, grt_classes, num_of_each_class) + png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) + if not png_format: + png_format_wrong_image.append(line) + is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) + if not is_label_correct: + label_wrong.append(line) + file_list_check(list_file) + imread_check() gt_check() - image_type_check(img_dim) + shape_check() + + + def check_val_dataset(): - val_list = cfg.DATASET.VAL_FILE_LIST - with open(val_list) as fid: - max_height = 0 - max_width = 0 - min_aspectratio = sys.float_info.max - max_aspectratio = 0.0 - img_dim = [] - print("check val dataset...") + list_file = cfg.DATASET.VAL_FILE_LIST + logger.info("\n-----------------------------\n2. Check val dataset...") + with open(list_file) as fid: lines = fid.readlines() for line in tqdm(lines): - parts = line.strip().split(cfg.DATASET.SEPARATOR) + line = line.strip() + parts = line.split(cfg.DATASET.SEPARATOR) if len(parts) != 2: - print( - line, "File list format incorrect! It should be" - " image_name{}label_name\\n ".format(cfg.DATASET.SEPARATOR)) + list_wrong.append(line) continue img_name, grt_name = parts[0], parts[1] img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) - img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) - grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) - - max_height, max_width = get_image_max_height_width( - img, max_height, max_width) - min_aspectratio, max_aspectratio = get_image_min_max_aspectratio( - img, min_aspectratio, max_aspectratio) - get_image_dim(img, img_dim) + try: + img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) + grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) + except Exception as e: + imread_failed.append((line, e.message)) + get_image_max_height_width(img) + get_image_min_max_aspectratio(img) + get_image_dim(img) is_equal_img_grt_shape = image_label_shape_check(img, grt) if not is_equal_img_grt_shape: - print(line, - "ERROR: source img and label img must has the same size") - - png_format, grt_classes, num_of_each_class = ground_truth_check( - grt, grt_path) - sum_gt_check(png_format, grt_classes, num_of_each_class) + shape_unequal_image.append(line) + png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) + if not png_format: + png_format_wrong_image.append(line) + is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) + if not is_label_correct: + label_wrong.append(line) + + file_list_check(list_file) + imread_check() gt_check() - - eval_crop_size_check(max_height, max_width, min_aspectratio, - max_aspectratio) image_type_check(img_dim) - + shape_check() + eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio) def check_test_dataset(): - test_list = cfg.DATASET.TEST_FILE_LIST - with open(test_list) as fid: - max_height = 0 - max_width = 0 - min_aspectratio = sys.float_info.max - max_aspectratio = 0.0 - img_dim = [] - print("check test dataset...") + list_file = cfg.DATASET.TEST_FILE_LIST + has_label = False + with open(list_file) as fid: + logger.info("\n-----------------------------\n3. Check test dataset...") lines = fid.readlines() for line in tqdm(lines): - parts = line.strip().split(cfg.DATASET.SEPARATOR) + line = line.strip() + parts = line.split(cfg.DATASET.SEPARATOR) if len(parts) == 1: img_name = parts - img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name) - img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) - + img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name[0]) + try: + img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) + except Exception as e: + imread_failed.append((line, str(e))) + continue elif len(parts) == 2: + has_label = True img_name, grt_name = parts[0], parts[1] img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) - img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) - grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) + try: + img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) + grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) + except Exception as e: + imread_failed.append((line, e.message)) + continue is_equal_img_grt_shape = image_label_shape_check(img, grt) if not is_equal_img_grt_shape: - print( - line, - "ERROR: source img and label img must has the same size" - ) - - png_format, grt_classes, num_of_each_class = ground_truth_check( - grt, grt_path) - sum_gt_check(png_format, grt_classes, num_of_each_class) - + shape_unequal_image.append(line) + png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) + if not png_format: + png_format_wrong_image.append(line) + is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) + if not is_label_correct: + label_wrong.append(line) else: - print( - line, "File list format incorrect! It should be" - " image_name{}label_name\\n or image_name\n ".format( - cfg.DATASET.SEPARATOR)) + list_wrong.append(lines) continue - - max_height, max_width = get_image_max_height_width( - img, max_height, max_width) - min_aspectratio, max_aspectratio = get_image_min_max_aspectratio( - img, min_aspectratio, max_aspectratio) - get_image_dim(img, img_dim) - - gt_check() - eval_crop_size_check(max_height, max_width, min_aspectratio, - max_aspectratio) + get_image_max_height_width(img) + get_image_min_max_aspectratio(img) + get_image_dim(img) + + file_list_check(list_file) + imread_check() + if has_label: + gt_check() image_type_check(img_dim) - + if has_label: + shape_check() + eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio) def main(args): if args.cfg_file is not None: cfg.update_from_file(args.cfg_file) cfg.check_and_infer(reset_dataset=True) - print(pprint.pformat(cfg)) + logger.info(pprint.pformat(cfg)) init_global_variable() check_train_dataset() @@ -428,8 +476,19 @@ def main(args): inf_resize_value_check() - if __name__ == "__main__": args = parse_args() - args.cfg_file = "../configs/cityscape.yaml" + logger = logging.getLogger() + logger.setLevel('DEBUG') + BASIC_FORMAT = "%(message)s" + formatter = logging.Formatter(BASIC_FORMAT) + sh = logging.StreamHandler() + sh.setFormatter(formatter) + sh.setLevel('INFO') + th = logging.FileHandler('detail.log', 'w') + th.setFormatter(formatter) + logger.addHandler(sh) + logger.addHandler(th) main(args) + + -- GitLab