提交 403b4a5b 编写于 作者: W wuyefeilin 提交者: wuzewu

update check.py (#3)

* update check.py
上级 08f001b0
...@@ -12,76 +12,134 @@ import argparse ...@@ -12,76 +12,134 @@ import argparse
import cv2 import cv2
from tqdm import tqdm from tqdm import tqdm
import imghdr import imghdr
import logging
from utils.config import cfg from utils.config import cfg
def init_global_variable():
    """Create (or reset) the module-level accumulators used by the checks.

    Every counter, list and running extreme that the per-dataset check
    functions update is bound here as a module global.
    """
    # Label-format bookkeeping.
    global png_format_right_num  # number of label images in PNG format
    global png_format_wrong_num  # number of label images not in PNG format
    global png_format_wrong_image  # file-list lines whose label is not PNG
    # Label-class statistics.
    global total_grt_classes  # every label class seen so far
    global total_num_of_each_class  # pixel count for each seen class
    global label_wrong  # file-list lines whose label classes are invalid
    # Image geometry statistics.
    global shape_unequal_image  # lines where image and label shapes differ
    global max_width  # largest image width seen
    global max_height  # largest image height seen
    global min_aspectratio  # smallest width/height ratio seen
    global max_aspectratio  # largest width/height ratio seen
    global img_dim  # distinct channel counts seen
    # File-list and reading problems.
    global list_wrong  # file-list lines that are badly separated
    global imread_failed  # (line, error text) pairs for unreadable images

    png_format_right_num = 0
    png_format_wrong_num = 0
    png_format_wrong_image = []

    total_grt_classes = []
    total_num_of_each_class = []
    label_wrong = []

    shape_unequal_image = []
    max_width = 0
    max_height = 0
    # Start the minimum at the largest float so the first image replaces it.
    min_aspectratio = sys.float_info.max
    max_aspectratio = 0
    img_dim = []

    list_wrong = []
    imread_failed = []
def parse_args():
    """Parse the command line of the checker.

    Returns:
        argparse.Namespace with a single attribute ``cfg_file`` (``None``
        when ``--cfg`` is not given).
    """
    parser = argparse.ArgumentParser(description='PaddleSeg check')
    parser.add_argument(
        '--cfg',
        dest='cfg_file',
        help='Config file for training (and optionally testing)',
        default=None,
        type=str,
    )
    return parser.parse_args()
def error_print(str):
    """Return ``str`` prefixed with a newline and the ``NOT PASS`` marker.

    The parameter keeps its original name (which shadows the builtin) so
    that existing keyword callers continue to work.
    """
    return "\nNOT PASS " + str
def correct_print(str):
    """Return ``str`` prefixed with a newline and the ``PASS`` marker.

    The parameter keeps its original name (which shadows the builtin) so
    that existing keyword callers continue to work.
    """
    return "\nPASS " + str
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
    """Read an image by decoding the file's raw bytes.

    The bytes are loaded with ``np.fromfile`` and handed to
    ``cv2.imdecode`` rather than calling ``cv2.imread`` directly; this
    works around ``cv2.imread`` failing on Chinese file paths on Windows.

    Args:
        file_path: path of the image file.
        flag: OpenCV imread flag passed through to ``cv2.imdecode``.

    Returns:
        Whatever ``cv2.imdecode`` returns for the file's bytes.
        NOTE(review): OpenCV documents an empty result (``None`` in Python)
        for undecodable data - callers should check for it.
    """
    raw_bytes = np.fromfile(file_path, dtype=np.uint8)
    return cv2.imdecode(raw_bytes, flag)
def get_image_max_height_width(img):
    """Fold one image's size into the global ``max_height``/``max_width``.

    Args:
        img: array whose first two ``shape`` entries are height and width.

    Returns:
        Tuple ``(max_height, max_width)`` after the update, mirroring what
        ``get_image_min_max_aspectratio`` does for its own extremes.
    """
    global max_width, max_height
    height, width = img.shape[0], img.shape[1]
    max_height = max(height, max_height)
    max_width = max(width, max_width)
    return max_height, max_width
def get_image_min_max_aspectratio(img):
    """Fold one image's width/height ratio into the global extremes.

    Args:
        img: array whose first two ``shape`` entries are height and width.

    Returns:
        Tuple ``(min_aspectratio, max_aspectratio)`` after the update.
    """
    global min_aspectratio, max_aspectratio
    height, width = img.shape[0], img.shape[1]
    ratio = width / height
    min_aspectratio = min(ratio, min_aspectratio)
    max_aspectratio = max(ratio, max_aspectratio)
    return min_aspectratio, max_aspectratio
def get_image_dim(img):
    """Record the image's channel count in the global ``img_dim`` list.

    A 2-D array has no channel axis, so it is recorded as one channel;
    taking ``shape[-1]`` there would store the image width instead.
    Each distinct channel count is stored once.

    Args:
        img: image array, either ``(H, W)`` or ``(H, W, C)``.
    """
    img_shape = img.shape
    channels = 1 if len(img_shape) == 2 else img_shape[-1]
    if channels not in img_dim:
        img_dim.append(channels)
def image_label_shape_check(img, grt):
    """Check that an image and its label have the same height and width.

    Only the first two ``shape`` entries are compared, so the number of
    channels does not matter.

    Returns:
        True when height and width both match, False otherwise.
    """
    return img.shape[0] == grt.shape[0] and img.shape[1] == grt.shape[1]
def ground_truth_check(grt, grt_path):
    """Inspect one label image: file format and per-class pixel counts.

    The PNG test compares the file's first eight bytes with the PNG
    signature. That is the same test ``imghdr.what`` applies for PNG, but
    it does not need the ``imghdr`` module, which is deprecated and was
    removed from the standard library in Python 3.13.

    Args:
        grt: decoded label image (array).
        grt_path: path of the label file on disk.

    Returns:
        png_format: True when the file starts with the PNG signature.
        unique: sorted array of the distinct label values in ``grt``.
        counts: number of pixels for each entry of ``unique``.
    """
    png_signature = b"\x89PNG\r\n\x1a\n"
    with open(grt_path, "rb") as fid:
        png_format = fid.read(len(png_signature)) == png_signature

    unique, counts = np.unique(grt, return_counts=True)
    return png_format, unique, counts
def sum_gt_check(png_format, grt_classes, num_of_each_class): def sum_gt_check(png_format, grt_classes, num_of_each_class):
""" """
统计所有标图上的格式、类别和每个类别的像素数 统计所有标图上的格式、类别和每个类别的像素数
params: params:
png_format: 返回是否是png格式图片 png_format: 返回是否是png格式图片
grt_classes: 标类别 grt_classes: 标类别
num_of_each_class: 各个类别的像素数目 num_of_each_class: 各个类别的像素数目
""" """
is_label_correct = True
global png_format_right_num, png_format_wrong_num, total_grt_classes, total_num_of_each_class global png_format_right_num, png_format_wrong_num, total_grt_classes, total_num_of_each_class
if png_format: if png_format:
...@@ -90,12 +148,11 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): ...@@ -90,12 +148,11 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
png_format_wrong_num += 1 png_format_wrong_num += 1
if cfg.DATASET.IGNORE_INDEX in grt_classes: if cfg.DATASET.IGNORE_INDEX in grt_classes:
grt_classes2 = np.delete( grt_classes2 = np.delete(grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX))
grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX)) else:
grt_classes2 = grt_classes
if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1: if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1:
print("fatal error: label class is out of range [0, {}]".format( is_label_correct = False
cfg.DATASET.NUM_CLASSES - 1))
add_class = [] add_class = []
add_num = [] add_num = []
for i in range(len(grt_classes)): for i in range(len(grt_classes)):
...@@ -108,145 +165,113 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): ...@@ -108,145 +165,113 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
add_num.append(num_of_each_class[i]) add_num.append(num_of_each_class[i])
total_num_of_each_class += add_num total_num_of_each_class += add_num
total_grt_classes += add_class total_grt_classes += add_class
return is_label_correct
def gt_check():
    """Report the label checks accumulated in the module globals.

    Logs, in order: the PNG-format verdict and counts (offending lines go
    to DEBUG), the per-class pixel statistics, and the label-class verdict.
    The class check passes only when no line was recorded in
    ``label_wrong`` and the smallest class seen is 0.
    """
    if png_format_wrong_num == 0:
        if not png_format_right_num:
            # Neither counter moved: there was nothing to check at all.
            logger.info(error_print("label format check"))
            logger.info("No label image to check")
            return
        logger.info(correct_print("label format check"))
    else:
        logger.info(error_print("label format check"))
    logger.info(
        "total {} label images are png format, {} label images are not png "
        "format".format(png_format_right_num, png_format_wrong_num))
    for wrong_line in png_format_wrong_image:
        logger.debug(wrong_line)

    total_nc = sorted(zip(total_grt_classes, total_num_of_each_class))
    logger.info("\nDoing label pixel statistics...\nTotal label calsses "
                "and their corresponding numbers:\n{} ".format(total_nc))

    # Guard the empty case so an empty class list is reported as a failed
    # class check instead of raising IndexError on total_nc[0].
    starts_from_zero = bool(total_nc) and not total_nc[0][0]
    if len(label_wrong) == 0 and starts_from_zero:
        logger.info(correct_print("label class check!"))
        return

    logger.info(error_print("label class check!"))
    if not starts_from_zero:
        logger.info("Warning: label classes should start from 0")
    if len(label_wrong) > 0:
        logger.info("fatal error: label class is out of range [0, {}]".format(
            cfg.DATASET.NUM_CLASSES - 1))
        for wrong_line in label_wrong:
            logger.debug(wrong_line)
