check.py 23.8 KB
Newer Older
C
chenguowei01 已提交
1
# coding: utf8
W
wuyefeilin 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
wuzewu 已提交
15 16 17 18 19 20 21 22 23 24 25 26 27

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys
import pprint
import argparse
import cv2
from tqdm import tqdm
import imghdr
C
chenguowei01 已提交
28
import logging
W
wuzewu 已提交
29 30

from utils.config import cfg
L
LutaoChu 已提交
31
from reader import pil_imread
W
wuzewu 已提交
32

L
LutaoChu 已提交
33

W
wuzewu 已提交
34 35 36 37
def init_global_variable():
    """
    Reset all module-level statistics gathered while checking one dataset split.

    main() calls this before each of the train / val / test checks so that
    counters and problem lists from one split do not leak into the next.
    """
    # Label-format counters and running per-class pixel statistics
    # (total_grt_classes / total_num_of_each_class are parallel lists).
    global png_format_right_num, png_format_wrong_num
    global total_grt_classes, total_num_of_each_class
    # Per-line problem lists (each entry is the offending file-list line,
    # except imread_failed which stores (line, error message) pairs).
    global shape_unequal_image, png_format_wrong_image
    global list_wrong, imread_failed, label_wrong, label_gray_wrong
    # Image geometry and channel statistics.
    global max_width, max_height, min_aspectratio, max_aspectratio, img_dim

    png_format_right_num, png_format_wrong_num = 0, 0
    total_grt_classes, total_num_of_each_class = [], []

    shape_unequal_image, png_format_wrong_image = [], []
    list_wrong, imread_failed = [], []
    label_wrong, label_gray_wrong = [], []

    max_width, max_height = 0, 0
    # Start the minimum at the largest float so the first image always wins.
    min_aspectratio, max_aspectratio = sys.float_info.max, 0
    img_dim = []
W
wuzewu 已提交
69

L
LutaoChu 已提交
70

W
wuzewu 已提交
71 72 73
def parse_args(argv=None):
    """
    Parse the command-line options of the PaddleSeg dataset check script.

    Args:
        argv: optional list of argument strings. The default ``None`` makes
            argparse read ``sys.argv[1:]`` exactly as before. Passing a list
            lets the parser be used without touching the real command line.

    Returns:
        argparse.Namespace with a ``cfg_file`` attribute (``None`` when
        ``--cfg`` is not given).
    """
    parser = argparse.ArgumentParser(description='PaddleSeg check')
    parser.add_argument(
        '--cfg',
        dest='cfg_file',
        help='Config file for training (and optionally testing)',
        default=None,
        type=str)
    return parser.parse_args(argv)

L
LutaoChu 已提交
81

C
chenguowei01 已提交
82 83 84
def error_print(str):
    """Return the check name ``str`` prefixed by a newline and ``NOT PASS``."""
    # The parameter shadows the builtin ``str``; the name is kept because it
    # is part of the function's signature.
    prefix = "\nNOT PASS "
    return prefix + str

L
LutaoChu 已提交
85

C
chenguowei01 已提交
86 87
def correct_print(str):
    """Return the check name ``str`` prefixed by a newline and ``PASS``."""
    # The parameter shadows the builtin ``str``; the name is kept because it
    # is part of the function's signature.
    prefix = "\nPASS "
    return prefix + str
W
wuzewu 已提交
88

L
LutaoChu 已提交
89

W
wuzewu 已提交
90
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
    """
    Read an image with OpenCV, working around cv2.imread failing to open
    non-ASCII (e.g. Chinese) paths on Windows.

    The file is loaded as raw bytes with numpy and decoded in memory.

    Args:
        file_path: path of the image file.
        flag: cv2.IMREAD_* flag forwarded to cv2.imdecode.

    Returns:
        The decoded image array.

    Raises:
        IOError: if the bytes cannot be decoded as an image. cv2.imdecode
            reports this by returning None. Raising instead lets the callers'
            ``except Exception`` record the file as unreadable, rather than
            crashing later on ``img.shape``.
        Errors raised by np.fromfile propagate unchanged.
    """
    raw_bytes = np.fromfile(file_path, dtype=np.uint8)
    img = cv2.imdecode(raw_bytes, flag)
    if img is None:
        raise IOError(
            "cv2.imdecode failed to decode image: {}".format(file_path))
    return img

