# check.py - committed by wuzewu
# coding: utf8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys
import pprint
import argparse
import cv2
from tqdm import tqdm
import imghdr

from utils.config import cfg


def init_global_variable():
    """
    Reset the module-level statistics used by the label checks.

    Every call rebinds the counters to zero and the collections to fresh,
    independent empty lists, so results of a previously checked dataset do
    not leak into the next one.
    """
    global png_format_right_num  # number of label images that are PNG files
    global png_format_wrong_num  # number of label images that are not PNG files
    global total_grt_classes  # every label class seen so far
    global total_num_of_each_class  # pixel count of each class seen so far
    global shape_unequal  # images whose shape differs from their label
    global png_format_wrong  # labels with a wrong file format

    png_format_right_num = png_format_wrong_num = 0
    total_grt_classes, total_num_of_each_class = [], []
    shape_unequal, png_format_wrong = [], []


def parse_args():
    """
    Parse the command line of the checker.

    return:
        the argparse namespace; its ``cfg_file`` attribute holds the value of
        ``--cfg`` or None when the option is not given.
    """
    cli = argparse.ArgumentParser(description='PaddleSeg check')
    cfg_option = dict(
        dest='cfg_file',
        help='Config file for training (and optionally testing)',
        default=None,
        type=str)
    cli.add_argument('--cfg', **cfg_option)
    return cli.parse_args()


def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
    """
    Read an image by decoding the raw bytes of the file.

    Loading the bytes with numpy and decoding them in memory works around
    cv2.imread failing on Chinese file paths on Windows.

    params:
        file_path: path of the image file
        flag: cv2 imread flag handed to cv2.imdecode
    return:
        whatever cv2.imdecode returns for the file content
    """
    raw_bytes = np.fromfile(file_path, dtype=np.uint8)
    return cv2.imdecode(raw_bytes, flag)


def get_image_max_height_width(img, max_height, max_width):
    """
    Update the running maximum height and width with one more image.

    params:
        img: image array; shape[0] is taken as height, shape[1] as width
        max_height: largest height seen so far
        max_width: largest width seen so far
    return:
        (max_height, max_width) including this image
    """
    rows, cols = img.shape[0], img.shape[1]
    return max(rows, max_height), max(cols, max_width)


def get_image_min_max_aspectratio(img, min_aspectratio, max_aspectratio):
    """
    Update the running minimum and maximum aspect ratio with one more image.

    The aspect ratio is width / height, with shape[0] taken as height and
    shape[1] as width.

    params:
        img: image array
        min_aspectratio: smallest aspect ratio seen so far
        max_aspectratio: largest aspect ratio seen so far
    return:
        (min_aspectratio, max_aspectratio) including this image
    """
    rows, cols = img.shape[0], img.shape[1]
    ratio = cols / rows
    return min(ratio, min_aspectratio), max(ratio, max_aspectratio)


def get_image_dim(img, img_dim):
    """
    Record the channel count of an image in img_dim (modified in place).

    params:
        img: image array, either (height, width) or (height, width, channels)
        img_dim: list of the distinct channel counts collected so far
    """
    img_shape = img.shape
    # A 2-D array is a single-channel (gray) image: its last axis is the
    # width, not a channel count, so it must be recorded as 1 channel.
    channels = 1 if len(img_shape) == 2 else img_shape[-1]
    if channels not in img_dim:
        img_dim.append(channels)


def sum_gt_check(png_format, grt_classes, num_of_each_class):
    """
    Fold the statistics of one label image into the module-level totals.

    Updates the PNG format counters, reports label classes outside
    [0, cfg.DATASET.NUM_CLASSES - 1] and accumulates the per-class pixel
    counts into total_grt_classes / total_num_of_each_class.

    params:
        png_format: whether the label image is a PNG file
        grt_classes: label classes present in the image
        num_of_each_class: pixel count of each entry of grt_classes
    """
    global png_format_right_num, png_format_wrong_num, total_grt_classes, total_num_of_each_class

    if png_format:
        png_format_right_num += 1
    else:
        png_format_wrong_num += 1

    # The range check applies to real classes only, so the ignore index is
    # filtered out first. The filtered array is always defined (also when the
    # ignore index is absent) and may be empty when the label holds nothing
    # but the ignore index, in which case there is nothing to range-check.
    grt_classes = np.asarray(grt_classes)
    valid_classes = grt_classes[grt_classes != cfg.DATASET.IGNORE_INDEX]
    if valid_classes.size > 0 and (
            valid_classes.min() < 0
            or valid_classes.max() > cfg.DATASET.NUM_CLASSES - 1):
        print("fatal error: label class is out of range [0, {}]".format(
            cfg.DATASET.NUM_CLASSES - 1))

    # Every class, including the ignore index, is accumulated in the totals.
    for label_class, pixel_num in zip(grt_classes, num_of_each_class):
        if label_class in total_grt_classes:
            idx = total_grt_classes.index(label_class)
            total_num_of_each_class[idx] += pixel_num
        else:
            total_grt_classes.append(label_class)
            total_num_of_each_class.append(pixel_num)