def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio):
    """Check that cfg.EVAL_CROP_SIZE is large enough for the dataset.

    The size the images can reach depends on cfg.AUG.AUG_METHOD:

    * ``stepscaling``: the largest width and height found in the dataset;
      BOTH must fit inside EVAL_CROP_SIZE.
    * ``rangescaling``: derived from cfg.AUG.INF_RESIZE_VALUE and the
      aspect-ratio range of the dataset.
    * ``unpadding``: cfg.AUG.FIX_RESIZE_SIZE.

    Any other method is reported as a configuration error.

    Args:
        max_height: largest image height in the dataset.
        max_width: largest image width in the dataset.
        min_aspectratio: smallest width/height ratio in the dataset.
        max_aspectratio: largest width/height ratio in the dataset.
    """
    method = cfg.AUG.AUG_METHOD
    crop_width, crop_height = cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]

    if method == "stepscaling":
        width_fits = max_width <= crop_width
        height_fits = max_height <= crop_height
        # Both dimensions must fit; passing when only one fits would hide
        # the very problem the messages below describe.
        if width_fits and height_fits:
            logger.info(correct_print("EVAL_CROP_SIZE check"))
            return
        logger.info(error_print("EVAL_CROP_SIZE check"))
        if not width_fits:
            logger.info(
                "The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!"
                .format(crop_width, max_width))
        if not height_fits:
            logger.info(
                "The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!"
                .format(crop_height, max_height))
        return

    if method == "rangescaling":
        resize_value = cfg.AUG.INF_RESIZE_VALUE
        if min_aspectratio > 1:
            # Every image is wider than tall: height is the scaled-down side.
            need_width = resize_value
            need_height = round(resize_value / min_aspectratio)
        elif max_aspectratio < 1:
            # Every image is taller than wide: width is the scaled-down side.
            need_width = round(resize_value * max_aspectratio)
            need_height = resize_value
        else:
            need_width = need_height = resize_value
    elif method == "unpadding":
        need_width = cfg.AUG.FIX_RESIZE_SIZE[0]
        need_height = cfg.AUG.FIX_RESIZE_SIZE[1]
    else:
        logger.info("\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of "
                    "[unpadding, stepscaling, rangescaling]")
        return

    if crop_width >= need_width and crop_height >= need_height:
        logger.info(correct_print("EVAL_CROP_SIZE check"))
    else:
        logger.info(error_print("EVAL_CROP_SIZE check"))
        logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
                    .format(crop_width, crop_height, need_width, need_height))
