# -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import os import sys import pprint import argparse import cv2 from tqdm import tqdm import imghdr import logging from utils.config import cfg def init_global_variable(): """ 初始化全局变量 """ global png_format_right_num # 格式正确的标注图数量 global png_format_wrong_num # 格式错误的标注图数量 global total_grt_classes # 总的标注类别 global total_num_of_each_class # 每个类别总的像素数 global shape_unequal_image # 图片和标注shape不一致列表 global png_format_wrong_image # 标注格式错误列表 global max_width # 图片最长宽 global max_height # 图片最长高 global min_aspectratio # 图片最小宽高比 global max_aspectratio # 图片最大宽高比 global img_dim # 图片的通道数 global list_wrong #文件名格式错误列表 global imread_failed #图片读取失败列表, 二元列表 global label_wrong # 标注图片出错列表 png_format_right_num = 0 png_format_wrong_num = 0 total_grt_classes = [] total_num_of_each_class = [] shape_unequal_image = [] png_format_wrong_image = [] max_width = 0 max_height = 0 min_aspectratio = sys.float_info.max max_aspectratio = 0 img_dim = [] list_wrong = [] imread_failed = [] label_wrong = [] def parse_args(): parser = argparse.ArgumentParser(description='PaddleSeg check') parser.add_argument( '--cfg', dest='cfg_file', help='Config file for training (and optionally testing)', default=None, type=str ) return parser.parse_args() def error_print(str): return "".join(["\nNOT PASS ", str]) def correct_print(str): return "".join(["\nPASS ", str]) def cv2_imread(file_path, flag=cv2.IMREAD_COLOR): # resolve cv2.imread open Chinese file path issues on Windows Platform. return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) def get_image_max_height_width(img): global max_width, max_height img_shape = img.shape height, width = img_shape[0], img_shape[1] max_height = max(height, max_height) max_width = max(width, max_width) def get_image_min_max_aspectratio(img): global min_aspectratio, max_aspectratio img_shape = img.shape height, width = img_shape[0], img_shape[1] min_aspectratio = min(width/height, min_aspectratio) max_aspectratio = max(width/height, max_aspectratio) return min_aspectratio, max_aspectratio def get_image_dim(img): """获取图像的维度""" img_shape = img.shape if img_shape[-1] not in img_dim: img_dim.append(img_shape[-1]) def image_label_shape_check(img, grt): """ 验证图像和标注的大小是否匹配 """ flag = True img_height = img.shape[0] img_width = img.shape[1] grt_height = grt.shape[0] grt_width = grt.shape[1] if img_height != grt_height or img_width != grt_width: flag = False return flag def ground_truth_check(grt, grt_path): """ 验证标注是否重零开始,标注值为0,1,...,num_classes-1, ingnore_idx 验证标注图像的格式 返回标注的像素数 检查图像是否都是ignore_index params: grt: 标注图 grt_path: 标注图路径 return: png_format: 返回是否是png格式图片 label_correct: 返回标注是否是正确的 label_pixel_num: 返回标注的像素数 """ if imghdr.what(grt_path) == "png": png_format = True else: png_format = False unique, counts = np.unique(grt, return_counts=True) return png_format, unique, counts def sum_gt_check(png_format, grt_classes, num_of_each_class): """ 统计所有标注图上的格式、类别和每个类别的像素数 params: png_format: 返回是否是png格式图片 grt_classes: 标注类别 num_of_each_class: 各个类别的像素数目 """ is_label_correct = True global png_format_right_num, png_format_wrong_num, total_grt_classes, total_num_of_each_class if png_format: png_format_right_num += 1 else: png_format_wrong_num += 1 if cfg.DATASET.IGNORE_INDEX in grt_classes: grt_classes2 = np.delete(grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX)) else: grt_classes2 = grt_classes if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1: is_label_correct = False add_class = [] add_num = [] for i in range(len(grt_classes)): gi = grt_classes[i] if gi in total_grt_classes: j = total_grt_classes.index(gi) total_num_of_each_class[j] += num_of_each_class[i] else: add_class.append(gi) add_num.append(num_of_each_class[i]) total_num_of_each_class += add_num total_grt_classes += add_class return is_label_correct def gt_check(): """ 对标注图像进行校验,输出校验结果 """ if png_format_wrong_num == 0: if png_format_right_num: logger.info(correct_print("label format check")) else: logger.info(error_print("label format check")) logger.info("No label image to check") return else: logger.info(error_print("label format check")) logger.info("total {} label images are png format, {} label images are not png " "format".format(png_format_right_num, png_format_wrong_num)) if len(png_format_wrong_image) > 0: for i in png_format_wrong_image: logger.debug(i) total_nc = sorted(zip(total_grt_classes, total_num_of_each_class)) logger.info("\nDoing label pixel statistics...\nTotal label calsses " "and their corresponding numbers:\n{} ".format(total_nc)) if len(label_wrong) == 0 and not total_nc[0][0]: logger.info(correct_print("label class check!")) else: logger.info(error_print("label class check!")) if total_nc[0][0]: logger.info("Warning: label classes should start from 0") if len(label_wrong) > 0: logger.info("fatal error: label class is out of range [0, {}]".format(cfg.DATASET.NUM_CLASSES - 1)) for i in label_wrong: logger.debug(i) def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio): """ 判断eval_crop_siz与验证集及测试集的max_height, max_width的关系 param max_height: 数据集的最大高 max_width: 数据集的最大宽 """ if cfg.AUG.AUG_METHOD == "stepscaling": if max_width <= cfg.EVAL_CROP_SIZE[0] or max_height <= cfg.EVAL_CROP_SIZE[1]: logger.info(correct_print("EVAL_CROP_SIZE check")) else: logger.info(error_print("EVAL_CROP_SIZE check")) if max_width > cfg.EVAL_CROP_SIZE[0]: logger.info("The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!".format( cfg.EVAL_CROP_SIZE[0], max_width)) if max_height > cfg.EVAL_CROP_SIZE[1]: logger.info(error_print("The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!".format( cfg.EVAL_CROP_SIZE[1], max_height))) elif cfg.AUG.AUG_METHOD == "rangescaling": if min_aspectratio <= 1 and max_aspectratio >= 1: if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE: logger.info(correct_print("EVAL_CROP_SIZE check")) else: logger.info(error_print("EVAL_CROP_SIZE check")) logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE)) elif min_aspectratio > 1: max_height_rangscaling = cfg.AUG.INF_RESIZE_VALUE / min_aspectratio max_height_rangscaling = round(max_height_rangscaling) if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= max_height_rangscaling: logger.info(correct_print("EVAL_CROP_SIZE check")) else: logger.info(error_print("EVAL_CROP_SIZE check")) logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling)) elif max_aspectratio < 1: max_width_rangscaling = cfg.AUG.INF_RESIZE_VALUE * max_aspectratio max_width_rangscaling = round(max_width_rangscaling) if cfg.EVAL_CROP_SIZE[0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE: logger.info(correct_print("EVAL_CROP_SIZE check")) else: logger.info(error_print("EVAL_CROP_SIZE check")) logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE)) elif cfg.AUG.AUG_METHOD == "unpadding": if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[0] and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]: logger.info(correct_print("EVAL_CROP_SIZE check")) else: logger.info(error_print("EVAL_CROP_SIZE check")) logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1])) else: logger.info("\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of " "[unpadding, stepscaling, rangescaling]") def inf_resize_value_check(): if cfg.AUG.AUG_METHOD == "rangescaling": if cfg.AUG.INF_RESIZE_VALUE < cfg.AUG.MIN_RESIZE_VALUE or \ cfg.AUG.INF_RESIZE_VALUE > cfg.AUG.MIN_RESIZE_VALUE: logger.info("\nWARNING! you set AUG.AUG_METHOD = 'rangescaling'" "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: " "[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.MIN_RESIZE_VALUE, cfg.AUG.MAX_RESIZE_VALUE)) def image_type_check(img_dim): """ 验证图像的格式与DATASET.IMAGE_TYPE是否一致 param img_dim: 图像包含的通道数 return """ if (1 in img_dim or 3 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgba': logger.info(error_print("DATASET.IMAGE_TYPE check")) logger.info("DATASET.IMAGE_TYPE is {} but the type of image has " "gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE)) elif (1 not in img_dim and 3 not in img_dim and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb': logger.info(correct_print("DATASET.IMAGE_TYPE check")) logger.info("\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba".format(cfg.DATASET.IMAGE_TYPE)) else: logger.info(correct_print("DATASET.IMAGE_TYPE check")) def shape_check(): """输出shape校验结果""" if len(shape_unequal_image) == 0: logger.info(correct_print("shape check")) logger.info("All images are the same shape as the labels") else: logger.info(error_print("shape check")) logger.info("Some images are not the same shape as the labels as follow: ") for i in shape_unequal_image: logger.debug(i) def file_list_check(list_name): """检查分割符是否复合要求""" if len(list_wrong) == 0: logger.info(correct_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check")) else: logger.info(error_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check")) logger.info("The following list is not separated by {}".format(cfg.DATASET.SEPARATOR)) for i in list_wrong: logger.debug(i) def imread_check(): if len(imread_failed) == 0: logger.info(correct_print("dataset reading check")) logger.info("All images can be read successfully") else: logger.info(error_print("dataset reading check")) logger.info("Failed to read {} images".format(len(imread_failed))) for i in imread_failed: logger.debug(i) def check_train_dataset(): list_file = cfg.DATASET.TRAIN_FILE_LIST logger.info("-----------------------------\n1. Check train dataset...") with open(list_file, 'r') as fid: lines = fid.readlines() for line in tqdm(lines): line = line.strip() parts = line.split(cfg.DATASET.SEPARATOR) if len(parts) != 2: list_wrong.append(line) continue img_name, grt_name = parts[0], parts[1] img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) try: img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) except Exception as e: imread_failed.append((line, str(e))) continue get_image_dim(img) is_equal_img_grt_shape = image_label_shape_check(img, grt) if not is_equal_img_grt_shape: shape_unequal_image.append(line) png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) if not png_format: png_format_wrong_image.append(line) is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) if not is_label_correct: label_wrong.append(line) file_list_check(list_file) imread_check() gt_check() image_type_check(img_dim) shape_check() def check_val_dataset(): list_file = cfg.DATASET.VAL_FILE_LIST logger.info("\n-----------------------------\n2. Check val dataset...") with open(list_file) as fid: lines = fid.readlines() for line in tqdm(lines): line = line.strip() parts = line.split(cfg.DATASET.SEPARATOR) if len(parts) != 2: list_wrong.append(line) continue img_name, grt_name = parts[0], parts[1] img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) try: img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) except Exception as e: imread_failed.append((line, e.message)) get_image_max_height_width(img) get_image_min_max_aspectratio(img) get_image_dim(img) is_equal_img_grt_shape = image_label_shape_check(img, grt) if not is_equal_img_grt_shape: shape_unequal_image.append(line) png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) if not png_format: png_format_wrong_image.append(line) is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) if not is_label_correct: label_wrong.append(line) file_list_check(list_file) imread_check() gt_check() image_type_check(img_dim) shape_check() eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio) def check_test_dataset(): list_file = cfg.DATASET.TEST_FILE_LIST has_label = False with open(list_file) as fid: logger.info("\n-----------------------------\n3. Check test dataset...") lines = fid.readlines() for line in tqdm(lines): line = line.strip() parts = line.split(cfg.DATASET.SEPARATOR) if len(parts) == 1: img_name = parts img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name[0]) try: img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) except Exception as e: imread_failed.append((line, str(e))) continue elif len(parts) == 2: has_label = True img_name, grt_name = parts[0], parts[1] img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) try: img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) except Exception as e: imread_failed.append((line, e.message)) continue is_equal_img_grt_shape = image_label_shape_check(img, grt) if not is_equal_img_grt_shape: shape_unequal_image.append(line) png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) if not png_format: png_format_wrong_image.append(line) is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) if not is_label_correct: label_wrong.append(line) else: list_wrong.append(lines) continue get_image_max_height_width(img) get_image_min_max_aspectratio(img) get_image_dim(img) file_list_check(list_file) imread_check() if has_label: gt_check() image_type_check(img_dim) if has_label: shape_check() eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio) def main(args): if args.cfg_file is not None: cfg.update_from_file(args.cfg_file) cfg.check_and_infer(reset_dataset=True) logger.info(pprint.pformat(cfg)) init_global_variable() check_train_dataset() init_global_variable() check_val_dataset() init_global_variable() check_test_dataset() inf_resize_value_check() if __name__ == "__main__": args = parse_args() logger = logging.getLogger() logger.setLevel('DEBUG') BASIC_FORMAT = "%(message)s" formatter = logging.Formatter(BASIC_FORMAT) sh = logging.StreamHandler() sh.setFormatter(formatter) sh.setLevel('INFO') th = logging.FileHandler('detail.log', 'w') th.setFormatter(formatter) logger.addHandler(sh) logger.addHandler(th) main(args)