check.py 18.7 KB
Newer Older
C
chenguowei01 已提交
1
# -*- coding: utf-8 -*-
W
wuzewu 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys
import pprint
import argparse
import cv2
from tqdm import tqdm
import imghdr
C
chenguowei01 已提交
15
import logging
W
wuzewu 已提交
16 17 18 19 20 21 22

from utils.config import cfg

def init_global_variable():
    """
    Reset every module-level statistic used by the dataset checks.

    Called before each dataset (train / val / test) is scanned so that
    counters and collected entries from a previous pass do not leak into the
    next report.
    """
    # Label-format counters
    global png_format_right_num  # number of label images in PNG format
    global png_format_wrong_num  # number of label images not in PNG format
    # Per-class pixel statistics (two parallel lists)
    global total_grt_classes  # every class id seen across label images
    global total_num_of_each_class  # total pixel count for each class id
    # Image geometry statistics
    global max_width  # largest image width seen
    global max_height  # largest image height seen
    global min_aspectratio  # smallest width/height ratio seen
    global max_aspectratio  # largest width/height ratio seen
    global img_dim  # distinct image channel counts seen
    # Offending entries collected for the report
    global shape_unequal_image  # lines whose image and label shapes differ
    global png_format_wrong_image  # lines whose label image is not PNG
    global list_wrong  # lines with a wrong file-list format
    global imread_failed  # (line, error text) pairs for unreadable files
    global label_wrong  # lines whose label values are wrong

    png_format_right_num = png_format_wrong_num = 0
    max_width = max_height = 0
    max_aspectratio = 0
    # Start at the largest float so the first image always lowers it.
    min_aspectratio = sys.float_info.max

    # Each list must be a distinct object.
    total_grt_classes, total_num_of_each_class = [], []
    shape_unequal_image, png_format_wrong_image = [], []
    img_dim, list_wrong = [], []
    imread_failed, label_wrong = [], []
W
wuzewu 已提交
52 53 54 55

def parse_args():
    """
    Parse the command line.

    Only ``--cfg`` is accepted; its value is stored as ``cfg_file`` and is
    None when the option is omitted.
    """
    arg_parser = argparse.ArgumentParser(description='PaddleSeg check')
    arg_parser.add_argument('--cfg',
                            dest='cfg_file',
                            type=str,
                            default=None,
                            help='Config file for training (and optionally testing)')
    parsed = arg_parser.parse_args()
    return parsed

C
chenguowei01 已提交
64 65 66 67 68
def error_print(str):
    """Return ``str`` prefixed with a newline and the failure marker ``NOT PASS``."""
    # NOTE: the parameter shadows the builtin ``str``; the name is kept so
    # keyword callers keep working.
    return "\nNOT PASS " + str

def correct_print(str):
    """Return ``str`` prefixed with a newline and the success marker ``PASS``."""
    # NOTE: the parameter shadows the builtin ``str``; the name is kept so
    # keyword callers keep working.
    return "\nPASS " + str
W
wuzewu 已提交
69 70 71 72 73

def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
    """
    Read an image file by decoding its raw bytes with OpenCV.

    cv2.imread cannot open file paths containing Chinese characters on
    Windows. Loading the bytes with numpy and decoding them with
    cv2.imdecode avoids that.

    params:
        file_path: path of the image file
        flag: OpenCV imread flag passed to cv2.imdecode
    return:
        whatever cv2.imdecode returns for the file's bytes
    """
    raw_bytes = np.fromfile(file_path, dtype=np.uint8)
    decoded = cv2.imdecode(raw_bytes, flag)
    return decoded

C
chenguowei01 已提交
74 75
def get_image_max_height_width(img):
    """Grow the global ``max_height`` / ``max_width`` so they cover ``img``."""
    global max_width, max_height
    height = img.shape[0]
    width = img.shape[1]
    if height > max_height:
        max_height = height
    if width > max_width:
        max_width = width

C
chenguowei01 已提交
81 82
def get_image_min_max_aspectratio(img):
    """
    Update the global width/height aspect-ratio range with ``img``.

    return:
        (min_aspectratio, max_aspectratio) after the update
    """
    global min_aspectratio, max_aspectratio
    ratio = img.shape[1] / img.shape[0]  # width / height
    min_aspectratio = min(ratio, min_aspectratio)
    max_aspectratio = max(ratio, max_aspectratio)
    return min_aspectratio, max_aspectratio