def gt_check():
    """
    Print the verdict of the label checks from the accumulated statistics.

    Reads these module-level values:
         png_format_wrong_num: number of label images that are not PNG
         png_format_right_num: number of label images that are PNG
         total_grt_classes: every label class seen so far
         total_num_of_each_class: pixel count of each class
    return:
        total_nc: list of (class, pixel count) pairs sorted in ascending
                  order; empty when no label was collected
    """
    # The format check passes only when no label had a wrong format.
    if png_format_wrong_num == 0:
        print("Pass label png format check!")
    else:
        print("Not pass label png format check!")
    print(
        "total {} label imgs are png format, {} label imgs are not png fromat".
        format(png_format_right_num, png_format_wrong_num))

    total_nc = sorted(zip(total_grt_classes, total_num_of_each_class))
    print("total label calsses and their corresponding numbers:\n{} ".format(
        total_nc))
    if not total_nc:
        # Nothing was collected (e.g. a file list without any label), so
        # there is no smallest class to inspect.
        print("Warning: no label class found, skip label class check.")
    elif total_nc[0][0] != 0:
        print(
            "Not pass label class check!\nWarning: label classes should start from 0 !!!"
        )
    else:
        print("Pass label class check!")
    return total_nc


def ground_truth_check(grt, grt_path):
    """
    Collect the per-image facts about one label image.

    Determines whether the label file is a PNG file and counts the pixels of
    every label value that occurs in the label array.

    params:
        grt: label image array
        grt_path: path of the label image file
    return:
        png_format: True when the file content is PNG, False otherwise
        unique: sorted array of the distinct label values in grt
        counts: number of pixels of each value in unique
    """
    # Every PNG file starts with this fixed 8-byte signature. Checking it
    # directly gives the same verdict as imghdr.what(...) == "png" without
    # relying on the imghdr module, which was removed in Python 3.13.
    png_signature = b'\x89PNG\r\n\x1a\n'
    with open(grt_path, 'rb') as label_file:
        png_format = label_file.read(len(png_signature)) == png_signature

    unique, counts = np.unique(grt, return_counts=True)

    return png_format, unique, counts


def eval_crop_size_check(max_height, max_width, min_aspectratio,
                         max_aspectratio):
    """
    Check cfg.EVAL_CROP_SIZE against the image sizes of the val / test set.

    EVAL_CROP_SIZE[0] is compared with widths and EVAL_CROP_SIZE[1] with
    heights. What the crop size has to cover depends on cfg.AUG.AUG_METHOD:
        stepscaling: the largest width and height found in the dataset
        rangescaling: the size reached when the image is resized with
                      cfg.AUG.INF_RESIZE_VALUE, derived from the aspect ratios
        unpadding: cfg.AUG.FIX_RESIZE_SIZE
    The outcome is printed, nothing is returned.

    param
        max_height: largest image height of the dataset
        max_width: largest image width of the dataset
        min_aspectratio: smallest width / height ratio of the dataset
        max_aspectratio: largest width / height ratio of the dataset
    """
    aug_method = cfg.AUG.AUG_METHOD

    if aug_method == "stepscaling":
        # Width and height are reported independently of each other.
        size_ok = True
        if max_width > cfg.EVAL_CROP_SIZE[0]:
            print(
                "ERROR: The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!"
                .format(cfg.EVAL_CROP_SIZE[0], max_width))
            size_ok = False
        if max_height > cfg.EVAL_CROP_SIZE[1]:
            print(
                "ERROR: The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!"
                .format(cfg.EVAL_CROP_SIZE[1], max_height))
            size_ok = False
        if size_ok:
            print("EVAL_CROP_SIZE setting correct")
        return

    # The remaining methods all compare the crop size with one required
    # (width, height) pair; only the way that pair is obtained differs.
    if aug_method == "rangescaling":
        inf_value = cfg.AUG.INF_RESIZE_VALUE
        if min_aspectratio <= 1 and max_aspectratio >= 1:
            required_w, required_h = inf_value, inf_value
        elif min_aspectratio > 1:
            required_w = inf_value
            required_h = round(inf_value / min_aspectratio)
        elif max_aspectratio < 1:
            required_w = round(inf_value * max_aspectratio)
            required_h = inf_value
        else:
            return
    elif aug_method == "unpadding":
        required_w = cfg.AUG.FIX_RESIZE_SIZE[0]
        required_h = cfg.AUG.FIX_RESIZE_SIZE[1]
    else:
        print(
            "ERROR: cfg.AUG.AUG_METHOD setting wrong, it should be one of [unpadding, stepscaling, rangescaling]"
        )
        return

    crop_w, crop_h = cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]
    if crop_w >= required_w and crop_h >= required_h:
        print("EVAL_CROP_SIZE setting correct")
    else:
        print(
            "ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
            .format(crop_w, crop_h, required_w, required_h))


