未验证 提交 38d5b571 编写于 作者: Z Zeyu Chen 提交者: GitHub

Merge pull request #10 from wuyefeilin/master

add tqdm in requirements.txt, add a parser params to local_test_*.py
......@@ -34,6 +34,7 @@ def init_global_variable():
global list_wrong #文件名格式错误列表
global imread_failed #图片读取失败列表, 二元列表
global label_wrong # 标注图片出错列表
global label_gray_wrong # 标注图非灰度图列表
png_format_right_num = 0
png_format_wrong_num = 0
......@@ -49,6 +50,7 @@ def init_global_variable():
list_wrong = []
imread_failed = []
label_wrong = []
label_gray_wrong = []
def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg check')
......@@ -68,10 +70,13 @@ def correct_print(str):
return "".join(["\nPASS ", str])
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
# resolve cv2.imread open Chinese file path issues on Windows Platform.
"""
解决 cv2.imread 在window平台打开中文路径的问题.
"""
return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)
def get_image_max_height_width(img):
"""获取图片最大宽和高"""
global max_width, max_height
img_shape = img.shape
height, width = img_shape[0], img_shape[1]
......@@ -79,6 +84,7 @@ def get_image_max_height_width(img):
max_width = max(width, max_width)
def get_image_min_max_aspectratio(img):
"""计算图片最大宽高比"""
global min_aspectratio, max_aspectratio
img_shape = img.shape
height, width = img_shape[0], img_shape[1]
......@@ -87,11 +93,19 @@ def get_image_min_max_aspectratio(img):
return min_aspectratio, max_aspectratio
def get_image_dim(img):
"""获取图像的维度"""
"""获取图像的通道数"""
img_shape = img.shape
if img_shape[-1] not in img_dim:
img_dim.append(img_shape[-1])
def is_label_gray(grt):
"""判断标签是否为灰度图"""
grt_shape = grt.shape
if len(grt_shape) == 2:
return True
else:
return False
def image_label_shape_check(img, grt):
"""
验证图像和标注的大小是否匹配
......@@ -110,17 +124,15 @@ def image_label_shape_check(img, grt):
def ground_truth_check(grt, grt_path):
"""
验证标注是否重零开始,标注值为0,1,...,num_classes-1, ingnore_idx
验证标注图像的格式
返回标注的像素数
检查图像是否都是ignore_index
统计标注图类别和像素数
params:
grt: 标注图
grt_path: 标注图路径
return:
png_format: 返回是否是png格式图片
label_correct: 返回标注是否是正确的
label_pixel_num: 返回标注的像素数
unique: 返回标注类别
counts: 返回标注的像素数
"""
if imghdr.what(grt_path) == "png":
png_format = True
......@@ -135,7 +147,7 @@ def sum_gt_check(png_format, grt_classes, num_of_each_class):
"""
统计所有标注图上的格式、类别和每个类别的像素数
params:
png_format: 返回是否是png格式图片
png_format: 是否是png格式图片
grt_classes: 标注类别
num_of_each_class: 各个类别的像素数目
"""
......@@ -188,7 +200,7 @@ def gt_check():
total_nc = sorted(zip(total_grt_classes, total_num_of_each_class))
logger.info("\nDoing label pixel statistics...\nTotal label calsses "
logger.info("\nDoing label pixel statistics...\nTotal label classes "
"and their corresponding numbers:\n{} ".format(total_nc))
if len(label_wrong) == 0 and not total_nc[0][0]:
......@@ -322,6 +334,17 @@ def imread_check():
for i in imread_failed:
logger.debug(i)
def label_gray_check():
if len(label_gray_wrong) == 0:
logger.info(correct_print("label gray check"))
logger.info("All label images are gray")
else:
logger.info(error_print("label gray check"))
logger.info("{} label images are not gray\nLabel pixel statistics may "
"be insignificant".format(len(label_gray_wrong)))
for i in label_gray_wrong:
logger.debug(i)
def check_train_dataset():
......@@ -340,11 +363,15 @@ def check_train_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e:
imread_failed.append((line, str(e)))
continue
is_gray = is_label_gray(grt)
if not is_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
get_image_dim(img)
is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape:
......@@ -359,6 +386,7 @@ def check_train_dataset():
file_list_check(list_file)
imread_check()
label_gray_check()
gt_check()
image_type_check(img_dim)
shape_check()
......@@ -383,9 +411,14 @@ def check_val_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e:
imread_failed.append((line, e.message))
is_gray = is_label_gray(grt)
if not is_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
get_image_max_height_width(img)
get_image_min_max_aspectratio(img)
get_image_dim(img)
......@@ -401,6 +434,7 @@ def check_val_dataset():
file_list_check(list_file)
imread_check()
label_gray_check()
gt_check()
image_type_check(img_dim)
shape_check()
......@@ -430,10 +464,15 @@ def check_test_dataset():
grt_path = os.path.join(cfg.DATASET.DATA_DIR, grt_name)
try:
img = cv2_imread(img_path, cv2.IMREAD_UNCHANGED)
grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
grt = cv2_imread(grt_path, cv2.IMREAD_UNCHANGED)
except Exception as e:
imread_failed.append((line, e.message))
continue
is_gray = is_label_gray(grt)
if not is_gray:
label_gray_wrong.append(line)
grt = cv2.cvtColor(grt, cv2.COLOR_BGR2GRAY)
is_equal_img_grt_shape = image_label_shape_check(img, grt)
if not is_equal_img_grt_shape:
shape_unequal_image.append(line)
......@@ -452,6 +491,8 @@ def check_test_dataset():
file_list_check(list_file)
imread_check()
if has_label:
label_gray_check()
if has_label:
gt_check()
image_type_check(img_dim)
......
......@@ -14,6 +14,7 @@
from test_utils import download_file_and_uncompress, train, eval, vis, export_model
import os
import argparse
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset")
......@@ -43,7 +44,16 @@ if __name__ == "__main__":
vis_dir = os.path.join(LOCAL_PATH, "visual", model_name)
saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name)
devices = ['0']
parser = argparse.ArgumentParser(description="PaddleSeg loacl test")
parser.add_argument("--devices",
dest="devices",
help="GPU id of running. if more than one, use spacing to separate.",
nargs="+",
default=0,
type=int)
args = parser.parse_args()
devices = [str(x) for x in args.devices]
export_model(
flags=["--cfg", cfg],
......
......@@ -14,6 +14,7 @@
from test_utils import download_file_and_uncompress, train, eval, vis, export_model
import os
import argparse
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
DATASET_PATH = os.path.join(LOCAL_PATH, "..", "dataset")
......@@ -44,7 +45,16 @@ if __name__ == "__main__":
vis_dir = os.path.join(LOCAL_PATH, "visual", model_name)
saved_model = os.path.join(LOCAL_PATH, "saved_model", model_name)
devices = ['0']
parser = argparse.ArgumentParser(description="PaddleSeg loacl test")
parser.add_argument("--devices",
dest="devices",
help="GPU id of running. if more than one, use spacing to separate.",
nargs="+",
default=0,
type=int)
args = parser.parse_args()
devices = [str(x) for x in args.devices]
train(
flags=["--cfg", cfg, "--use_gpu", "--log_steps", "10"],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册