C
chenguowei01 已提交
89
def get_image_dim(img):
    """
    Record the channel count of ``img`` in the global ``img_dim`` list.

    A 2-D array (a single-channel image, e.g. a grayscale file read with
    cv2.IMREAD_UNCHANGED) counts as 1 channel. Otherwise the size of the
    last axis is used. Each distinct channel count is stored once.
    """
    # A 2-D array has no channel axis: shape[-1] would be the image width,
    # and image_type_check looks for the value 1 to detect gray images.
    channels = 1 if img.ndim == 2 else img.shape[-1]
    if channels not in img_dim:
        img_dim.append(channels)

C
chenguowei01 已提交
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
def image_label_shape_check(img, grt):
    """
    Return True when ``img`` and ``grt`` have the same height and width.

    Only the first two axes are compared, so a multi-channel image can still
    match a single-channel label.
    """
    img_hw = (img.shape[0], img.shape[1])
    grt_hw = (grt.shape[0], grt.shape[1])
    return img_hw == grt_hw

def ground_truth_check(grt, grt_path):
    """
    Inspect one label image: its file format and its per-class pixel counts.

    Range and ignore-index validation of the class ids is not done here; the
    returned ``unique``/``counts`` are meant to be passed on for that.

    params:
        grt: label image array
        grt_path: path of the label image file
    return:
        png_format: True if the file at ``grt_path`` is a PNG image
        unique: sorted class ids present in ``grt``
        counts: number of pixels for each entry of ``unique``
    """
    # Same test imghdr.what() applies for PNG (the file starts with the
    # 8-byte PNG signature). It is done directly because imghdr is
    # deprecated (PEP 594) and removed in Python 3.13.
    png_signature = b'\x89PNG\r\n\x1a\n'
    with open(grt_path, 'rb') as label_file:
        png_format = label_file.read(len(png_signature)) == png_signature

    unique, counts = np.unique(grt, return_counts=True)
    return png_format, unique, counts
W
wuzewu 已提交
133 134 135

def sum_gt_check(png_format, grt_classes, num_of_each_class):
    """
    Fold one label image's statistics into the module-level totals.

    Increments ``png_format_right_num`` or ``png_format_wrong_num``, and adds
    the image's per-class pixel counts to ``total_grt_classes`` /
    ``total_num_of_each_class``. It also checks that every class id other
    than ``cfg.DATASET.IGNORE_INDEX`` lies in
    ``[0, cfg.DATASET.NUM_CLASSES - 1]``.

    params:
        png_format: whether the label image is a PNG file
        grt_classes: class ids present in the label image (unique values,
            as returned by ``np.unique``)
        num_of_each_class: pixel count for each entry of ``grt_classes``
    return:
        is_label_correct: False if any non-ignored class id is out of range
    """
    global png_format_right_num, png_format_wrong_num, total_grt_classes, total_num_of_each_class

    if png_format:
        png_format_right_num += 1
    else:
        png_format_wrong_num += 1

    grt_classes = np.asarray(grt_classes)
    # The ignore index is excluded from the range check.
    valid_classes = grt_classes[grt_classes != cfg.DATASET.IGNORE_INDEX]
    is_label_correct = True
    # A label image made only of ignore pixels leaves nothing to check, and
    # min()/max() of an empty array would raise ValueError.
    if valid_classes.size > 0:
        if valid_classes.min() < 0 or valid_classes.max() > cfg.DATASET.NUM_CLASSES - 1:
            is_label_correct = False

    # grt_classes holds unique ids, so each id is merged at most once per call.
    for class_id, pixel_num in zip(grt_classes, num_of_each_class):
        if class_id in total_grt_classes:
            total_num_of_each_class[total_grt_classes.index(class_id)] += pixel_num
        else:
            total_grt_classes.append(class_id)
            total_num_of_each_class.append(pixel_num)

    return is_label_correct
W
wuzewu 已提交
169 170 171

def gt_check():
    """
    Log the label-image checks accumulated for the current dataset.

    Reports three things:
      * the PNG-format check;
      * the total pixel count per class;
      * the label-class check: classes should start at 0 and no label may be
        out of ``[0, cfg.DATASET.NUM_CLASSES - 1]``.
    Offending list lines are logged at DEBUG level. Returns early when no
    label image was counted.
    """
    # --- label file format ---
    if png_format_wrong_num != 0:
        logger.info(error_print("label format check"))
    elif png_format_right_num:
        logger.info(correct_print("label format check"))
    else:
        # Neither counter moved: there is nothing to analyse.
        logger.info(error_print("label format check"))
        logger.info("No label image to check")
        return
    logger.info("total {} label images are png format, {} label images are not png "
                "format".format(png_format_right_num, png_format_wrong_num))
    for entry in png_format_wrong_image:
        logger.debug(entry)

    # --- per-class pixel statistics ---
    total_nc = sorted(zip(total_grt_classes, total_num_of_each_class))
    logger.info("\nDoing label pixel statistics...\nTotal label calsses "
                "and their corresponding numbers:\n{} ".format(total_nc))

    # --- label class check (total_nc[0][0] is the smallest class id) ---
    if not label_wrong and not total_nc[0][0]:
        logger.info(correct_print("label class check!"))
        return
    logger.info(error_print("label class check!"))
    if total_nc[0][0]:
        logger.info("Warning: label classes should start from 0")
    if label_wrong:
        logger.info("fatal error: label class is out of range [0, {}]".format(cfg.DATASET.NUM_CLASSES - 1))
        for entry in label_wrong:
            logger.debug(entry)