def inf_resize_value_check():
    """Warn when AUG.INF_RESIZE_VALUE lies outside the rangescaling range.

    Only applies when cfg.AUG.AUG_METHOD is ``rangescaling``. The value is
    expected inside [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]; the upper
    bound is MAX_RESIZE_VALUE (comparing against MIN on both sides would
    warn for every value other than MIN itself).
    """
    if cfg.AUG.AUG_METHOD != "rangescaling":
        return

    inf_value = cfg.AUG.INF_RESIZE_VALUE
    lower, upper = cfg.AUG.MIN_RESIZE_VALUE, cfg.AUG.MAX_RESIZE_VALUE
    if inf_value < lower or inf_value > upper:
        logger.info(
            "\nWARNING! you set AUG.AUG_METHOD = 'rangescaling', but "
            "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: "
            "[{}, {}].".format(inf_value, lower, upper))
def image_type_check(img_dim):
    """Compare cfg.DATASET.IMAGE_TYPE with the channel counts that were seen.

    * IMAGE_TYPE ``rgba`` while some image has 1 or 3 channels: NOT PASS.
    * IMAGE_TYPE ``rgb`` while every image has 4 channels: PASS, followed
      by a warning.
    * Anything else: PASS.

    Args:
        img_dim: collection of the distinct channel counts in the dataset.
    """
    image_type = cfg.DATASET.IMAGE_TYPE
    has_gray_or_rgb = 1 in img_dim or 3 in img_dim
    only_rgba = not has_gray_or_rgb and 4 in img_dim

    if has_gray_or_rgb and image_type == 'rgba':
        logger.info(error_print("DATASET.IMAGE_TYPE check"))
        logger.info("DATASET.IMAGE_TYPE is {} but the type of image has "
                    "gray or rgb\n".format(image_type))
    elif only_rgba and image_type == 'rgb':
        logger.info(correct_print("DATASET.IMAGE_TYPE check"))
        logger.info("\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba".format(image_type))
    else:
        logger.info(correct_print("DATASET.IMAGE_TYPE check"))
