未验证 提交 38d5b571 编写于 作者: Z Zeyu Chen 提交者: GitHub

Merge pull request #10 from wuyefeilin/master

add tqdm in requirements.txt, add a parser params to local_test_*.py
...@@ -34,6 +34,7 @@ def init_global_variable(): ...@@ -34,6 +34,7 @@ def init_global_variable():
global list_wrong #文件名格式错误列表 global list_wrong #文件名格式错误列表
global imread_failed #图片读取失败列表, 二元列表 global imread_failed #图片读取失败列表, 二元列表
global label_wrong # 标注图片出错列表 global label_wrong # 标注图片出错列表
global label_gray_wrong # 标注图非灰度图列表
png_format_right_num = 0 png_format_right_num = 0
png_format_wrong_num = 0 png_format_wrong_num = 0
...@@ -49,6 +50,7 @@ def init_global_variable(): ...@@ -49,6 +50,7 @@ def init_global_variable():
list_wrong = [] list_wrong = []
imread_failed = [] imread_failed = []
label_wrong = [] label_wrong = []
label_gray_wrong = []
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg check') parser = argparse.ArgumentParser(description='PaddleSeg check')
...@@ -68,10 +70,13 @@ def correct_print(str): ...@@ -68,10 +70,13 @@ def correct_print(str):
return "".join(["\nPASS ", str]) return "".join(["\nPASS ", str])
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR): def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
# resolve cv2.imread open Chinese file path issues on Windows Platform. """
解决 cv2.imread 在window平台打开中文路径的问题.
"""
return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
def get_image_max_height_width(img): def get_image_max_height_width(img):
"""获取图片最大宽和高"""
global max_width, max_height global max_width, max_height
img_shape = img.shape img_shape = img.shape
height, width = img_shape[0], img_shape[1] height, width = img_shape[0], img_shape[1]
...@@ -79,6 +84,7 @@ def get_image_max_height_width(img): ...@@ -79,6 +84,7 @@ def get_image_max_height_width(img):
max_width = max(width, max_width) max_width = max(width, max_width)
def get_image_min_max_aspectratio(img): def get_image_min_max_aspectratio(img):
"""计算图片最大宽高比"""
global min_aspectratio, max_aspectratio global min_aspectratio, max_aspectratio
img_shape = img.shape img_shape = img.shape
height, width = img_shape[0], img_shape[1] height, width = img_shape[0], img_shape[1]
...@@ -87,11 +93,19 @@ def get_image_min_max_aspectratio(img): ...@@ -87,11 +93,19 @@ def get_image_min_max_aspectratio(img):
return min_aspectratio, max_aspectratio return min_aspectratio, max_aspectratio
def get_image_dim(img): def get_image_dim(img):
"""获取图像的维度""" """获取图像的通道数"""
img_shape = img.shape img_shape = img.shape
if img_shape[-1] not in img_dim: if img_shape[-1] not in img_dim:
img_dim.append(img_shape[-1]) img_dim.append(img_shape[-1])
def is_label_gray(grt):
"""判断标签是否为灰度图"""
grt_shape = grt.shape
if len(grt_shape) == 2:
return True
else:
return False
def image_label_shape_check(img, grt): def image_label_shape_check(img, grt):
""" """
验证图像和标注的大小是否匹配 验证图像和标注的大小是否匹配
...@@ -110,17 +124,15 @@ def image_label_shape_check(img, grt): ...@@ -110,17 +124,15 @@ def image_label_shape_check(img, grt):
def ground_truth_check(grt, grt_path): def ground_truth_check(grt, grt_path):
""" """
验证标注是否重零开始,标注值为0,1,...,num_classes-1, ingnore_idx
验证标注图像的格式 验证标注图像的格式
返回标注的像素数 统计标注图类别和像素数
检查图像是否都是ignore_index
params: params:
grt: 标注图 grt: 标注图
grt_path: 标注图路径 grt_path: 标注图路径
return: return:
png_format: 返回是否是png格式图片 png_format: 返回是否是png格式图片
label_correct: 返回标注是否是正确的 unique: 返回标注类别
label_pixel_num: 返回标注的像素数 counts: 返回标注的像素数
""" """
if imghdr.what(grt_path) == "png": if imghdr.what(grt_path) == "png":
png_format = True png_format = True
...@@ -135,7 +147,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class): ...@@ -135,7 +147,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
""" """
统计所有标注图上的格式、类别和每个类别的像素数 统计所有标注图上的格式、类别和每个类别的像素数
params: params:
png_format: 返回是否是png格式图片 png_format: 是否是png格式图片
grt_classes: 标注类别 grt_classes: 标注类别
num_of_each_class: 各个类别的像素数目 num_of_each_class: 各个类别的像素数目
""" """
...@@ -188,7 +200,7 @@ def gt_check(): ...@@ -188,7 +200,7 @@ def gt_check():
total_nc = sorted(zip(total_grt_classes, total_num_of_each_class)) total_nc = sorted(zip(total_grt_classes, total_num_of_each_class))
logger.info("\nDoing label pixel statistics...\nTotal label calsses " logger.info("\nDoing label pixel statistics...\nTotal label classes "
"and their corresponding numbers:\n{} ".format(total_nc)) "and their corresponding numbers:\n{} ".format(total_nc))
if len(label_wrong) == 0 and not total_nc[0][0]: if len(label_wrong) == 0 and not total_nc[0][0]:
...@@ -322,6 +334,17 @@ def imread_check(): ...@@ -322,6 +334,17 @@ def imread_check():
for i in imread_failed: for i in imread_failed:
logger.debug(i) logger.debug(i)
def label_gray_check():
if len(label_gray_wrong) == 0:
logger.info(correct_print("label gray check"))
logger.info("All label images are gray")
else:
logger.info(error_print("label gray check"))
logger.info("{} label images are not gray\nLabel pixel statistics may "
"be insignificant".format(len(label_gray_wrong)))
for i in label_gray_wrong:
logger.debug(i)
def check_train_dataset(): def check_train_dataset():
...@@ -340,11 +363,15 @@ def check_train_dataset(): ...@@ -340,11 +363,15 @@ def check_train_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try: try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e: except Exception as e:
imread_failed.append((line, str(e))) imread_failed.append((line, str(e)))
continue continue
is_gray = is_label_gray(grt)
if not is_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
get_image_dim(img) get_image_dim(img)
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
...@@ -359,6 +386,7 @@ def check_train_dataset(): ...@@ -359,6 +386,7 @@ def check_train_dataset():
file_list_check(list_file) file_list_check(list_file)
imread_check() imread_check()
label_gray_check()
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
shape_check() shape_check()
...@@ -383,9 +411,14 @@ def check_val_dataset(): ...@@ -383,9 +411,14 @@ def check_val_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try: try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e: except Exception as e:
imread_failed.append((line, e.message)) imread_failed.append((line, e.message))
is_gray = is_label_gray(grt)
if not is_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
get_image_max_height_width(img) get_image_max_height_width(img)
get_image_min_max_aspectratio(img) get_image_min_max_aspectratio(img)
get_image_dim(img) get_image_dim(img)
...@@ -401,6 +434,7 @@ def check_val_dataset(): ...@@ -401,6 +434,7 @@ def check_val_dataset():
file_list_check(list_file) file_list_check(list_file)
imread_check() imread_check()
label_gray_check()
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
shape_check() shape_check()
...@@ -430,10 +464,15 @@ def check_test_dataset(): ...@@ -430,10 +464,15 @@ def check_test_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name) grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try: try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED) img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e: except Exception as e:
imread_failed.append((line, e.message)) imread_failed.append((line, e.message))
continue continue
is_gray = is_label_gray(grt)
if not is_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
is_equal_img_grt_shape = image_label_shape_check(img, grt) is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape: if not is_equal_img_grt_shape:
shape_unequal_image.append(line) shape_unequal_image.append(line)
...@@ -452,6 +491,8 @@ def check_test_dataset(): ...@@ -452,6 +491,8 @@ def check_test_dataset():
file_list_check(list_file) file_list_check(list_file)
imread_check() imread_check()
if has_label:
label_gray_check()
if has_label: if has_label:
gt_check() gt_check()
image_type_check(img_dim) image_type_check(img_dim)
......
...@@ -8,3 +8,4 @@ Pillow ...@@ -8,3 +8,4 @@ Pillow
numpy numpy
six six
opencv-python opencv-python
tqdm
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from test_utils import download_file_and_uncompress, train, eval, vis, export_model from test_utils import download_file_and_uncompress, train, eval, vis, export_model
import os import os
import argparse
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset") DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset")
...@@ -43,7 +44,16 @@ if __name__ == "__main__": ...@@ -43,7 +44,16 @@ if __name__ == "__main__":
vis_dir = os.path.join(LOCAL_PATH, "visual", model_name) vis_dir = os.path.join(LOCAL_PATH, "visual", model_name)
saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name) saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name)
devices = ['0'] parser = argparse.ArgumentParser(description="PaddleSeg loacl test")
parser.add_argument("--devices",
dest="devices",
help="GPU id of running. if more than one, use spacing to separate.",
nargs="+",
default=0,
type=int)
args = parser.parse_args()
devices = [str(x) for x in args.devices]
export_model( export_model(
flags=["--cfg", cfg], flags=["--cfg", cfg],
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from test_utils import download_file_and_uncompress, train, eval, vis, export_model from test_utils import download_file_and_uncompress, train, eval, vis, export_model
import os import os
import argparse
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__)) LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset") DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset")
...@@ -44,7 +45,16 @@ if __name__ == "__main__": ...@@ -44,7 +45,16 @@ if __name__ == "__main__":
vis_dir = os.path.join(LOCAL_PATH, "visual", model_name) vis_dir = os.path.join(LOCAL_PATH, "visual", model_name)
saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name) saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name)
devices = ['0'] parser = argparse.ArgumentParser(description="PaddleSeg loacl test")
parser.add_argument("--devices",
dest="devices",
help="GPU id of running. if more than one, use spacing to separate.",
nargs="+",
default=0,
type=int)
args = parser.parse_args()
devices = [str(x) for x in args.devices]
train( train(
flags=["--cfg", cfg, "--use_gpu", "--log_steps", "10"], flags=["--cfg", cfg, "--use_gpu", "--log_steps", "10"],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册