W
wuzewu 已提交
204 205 206



C
chenguowei01 已提交
207
def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio):
    """
    Check cfg.EVAL_CROP_SIZE (width, height) against the val/test image sizes.

    The required size depends on cfg.AUG.AUG_METHOD:
      * "stepscaling": the crop is compared with the raw maximum image width
        and height; both must fit.
      * "rangescaling": the required size is derived from
        cfg.AUG.INF_RESIZE_VALUE and the aspect-ratio range.
      * "unpadding": the crop is compared with cfg.AUG.FIX_RESIZE_SIZE.
    Results are logged; nothing is returned.

    params:
        max_height: largest image height in the dataset
        max_width: largest image width in the dataset
        min_aspectratio: smallest width/height ratio in the dataset
        max_aspectratio: largest width/height ratio in the dataset
    """

    def _check_crop_covers(required_width, required_height):
        # Shared PASS / NOT PASS report for the resize-based methods.
        if cfg.EVAL_CROP_SIZE[0] >= required_width and cfg.EVAL_CROP_SIZE[1] >= required_height:
            logger.info(correct_print("EVAL_CROP_SIZE check"))
        else:
            logger.info(error_print("EVAL_CROP_SIZE check"))
            logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
                        .format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
                                required_width, required_height))

    if cfg.AUG.AUG_METHOD == "stepscaling":
        crop_width, crop_height = cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]
        # Both dimensions must fit; a crop covering only one of them still
        # fails for the other.
        if max_width <= crop_width and max_height <= crop_height:
            logger.info(correct_print("EVAL_CROP_SIZE check"))
        else:
            logger.info(error_print("EVAL_CROP_SIZE check"))
            if max_width > crop_width:
                logger.info("The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!".format(
                    crop_width, max_width))
            if max_height > crop_height:
                logger.info("The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!".format(
                    crop_height, max_height))

    elif cfg.AUG.AUG_METHOD == "rangescaling":
        resize_value = cfg.AUG.INF_RESIZE_VALUE
        if min_aspectratio <= 1 and max_aspectratio >= 1:
            _check_crop_covers(resize_value, resize_value)
        elif min_aspectratio > 1:
            # Every image is wider than tall.
            _check_crop_covers(resize_value, round(resize_value / min_aspectratio))
        elif max_aspectratio < 1:
            # Every image is taller than wide.
            _check_crop_covers(round(resize_value * max_aspectratio), resize_value)

    elif cfg.AUG.AUG_METHOD == "unpadding":
        _check_crop_covers(cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1])

    else:
        logger.info("\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of "
                    "[unpadding, stepscaling, rangescaling]")
W
wuzewu 已提交
267 268 269 270 271

def inf_resize_value_check():
    """
    Warn when AUG.INF_RESIZE_VALUE lies outside the training resize range.

    Only applies when AUG.AUG_METHOD is 'rangescaling'. The value is
    expected in the inclusive range
    [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE].
    """
    if cfg.AUG.AUG_METHOD != "rangescaling":
        return
    # The upper bound is MAX_RESIZE_VALUE, matching the range named in the
    # warning below.
    if not cfg.AUG.MIN_RESIZE_VALUE <= cfg.AUG.INF_RESIZE_VALUE <= cfg.AUG.MAX_RESIZE_VALUE:
        logger.info("\nWARNING! you set AUG.AUG_METHOD = 'rangescaling', "
                    "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: "
                    "[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.MIN_RESIZE_VALUE,
                                       cfg.AUG.MAX_RESIZE_VALUE))
W
wuzewu 已提交
275 276 277 278 279 280 281 282 283