def inf_resize_value_check():
    """
    Check that cfg.AUG.INF_RESIZE_VALUE lies within
    [cfg.AUG.MIN_RESIZE_VALUE, cfg.AUG.MAX_RESIZE_VALUE].

    Only applies when cfg.AUG.AUG_METHOD is "rangescaling"; an error is
    printed when the value is outside the range, nothing is returned.
    """
    if cfg.AUG.AUG_METHOD != "rangescaling":
        return

    inf_value = cfg.AUG.INF_RESIZE_VALUE
    lower = cfg.AUG.MIN_RESIZE_VALUE
    upper = cfg.AUG.MAX_RESIZE_VALUE
    # The upper bound is MAX_RESIZE_VALUE; comparing against the minimum on
    # both sides would reject every value different from the minimum.
    if inf_value < lower or inf_value > upper:
        print(
            "ERROR: you set AUG.AUG_METHOD = 'rangescaling', but "
            "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: "
            "[{}, {}].".format(inf_value, lower, upper))


def image_type_check(img_dim):
    """
    Check that the images agree with cfg.DATASET.IMAGE_TYPE.

    An error is printed when IMAGE_TYPE is 'rgba' although images with 1 or
    3 channels were found. The opposite case (4-channel images with another
    IMAGE_TYPE) is not checked here.

    param
        img_dim: channel counts found in the images
    return
    """
    has_gray_or_rgb = 1 in img_dim or 3 in img_dim
    if has_gray_or_rgb and cfg.DATASET.IMAGE_TYPE == 'rgba':
        print(
            "ERROR: DATASET.IMAGE_TYPE is {} but the type of image has gray or rgb\n"
            .format(cfg.DATASET.IMAGE_TYPE))
        return
    print("DATASET.IMAGE_TYPE setting correct")


def image_label_shape_check(img, grt):
    """
    Tell whether an image and its label have the same height and width.

    Only the first two axes are compared, so a channel axis is ignored.

    return:
        True when both heights and both widths are equal, False otherwise
    """
    img_h, img_w = img.shape[0], img.shape[1]
    grt_h, grt_w = grt.shape[0], grt.shape[1]
    return img_h == grt_h and img_w == grt_w


def check_train_dataset():
    """
    Check every image / label pair listed in cfg.DATASET.TRAIN_FILE_LIST.

    Each line of the list has to hold an image name and a label name
    separated by cfg.DATASET.SEPARATOR, both relative to
    cfg.DATASET.DATA_DIR. For every pair the image / label shapes are
    compared and the label statistics are accumulated; afterwards the label
    verdict and the image type verdict are printed.
    """
    train_list = cfg.DATASET.TRAIN_FILE_LIST
    print("\ncheck train dataset...")
    with open(train_list, 'r') as fid:
        lines = fid.readlines()

    img_dim = []
    for line in tqdm(lines):
        parts = line.strip().split(cfg.DATASET.SEPARATOR)
        if len(parts) != 2:
            print(
                line, "File list format incorrect! It should be"
                " image_name{}label_name\\n ".format(cfg.DATASET.SEPARATOR))
            continue
        img_name, grt_name = parts
        img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
        grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
        img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
        grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imdecode gives None for content it cannot decode; report the
        # pair and go on instead of failing on the first `.shape` access.
        if img is None or grt is None:
            print(line, "ERROR: source img or label img can not be decoded")
            continue

        get_image_dim(img, img_dim)
        if not image_label_shape_check(img, grt):
            print(line,
                  "ERROR: source img and label img must has the same size")

        png_format, grt_classes, num_of_each_class = ground_truth_check(
            grt, grt_path)
        sum_gt_check(png_format, grt_classes, num_of_each_class)

    gt_check()

    image_type_check(img_dim)