L
LutaoChu 已提交
96

C
chenguowei01 已提交
97
def get_image_max_height_width(img):
    """
    Fold one image's height and width into the running maxima.

    ``img.shape[0]`` is taken as the height and ``img.shape[1]`` as the width.
    The module-level ``max_height`` / ``max_width`` are updated in place.
    """
    global max_width, max_height
    h = img.shape[0]
    w = img.shape[1]
    if h > max_height:
        max_height = h
    if w > max_width:
        max_width = w

L
LutaoChu 已提交
105

C
chenguowei01 已提交
106
def get_image_min_max_aspectratio(img):
    """
    Fold one image's width / height ratio into the running min and max.

    Updates the module-level ``min_aspectratio`` and ``max_aspectratio`` and
    returns the updated ``(min_aspectratio, max_aspectratio)`` pair.
    """
    global min_aspectratio, max_aspectratio
    shape = img.shape
    # True division (the module imports __future__.division).
    ratio = shape[1] / shape[0]
    min_aspectratio = min(ratio, min_aspectratio)
    max_aspectratio = max(ratio, max_aspectratio)
    return min_aspectratio, max_aspectratio

L
LutaoChu 已提交
115

C
chenguowei01 已提交
116
def get_image_dim(img):
    """
    Record the channel count of ``img`` in the module-level ``img_dim`` list.

    A 2-D array (e.g. a single-channel image decoded by cv2) counts as
    1 channel. Otherwise the last axis is taken as the channel axis.
    Each distinct channel count is stored only once.
    """
    img_shape = img.shape
    # A 2-D array has no channel axis: shape[-1] would be the image width,
    # and image_type_check would never see gray (1-channel) images.
    channels = 1 if len(img_shape) == 2 else img_shape[-1]
    if channels not in img_dim:
        img_dim.append(channels)

L
LutaoChu 已提交
122

C
chenguowei01 已提交
123 124 125 126 127 128 129 130
def is_label_gray(grt):
    """Return True when the label array ``grt`` is 2-D (single channel), else False."""
    return len(grt.shape) == 2

L
LutaoChu 已提交
131

C
chenguowei01 已提交
132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
def image_label_shape_check(img, grt):
    """
    Check that an image and its label have the same spatial size.

    Returns True when both the height (shape[0]) and the width (shape[1])
    match, and False otherwise. Any trailing axes (e.g. channels) are ignored.
    """
    img_h, img_w = img.shape[0], img.shape[1]
    grt_h, grt_w = grt.shape[0], grt.shape[1]
    return img_h == grt_h and img_w == grt_w

L
LutaoChu 已提交
147

C
chenguowei01 已提交
148 149 150
def ground_truth_check(grt, grt_path):
    """
    Check a label image's file format and count its classes.

    params:
        grt: label image array
        grt_path: path of the label image file
    return:
        png_format: True if the file starts with the PNG signature
        unique: sorted array of the distinct label values in ``grt``
        counts: pixel count for each value in ``unique``
    """
    # This is the same test imghdr.what() uses for "png". It is done by hand
    # because imghdr is deprecated since Python 3.11 and removed in 3.13.
    png_signature = b'\x89PNG\r\n\x1a\n'
    with open(grt_path, 'rb') as f:
        png_format = f.read(len(png_signature)) == png_signature

    unique, counts = np.unique(grt, return_counts=True)

    return png_format, unique, counts
W
wuzewu 已提交
168

L
LutaoChu 已提交
169

