提交 8b1504aa 编写于 作者: C chenguowei01

add lael gray check

上级 ff6c36f9
...@@ -34,6 +34,7 @@ def init_global_variable(): ...@@ -34,6 +34,7 @@ def init_global_variable():
global list_wrong #文件名格式错误列表 global list_wrong #文件名格式错误列表
global imread_failed #图片读取失败列表, 二元列表 global imread_failed #图片读取失败列表, 二元列表
global label_wrong # 标注图片出错列表 global label_wrong # 标注图片出错列表
global label_gray_wrong # 标注图非灰度图列表
png_format_right_num = 0 png_format_right_num = 0
png_format_wrong_num = 0 png_format_wrong_num = 0
...@@ -49,6 +50,7 @@ def init_global_variable(): ...@@ -49,6 +50,7 @@ def init_global_variable():
list_wrong = [] list_wrong = []
imread_failed = [] imread_failed = []
label_wrong = [] label_wrong = []
label_gray_wrong = []
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg check') parser = argparse.ArgumentParser(description='PaddleSeg check')
...@@ -68,10 +70,13 @@ def correct_print(str): ...@@ -68,10 +70,13 @@ def correct_print(str):
return "".join(["\nPASS ", str]) return "".join(["\nPASS ", str])
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR): def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
# resolve cv2.imread open Chinese file path issues on Windows Platform. """
解决 cv2.imread 在window平台打开中文路径的问题.
"""
return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
def get_image_max_height_width(img): def get_image_max_height_width(img):
"""获取图片最大宽和高"""
global max_width, max_height global max_width, max_height
img_shape = img.shape img_shape = img.shape
height, width = img_shape[0], img_shape[1] height, width = img_shape[0], img_shape[1]
...@@ -79,6 +84,7 @@ def get_image_max_height_width(img): ...@@ -79,6 +84,7 @@ def get_image_max_height_width(img):
max_width = max(width, max_width) max_width = max(width, max_width)
def get_image_min_max_aspectratio(img): def get_image_min_max_aspectratio(img):
"""计算图片最大宽高比"""
global min_aspectratio, max_aspectratio global min_aspectratio, max_aspectratio
img_shape = img.shape img_shape = img.shape
height, width = img_shape[0], img_shape[1] height, width = img_shape[0], img_shape[1]
...@@ -87,11 +93,19 @@ def get_image_min_max_aspectratio(img): ...@@ -87,11 +93,19 @@ def get_image_min_max_aspectratio(img):
return min_aspectratio, max_aspectratio return min_aspectratio, max_aspectratio
def get_image_dim(img): def get_image_dim(img):
"""获取图像的维度""" """获取图像的通道数"""
img_shape = img.shape img_shape = img.shape
if img_shape[-1] not in img_dim: if img_shape[-1] not in img_dim:
img_dim.append(img_shape[-1]) img_dim.append(img_shape[-1])
def is_label_gray(grt):
"""判断标签是否为灰度图"""
grt_shape = grt.shape
if len(grt_shape) == 2:
return True
else:
return False
def image_label_shape_check(img, grt): def image_label_shape_check(img, grt):
""" """
验证图像和标注的大小是否匹配 验证图像和标注的大小是否匹配
...@@ -110,17 +124,15 @@ def image_label_shape_check(img, grt): ...@@ -110,17 +124,15 @@ def image_label_shape_check(img, grt):
def ground_truth_check(grt, grt_path): def ground_truth_check(grt, grt_path):
""" """
验证标注是否重零开始,标注值为0,1,...,num_classes-1, ingnore_idx
验证标注图像的格式 验证标注图像的格式
返回标注的像素数 统计标注图类别和像素数
检查图像是否都是ignore_index
params: params:
grt: 标注图 grt: 标注图
grt_path: 标注图路径 grt_path: 标注图路径
return: return:
png_format: 返回是否是png格式图片 png_format: 返回是否是png格式图片
label_correct: 返回标注是否是正确的 unique: 返回标注类别
label_pixel_num: 返回标注的像素数 counts: 返回标注的像素数
""" """
if imghdr.what(grt_path) == "png": if imghdr.what(grt_path) == "png":
png_format = True png_format = True
...@@ -135,7 +147,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): ...@@ -135,7 +147,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
""" """
统计所有标注图上的格式、类别和每个类别的像素数 统计所有标注图上的格式、类别和每个类别的像素数
params: params:
png_format: 返回是否是png格式图片 png_format: 是否是png格式图片
grt_classes: 标注类别 grt_classes: 标注类别
num_of_each_class: 各个类别的像素数目 num_of_each_class: 各个类别的像素数目
""" """
...@@ -322,6 +334,16 @@ def imread_check(): ...@@ -322,6 +334,16 @@ def imread_check():
for i in imread_failed: for i in imread_failed:
logger.debug(i) logger.debug(i)
def label_gray_check():
if len(label_gray_wrong) == 0:
logger.info(correct_print("label gray check"))
logger.info("All label images are gray")
else:
logger.info(error_print("label gray check"))
logger.info("{} label images are not gray".format(len(label_gray_wrong)))
for i in label_gray_wrong:
logger.debug(i)
def check_train_dataset(): def check_train_dataset():
...@@ -340,11 +362,15 @@ def check_train_dataset(): ...@@ -340,11 +362,15 @@ def check_train_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try: try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e: except Exception as e:
imread_failed.append((line, str(e))) imread_failed.append((line, str(e)))
continue continue
is_label_gray = label_gray_check(grt)
if not is_label_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
get_image_dim(img) get_image_dim(img)
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
...@@ -383,9 +409,14 @@ def check_val_dataset(): ...@@ -383,9 +409,14 @@ def check_val_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try: try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e: except Exception as e:
imread_failed.append((line, e.message)) imread_failed.append((line, e.message))
is_label_gray = label_gray_check(grt)
if not is_label_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
get_image_max_height_width(img) get_image_max_height_width(img)
get_image_min_max_aspectratio(img) get_image_min_max_aspectratio(img)
get_image_dim(img) get_image_dim(img)
...@@ -430,10 +461,15 @@ def check_test_dataset(): ...@@ -430,10 +461,15 @@ def check_test_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try: try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e: except Exception as e:
imread_failed.append((line, e.message)) imread_failed.append((line, e.message))
continue continue
is_label_gray = label_gray_check(grt)
if not is_label_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
shape_unequal_image.append(line) shape_unequal_image.append(line)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册