提交 cd6e06d3 编写于 作者: L LutaoChu 提交者: wuzewu

improve data check (#49)

* update check.py

* Update check.md
上级 f10169ae
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# YAML_FILE_PATH为yaml配置文件路径 # YAML_FILE_PATH为yaml配置文件路径
python pdseg/check.py --cfg ${YAML_FILE_PATH} python pdseg/check.py --cfg ${YAML_FILE_PATH}
``` ```
运行后,命令行将显示校验结果的概览信息,详细信息可到detail.log文件中查看。 运行后,命令行将显示校验结果的概览信息,详细的错误信息可到detail.log文件中查看。
### 1 列表分割符校验 ### 1 列表分割符校验
判断在`TRAIN_FILE_LIST``VAL_FILE_LIST``TEST_FILE_LIST`列表文件中的分隔符`DATASET.SEPARATOR`设置是否正确。 判断在`TRAIN_FILE_LIST``VAL_FILE_LIST``TEST_FILE_LIST`列表文件中的分隔符`DATASET.SEPARATOR`设置是否正确。
...@@ -31,18 +31,24 @@ python pdseg/check.py --cfg ${YAML_FILE_PATH} ...@@ -31,18 +31,24 @@ python pdseg/check.py --cfg ${YAML_FILE_PATH}
标注类别最好从0开始,否则可能影响精度。 标注类别最好从0开始,否则可能影响精度。
### 6 标注像素统计 ### 6 标注像素统计
统计每种类别像素数量,显示以供参考。 统计每种类别的像素总数和所占比例,显示以供参考。统计结果如下:
```
Doing label pixel statistics:
(label class, total pixel number, percentage) = [(0, 2048984, 0.5211), (1, 1682943, 0.428), (2, 197976, 0.0503), (3, 2257, 0.0006)]
```
### 7 图像格式校验 ### 7 图像格式校验
检查图片类型`DATASET.IMAGE_TYPE`是否设置正确。 检查图片类型`DATASET.IMAGE_TYPE`是否设置正确。
**NOTE:** 当数据集包含三通道图片时`DATASET.IMAGE_TYPE`设置为rgb; **NOTE:** 当数据集包含三通道图片时`DATASET.IMAGE_TYPE`设置为rgb;
当数据集全部为四通道图片时`DATASET.IMAGE_TYPE`设置为rgba; 当数据集全部为四通道图片时`DATASET.IMAGE_TYPE`设置为rgba;
### 8 图像与标注图尺寸一致性校验 ### 8 图像最大尺寸统计
统计数据集中图片的最大高和最大宽,显示以供参考。
### 9 图像与标注图尺寸一致性校验
验证图像尺寸和对应标注图尺寸是否一致。 验证图像尺寸和对应标注图尺寸是否一致。
### 9 模型验证参数`EVAL_CROP_SIZE`校验 ### 10 模型验证参数`EVAL_CROP_SIZE`校验
验证`EVAL_CROP_SIZE`是否设置正确,共有3种情形: 验证`EVAL_CROP_SIZE`是否设置正确,共有3种情形:
-`AUG.AUG_METHOD`为unpadding时,`EVAL_CROP_SIZE`的宽高应不小于`AUG.FIX_RESIZE_SIZE`的宽高。 -`AUG.AUG_METHOD`为unpadding时,`EVAL_CROP_SIZE`的宽高应不小于`AUG.FIX_RESIZE_SIZE`的宽高。
...@@ -51,5 +57,5 @@ python pdseg/check.py --cfg ${YAML_FILE_PATH} ...@@ -51,5 +57,5 @@ python pdseg/check.py --cfg ${YAML_FILE_PATH}
-`AUG.AUG_METHOD`为rangscaling时,`EVAL_CROP_SIZE`的宽高应不小于缩放后图像中最大的宽高。 -`AUG.AUG_METHOD`为rangscaling时,`EVAL_CROP_SIZE`的宽高应不小于缩放后图像中最大的宽高。
### 10 数据增强参数`AUG.INF_RESIZE_VALUE`校验 ### 11 数据增强参数`AUG.INF_RESIZE_VALUE`校验
验证`AUG.INF_RESIZE_VALUE`是否在[`AUG.MIN_RESIZE_VALUE`~`AUG.MAX_RESIZE_VALUE`]范围内。若在范围内,则通过校验。 验证`AUG.INF_RESIZE_VALUE`是否在[`AUG.MIN_RESIZE_VALUE`~`AUG.MAX_RESIZE_VALUE`]范围内。若在范围内,则通过校验。
...@@ -16,6 +16,7 @@ import logging ...@@ -16,6 +16,7 @@ import logging
from utils.config import cfg from utils.config import cfg
def init_global_variable(): def init_global_variable():
""" """
初始化全局变量 初始化全局变量
...@@ -31,8 +32,8 @@ def init_global_variable(): ...@@ -31,8 +32,8 @@ def init_global_variable():
global min_aspectratio # 图片最小宽高比 global min_aspectratio # 图片最小宽高比
global max_aspectratio # 图片最大宽高比 global max_aspectratio # 图片最大宽高比
global img_dim # 图片的通道数 global img_dim # 图片的通道数
global list_wrong #文件名格式错误列表 global list_wrong # 文件名格式错误列表
global imread_failed #图片读取失败列表, 二元列表 global imread_failed # 图片读取失败列表, 二元列表
global label_wrong # 标注图片出错列表 global label_wrong # 标注图片出错列表
global label_gray_wrong # 标注图非灰度图列表 global label_gray_wrong # 标注图非灰度图列表
...@@ -52,29 +53,33 @@ def init_global_variable(): ...@@ -52,29 +53,33 @@ def init_global_variable():
label_wrong = [] label_wrong = []
label_gray_wrong = [] label_gray_wrong = []
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg check') parser = argparse.ArgumentParser(description='PaddleSeg check')
parser.add_argument( parser.add_argument(
'--cfg', '--cfg',
dest='cfg_file', dest='cfg_file',
help='Config file for training (and optionally testing)', help='Config file for training (and optionally testing)',
default=None, default=None,
type=str type=str)
)
return parser.parse_args() return parser.parse_args()
def error_print(str): def error_print(str):
return "".join(["\nNOT PASS ", str]) return "".join(["\nNOT PASS ", str])
def correct_print(str): def correct_print(str):
return "".join(["\nPASS ", str]) return "".join(["\nPASS ", str])
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR): def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
""" """
解决 cv2.imread 在window平台打开中文路径的问题. 解决 cv2.imread 在window平台打开中文路径的问题.
""" """
return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
def get_image_max_height_width(img): def get_image_max_height_width(img):
"""获取图片最大宽和高""" """获取图片最大宽和高"""
global max_width, max_height global max_width, max_height
...@@ -83,21 +88,24 @@ def get_image_max_height_width(img): ...@@ -83,21 +88,24 @@ def get_image_max_height_width(img):
max_height = max(height, max_height) max_height = max(height, max_height)
max_width = max(width, max_width) max_width = max(width, max_width)
def get_image_min_max_aspectratio(img): def get_image_min_max_aspectratio(img):
"""计算图片最大宽高比""" """计算图片最大宽高比"""
global min_aspectratio, max_aspectratio global min_aspectratio, max_aspectratio
img_shape = img.shape img_shape = img.shape
height, width = img_shape[0], img_shape[1] height, width = img_shape[0], img_shape[1]
min_aspectratio = min(width/height, min_aspectratio) min_aspectratio = min(width / height, min_aspectratio)
max_aspectratio = max(width/height, max_aspectratio) max_aspectratio = max(width / height, max_aspectratio)
return min_aspectratio, max_aspectratio return min_aspectratio, max_aspectratio
def get_image_dim(img): def get_image_dim(img):
"""获取图像的通道数""" """获取图像的通道数"""
img_shape = img.shape img_shape = img.shape
if img_shape[-1] not in img_dim: if img_shape[-1] not in img_dim:
img_dim.append(img_shape[-1]) img_dim.append(img_shape[-1])
def is_label_gray(grt): def is_label_gray(grt):
"""判断标签是否为灰度图""" """判断标签是否为灰度图"""
grt_shape = grt.shape grt_shape = grt.shape
...@@ -106,6 +114,7 @@ def is_label_gray(grt): ...@@ -106,6 +114,7 @@ def is_label_gray(grt):
else: else:
return False return False
def image_label_shape_check(img, grt): def image_label_shape_check(img, grt):
""" """
验证图像和标注的大小是否匹配 验证图像和标注的大小是否匹配
...@@ -117,11 +126,11 @@ def image_label_shape_check(img, grt): ...@@ -117,11 +126,11 @@ def image_label_shape_check(img, grt):
grt_height = grt.shape[0] grt_height = grt.shape[0]
grt_width = grt.shape[1] grt_width = grt.shape[1]
if img_height != grt_height or img_width != grt_width: if img_height != grt_height or img_width != grt_width:
flag = False flag = False
return flag return flag
def ground_truth_check(grt, grt_path): def ground_truth_check(grt, grt_path):
""" """
验证标注图像的格式 验证标注图像的格式
...@@ -143,6 +152,7 @@ def ground_truth_check(grt, grt_path): ...@@ -143,6 +152,7 @@ def ground_truth_check(grt, grt_path):
return png_format, unique, counts return png_format, unique, counts
def sum_gt_check(png_format, grt_classes, num_of_each_class): def sum_gt_check(png_format, grt_classes, num_of_each_class):
""" """
统计所有标注图上的格式、类别和每个类别的像素数 统计所有标注图上的格式、类别和每个类别的像素数
...@@ -160,7 +170,8 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): ...@@ -160,7 +170,8 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
png_format_wrong_num += 1 png_format_wrong_num += 1
if cfg.DATASET.IGNORE_INDEX in grt_classes: if cfg.DATASET.IGNORE_INDEX in grt_classes:
grt_classes2 = np.delete(grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX)) grt_classes2 = np.delete(
grt_classes, np.where(grt_classes == cfg.DATASET.IGNORE_INDEX))
else: else:
grt_classes2 = grt_classes grt_classes2 = grt_classes
if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1: if min(grt_classes2) < 0 or max(grt_classes2) > cfg.DATASET.NUM_CLASSES - 1:
...@@ -179,6 +190,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): ...@@ -179,6 +190,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
total_grt_classes += add_class total_grt_classes += add_class
return is_label_correct return is_label_correct
def gt_check(): def gt_check():
""" """
对标注图像进行校验,输出校验结果 对标注图像进行校验,输出校验结果
...@@ -192,16 +204,20 @@ def gt_check(): ...@@ -192,16 +204,20 @@ def gt_check():
return return
else: else:
logger.info(error_print("label format check")) logger.info(error_print("label format check"))
logger.info("total {} label images are png format, {} label images are not png " logger.info(
"format".format(png_format_right_num, png_format_wrong_num)) "total {} label images are png format, {} label images are not png "
"format".format(png_format_right_num, png_format_wrong_num))
if len(png_format_wrong_image) > 0: if len(png_format_wrong_image) > 0:
for i in png_format_wrong_image: for i in png_format_wrong_image:
logger.debug(i) logger.debug(i)
total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
total_nc = sorted(zip(total_grt_classes, total_num_of_each_class)) total_ratio = np.around(total_ratio, decimals=4)
logger.info("\nDoing label pixel statistics...\nTotal label classes " total_nc = sorted(
"and their corresponding numbers:\n{} ".format(total_nc)) zip(total_grt_classes, total_num_of_each_class, total_ratio))
logger.info(
"\nDoing label pixel statistics:\n"
"(label class, total pixel number, percentage) = {} ".format(total_nc))
if len(label_wrong) == 0 and not total_nc[0][0]: if len(label_wrong) == 0 and not total_nc[0][0]:
logger.info(correct_print("label class check!")) logger.info(correct_print("label class check!"))
...@@ -210,13 +226,15 @@ def gt_check(): ...@@ -210,13 +226,15 @@ def gt_check():
if total_nc[0][0]: if total_nc[0][0]:
logger.info("Warning: label classes should start from 0") logger.info("Warning: label classes should start from 0")
if len(label_wrong) > 0: if len(label_wrong) > 0:
logger.info("fatal error: label class is out of range [0, {}]".format(cfg.DATASET.NUM_CLASSES - 1)) logger.info(
"fatal error: label class is out of range [0, {}]".format(
cfg.DATASET.NUM_CLASSES - 1))
for i in label_wrong: for i in label_wrong:
logger.debug(i) logger.debug(i)
def eval_crop_size_check(max_height, max_width, min_aspectratio,
def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio): max_aspectratio):
""" """
判断eval_crop_siz与验证集及测试集的max_height, max_width的关系 判断eval_crop_siz与验证集及测试集的max_height, max_width的关系
param param
...@@ -225,69 +243,109 @@ def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio ...@@ -225,69 +243,109 @@ def eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio
""" """
if cfg.AUG.AUG_METHOD == "stepscaling": if cfg.AUG.AUG_METHOD == "stepscaling":
if max_width <= cfg.EVAL_CROP_SIZE[0] and max_height <= cfg.EVAL_CROP_SIZE[1]: if max_width <= cfg.EVAL_CROP_SIZE[
0] and max_height <= cfg.EVAL_CROP_SIZE[1]:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= max width and max height of images: ({},{})"
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], max_width,
max_height))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
if max_width > cfg.EVAL_CROP_SIZE[0]: if max_width > cfg.EVAL_CROP_SIZE[0]:
logger.info("The EVAL_CROP_SIZE[0]: {} should larger max width of images {}!".format( logger.info(
cfg.EVAL_CROP_SIZE[0], max_width)) "EVAL_CROP_SIZE[0]: {} should >= max width of images {}!".
format(cfg.EVAL_CROP_SIZE[0], max_width))
if max_height > cfg.EVAL_CROP_SIZE[1]: if max_height > cfg.EVAL_CROP_SIZE[1]:
logger.info(error_print("The EVAL_CROP_SIZE[1]: {} should larger max height of images {}!".format( logger.info(
cfg.EVAL_CROP_SIZE[1], max_height))) "EVAL_CROP_SIZE[1]: {} should >= max height of images {}!".
format(cfg.EVAL_CROP_SIZE[1], max_height))
elif cfg.AUG.AUG_METHOD == "rangescaling": elif cfg.AUG.AUG_METHOD == "rangescaling":
if min_aspectratio <= 1 and max_aspectratio >= 1: if min_aspectratio <= 1 and max_aspectratio >= 1:
if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE: if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE \
and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" logger.info(
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE)) .format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.INF_RESIZE_VALUE,
cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))
elif min_aspectratio > 1: elif min_aspectratio > 1:
max_height_rangscaling = cfg.AUG.INF_RESIZE_VALUE / min_aspectratio max_height_rangscaling = cfg.AUG.INF_RESIZE_VALUE / min_aspectratio
max_height_rangscaling = round(max_height_rangscaling) max_height_rangscaling = round(max_height_rangscaling)
if cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[1] >= max_height_rangscaling: if cfg.EVAL_CROP_SIZE[
0] >= cfg.AUG.INF_RESIZE_VALUE and cfg.EVAL_CROP_SIZE[
1] >= max_height_rangscaling:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" logger.info(
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling)) .format(cfg.AUG.INF_RESIZE_VALUE, max_height_rangscaling,
cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))
elif max_aspectratio < 1: elif max_aspectratio < 1:
max_width_rangscaling = cfg.AUG.INF_RESIZE_VALUE * max_aspectratio max_width_rangscaling = cfg.AUG.INF_RESIZE_VALUE * max_aspectratio
max_width_rangscaling = round(max_width_rangscaling) max_width_rangscaling = round(max_width_rangscaling)
if cfg.EVAL_CROP_SIZE[0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.INF_RESIZE_VALUE: if cfg.EVAL_CROP_SIZE[
0] >= max_width_rangscaling and cfg.EVAL_CROP_SIZE[
1] >= cfg.AUG.INF_RESIZE_VALUE:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= ({},{}) ".
format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
max_height_rangscaling, cfg.AUG.INF_RESIZE_VALUE))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" logger.info(
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], "EVAL_CROP_SIZE must >= img size({},{}), current EVAL_CROP_SIZE is ({},{})"
max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE)) .format(max_width_rangscaling, cfg.AUG.INF_RESIZE_VALUE,
cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1]))
elif cfg.AUG.AUG_METHOD == "unpadding": elif cfg.AUG.AUG_METHOD == "unpadding":
if len(cfg.AUG.FIX_RESIZE_SIZE) != 2: if len(cfg.AUG.FIX_RESIZE_SIZE) != 2:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("you set AUG.AUG_METHOD = 'unpadding', but AUG.FIX_RESIZE_SIZE is wrong. " logger.info(
"AUG.FIX_RESIZE_SIZE should be a tuple of length 2") "you set AUG.AUG_METHOD = 'unpadding', but AUG.FIX_RESIZE_SIZE is wrong. "
elif cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[0] and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]: "AUG.FIX_RESIZE_SIZE should be a tuple of length 2")
elif cfg.EVAL_CROP_SIZE[0] >= cfg.AUG.FIX_RESIZE_SIZE[0] \
and cfg.EVAL_CROP_SIZE[1] >= cfg.AUG.FIX_RESIZE_SIZE[1]:
logger.info(correct_print("EVAL_CROP_SIZE check")) logger.info(correct_print("EVAL_CROP_SIZE check"))
logger.info(
"satisfy current EVAL_CROP_SIZE: ({},{}) >= AUG.FIX_RESIZE_SIZE: ({},{}) "
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1]))
else: else:
logger.info(error_print("EVAL_CROP_SIZE check")) logger.info(error_print("EVAL_CROP_SIZE check"))
logger.info("EVAL_CROP_SIZE: ({},{}) must large than img size({},{})" logger.info(
.format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1], "EVAL_CROP_SIZE: ({},{}) must >= AUG.FIX_RESIZE_SIZE: ({},{})".
cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1])) format(cfg.EVAL_CROP_SIZE[0], cfg.EVAL_CROP_SIZE[1],
cfg.AUG.FIX_RESIZE_SIZE[0], cfg.AUG.FIX_RESIZE_SIZE[1]))
else: else:
logger.info("\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of " logger.info(
"[unpadding, stepscaling, rangescaling]") "\nERROR! cfg.AUG.AUG_METHOD setting wrong, it should be one of "
"[unpadding, stepscaling, rangescaling]")
def inf_resize_value_check(): def inf_resize_value_check():
if cfg.AUG.AUG_METHOD == "rangescaling": if cfg.AUG.AUG_METHOD == "rangescaling":
if cfg.AUG.INF_RESIZE_VALUE < cfg.AUG.MIN_RESIZE_VALUE or \ if cfg.AUG.INF_RESIZE_VALUE < cfg.AUG.MIN_RESIZE_VALUE or \
cfg.AUG.INF_RESIZE_VALUE > cfg.AUG.MIN_RESIZE_VALUE: cfg.AUG.INF_RESIZE_VALUE > cfg.AUG.MIN_RESIZE_VALUE:
logger.info("\nWARNING! you set AUG.AUG_METHOD = 'rangescaling'" logger.info(
"AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: " "\nWARNING! you set AUG.AUG_METHOD = 'rangescaling'"
"[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE, cfg.AUG.MIN_RESIZE_VALUE, cfg.AUG.MAX_RESIZE_VALUE)) "AUG.INF_RESIZE_VALUE: {} not in [AUG.MIN_RESIZE_VALUE, AUG.MAX_RESIZE_VALUE]: "
"[{}, {}].".format(cfg.AUG.INF_RESIZE_VALUE,
cfg.AUG.MIN_RESIZE_VALUE,
cfg.AUG.MAX_RESIZE_VALUE))
def image_type_check(img_dim): def image_type_check(img_dim):
""" """
...@@ -299,13 +357,17 @@ def image_type_check(img_dim): ...@@ -299,13 +357,17 @@ def image_type_check(img_dim):
if (1 in img_dim or 3 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgba': if (1 in img_dim or 3 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgba':
logger.info(error_print("DATASET.IMAGE_TYPE check")) logger.info(error_print("DATASET.IMAGE_TYPE check"))
logger.info("DATASET.IMAGE_TYPE is {} but the type of image has " logger.info("DATASET.IMAGE_TYPE is {} but the type of image has "
"gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE)) "gray or rgb\n".format(cfg.DATASET.IMAGE_TYPE))
elif (1 not in img_dim and 3 not in img_dim and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb': elif (1 not in img_dim and 3 not in img_dim
and 4 in img_dim) and cfg.DATASET.IMAGE_TYPE == 'rgb':
logger.info(correct_print("DATASET.IMAGE_TYPE check")) logger.info(correct_print("DATASET.IMAGE_TYPE check"))
logger.info("\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba".format(cfg.DATASET.IMAGE_TYPE)) logger.info(
"\nWARNING: DATASET.IMAGE_TYPE is {} but the type of all image is rgba"
.format(cfg.DATASET.IMAGE_TYPE))
else: else:
logger.info(correct_print("DATASET.IMAGE_TYPE check")) logger.info(correct_print("DATASET.IMAGE_TYPE check"))
def shape_check(): def shape_check():
"""输出shape校验结果""" """输出shape校验结果"""
if len(shape_unequal_image) == 0: if len(shape_unequal_image) == 0:
...@@ -313,7 +375,8 @@ def shape_check(): ...@@ -313,7 +375,8 @@ def shape_check():
logger.info("All images are the same shape as the labels") logger.info("All images are the same shape as the labels")
else: else:
logger.info(error_print("shape check")) logger.info(error_print("shape check"))
logger.info("Some images are not the same shape as the labels as follow: ") logger.info(
"Some images are not the same shape as the labels as follow: ")
for i in shape_unequal_image: for i in shape_unequal_image:
logger.debug(i) logger.debug(i)
...@@ -321,13 +384,19 @@ def shape_check(): ...@@ -321,13 +384,19 @@ def shape_check():
def file_list_check(list_name): def file_list_check(list_name):
"""检查分割符是否复合要求""" """检查分割符是否复合要求"""
if len(list_wrong) == 0: if len(list_wrong) == 0:
logger.info(correct_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check")) logger.info(
correct_print(
list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check"))
else: else:
logger.info(error_print(list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check")) logger.info(
logger.info("The following list is not separated by {}".format(cfg.DATASET.SEPARATOR)) error_print(
list_name.split(os.sep)[-1] + " DATASET.SEPARATOR check"))
logger.info("The following list is not separated by {}".format(
cfg.DATASET.SEPARATOR))
for i in list_wrong: for i in list_wrong:
logger.debug(i) logger.debug(i)
def imread_check(): def imread_check():
if len(imread_failed) == 0: if len(imread_failed) == 0:
logger.info(correct_print("dataset reading check")) logger.info(correct_print("dataset reading check"))
...@@ -338,18 +407,25 @@ def imread_check(): ...@@ -338,18 +407,25 @@ def imread_check():
for i in imread_failed: for i in imread_failed:
logger.debug(i) logger.debug(i)
def label_gray_check(): def label_gray_check():
if len(label_gray_wrong) == 0: if len(label_gray_wrong) == 0:
logger.info(correct_print("label gray check")) logger.info(correct_print("label gray check"))
logger.info("All label images are gray") logger.info("All label images are gray")
else: else:
logger.info(error_print("label gray check")) logger.info(error_print("label gray check"))
logger.info("{} label images are not gray\nLabel pixel statistics may " logger.info(
"be insignificant".format(len(label_gray_wrong))) "{} label images are not gray\nLabel pixel statistics may be insignificant"
.format(len(label_gray_wrong)))
for i in label_gray_wrong: for i in label_gray_wrong:
logger.debug(i) logger.debug(i)
def max_img_size_statistics():
logger.info("\nDoing max image size statistics:")
logger.info("max width and max height of images are ({},{})".format(
max_width, max_height))
def check_train_dataset(): def check_train_dataset():
list_file = cfg.DATASET.TRAIN_FILE_LIST list_file = cfg.DATASET.TRAIN_FILE_LIST
...@@ -376,15 +452,18 @@ def check_train_dataset(): ...@@ -376,15 +452,18 @@ def check_train_dataset():
if not is_gray: if not is_gray:
label_gray_wrong.append(line) label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY) grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
get_image_max_height_width(img)
get_image_dim(img) get_image_dim(img)
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
shape_unequal_image.append(line) shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
if not png_format: if not png_format:
png_format_wrong_image.append(line) png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) is_label_correct = sum_gt_check(png_format, grt_classes,
num_of_each_class)
if not is_label_correct: if not is_label_correct:
label_wrong.append(line) label_wrong.append(line)
...@@ -393,12 +472,10 @@ def check_train_dataset(): ...@@ -393,12 +472,10 @@ def check_train_dataset():
label_gray_check() label_gray_check()
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
max_img_size_statistics()
shape_check() shape_check()
def check_val_dataset(): def check_val_dataset():
list_file = cfg.DATASET.VAL_FILE_LIST list_file = cfg.DATASET.VAL_FILE_LIST
logger.info("\n-----------------------------\n2. Check val dataset...") logger.info("\n-----------------------------\n2. Check val dataset...")
...@@ -430,10 +507,12 @@ def check_val_dataset(): ...@@ -430,10 +507,12 @@ def check_val_dataset():
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
shape_unequal_image.append(line) shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
if not png_format: if not png_format:
png_format_wrong_image.append(line) png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) is_label_correct = sum_gt_check(png_format, grt_classes,
num_of_each_class)
if not is_label_correct: if not is_label_correct:
label_wrong.append(line) label_wrong.append(line)
...@@ -442,8 +521,11 @@ def check_val_dataset(): ...@@ -442,8 +521,11 @@ def check_val_dataset():
label_gray_check() label_gray_check()
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
max_img_size_statistics()
shape_check() shape_check()
eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio) eval_crop_size_check(max_height, max_width, min_aspectratio,
max_aspectratio)
def check_test_dataset(): def check_test_dataset():
list_file = cfg.DATASET.TEST_FILE_LIST list_file = cfg.DATASET.TEST_FILE_LIST
...@@ -481,10 +563,12 @@ def check_test_dataset(): ...@@ -481,10 +563,12 @@ def check_test_dataset():
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
shape_unequal_image.append(line) shape_unequal_image.append(line)
png_format, grt_classes, num_of_each_class = ground_truth_check(grt, grt_path) png_format, grt_classes, num_of_each_class = ground_truth_check(
grt, grt_path)
if not png_format: if not png_format:
png_format_wrong_image.append(line) png_format_wrong_image.append(line)
is_label_correct = sum_gt_check(png_format, grt_classes, num_of_each_class) is_label_correct = sum_gt_check(png_format, grt_classes,
num_of_each_class)
if not is_label_correct: if not is_label_correct:
label_wrong.append(line) label_wrong.append(line)
else: else:
...@@ -501,9 +585,12 @@ def check_test_dataset(): ...@@ -501,9 +585,12 @@ def check_test_dataset():
if has_label: if has_label:
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
max_img_size_statistics()
if has_label: if has_label:
shape_check() shape_check()
eval_crop_size_check(max_height, max_width, min_aspectratio, max_aspectratio) eval_crop_size_check(max_height, max_width, min_aspectratio,
max_aspectratio)
def main(args): def main(args):
if args.cfg_file is not None: if args.cfg_file is not None:
...@@ -522,6 +609,9 @@ def main(args): ...@@ -522,6 +609,9 @@ def main(args):
inf_resize_value_check() inf_resize_value_check()
print("\nDetailed error information can be viewed in detail.log file.")
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
logger = logging.getLogger() logger = logging.getLogger()
...@@ -536,5 +626,3 @@ if __name__ == "__main__": ...@@ -536,5 +626,3 @@ if __name__ == "__main__":
logger.addHandler(sh) logger.addHandler(sh)
logger.addHandler(th) logger.addHandler(th)
main(args) main(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册