W
wuzewu 已提交
170 171
def sum_gt_check(png_format, grt_classes, num_of_each_class):
    """
    Accumulate one label image's statistics into the module-level totals.

    Updates the global PNG-format counters and the running per-class pixel
    totals (``total_grt_classes`` / ``total_num_of_each_class``, which are
    parallel lists). Also validates the label values against the config.

    params:
        png_format: whether the label file is a PNG
        grt_classes: distinct label values of the image (np.unique output)
        num_of_each_class: pixel count for each entry of ``grt_classes``
    return:
        True if every label value other than cfg.DATASET.IGNORE_INDEX lies in
        [0, cfg.DATASET.NUM_CLASSES - 1], else False. An image whose pixels
        are all IGNORE_INDEX counts as correct.
    """
    global png_format_right_num, png_format_wrong_num, total_grt_classes, total_num_of_each_class

    if png_format:
        png_format_right_num += 1
    else:
        png_format_wrong_num += 1

    grt_classes = np.asarray(grt_classes)
    # The ignore value may lie outside the class range, so drop it before
    # range-checking.
    valid_classes = grt_classes[grt_classes != cfg.DATASET.IGNORE_INDEX]
    is_label_correct = True
    # Guard the empty case: min()/max() raise on an empty array.
    if valid_classes.size and (
            valid_classes.min() < 0
            or valid_classes.max() > cfg.DATASET.NUM_CLASSES - 1):
        is_label_correct = False

    add_class = []
    add_num = []
    for gi, num in zip(grt_classes, num_of_each_class):
        if gi in total_grt_classes:
            j = total_grt_classes.index(gi)
            total_num_of_each_class[j] += num
        else:
            add_class.append(gi)
            add_num.append(num)
    total_num_of_each_class += add_num
    total_grt_classes += add_class
    return is_label_correct
W
wuzewu 已提交
206

L
LutaoChu 已提交
207

W
wuzewu 已提交
208 209
def gt_check():
    """
    Log the label-image results accumulated by sum_gt_check.

    Reports the PNG-format check, the per-class pixel statistics and the
    label-class range check. If no label image was recorded at all, a
    failure is logged and the function returns early.
    """
    if png_format_wrong_num != 0:
        logger.info(error_print("label format check"))
    elif png_format_right_num:
        logger.info(correct_print("label format check"))
    else:
        logger.info(error_print("label format check"))
        logger.info("No label image to check")
        return

    logger.info(
        "total {} label images are png format, {} label images are not png "
        "format".format(png_format_right_num, png_format_wrong_num))
    for wrong_image in png_format_wrong_image:
        logger.debug(wrong_image)

    pixel_total = sum(total_num_of_each_class)
    total_ratio = np.around(total_num_of_each_class / pixel_total, decimals=4)
    # Sorted by label class, so total_nc[0][0] is the smallest class seen.
    total_nc = sorted(
        zip(total_grt_classes, total_num_of_each_class, total_ratio))
    logger.info(
        "\nDoing label pixel statistics:\n"
        "(label class, total pixel number, percentage) = {} ".format(total_nc))

    starts_at_zero = not total_nc[0][0]
    if not label_wrong and starts_at_zero:
        logger.info(correct_print("label class check!"))
        return

    logger.info(error_print("label class check!"))
    if not starts_at_zero:
        logger.info("Warning: label classes should start from 0")
    if label_wrong:
        logger.info(
            "fatal error: label class is out of range [0, {}]".format(
                cfg.DATASET.NUM_CLASSES - 1))
        for wrong_line in label_wrong:
            logger.debug(wrong_line)
W
wuzewu 已提交
248 249