def shape_check():
    """Report the image/label shape comparison results.

    Passes when the global ``shape_unequal_image`` list is empty; otherwise
    logs a failure and writes every offending file-list line at DEBUG level.
    """
    if not shape_unequal_image:
        logger.info(correct_print("shape check"))
        logger.info("All images are the same shape as the labels")
        return

    logger.info(error_print("shape check"))
    logger.info("Some images are not the same shape as the labels as follow: ")
    for unequal_line in shape_unequal_image:
        logger.debug(unequal_line)
def file_list_check(list_name):
    """Report whether every file-list line used the configured separator.

    Passes when the global ``list_wrong`` list is empty; otherwise logs a
    failure and writes every badly separated line at DEBUG level.

    Args:
        list_name: path of the file list; only its base name is shown in
            the messages. ``os.path.basename`` is used so that paths
            written with ``/`` are also shortened on Windows.
    """
    check_title = os.path.basename(list_name) + " DATASET.SEPARATOR check"
    if not list_wrong:
        logger.info(correct_print(check_title))
        return

    logger.info(error_print(check_title))
    logger.info("The following list is not separated by {}".format(cfg.DATASET.SEPARATOR))
    for wrong_line in list_wrong:
        logger.debug(wrong_line)
def imread_check():
    """Report whether every image of the dataset could be read.

    Passes when the global ``imread_failed`` list is empty; otherwise logs
    the number of failures and writes each recorded entry at DEBUG level.
    """
    if not imread_failed:
        logger.info(correct_print("dataset reading check"))
        logger.info("All images can be read successfully")
        return

    logger.info(error_print("dataset reading check"))
    logger.info("Failed to read {} images".format(len(imread_failed)))
    for failure in imread_failed:
        logger.debug(failure)
def check_train_dataset():
    """Run every check on the training file list and log the results.

    Each line of cfg.DATASET.TRAIN_FILE_LIST must hold an image name and a
    label name separated by cfg.DATASET.SEPARATOR. Problems are collected
    in the module-level lists and reported once the whole list is scanned.
    """
    list_file = cfg.DATASET.TRAIN_FILE_LIST
    logger.info("-----------------------------\n1. Check train dataset...")
    with open(list_file, 'r') as fid:
        lines = fid.readlines()

    for line in tqdm(lines):
        line = line.strip()
        parts = line.split(cfg.DATASET.SEPARATOR)
        if len(parts) != 2:
            list_wrong.append(line)
            continue
        img_name, grt_name = parts
        img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
        grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
        try:
            img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            imread_failed.append((line, str(e)))
            continue
        if img is None or grt is None:
            # The decoder can hand back None instead of raising; without
            # this the next `.shape` access would abort the whole check.
            imread_failed.append((line, "image could not be decoded"))
            continue

        get_image_dim(img)
        if not image_label_shape_check(img, grt):
            shape_unequal_image.append(line)

        png_format, grt_classes, num_of_each_class = ground_truth_check(
            grt, grt_path)
        if not png_format:
            png_format_wrong_image.append(line)
        if not sum_gt_check(png_format, grt_classes, num_of_each_class):
            label_wrong.append(line)

    file_list_check(list_file)
    imread_check()
    gt_check()
    image_type_check(img_dim)
    shape_check()
