提交 26fe6a0f 编写于 作者: W wuyefeilin 提交者: wuzewu

add loss check in check.py (#72)

* add loss check
上级 6d0d104e
......@@ -426,6 +426,15 @@ def max_img_size_statistics():
logger.info("max width and max height of images are ({},{})".format(
max_width, max_height))
def num_classes_loss_matching_check():
loss_type = cfg.SOLVER.LOSS
num_classes = cfg.DATASET.NUM_CLASSES
if num_classes > 2 and (("dice_loss" in loss_type) or ("bce_loss" in loss_type)):
logger.info(error_print("loss check."
" Dice loss and bce loss is only applicable to binary classfication"))
else:
logger.info(correct_print("loss check"))
def check_train_dataset():
list_file = cfg.DATASET.TRAIN_FILE_LIST
......@@ -474,6 +483,7 @@ def check_train_dataset():
image_type_check(img_dim)
max_img_size_statistics()
shape_check()
num_classes_loss_matching_check()
def check_val_dataset():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册