L
LutaoChu 已提交
250 251
def eval_crop_size_check(max_height, max_width, min_aspectratio,
                         max_aspectratio):
    """
    Check cfg.EVAL_CROP_SIZE (width, height) against a split's image sizes.

    The required size depends on cfg.AUG.AUG_METHOD:
      * "stepscaling": EVAL_CROP_SIZE must be >= (max_width, max_height).
      * "rangescaling": the required size is derived from
        AUG.INF_RESIZE_VALUE and the aspect-ratio range. The checks treat
        INF_RESIZE_VALUE as the resized length of an image's longer side.
      * "unpadding": EVAL_CROP_SIZE must be >= AUG.FIX_RESIZE_SIZE, which
        has to be a pair.
    Any other AUG_METHOD is reported as a configuration error. Results are
    only logged; nothing is returned.

    params:
        max_height: largest image height in the split
        max_width: largest image width in the split
        min_aspectratio: smallest width / height ratio in the split
        max_aspectratio: largest width / height ratio in the split
    """
    if cfg.AUG.AUG_METHOD == "stepscaling":
        if max_width <= cfg.EVAL_CROP_SIZE[0] \
                and max_height <= cfg.EVAL_CROP_SIZE[1]:
            logger.info(correct_print("EVAL_CROP_SIZE check"))
            logger.info(
                "satisfy current EVAL_CROP_SIZE: ({},{}) >= max width and max height of images: ({},{})"
                .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], max_width,
                        max_height))
        else:
            logger.info(error_print("EVAL_CROP_SIZE check"))
            if max_width > cfg.EVAL_CROP_SIZE[0]:
                logger.info(
                    "EVAL_CROP_SIZE[0]: {} should >= max width of images {}!".
                    format(cfg.EVAL_CROP_SIZE[0], max_width))
            if max_height > cfg.EVAL_CROP_SIZE[1]:
                logger.info(
                    "EVAL_CROP_SIZE[1]: {} should >= max height of images {}!".
                    format(cfg.EVAL_CROP_SIZE[1], max_height))

    elif cfg.AUG.AUG_METHOD == "rangescaling":
        if min_aspectratio <= 1 and max_aspectratio >= 1:
            # Both landscape and portrait images exist: either side may reach
            # INF_RESIZE_VALUE, so both crop sides must cover it.
            if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE \
                    and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE:
                logger.info(correct_print("EVAL_CROP_SIZE check"))
                logger.info(
                    "satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
                    format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
                           cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE))
            else:
                logger.info(error_print("EVAL_CROP_SIZE check"))
                logger.info(
                    "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
                    .format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE,
                            cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))
        elif min_aspectratio > 1:
            # All images are wider than tall: the width is INF_RESIZE_VALUE,
            # and the tallest resized height comes from the smallest ratio.
            max_height_rangscaling = round(
                cfg.AUG.INF_RESIZE_VALUE / min_aspectratio)
            if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE \
                    and cfg.EVAL_CROP_SIZE[1] >= max_height_rangscaling:
                logger.info(correct_print("EVAL_CROP_SIZE check"))
                logger.info(
                    "satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
                    format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
                           cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling))
            else:
                logger.info(error_print("EVAL_CROP_SIZE check"))
                logger.info(
                    "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
                    .format(cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling,
                            cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))
        elif max_aspectratio < 1:
            # All images are taller than wide: the height is INF_RESIZE_VALUE,
            # and the widest resized width comes from the largest ratio.
            max_width_rangscaling = round(
                cfg.AUG.INF_RESIZE_VALUE * max_aspectratio)
            if cfg.EVAL_CROP_SIZE[0] >= max_width_rangscaling \
                    and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE:
                logger.info(correct_print("EVAL_CROP_SIZE check"))
                # Report the width computed in this branch; the height
                # variable is only assigned in the landscape branch above.
                logger.info(
                    "satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
                    format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
                           max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE))
            else:
                logger.info(error_print("EVAL_CROP_SIZE check"))
                logger.info(
                    "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
                    .format(max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE,
                            cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))

    elif cfg.AUG.AUG_METHOD == "unpadding":
        if len(cfg.AUG.FIX_RESIZE_SIZE) != 2:
            logger.info(error_print("EVAL_CROP_SIZE check"))
            logger.info(
                "you set AUG.AUG_METHOD = 'unpadding', but AUG.FIX_RESIZE_SIZE is wrong. "
                "AUG.FIX_RESIZE_SIZE should be a tuple of length 2")
        elif cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[0] \
                and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]:
            logger.info(correct_print("EVAL_CROP_SIZE check"))
            logger.info(
                "satisfy current EVAL_CROP_SIZE: ({},{}) >= AUG.FIX_RESIZE_SIZE: ({},{}) "
                .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
                        cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1]))
        else:
            logger.info(error_print("EVAL_CROP_SIZE check"))
            logger.info(
                "EVAL_CROP_SIZE: ({},{}) must >= AUG.FIX_RESIZE_SIZE: ({},{})".
                format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
                       cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1]))
    else:
        logger.info(
            "\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of "
            "[unpadding, stepscaling, rangescaling]")

W
wuzewu 已提交
351 352 353 354 355