def image_type_check(img_dim):
    """
    Check that the image channel counts agree with cfg.DATASET.IMAGE_TYPE.

    params:
        img_dim: distinct channel counts of the scanned images
    """
    has_gray_or_rgb = 1 in img_dim or 3 in img_dim
    if has_gray_or_rgb and cfg.DATASET.IMAGE_TYPE == 'rgba':
        logger.info(error_print("DATASET.IMAGE_TYPE check"))
        logger.info("DATASET.IMAGE_TYPE is {} but the type of image has "
                    "gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE))
        return
    if not has_gray_or_rgb and 4 in img_dim and cfg.DATASET.IMAGE_TYPE == 'rgb':
        # Every image has 4 channels while 'rgb' is configured: pass, but warn.
        logger.info(correct_print("DATASET.IMAGE_TYPE check"))
        logger.info("\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba".format(cfg.DATASET.IMAGE_TYPE))
        return
    logger.info(correct_print("DATASET.IMAGE_TYPE check"))
W
wuzewu 已提交
292

C
chenguowei01 已提交
293 294 295 296 297 298 299 300 301 302
def shape_check():
    """Log whether every image matched its label's shape; list mismatches at DEBUG level."""
    if shape_unequal_image:
        logger.info(error_print("shape check"))
        logger.info("Some images are not the same shape as the labels as follow: ")
        for entry in shape_unequal_image:
            logger.debug(entry)
        return
    logger.info(correct_print("shape check"))
    logger.info("All images are the same shape as the labels")
W
wuzewu 已提交
303 304


C
chenguowei01 已提交
305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
def file_list_check(list_name):
    """
    Log whether every entry of the file list split correctly on DATASET.SEPARATOR.

    params:
        list_name: path of the file list; only its last path component is shown
    """
    title = list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check"
    if not list_wrong:
        logger.info(correct_print(title))
        return
    logger.info(error_print(title))
    logger.info("The following list is not separated by {}".format(cfg.DATASET.SEPARATOR))
    for entry in list_wrong:
        logger.debug(entry)

def imread_check():
    """Log whether all images could be read; list the failures at DEBUG level."""
    if imread_failed:
        logger.info(error_print("dataset reading check"))
        logger.info("Failed to read {} images".format(len(imread_failed)))
        for entry in imread_failed:
            logger.debug(entry)
        return
    logger.info(correct_print("dataset reading check"))
    logger.info("All images can be read successfully")
W
wuzewu 已提交
324 325 326 327



def check_train_dataset():
    """
    Scan cfg.DATASET.TRAIN_FILE_LIST and log the train-set checks.

    Each non-empty line must be "<image><SEPARATOR><label>", with paths
    relative to cfg.DATASET.DATA_DIR. For every readable pair, the module-level
    statistics record:
      * the image channel count;
      * whether image and label shapes match;
      * the label file format and label values.
    Then the separator, read, label, image-type and shape reports are logged.
    """
    list_file = cfg.DATASET.TRAIN_FILE_LIST
    logger.info("-----------------------------\n1. Check train dataset...")
    with open(list_file, 'r') as fid:
        lines = fid.readlines()

    for line in tqdm(lines):
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline) are not malformed entries.
            continue
        parts = line.split(cfg.DATASET.SEPARATOR)
        if len(parts) != 2:
            list_wrong.append(line)
            continue
        img_name, grt_name = parts
        img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
        grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
        try:
            img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            imread_failed.append((line, str(e)))
            continue
        # cv2.imdecode returns None instead of raising for undecodable data;
        # record it as a read failure instead of crashing on .shape below.
        if img is None or grt is None:
            bad_path = img_path if img is None else grt_path
            imread_failed.append((line, "failed to decode {}".format(bad_path)))
            continue

        get_image_dim(img)

        if not image_label_shape_check(img, grt):
            shape_unequal_image.append(line)

        png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path)
        if not png_format:
            png_format_wrong_image.append(line)
        if not sum_gt_check(png_format, grt_classes, num_of_each_class):
            label_wrong.append(line)

    file_list_check(list_file)
    imread_check()
    gt_check()
    image_type_check(img_dim)
    shape_check()



W
wuzewu 已提交
368 369 370


