提交 403b4a5b 编写于 作者: W wuyefeilin 提交者: wuzewu

update check.py (#3)

* update check.py
上级 08f001b0
......@@ -12,76 +12,134 @@ import argparse
import cv2
from tqdm import tqdm
import imghdr
import logging
from utils.config import cfg
def init_global_variable():
"""
初始化全局变量
"""
global png_format_right_num # 格式错误的标签图数量
global png_format_wrong_num # 格式错误的标图数量
global total_grt_classes # 总的标类别
global png_format_right_num # 格式正确的标注图数量
global png_format_wrong_num # 格式错误的标图数量
global total_grt_classes # 总的标类别
global total_num_of_each_class # 每个类别总的像素数
global shape_unequal # 图片和标签shape不一致
global png_format_wrong # 标签格式错误
global shape_unequal_image # 图片和标注shape不一致列表
global png_format_wrong_image # 标注格式错误列表
global max_width # 图片最长宽
global max_height # 图片最长高
global min_aspectratio # 图片最小宽高比
global max_aspectratio # 图片最大宽高比
global img_dim # 图片的通道数
global list_wrong #文件名格式错误列表
global imread_failed #图片读取失败列表, 二元列表
global label_wrong # 标注图片出错列表
png_format_right_num = 0
png_format_wrong_num = 0
total_grt_classes = []
total_num_of_each_class = []
shape_unequal = []
png_format_wrong = []
shape_unequal_image = []
png_format_wrong_image = []
max_width = 0
max_height = 0
min_aspectratio = sys.float_info.max
max_aspectratio = 0
img_dim = []
list_wrong = []
imread_failed = []
label_wrong = []
def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg check')
parser.add_argument(
'--cfg',
dest='cfg_file',
help='Config file for training (and optionally testing)',
default=None,
type=str)
'--cfg',
dest='cfg_file',
help='Config file for training (and optionally testing)',
default=None,
type=str
)
return parser.parse_args()
def error_print(str):
return "".join(["\nNOT PASS ", str])
def correct_print(str):
return "".join(["\nPASS ", str])
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
# resolve cv2.imread open Chinese file path issues on Windows Platform.
return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
def get_image_max_height_width(img, max_height, max_width):
def get_image_max_height_width(img):
global max_width, max_height
img_shape = img.shape
height, width = img_shape[0], img_shape[1]
max_height = max(height, max_height)
max_width = max(width, max_width)
return max_height, max_width
def get_image_min_max_aspectratio(img, min_aspectratio, max_aspectratio):
def get_image_min_max_aspectratio(img):
global min_aspectratio, max_aspectratio
img_shape = img.shape
height, width = img_shape[0], img_shape[1]
min_aspectratio = min(width / height, min_aspectratio)
max_aspectratio = max(width / height, max_aspectratio)
min_aspectratio = min(width/height, min_aspectratio)
max_aspectratio = max(width/height, max_aspectratio)
return min_aspectratio, max_aspectratio
def get_image_dim(img, img_dim):
def get_image_dim(img):
"""获取图像的维度"""
img_shape = img.shape
if img_shape[-1] not in img_dim:
img_dim.append(img_shape[-1])
def image_label_shape_check(img, grt):
"""
验证图像和标注的大小是否匹配
"""
flag = True
img_height = img.shape[0]
img_width = img.shape[1]
grt_height = grt.shape[0]
grt_width = grt.shape[1]
if img_height != grt_height or img_width != grt_width:
flag = False
return flag
def ground_truth_check(grt, grt_path):
"""
验证标注是否重零开始,标注值为0,1,...,num_classes-1, ingnore_idx
验证标注图像的格式
返回标注的像素数
检查图像是否都是ignore_index
params:
grt: 标注图
grt_path: 标注图路径
return:
png_format: 返回是否是png格式图片
label_correct: 返回标注是否是正确的
label_pixel_num: 返回标注的像素数
"""
if imghdr.what(grt_path) == "png":
png_format = True
else:
png_format = False
unique, counts = np.unique(grt, return_counts=True)
return png_format, unique, counts
def sum_gt_check(png_format, grt_classes, num_of_each_class):
"""
统计所有标图上的格式、类别和每个类别的像素数
统计所有标图上的格式、类别和每个类别的像素数
params:
png_format: 返回是否是png格式图片
grt_classes: 标类别
grt_classes: 标类别
num_of_each_class: 各个类别的像素数目
"""
is_label_correct = True
global png_format_right_num, png_format_wrong_num, total_grt_classes, total_num_of_each_class
if png_format:
......@@ -90,12 +148,11 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
png_format_wrong_num += 1
if cfg.DATASET.IGNORE_INDEX in grt_classes:
grt_classes2 = np.delete(
grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX))
grt_classes2 = np.delete(grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX))
else:
grt_classes2 = grt_classes
if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1:
print("fatal error: label class is out of range [0, {}]".format(
cfg.DATASET.NUM_CLASSES - 1))
is_label_correct = False
add_class = []
add_num = []
for i in range(len(grt_classes)):
......@@ -108,145 +165,113 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
add_num.append(num_of_each_class[i])
total_num_of_each_class += add_num
total_grt_classes += add_class
return is_label_correct
def gt_check():
"""
对标签进行校验,输出校验结果
params:
png_format_wrong_num: 格式错误的标签图数量
png_format_right_num: 格式正确的标签图数量
total_grt_classes: 总的标签类别
total_num_of_each_class: 每个类别总的像素数目
return:
total_nc: 按升序排序后的总标签类别和像素数目
对标注图像进行校验,输出校验结果
"""
if png_format_wrong_num == 0:
print("Not pass label png format check!")
if png_format_right_num:
logger.info(correct_print("label format check"))
else:
logger.info(error_print("label format check"))
logger.info("No label image to check")
return
else:
print("Pass label png format check!")
print(
"total {} label imgs are png format, {} label imgs are not png fromat".
format(png_format_right_num, png_format_wrong_num))
logger.info(error_print("label format check"))
logger.info("total {} label images are png format, {} label images are not png "
"format".format(png_format_right_num, png_format_wrong_num))
if len(png_format_wrong_image) > 0:
for i in png_format_wrong_image:
logger.debug(i)
total_nc = sorted(zip(total_grt_classes, total_num_of_each_class))
print("total label calsses and their corresponding numbers:\n{} ".format(
total_nc))
if total_nc[0][0]:
print(
"Not pass label class check!\nWarning: label classes should start from 0 !!!"
)
else:
print("Pass label class check!")
total_nc = sorted(zip(total_grt_classes, total_num_of_each_class))
logger.info("\nDoing label pixel statistics...\nTotal label calsses "
"and their corresponding numbers:\n{} ".format(total_nc))
def ground_truth_check(grt, grt_path):
"""
验证标签是否重零开始,标签值为0,1,...,num_classes-1, ingnore_idx
验证标签图像的格式
返回标签的像素数
检查图像是否都是ignore_index
params:
grt: 标签图
grt_path: 标签图路径
return:
png_format: 返回是否是png格式图片
label_correct: 返回标签是否是正确的
label_pixel_num: 返回标签的像素数
"""
if imghdr.what(grt_path) == "png":
png_format = True
if len(label_wrong) == 0 and not total_nc[0][0]:
logger.info(correct_print("label class check!"))
else:
png_format = False
logger.info(error_print("label class check!"))
if total_nc[0][0]:
logger.info("Warning: label classes should start from 0")
if len(label_wrong) > 0:
logger.info("fatal error: label class is out of range [0, {}]".format(cfg.DATASET.NUM_CLASSES - 1))
for i in label_wrong:
logger.debug(i)
unique, counts = np.unique(grt, return_counts=True)
return png_format, unique, counts
def eval_crop_size_check(max_height, max_width, min_aspectratio,
max_aspectratio):
def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio):
"""
判断eval_crop_siz与验证集及测试集的max_height, max_width的关系
param
max_height: 数据集的最大高
max_width: 数据集的最大宽
"""
if cfg.AUG.AUG_METHOD == "stepscaling":
flag = True
if max_width > cfg.EVAL_CROP_SIZE[0]:
print(
"ERROR: The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!"
.format(cfg.EVAL_CROP_SIZE[0], max_width))
flag = False
if max_height > cfg.EVAL_CROP_SIZE[1]:
print(
"ERROR: The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!"
.format(cfg.EVAL_CROP_SIZE[1], max_height))
flag = False
if flag:
print("EVAL_CROP_SIZE setting correct")
if max_width <= cfg.EVAL_CROP_SIZE[0] or max_height <= cfg.EVAL_CROP_SIZE[1]:
logger.info(correct_print("EVAL_CROP_SIZE check"))
else:
logger.info(error_print("EVAL_CROP_SIZE check"))
if max_width > cfg.EVAL_CROP_SIZE[0]:
logger.info("The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!".format(
cfg.EVAL_CROP_SIZE[0], max_width))
if max_height > cfg.EVAL_CROP_SIZE[1]:
logger.info(error_print("The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!".format(
cfg.EVAL_CROP_SIZE[1], max_height)))
elif cfg.AUG.AUG_METHOD == "rangescaling":
if min_aspectratio <= 1 and max_aspectratio >= 1:
if cfg.EVAL_CROP_SIZE[
0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[
1] >= cfg.AUG.INF_RESIZE_VALUE:
print("EVAL_CROP_SIZE setting correct")
if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE:
logger.info(correct_print("EVAL_CROP_SIZE check"))
else:
print(
"ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE))
elif min_aspectratio > 1:
max_height_rangscaling = cfg.AUG.INF_RESIZE_VALUE / min_aspectratio
max_height_rangscaling = round(max_height_rangscaling)
if cfg.EVAL_CROP_SIZE[
0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[
1] >= max_height_rangscaling:
print("EVAL_CROP_SIZE setting correct")
if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= max_height_rangscaling:
logger.info(correct_print("EVAL_CROP_SIZE check"))
else:
print(
"ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling))
logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling))
elif max_aspectratio < 1:
max_width_rangscaling = cfg.AUG.INF_RESIZE_VALUE * max_aspectratio
max_width_rangscaling = round(max_width_rangscaling)
if cfg.EVAL_CROP_SIZE[
0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[
1] >= cfg.AUG.INF_RESIZE_VALUE:
print("EVAL_CROP_SIZE setting correct")
if cfg.EVAL_CROP_SIZE[0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE:
logger.info(correct_print("EVAL_CROP_SIZE check"))
else:
print(
"ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE))
elif cfg.AUG.AUG_METHOD == "unpadding":
if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[
0] and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]:
print("EVAL_CROP_SIZE setting correct")
if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[0] and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]:
logger.info(correct_print("EVAL_CROP_SIZE check"))
else:
print(
"ERROR: EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1]))
logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})"
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1]))
else:
print(
"ERROR: cfg.AUG.AUG_METHOD setting wrong, it should be one of [unpadding, stepscaling, rangescaling]"
)
logger.info("\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of "
"[unpadding, stepscaling, rangescaling]")
def inf_resize_value_check():
if cfg.AUG.AUG_METHOD == "rangescaling":
if cfg.AUG.INF_RESIZE_VALUE < cfg.AUG.MIN_RESIZE_VALUE or \
cfg.AUG.INF_RESIZE_VALUE > cfg.AUG.MIN_RESIZE_VALUE:
print(
"ERROR: you set AUG.AUG_METHOD = 'rangescaling'"
"AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: "
"[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE,
cfg.AUG.MIN_RESIZE_VALUE,
cfg.AUG.MAX_RESIZE_VALUE))
logger.info("\nWARNING! you set AUG.AUG_METHOD = 'rangescaling'"
"AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: "
"[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.MIN_RESIZE_VALUE, cfg.AUG.MAX_RESIZE_VALUE))
def image_type_check(img_dim):
"""
......@@ -256,166 +281,189 @@ def image_type_check(img_dim):
return
"""
if (1 in img_dim or 3 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgba':
print(
"ERROR: DATASET.IMAGE_TYPE is {} but the type of image has gray or rgb\n"
.format(cfg.DATASET.IMAGE_TYPE))
# elif (1 not in img_dim and 3 not in img_dim and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb':
# print("ERROR: DATASET.IMAGE_TYPE is {} but the type of image is rgba\n".format(cfg.DATASET.IMAGE_TYPE))
logger.info(error_print("DATASET.IMAGE_TYPE check"))
logger.info("DATASET.IMAGE_TYPE is {} but the type of image has "
"gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE))
elif (1 not in img_dim and 3 not in img_dim and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb':
logger.info(correct_print("DATASET.IMAGE_TYPE check"))
logger.info("\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba".format(cfg.DATASET.IMAGE_TYPE))
else:
print("DATASET.IMAGE_TYPE setting correct")
logger.info(correct_print("DATASET.IMAGE_TYPE check"))
def shape_check():
"""输出shape校验结果"""
if len(shape_unequal_image) == 0:
logger.info(correct_print("shape check"))
logger.info("All images are the same shape as the labels")
else:
logger.info(error_print("shape check"))
logger.info("Some images are not the same shape as the labels as follow: ")
for i in shape_unequal_image:
logger.debug(i)
def image_label_shape_check(img, grt):
"""
验证图像和标签的大小是否匹配
"""
flag = True
img_height = img.shape[0]
img_width = img.shape[1]
grt_height = grt.shape[0]
grt_width = grt.shape[1]
def file_list_check(list_name):
"""检查分割符是否复合要求"""
if len(list_wrong) == 0:
logger.info(correct_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check"))
else:
logger.info(error_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check"))
logger.info("The following list is not separated by {}".format(cfg.DATASET.SEPARATOR))
for i in list_wrong:
logger.debug(i)
def imread_check():
if len(imread_failed) == 0:
logger.info(correct_print("dataset reading check"))
logger.info("All images can be read successfully")
else:
logger.info(error_print("dataset reading check"))
logger.info("Failed to read {} images".format(len(imread_failed)))
for i in imread_failed:
logger.debug(i)
if img_height != grt_height or img_width != grt_width:
flag = False
return flag
def check_train_dataset():
train_list = cfg.DATASET.TRAIN_FILE_LIST
print("\ncheck train dataset...")
with open(train_list, 'r') as fid:
img_dim = []
list_file = cfg.DATASET.TRAIN_FILE_LIST
logger.info("-----------------------------\n1. Check train dataset...")
with open(list_file, 'r') as fid:
lines = fid.readlines()
for line in tqdm(lines):
parts = line.strip().split(cfg.DATASET.SEPARATOR)
line = line.strip()
parts = line.split(cfg.DATASET.SEPARATOR)
if len(parts) != 2:
print(
line, "File list format incorrect! It should be"
" image_name{}label_name\\n ".format(cfg.DATASET.SEPARATOR))
list_wrong.append(line)
continue
img_name, grt_name = parts[0], parts[1]
img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
except Exception as e:
imread_failed.append((line, str(e)))
continue
get_image_dim(img, img_dim)
get_image_dim(img)
is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape:
print(line,
"ERROR: source img and label img must has the same size")
shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
sum_gt_check(png_format, grt_classes, num_of_each_class)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path)
if not png_format:
png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class)
if not is_label_correct:
label_wrong.append(line)
file_list_check(list_file)
imread_check()
gt_check()
image_type_check(img_dim)
shape_check()
def check_val_dataset():
val_list = cfg.DATASET.VAL_FILE_LIST
with open(val_list) as fid:
max_height = 0
max_width = 0
min_aspectratio = sys.float_info.max
max_aspectratio = 0.0
img_dim = []
print("check val dataset...")
list_file = cfg.DATASET.VAL_FILE_LIST
logger.info("\n-----------------------------\n2. Check val dataset...")
with open(list_file) as fid:
lines = fid.readlines()
for line in tqdm(lines):
parts = line.strip().split(cfg.DATASET.SEPARATOR)
line = line.strip()
parts = line.split(cfg.DATASET.SEPARATOR)
if len(parts) != 2:
print(
line, "File list format incorrect! It should be"
" image_name{}label_name\\n ".format(cfg.DATASET.SEPARATOR))
list_wrong.append(line)
continue
img_name, grt_name = parts[0], parts[1]
img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
max_height, max_width = get_image_max_height_width(
img, max_height, max_width)
min_aspectratio, max_aspectratio = get_image_min_max_aspectratio(
img, min_aspectratio, max_aspectratio)
get_image_dim(img, img_dim)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
except Exception as e:
imread_failed.append((line, e.message))
get_image_max_height_width(img)
get_image_min_max_aspectratio(img)
get_image_dim(img)
is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape:
print(line,
"ERROR: source img and label img must has the same size")
png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
sum_gt_check(png_format, grt_classes, num_of_each_class)
shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path)
if not png_format:
png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class)
if not is_label_correct:
label_wrong.append(line)
file_list_check(list_file)
imread_check()
gt_check()
eval_crop_size_check(max_height, max_width, min_aspectratio,
max_aspectratio)
image_type_check(img_dim)
shape_check()
eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio)
def check_test_dataset():
test_list = cfg.DATASET.TEST_FILE_LIST
with open(test_list) as fid:
max_height = 0
max_width = 0
min_aspectratio = sys.float_info.max
max_aspectratio = 0.0
img_dim = []
print("check test dataset...")
list_file = cfg.DATASET.TEST_FILE_LIST
has_label = False
with open(list_file) as fid:
logger.info("\n-----------------------------\n3. Check test dataset...")
lines = fid.readlines()
for line in tqdm(lines):
parts = line.strip().split(cfg.DATASET.SEPARATOR)
line = line.strip()
parts = line.split(cfg.DATASET.SEPARATOR)
if len(parts) == 1:
img_name = parts
img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name[0])
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
except Exception as e:
imread_failed.append((line, str(e)))
continue
elif len(parts) == 2:
has_label = True
img_name, grt_name = parts[0], parts[1]
img_path = os.path.join(cfg.DATASET.DATA_DIR, img_name)
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
except Exception as e:
imread_failed.append((line, e.message))
continue
is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape:
print(
line,
"ERROR: source img and label img must has the same size"
)
png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
sum_gt_check(png_format, grt_classes, num_of_each_class)
shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path)
if not png_format:
png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class)
if not is_label_correct:
label_wrong.append(line)
else:
print(
line, "File list format incorrect! It should be"
" image_name{}label_name\\n or image_name\n ".format(
cfg.DATASET.SEPARATOR))
list_wrong.append(lines)
continue
max_height, max_width = get_image_max_height_width(
img, max_height, max_width)
min_aspectratio, max_aspectratio = get_image_min_max_aspectratio(
img, min_aspectratio, max_aspectratio)
get_image_dim(img, img_dim)
gt_check()
eval_crop_size_check(max_height, max_width, min_aspectratio,
max_aspectratio)
get_image_max_height_width(img)
get_image_min_max_aspectratio(img)
get_image_dim(img)
file_list_check(list_file)
imread_check()
if has_label:
gt_check()
image_type_check(img_dim)
if has_label:
shape_check()
eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio)
def main(args):
if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file)
cfg.check_and_infer(reset_dataset=True)
print(pprint.pformat(cfg))
logger.info(pprint.pformat(cfg))
init_global_variable()
check_train_dataset()
......@@ -428,8 +476,19 @@ def main(args):
inf_resize_value_check()
if __name__ == "__main__":
args = parse_args()
args.cfg_file = "../configs/cityscape.yaml"
logger = logging.getLogger()
logger.setLevel('DEBUG')
BASIC_FORMAT = "%(message)s"
formatter = logging.Formatter(BASIC_FORMAT)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
sh.setLevel('INFO')
th = logging.FileHandler('detail.log', 'w')
th.setFormatter(formatter)
logger.addHandler(sh)
logger.addHandler(th)
main(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册