def check_val_dataset():
    """Run every check on the validation file list and log the results.

    Same per-line checks as the training set, plus tracking of the largest
    image size and the aspect-ratio range, which feed the EVAL_CROP_SIZE
    check at the end.
    """
    list_file = cfg.DATASET.VAL_FILE_LIST
    logger.info("\n-----------------------------\n2. Check val dataset...")
    with open(list_file) as fid:
        lines = fid.readlines()

    for line in tqdm(lines):
        line = line.strip()
        parts = line.split(cfg.DATASET.SEPARATOR)
        if len(parts) != 2:
            list_wrong.append(line)
            continue
        img_name, grt_name = parts
        img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
        grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
        try:
            img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            # str(e), not e.message: Python 3 exceptions have no `.message`.
            imread_failed.append((line, str(e)))
            # Skip the line; otherwise the checks below would run on a
            # stale or undefined image.
            continue
        if img is None or grt is None:
            # The decoder can hand back None instead of raising.
            imread_failed.append((line, "image could not be decoded"))
            continue

        get_image_max_height_width(img)
        get_image_min_max_aspectratio(img)
        get_image_dim(img)
        if not image_label_shape_check(img, grt):
            shape_unequal_image.append(line)

        png_format, grt_classes, num_of_each_class = ground_truth_check(
            grt, grt_path)
        if not png_format:
            png_format_wrong_image.append(line)
        if not sum_gt_check(png_format, grt_classes, num_of_each_class):
            label_wrong.append(line)

    file_list_check(list_file)
    imread_check()
    gt_check()
    image_type_check(img_dim)
    shape_check()
    eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio)
def check_test_dataset():
    """Run every check on the test file list and log the results.

    A test-list line may hold only an image name, or an image name and a
    label name separated by cfg.DATASET.SEPARATOR. Label checks and their
    reports run only when at least one line carries a label.
    """
    list_file = cfg.DATASET.TEST_FILE_LIST
    has_label = False
    with open(list_file) as fid:
        logger.info("\n-----------------------------\n3. Check test dataset...")
        lines = fid.readlines()

    for line in tqdm(lines):
        line = line.strip()
        parts = line.split(cfg.DATASET.SEPARATOR)

        if len(parts) == 1:
            img_path = os.path.join(cfg.DATASET.DATA_DIR, parts[0])
            try:
                img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
            except Exception as e:
                imread_failed.append((line, str(e)))
                continue
            if img is None:
                # The decoder can hand back None instead of raising.
                imread_failed.append((line, "image could not be decoded"))
                continue
        elif len(parts) == 2:
            has_label = True
            img_name, grt_name = parts
            img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
            grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
            try:
                img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
                grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
            except Exception as e:
                # str(e), not e.message: Python 3 exceptions have no `.message`.
                imread_failed.append((line, str(e)))
                continue
            if img is None or grt is None:
                imread_failed.append((line, "image could not be decoded"))
                continue

            if not image_label_shape_check(img, grt):
                shape_unequal_image.append(line)
            png_format, grt_classes, num_of_each_class = ground_truth_check(
                grt, grt_path)
            if not png_format:
                png_format_wrong_image.append(line)
            if not sum_gt_check(png_format, grt_classes, num_of_each_class):
                label_wrong.append(line)
        else:
            # Record the offending line itself, not the whole file list.
            list_wrong.append(line)
            continue

        get_image_max_height_width(img)
        get_image_min_max_aspectratio(img)
        get_image_dim(img)

    file_list_check(list_file)
    imread_check()
    if has_label:
        gt_check()
    image_type_check(img_dim)
    if has_label:
        shape_check()
    eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio)
def main(args): def main(args):
if args.cfg_file is not None: if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file) cfg.update_from_file(args.cfg_file)
cfg.check_and_infer(reset_dataset=True) cfg.check_and_infer(reset_dataset=True)
print(pprint.pformat(cfg)) logger.info(pprint.pformat(cfg))
init_global_variable() init_global_variable()
check_train_dataset() check_train_dataset()
...@@ -428,8 +476,19 @@ def main(args): ...@@ -428,8 +476,19 @@ def main(args):
inf_resize_value_check() inf_resize_value_check()
if __name__ == "__main__":
    args = parse_args()

    # The root logger collects everything; each handler filters by itself.
    logger = logging.getLogger()
    logger.setLevel('DEBUG')
    BASIC_FORMAT = "%(message)s"
    formatter = logging.Formatter(BASIC_FORMAT)

    # Console: summary only (INFO and above).
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    sh.setLevel('INFO')

    # detail.log: everything, including the DEBUG lists of offending lines.
    # UTF-8 is set explicitly because logged file-list lines may contain
    # non-ASCII (e.g. Chinese) paths that the locale encoding cannot write.
    th = logging.FileHandler('detail.log', 'w', encoding='utf-8')
    th.setFormatter(formatter)

    logger.addHandler(sh)
    logger.addHandler(th)

    main(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册