def check_val_dataset():
    """
    Scan cfg.DATASET.VAL_FILE_LIST and log the validation-set checks.

    Each non-empty line must be "<image><SEPARATOR><label>", with paths
    relative to cfg.DATASET.DATA_DIR. For every readable pair, the module-level
    statistics record:
      * the image size and aspect ratio;
      * the channel count;
      * whether image and label shapes match;
      * the label file format and label values.
    Then the separator, read, label, image-type, shape and EVAL_CROP_SIZE
    reports are logged.
    """
    list_file = cfg.DATASET.VAL_FILE_LIST
    logger.info("\n-----------------------------\n2. Check val dataset...")
    with open(list_file) as fid:
        lines = fid.readlines()

    for line in tqdm(lines):
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline) are not malformed entries.
            continue
        parts = line.split(cfg.DATASET.SEPARATOR)
        if len(parts) != 2:
            list_wrong.append(line)
            continue
        img_name, grt_name = parts
        img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
        grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
        try:
            img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            # Exceptions have no .message attribute in Python 3; use str(e).
            imread_failed.append((line, str(e)))
            # Skip the pair: img/grt would otherwise be undefined or stale.
            continue
        # cv2.imdecode returns None instead of raising for undecodable data.
        if img is None or grt is None:
            bad_path = img_path if img is None else grt_path
            imread_failed.append((line, "failed to decode {}".format(bad_path)))
            continue

        get_image_max_height_width(img)
        get_image_min_max_aspectratio(img)
        get_image_dim(img)

        if not image_label_shape_check(img, grt):
            shape_unequal_image.append(line)

        png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path)
        if not png_format:
            png_format_wrong_image.append(line)
        if not sum_gt_check(png_format, grt_classes, num_of_each_class):
            label_wrong.append(line)

    file_list_check(list_file)
    imread_check()
    gt_check()
    image_type_check(img_dim)
    shape_check()
    eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio)
W
wuzewu 已提交
408 409

def check_test_dataset():
    """
    Scan cfg.DATASET.TEST_FILE_LIST and log the test-set checks.

    Each non-empty line is either "<image>" or "<image><SEPARATOR><label>",
    with paths relative to cfg.DATASET.DATA_DIR. Image size, aspect ratio and
    channel statistics are collected for every readable image. Label checks
    run for lines that carry a label. The label and shape reports are logged
    only if at least one line had a label.
    """
    list_file = cfg.DATASET.TEST_FILE_LIST
    has_label = False
    with open(list_file) as fid:
        logger.info("\n-----------------------------\n3. Check test dataset...")
        lines = fid.readlines()

    for line in tqdm(lines):
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline) are not malformed entries.
            continue
        parts = line.split(cfg.DATASET.SEPARATOR)
        if len(parts) == 1:
            img_path = os.path.join(cfg.DATASET.DATA_DIR, parts[0])
            try:
                img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            except Exception as e:
                imread_failed.append((line, str(e)))
                continue
            # cv2.imdecode returns None instead of raising for undecodable data.
            if img is None:
                imread_failed.append((line, "failed to decode {}".format(img_path)))
                continue
        elif len(parts) == 2:
            has_label = True
            img_name, grt_name = parts
            img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
            grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
            try:
                img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
                grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
            except Exception as e:
                # Exceptions have no .message attribute in Python 3; use str(e).
                imread_failed.append((line, str(e)))
                continue
            if img is None or grt is None:
                bad_path = img_path if img is None else grt_path
                imread_failed.append((line, "failed to decode {}".format(bad_path)))
                continue

            if not image_label_shape_check(img, grt):
                shape_unequal_image.append(line)
            png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path)
            if not png_format:
                png_format_wrong_image.append(line)
            if not sum_gt_check(png_format, grt_classes, num_of_each_class):
                label_wrong.append(line)
        else:
            # Record the offending line, not the whole file.
            list_wrong.append(line)
            continue

        get_image_max_height_width(img)
        get_image_min_max_aspectratio(img)
        get_image_dim(img)

    file_list_check(list_file)
    imread_check()
    if has_label:
        gt_check()
    image_type_check(img_dim)
    if has_label:
        shape_check()
    eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio)
W
wuzewu 已提交
461 462 463 464 465

def main(args):
    """
    Load the configuration, then check the train, val and test datasets in turn.

    params:
        args: parsed command-line namespace; ``args.cfg_file`` is an optional
            config file merged into ``cfg``
    """
    if args.cfg_file is not None:
        cfg.update_from_file(args.cfg_file)
    cfg.check_and_infer(reset_dataset=True)
    logger.info(pprint.pformat(cfg))

    # The statistics live in module globals, so reset them before each dataset.
    for check_dataset in (check_train_dataset, check_val_dataset, check_test_dataset):
        init_global_variable()
        check_dataset()

    inf_resize_value_check()

if __name__ == "__main__":
    args = parse_args()

    # Root logger: INFO and above go to the console; everything, including
    # the DEBUG lists of offending entries, goes to detail.log.
    formatter = logging.Formatter("%(message)s")

    console_handler = logging.StreamHandler()
    console_handler.setLevel('INFO')
    console_handler.setFormatter(formatter)

    file_handler = logging.FileHandler('detail.log', 'w')
    file_handler.setFormatter(formatter)

    logger = logging.getLogger()
    logger.setLevel('DEBUG')
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)

    main(args)
C
chenguowei01 已提交
493 494