def check_val_dataset():
    """
    Check every image / label pair listed in cfg.DATASET.VAL_FILE_LIST.

    Each line of the list has to hold an image name and a label name
    separated by cfg.DATASET.SEPARATOR, both relative to
    cfg.DATASET.DATA_DIR. Besides the shape and label checks, the largest
    height / width and the aspect ratio range of the images are collected
    and used to check the EVAL_CROP_SIZE setting.
    """
    val_list = cfg.DATASET.VAL_FILE_LIST
    with open(val_list) as fid:
        print("check val dataset...")
        lines = fid.readlines()

    max_height = 0
    max_width = 0
    min_aspectratio = sys.float_info.max
    max_aspectratio = 0.0
    img_dim = []
    for line in tqdm(lines):
        parts = line.strip().split(cfg.DATASET.SEPARATOR)
        if len(parts) != 2:
            print(
                line, "File list format incorrect! It should be"
                " image_name{}label_name\\n ".format(cfg.DATASET.SEPARATOR))
            continue
        img_name, grt_name = parts
        img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
        grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
        img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
        grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imdecode gives None for content it cannot decode; report the
        # pair and go on instead of failing on the first `.shape` access.
        if img is None or grt is None:
            print(line, "ERROR: source img or label img can not be decoded")
            continue

        max_height, max_width = get_image_max_height_width(
            img, max_height, max_width)
        min_aspectratio, max_aspectratio = get_image_min_max_aspectratio(
            img, min_aspectratio, max_aspectratio)
        get_image_dim(img, img_dim)
        if not image_label_shape_check(img, grt):
            print(line,
                  "ERROR: source img and label img must has the same size")

        png_format, grt_classes, num_of_each_class = ground_truth_check(
            grt, grt_path)
        sum_gt_check(png_format, grt_classes, num_of_each_class)
    gt_check()

    eval_crop_size_check(max_height, max_width, min_aspectratio,
                         max_aspectratio)
    image_type_check(img_dim)


def check_test_dataset():
    """
    Check every entry listed in cfg.DATASET.TEST_FILE_LIST.

    A line holds either an image name alone or an image name and a label
    name separated by cfg.DATASET.SEPARATOR, both relative to
    cfg.DATASET.DATA_DIR. Image sizes are collected for every readable
    image; the shape and label checks only run for lines that carry a label.
    """
    test_list = cfg.DATASET.TEST_FILE_LIST
    with open(test_list) as fid:
        print("check test dataset...")
        lines = fid.readlines()

    max_height = 0
    max_width = 0
    min_aspectratio = sys.float_info.max
    max_aspectratio = 0.0
    img_dim = []
    has_label = False
    for line in tqdm(lines):
        parts = line.strip().split(cfg.DATASET.SEPARATOR)
        if len(parts) not in (1, 2):
            print(
                line, "File list format incorrect! It should be"
                " image_name{}label_name\\n or image_name\n ".format(
                    cfg.DATASET.SEPARATOR))
            continue

        # parts is a list even for a single name, so the name is parts[0].
        img_path = os.path.join(cfg.DATASET.DATA_DIR, parts[0])
        img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
        # cv2.imdecode gives None for content it cannot decode.
        if img is None:
            print(line, "ERROR: source img can not be decoded")
            continue

        if len(parts) == 2:
            grt_path = os.path.join(cfg.DATASET.DATA_DIR, parts[1])
            grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
            if grt is None:
                print(line, "ERROR: label img can not be decoded")
            else:
                if not image_label_shape_check(img, grt):
                    print(
                        line,
                        "ERROR: source img and label img must has the same size"
                    )

                png_format, grt_classes, num_of_each_class = ground_truth_check(
                    grt, grt_path)
                sum_gt_check(png_format, grt_classes, num_of_each_class)
                has_label = True

        max_height, max_width = get_image_max_height_width(
            img, max_height, max_width)
        min_aspectratio, max_aspectratio = get_image_min_max_aspectratio(
            img, min_aspectratio, max_aspectratio)
        get_image_dim(img, img_dim)

    # A test list may consist of images only; without any label there are
    # no label statistics to report.
    if has_label:
        gt_check()
    else:
        print("No label found in test list, skip label check.")
    eval_crop_size_check(max_height, max_width, min_aspectratio,
                         max_aspectratio)
    image_type_check(img_dim)


def main(args):
    """
    Load the configuration and run all dataset and setting checks.

    params:
        args: parsed command line; when args.cfg_file is not None the
              configuration is updated from that file
    """
    config_path = args.cfg_file
    if config_path is not None:
        cfg.update_from_file(config_path)
    cfg.check_and_infer(reset_dataset=True)
    print(pprint.pformat(cfg))

    # The label statistics are reset before each dataset so that every
    # dataset is judged on its own.
    for check_dataset in (check_train_dataset, check_val_dataset,
                          check_test_dataset):
        init_global_variable()
        check_dataset()

    inf_resize_value_check()


if __name__ == "__main__":
    args = parse_args()
    # Use the sample config only as a fallback; assigning it unconditionally
    # would silently discard the file passed with --cfg.
    if args.cfg_file is None:
        args.cfg_file = "../configs/cityscape.yaml"
    main(args)