diff --git a/pdseg/check.py b/pdseg/check.py index 9eb8175e0c6377714be538f26c79e61aff51202e..5d11065a07d448cc18a62652f49529549a6fd0f4 100644 --- a/pdseg/check.py +++ b/pdseg/check.py @@ -426,6 +426,15 @@ def max_img_size_statistics(): logger.info("max width and max height of images are ({},{})".format( max_width, max_height)) +def num_classes_loss_matching_check(): + loss_type = cfg.SOLVER.LOSS + num_classes = cfg.DATASET.NUM_CLASSES + if num_classes > 2 and (("dice_loss" in loss_type) or ("bce_loss" in loss_type)): + logger.info(error_print("loss check." + " Dice loss and bce loss is only applicable to binary classfication")) + else: + logger.info(correct_print("loss check")) + def check_train_dataset(): list_file = cfg.DATASET.TRAIN_FILE_LIST @@ -474,6 +483,7 @@ def check_train_dataset(): image_type_check(img_dim) max_img_size_statistics() shape_check() + num_classes_loss_matching_check() def check_val_dataset():