def inf_resize_value_check():
    """
    Warn when AUG.INF_RESIZE_VALUE lies outside the configured resize range.

    Only applies when AUG.AUG_METHOD is "rangescaling". The value is expected
    to lie in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE], both ends
    inclusive. A warning is logged otherwise.
    """
    if cfg.AUG.AUG_METHOD != "rangescaling":
        return
    inf_value = cfg.AUG.INF_RESIZE_VALUE
    # The upper bound is MAX_RESIZE_VALUE (not MIN_RESIZE_VALUE).
    if inf_value < cfg.AUG.MIN_RESIZE_VALUE or \
            inf_value > cfg.AUG.MAX_RESIZE_VALUE:
        logger.info(
            "\nWARNING! you set AUG.AUG_METHOD = 'rangescaling', "
            "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: "
            "[{}, {}].".format(inf_value, cfg.AUG.MIN_RESIZE_VALUE,
                               cfg.AUG.MAX_RESIZE_VALUE))

W
wuzewu 已提交
363 364 365 366 367 368 369 370 371

def image_type_check(img_dim):
    """
    Compare the image channel counts found in a split with DATASET.IMAGE_TYPE.

    param
        img_dim: collection of the distinct image channel counts seen
    Logs a failure when IMAGE_TYPE is 'rgba' but 1- or 3-channel images exist.
    Logs a pass plus a warning when IMAGE_TYPE is 'rgb' but every image has
    4 channels. Logs a plain pass otherwise. Returns None.
    """
    has_gray_or_rgb = 1 in img_dim or 3 in img_dim
    if has_gray_or_rgb and cfg.DATASET.IMAGE_TYPE == 'rgba':
        logger.info(error_print("DATASET.IMAGE_TYPE check"))
        logger.info("DATASET.IMAGE_TYPE is {} but the type of image has "
                    "gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE))
        return
    only_rgba = not has_gray_or_rgb and 4 in img_dim
    logger.info(correct_print("DATASET.IMAGE_TYPE check"))
    if only_rgba and cfg.DATASET.IMAGE_TYPE == 'rgb':
        logger.info(
            "\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba"
            .format(cfg.DATASET.IMAGE_TYPE))
W
wuzewu 已提交
383

L
LutaoChu 已提交
384

C
chenguowei01 已提交
385 386 387 388 389 390 391
def shape_check():
    """Log whether every image matched its label's size (see shape_unequal_image)."""
    if shape_unequal_image:
        logger.info(error_print("shape check"))
        logger.info(
            "Some images are not the same shape as the labels as follow: ")
        for bad_line in shape_unequal_image:
            logger.debug(bad_line)
        return
    logger.info(correct_print("shape check"))
    logger.info("All images are the same shape as the labels")
W
wuzewu 已提交
396 397


C
chenguowei01 已提交
398 399 400
def file_list_check(list_name):
    """
    Log whether every line of a file list split correctly on DATASET.SEPARATOR.

    param
        list_name: path of the file list; only its last component (split on
            os.sep) appears in the log
    Lines recorded in ``list_wrong`` are written at debug level.
    """
    check_name = list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check"
    if not list_wrong:
        logger.info(correct_print(check_name))
        return
    logger.info(error_print(check_name))
    logger.info("The following list is not separated by {}".format(
        cfg.DATASET.SEPARATOR))
    for bad_line in list_wrong:
        logger.debug(bad_line)

L
LutaoChu 已提交
413

C
chenguowei01 已提交
414 415 416 417 418 419 420 421 422
def imread_check():
    """Log whether every listed image/label could be read (see imread_failed)."""
    failed_count = len(imread_failed)
    if failed_count:
        logger.info(error_print("dataset reading check"))
        logger.info("Failed to read {} images".format(failed_count))
        for record in imread_failed:
            logger.debug(record)
    else:
        logger.info(correct_print("dataset reading check"))
        logger.info("All images can be read successfully")
W
wuzewu 已提交
423

L
LutaoChu 已提交
424

C
chenguowei01 已提交
425 426 427 428 429 430
def label_gray_check():
    """Log whether every label image was single-channel (see label_gray_wrong)."""
    non_gray_count = len(label_gray_wrong)
    if non_gray_count:
        logger.info(error_print("label gray check"))
        logger.info(
            "{} label images are not gray\nLabel pixel statistics may be insignificant"
            .format(non_gray_count))
        for bad_line in label_gray_wrong:
            logger.debug(bad_line)
    else:
        logger.info(correct_print("label gray check"))
        logger.info("All label images are gray")

W
wuzewu 已提交
437

L
LutaoChu 已提交
438 439 440 441 442
def max_img_size_statistics():
    """Log the largest image width and height recorded for the current split."""
    size_msg = "max width and max height of images are ({},{})".format(
        max_width, max_height)
    logger.info("\nDoing max image size statistics:")
    logger.info(size_msg)

W
wuyefeilin 已提交
443

W
wuyefeilin 已提交
444 445 446
def num_classes_loss_matching_check():
    """
    Check that the configured loss suits cfg.DATASET.NUM_CLASSES.

    Dice loss and BCE loss are treated as binary-only. The check fails when
    either appears in cfg.SOLVER.LOSS and NUM_CLASSES is greater than 2.
    """
    loss_names = cfg.SOLVER.LOSS
    if cfg.DATASET.NUM_CLASSES <= 2 or not any(
            name in loss_names for name in ("dice_loss", "bce_loss")):
        logger.info(correct_print("loss check"))
        return
    logger.info(
        error_print(
            "loss check."
            " Dice loss and bce loss is only applicable to binary classfication"
        ))

W
wuzewu 已提交
457 458

def check_train_dataset():
    """
    Check every entry of cfg.DATASET.TRAIN_FILE_LIST and log the results.

    Each line must contain an image path and a label path separated by
    cfg.DATASET.SEPARATOR, both relative to cfg.DATASET.DATA_DIR. For every
    readable pair, the module-level statistics are updated: max size, channel
    counts, and the label gray / shape / format / class checks. The summary
    checks are logged afterwards.
    """
    list_file = cfg.DATASET.TRAIN_FILE_LIST
    logger.info("-----------------------------\n1. Check train dataset...")
    with open(list_file, 'r') as fid:
        lines = fid.readlines()

    for raw_line in tqdm(lines):
        line = raw_line.strip()
        parts = line.split(cfg.DATASET.SEPARATOR)
        if len(parts) != 2:
            # Wrong separator or field count: record the line and skip it.
            list_wrong.append(line)
            continue
        img_name, grt_name = parts
        img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
        grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
        try:
            img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            grt = pil_imread(grt_path)
        except Exception as e:
            imread_failed.append((line, str(e)))
            continue

        if not is_label_gray(grt):
            label_gray_wrong.append(line)
            # Collapse to one channel so the checks below see a 2-D label.
            grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
        get_image_max_height_width(img)
        get_image_dim(img)
        if not image_label_shape_check(img, grt):
            shape_unequal_image.append(line)

        png_format, grt_classes, num_of_each_class = ground_truth_check(
            grt, grt_path)
        if not png_format:
            png_format_wrong_image.append(line)
        if not sum_gt_check(png_format, grt_classes, num_of_each_class):
            label_wrong.append(line)

    file_list_check(list_file)
    imread_check()
    label_gray_check()
    gt_check()
    image_type_check(img_dim)
    max_img_size_statistics()
    shape_check()
    num_classes_loss_matching_check()
C
chenguowei01 已提交
506 507


W
wuzewu 已提交
508
def check_val_dataset():
    """
    Check every entry of cfg.DATASET.VAL_FILE_LIST and log the results.

    Each line must contain an image path and a label path separated by
    cfg.DATASET.SEPARATOR, both relative to cfg.DATASET.DATA_DIR. Besides the
    checks done for the train split, the aspect-ratio range is collected and
    EVAL_CROP_SIZE is validated against it.
    """
    list_file = cfg.DATASET.VAL_FILE_LIST
    logger.info("\n-----------------------------\n2. Check val dataset...")
    with open(list_file) as fid:
        lines = fid.readlines()

    for raw_line in tqdm(lines):
        line = raw_line.strip()
        parts = line.split(cfg.DATASET.SEPARATOR)
        if len(parts) != 2:
            # Wrong separator or field count: record the line and skip it.
            list_wrong.append(line)
            continue
        img_name, grt_name = parts
        img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
        grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
        try:
            img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            grt = pil_imread(grt_path)
        except Exception as e:
            imread_failed.append((line, str(e)))
            continue

        if not is_label_gray(grt):
            label_gray_wrong.append(line)
            # Collapse to one channel so the checks below see a 2-D label.
            grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
        get_image_max_height_width(img)
        get_image_min_max_aspectratio(img)
        get_image_dim(img)
        if not image_label_shape_check(img, grt):
            shape_unequal_image.append(line)

        png_format, grt_classes, num_of_each_class = ground_truth_check(
            grt, grt_path)
        if not png_format:
            png_format_wrong_image.append(line)
        if not sum_gt_check(png_format, grt_classes, num_of_each_class):
            label_wrong.append(line)

    file_list_check(list_file)
    imread_check()
    label_gray_check()
    gt_check()
    image_type_check(img_dim)
    max_img_size_statistics()
    shape_check()
    eval_crop_size_check(max_height, max_width, min_aspectratio,
                         max_aspectratio)

W
wuzewu 已提交
558 559

def check_test_dataset():
    """
    Check every entry of cfg.DATASET.TEST_FILE_LIST and log the results.

    A line may hold just an image path, or an image path and a label path
    separated by cfg.DATASET.SEPARATOR (paths relative to
    cfg.DATASET.DATA_DIR). Lines with any other number of fields are
    recorded in ``list_wrong``. The label-specific checks are only reported
    when at least one line carried a label.
    """
    list_file = cfg.DATASET.TEST_FILE_LIST
    has_label = False
    with open(list_file) as fid:
        logger.info("\n-----------------------------\n3. Check test dataset...")
        lines = fid.readlines()

    for raw_line in tqdm(lines):
        line = raw_line.strip()
        parts = line.split(cfg.DATASET.SEPARATOR)
        if len(parts) == 1:
            # Image-only entry.
            img_path = os.path.join(cfg.DATASET.DATA_DIR, parts[0])
            try:
                img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            except Exception as e:
                imread_failed.append((line, str(e)))
                continue
        elif len(parts) == 2:
            has_label = True
            img_name, grt_name = parts
            img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
            grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
            try:
                img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
                grt = pil_imread(grt_path)
            except Exception as e:
                imread_failed.append((line, str(e)))
                continue

            if not is_label_gray(grt):
                label_gray_wrong.append(line)
                # Collapse to one channel so the checks below see a 2-D label.
                grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
            if not image_label_shape_check(img, grt):
                shape_unequal_image.append(line)
            png_format, grt_classes, num_of_each_class = ground_truth_check(
                grt, grt_path)
            if not png_format:
                png_format_wrong_image.append(line)
            if not sum_gt_check(png_format, grt_classes, num_of_each_class):
                label_wrong.append(line)
        else:
            # Record the offending line itself (not the whole list of lines),
            # so file_list_check reports exactly the malformed entries.
            list_wrong.append(line)
            continue
        get_image_max_height_width(img)
        get_image_min_max_aspectratio(img)
        get_image_dim(img)

    file_list_check(list_file)
    imread_check()
    if has_label:
        label_gray_check()
        gt_check()
    image_type_check(img_dim)
    max_img_size_statistics()
    if has_label:
        shape_check()
    eval_crop_size_check(max_height, max_width, min_aspectratio,
                         max_aspectratio)

W
wuzewu 已提交
623 624 625 626

def main(args):
    """
    Load the configuration, then check the train, val and test splits in turn.

    Module-level statistics are reset before each split so that results do
    not carry over between splits. The resize-value check runs once at the end.
    """
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    cfg.check_and_infer()
    logger.info(pprint.pformat(cfg))

    for check_split in (check_train_dataset, check_val_dataset,
                        check_test_dataset):
        init_global_variable()
        check_split()

    inf_resize_value_check()

    print("\nDetailed error information can be viewed in detail.log file.")


W
wuzewu 已提交
644 645
if __name__ == "__main__":
    args = parse_args()
    # Root logger: INFO and above go to the console, while everything
    # (including the per-file DEBUG listings) goes to detail.log.
    logger = logging.getLogger()
    logger.setLevel('DEBUG')
    BASIC_FORMAT = "%(message)s"
    formatter = logging.Formatter(BASIC_FORMAT)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    sh.setLevel('INFO')
    # Explicit UTF-8: logged file-list lines may contain non-ASCII paths,
    # which the platform default encoding may be unable to represent.
    th = logging.FileHandler('detail.log', 'w', encoding='utf-8')
    th.setFormatter(formatter)
    logger.addHandler(sh)
    logger.addHandler(th)
    main(args)