From e79655cd1dc4831b699e5685c079d17f1c125221 Mon Sep 17 00:00:00 2001 From: Lv Yan <2839719742@qq.com> Date: Tue, 23 Nov 2021 20:25:45 +0800 Subject: [PATCH] Add the files of 3D segment! --- 3DSegment/dataset/make_json.py | 45 ++++++++ 3DSegment/main.py | 18 ++++ 3DSegment/nets/models.py | 68 +++++++++++++ 3DSegment/test.py | 102 +++++++++++++++++++ 3DSegment/train.py | 166 ++++++++++++++++++++++++++++++ 3DSegment/utils/utils.py | 68 ++++++++----- Coronary_seg/.gitkeep | 0 Coronary_seg/train.py | 181 --------------------------------- Coronary_seg/transforms.py | 84 --------------- 9 files changed, 444 insertions(+), 288 deletions(-) create mode 100644 3DSegment/dataset/make_json.py create mode 100644 3DSegment/main.py create mode 100644 3DSegment/nets/models.py create mode 100644 3DSegment/test.py create mode 100644 3DSegment/train.py delete mode 100644 Coronary_seg/.gitkeep delete mode 100644 Coronary_seg/train.py delete mode 100644 Coronary_seg/transforms.py diff --git a/3DSegment/dataset/make_json.py b/3DSegment/dataset/make_json.py new file mode 100644 index 0000000..8989b42 --- /dev/null +++ b/3DSegment/dataset/make_json.py @@ -0,0 +1,45 @@ +import json +import os + +image_path = r"./imagesTr" +label_path = r"./labelsTr" +image_files = [file for file in os.listdir(image_path)] +label_files = [file for file in os.listdir(label_path)] +list_data = [] +lengths = len(image_files) +print(lengths) +for i in range(len(image_files)): + image = image_files[i] + label = label_files[i] + print(image) + print(label) + dic = { + "image": "imagesTr/" + image, + "label": "labelsTr/" + label + } + list_data.append(dic) +folds = 6 +for i in range(folds): + data = { + "name": "BrainSeg", + "modality": { + "0": "FLAIR", + "1": "T1w", + "2": "t1gd", + "3": "T2w" + }, + "labels": { + "0": "background", + "1": "edema", + "2": "non-enhancing tumor", + "3": "enhancing tumour" + }, + "numTraining": int((folds - 1) * lengths / folds), + "numVal": int(lengths / folds), + "training": list_data[:i * int(lengths / folds)] + list_data[(i + 1) * int(lengths / folds):], + "validation": list_data[i * int(lengths / folds):(i + 1) * int(lengths / folds)] + } + + file_name = f'dataset_{i}.json' # 通过扩展名指定文件存储的数据为json格式 + with open(file_name, 'w') as file_object: + json.dump(data, file_object) diff --git a/3DSegment/main.py b/3DSegment/main.py new file mode 100644 index 0000000..d3bf5ab --- /dev/null +++ b/3DSegment/main.py @@ -0,0 +1,18 @@ +# This is a sample Python script. + +# Press Shift+F10 to execute it or replace it with your code. +# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. + + +def print_hi(name): + # Use a breakpoint in the code line below to debug your script. + print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. + + +# Press the green button in the gutter to run the script. +if __name__ == '__main__': + print_hi('PyCharm') + for i in range(4): + metric_value_ + +# See PyCharm help at https://www.jetbrains.com/help/pycharm/ diff --git a/3DSegment/nets/models.py b/3DSegment/nets/models.py new file mode 100644 index 0000000..b3b825e --- /dev/null +++ b/3DSegment/nets/models.py @@ -0,0 +1,68 @@ +from utils.utils import get_param +from monai.networks.layers.factories import Act, Norm +from monai.networks.nets import UNet, VNet, BasicUNet, DynUNet, UNETR + + +def get_model(args): + global model + if args.mode == "UNet": + model = UNet( + dimensions=args.dim, + in_channels=1, + out_channels=args.num_classes, + channels=(32, 64, 128, 256, 512), + strides=(2, 2, 2, 2), + num_res_units=2, + act=Act.PRELU, + norm=Norm.INSTANCE, + dropout=0.0, + ) + elif args.mode == "VNet": + model = VNet( + spatial_dims=args.dim, + in_channels=1, + out_channels=args.num_classes, + act=("elu", {"inplace": True}), + dropout_prob=0.5, + dropout_dim=args.dim + ) + elif args.mode == "BasicUNet": + model = BasicUNet( + dimensions=args.dim, + in_channels=1, + out_channels=args.num_classes, + features=(32, 32, 64, 128, 256, 32), + act=("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), + norm=("instance", {"affine": True}), + dropout=0.0, + upsample="deconv" + ) + elif args.mode == "DynUNet": + strides, kernels = get_param(args) + model = DynUNet( + spatial_dims=args.dim, + in_channels=1, + out_channels=args.num_classes, + kernel_size=kernels, + strides=strides, + upsample_kernel_size=strides[1:], + norm_name=("INSTANCE", {"affine": True}), + deep_supervision=False, + deep_supr_num=1, + res_block=True) + elif args.mode == "UNETR": + model = UNETR( + in_channels=1, + out_channels=args.num_classes, + img_size=args.spatial_size, + feature_size=16, + hidden_size=144, + mlp_dim=576, + num_heads=12, + pos_embed='conv', + norm_name='instance', + conv_block=True, + res_block=True, + dropout_rate=0.0 + ) + return model diff --git a/3DSegment/test.py b/3DSegment/test.py new file mode 100644 index 0000000..4cca86d --- /dev/null +++ b/3DSegment/test.py @@ -0,0 +1,102 @@ +import os + +import nibabel as nib +import numpy as np +import pandas as pd +import torch +from monai.data import load_decathlon_datalist, CacheDataset, DataLoader, decollate_batch +from monai.handlers import from_engine +from monai.inferers import sliding_window_inference + +from nets.models import get_model +from utils.utils import evaluate, get_args, post_transforms, test_transforms + + +def test(args): + data_dir = args.data + split_json = f"dataset_{args.fold}.json" + datasets = os.path.join(data_dir, split_json) + val_files = load_decathlon_datalist(datasets, True, "validation") + test_ds = CacheDataset( + data=val_files, transform=test_transforms(args), cache_rate=args.cache, num_workers=args.num_workers + ) + test_loader = DataLoader( + test_ds, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True + ) + + if args.gpus: + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model(args).to(device) + + save_path = os.path.join(args.save_model, args.mode) + model.load_state_dict(torch.load(os.path.join(save_path, f"best_metric_model_{args.fold}.pth"))) + model.eval() + M = [] + with torch.no_grad(): + for i, test_data in enumerate(test_loader): + # file name + img_name = os.path.split(test_data["image_meta_dict"]["filename_or_obj"][0])[1] + name = img_name.split('.')[0] + + test_inputs, test_labels = test_data["image"].cuda(), test_data["label"].cuda() + + roi_size = args.spatial_size + sw_batch_size = args.sw_batch_size + test_data["pred"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model, overlap=0.8) + + test_data = [post_transforms(args)(i) for i in decollate_batch(test_data)] + test_outputs, test_labels = from_engine(["pred", "label"])(test_data) + + # save nii label + print(test_outputs[0].shape[0]) + for k in range(test_outputs[0].shape[0]): + array = np.array(test_outputs[0][k]) + nii_class = nib.Nifti1Image(array, np.eye(4)) + save_path = os.path.join(args.pred_path, f"fold{args.fold}") + os.makedirs(save_path, exist_ok=True) + nib.save(nii_class, os.path.join(save_path, f"{name}_pred_{k}.nii.gz")) + out_array = 0 + for o in range(1, args.num_classes): + out_array += o * test_outputs[0][o] + out_array = np.array(out_array) + # print(out_array.shape, np.max(out_array), np.min(out_array)) + nii_img = nib.Nifti1Image(out_array, np.eye(4)) + save_path = os.path.join(args.pred_path, f"fold{args.fold}") + os.makedirs(save_path, exist_ok=True) + nib.save(nii_img, os.path.join(save_path, f"{name}_pred.nii.gz")) + + # metric + metric = evaluate(test_outputs, test_labels, num_classes=args.num_classes) + M.append(metric) + print(name, f"dice={metric[0]}") + + header = ['image', 'label', 'dice',] + for h1 in range(1, args.num_classes): + header.append(f'dice_{h1}') + header.append('precision') + header.append('recall') + + for j in range(len(val_files)): + val_files[j]['dice'] = M[j][0] + for h2 in range(1, args.num_classes): + val_files[j][f'dice_{h2}'] = M[j][1][h2-1] + val_files[j]['precision'] = M[j][2] + val_files[j]['recall'] = M[j][3] + + data_dict = {header[0]: [val_files[i]['image'] for i in range(len(val_files))], + header[1]: [val_files[i]['label'] for i in range(len(val_files))], + header[2]: [val_files[i]['dice'] for i in range(len(val_files))], + header[-2]: [val_files[i]['precision'] for i in range(len(val_files))], + header[-1]: [val_files[i]['recall'] for i in range(len(val_files))]} + for h3 in range(1, args.num_classes): + data_dict[header[2 + h3]] = [val_files[i][f'dice_{h3}'] for i in range(len(val_files))] + datas = pd.DataFrame(data_dict) + datas.to_csv(os.path.join(args.pred_path, f"fold{args.fold}", f"test-{args.fold}.csv"), index=False, sep=',') + + +if __name__ == '__main__': + args = get_args() + test(args) diff --git a/3DSegment/train.py b/3DSegment/train.py new file mode 100644 index 0000000..8d8d226 --- /dev/null +++ b/3DSegment/train.py @@ -0,0 +1,166 @@ +import math +import os +import time + +import torch +from monai.inferers import sliding_window_inference +from monai.visualize.img2tensorboard import SummaryWriter, plot_2d_or_3d_image +from monai.losses import DiceCELoss +from monai.metrics import DiceMetric +from monai.data import DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch +from monai.transforms import AsDiscrete + +from utils.utils import train_transforms, val_transforms, get_args, choice_optimizer, get_logger, cosing_schedule +from nets.models import get_model + +args = get_args() +logger = get_logger(args) +writer = SummaryWriter() + + +def train(args): + data_dir = args.data + split_json = f"dataset_{args.fold}.json" + datasets = os.path.join(data_dir, split_json) + train_files = load_decathlon_datalist(datasets, True, "training") + val_files = load_decathlon_datalist(datasets, True, "validation") + train_ds = CacheDataset( + data=train_files, transform=train_transforms(args), cache_rate=args.cache, num_workers=args.num_workers, + ) + train_loader = DataLoader( + train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True + ) + val_ds = CacheDataset( + data=val_files, transform=val_transforms(args), cache_rate=args.cache, num_workers=args.num_workers + ) + val_loader = DataLoader( + val_ds, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True + ) + + if args.gpus: + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + else: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model(args).to(device) + torch.backends.cudnn.benchmark = True + + loss_function = DiceCELoss(include_background=False, to_onehot_y=True, softmax=True) + post_label = AsDiscrete(to_onehot=True, n_classes=args.num_classes) + post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.num_classes) + dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False) + dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch", get_not_nans=False) + + epoch_loss_values = [] + val_interval = 1 + best_metric = -1 + best_metric_epoch = -1 + metric_values = [] + metric_values_class = {} + total_start = time.time() + for m1 in range(1, args.num_classes): + metric_values_class[f"{m1}"] = [] + + for epoch in range(args.max_epoch): + epoch_start = time.time() + print("-" * 100) + print(f"epoch {epoch + 1}/{args.max_epoch}") + logger.info(f"epoch {epoch + 1}/{args.max_epoch}") + model.train() + epoch_loss = 0 + step = 0 + for batch_data in train_loader: + step_start = time.time() + step += 1 + inputs, labels = (batch_data["image"].to(device), batch_data["label"].to(device)) + if args.scheduler: + new_lr = cosing_schedule(args.learning_rate, epoch, warmup_step=10, epoch_total=args.max_epoch) + else: + new_lr = args.learning_rate + optimizer = choice_optimizer(args, lr=new_lr, model=model) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + epoch_len = math.ceil(len(train_ds) / train_loader.batch_size) + print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f} step time: {(time.time() - step_start):.4f}") + logger.info(f"{step}/{epoch_len}, train_loss: {loss.item():.4f} step time: {(time.time() - step_start):.4f}") + epoch_loss /= step + epoch_loss_values.append(epoch_loss) + writer.add_scalar("train_loss", epoch_loss, epoch + 1) + print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + if (epoch + 1) % val_interval == 0: + model.eval() + with torch.no_grad(): + for val_data in val_loader: + val_inputs, val_labels = (val_data["image"].to(device), val_data["label"].to(device)) + val_outputs = sliding_window_inference(val_inputs, args.spatial_size, args.sw_batch_size, model) + val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)] + val_labels = [post_label(i) for i in decollate_batch(val_labels)] + dice_metric(y_pred=val_outputs, y=val_labels) + dice_metric_batch(y_pred=val_outputs, y=val_labels) + metric = dice_metric.aggregate().item() + metric_values.append(metric) + + metric_batch = dice_metric_batch.aggregate() + for m2 in range(args.num_classes - 1): + metric_num = metric_batch[m2].item() + metric_values_class[f"{m2 + 1}"].append(metric_num) + print(f"class {m2+1} dice: {metric_num:.4f}") + logger.info(f"class {m2+1} dice: {metric_num:.4f}") + writer.add_scalar(f"dice_{m2+1}", metric_num, epoch + 1) # tensorboard可视化 + dice_metric.reset() + dice_metric_batch.reset() + writer.add_scalar("val_mean_dice", metric, epoch + 1) # tensorboard可视化 + plot_2d_or_3d_image(val_inputs, epoch + 1, writer, index=0, tag="image") + plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label") + plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output") + + if metric > best_metric: + best_metric = metric + best_metric_epoch = epoch + 1 + os.makedirs(os.path.join(args.save_model, args.mode), exist_ok=True) + torch.save(model.state_dict(), + os.path.join(args.save_model, args.mode, f"best_metric_model.pth")) + print("saved new best metric model") + logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") + print( + f"current epoch: {epoch + 1} current" + f" mean dice: {metric:.4f}" + f" best mean dice: {best_metric:.4f}" + f" at epoch: {best_metric_epoch}" + ) + logger.info( + f"current epoch: {epoch + 1} current" + f" mean dice: {metric:.4f}" + f" best mean dice: {best_metric:.4f}" + f" at epoch: {best_metric_epoch}" + ) + print( + f"time consuming of epoch {epoch + 1} is:" + f" {(time.time() - epoch_start):.4f}" + ) + logger.info( + f"time consuming of epoch {epoch + 1} is:" + f" {(time.time() - epoch_start):.4f}" + ) + print( + f"train completed, best_metric: {best_metric:.4f}" + f" at epoch: {best_metric_epoch}" + f" total time: {(time.time() - total_start):.4f}" + ) + logger.info( + f"train completed, best_metric: {best_metric:.4f}" + f" at epoch: {best_metric_epoch}" + f" total time: {(time.time() - total_start):.4f}" + ) + + +if __name__ == '__main__': + train(args) + + + diff --git a/3DSegment/utils/utils.py b/3DSegment/utils/utils.py index b06d617..fbdfe24 100644 --- a/3DSegment/utils/utils.py +++ b/3DSegment/utils/utils.py @@ -1,3 +1,5 @@ +import logging +import math from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from typing import Sequence @@ -15,7 +17,7 @@ from monai.transforms import ( ScaleIntensityRanged, Spacingd, RandRotate90d, - ToTensord, EnsureTyped, Invertd, AsDiscreted, + ToTensord, EnsureTyped, Invertd, AsDiscreted, EnsureChannelFirstd ) @@ -41,13 +43,13 @@ def get_args(strings=None): arg("--learning_rate", type=float, default=1e-4, help="") arg("--weight_decay", type=float, default=1e-5, help="") arg("--momentum", type=float, default=0.99, help="") - arg("--max_epoch", type=int, default=1200, help="") + arg("--max_epoch", type=int, default=200, help="") arg("--sw_batch_size", type=int, default=4, help="") arg("--save_model", type=str, default="./models", help="") arg("--pred_path", type=str, default="./preds", help="") - arg("--mode", type=str, default="UNet", choices=["UNet", "VNet", "BasicUNet", "DynUNet", "UNETR"]) + arg("--mode", type=str, default="DynUNet", choices=["UNet", "VNet", "BasicUNet", "DynUNet", "UNETR"]) arg("--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"]) - arg("--scheduler", type=str, default="", choices=["cosing", "plateau"]) + arg("--scheduler", action="store_true", help="") args = parser.parse_args() return args @@ -56,7 +58,7 @@ def train_transforms(args): train_trans = Compose( [ LoadImaged(keys=["image", "label"]), - AddChanneld(keys=["image", "label"]), + EnsureChannelFirstd(keys=["image", "label"]), Spacingd( keys=["image", "label"], pixdim=args.pixdim, @@ -206,7 +208,7 @@ def get_param(args): # 计算metrics -def evaluate(y_pred, y_label): +def evaluate(y_pred, y_label, num_classes): # print(y_pred, y_pred[0].shape) dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False) dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch", get_not_nans=False) @@ -214,40 +216,60 @@ def evaluate(y_pred, y_label): dice_metric_batch(y_pred=y_pred, y=y_label) dice_metric_result = dice_metric.aggregate().item() metric_batch = dice_metric_batch.aggregate() - metric_0 = metric_batch[0].item() - metric_1 = metric_batch[1].item() - metric_2 = metric_batch[2].item() - # print(dice_metric_result) - # dice_metric.reset() - - auc_metric = ROCAUCMetric(average="weighted") - auc_metric(y_pred=y_pred[0].view(-1), y=y_label[0].view(-1)) - auc_metric_result = auc_metric.aggregate().item() - auc_metric.reset() - # auc_metric_result = 0 + metric = [] + for m in range(num_classes - 1): + metric_num = metric_batch[m].item() + metric.append(metric_num) + dice_metric.reset() + dice_metric_batch.reset() '''["sensitivity-TPR", "specificity-TNR", "precision-P", "negative predictive value", "miss rate-FNR", "fall out", "false discovery rate", "false omission rate", "prevalence threshold", "threat score", "accuracy", "balanced accuracy", "f1 score", "matthews correlation coefficient", "fowlkes mallows index", "informedness", "markedness"]''' - cfm_metric = ConfusionMatrixMetric(include_background=False, metric_name=("specificity", "precision", "recall"), + cfm_metric = ConfusionMatrixMetric(include_background=False, metric_name=("precision", "recall"), compute_sample=True) cfm_metric(y_pred=y_pred, y=y_label) cfm_metric_result = cfm_metric.aggregate() cfm_metric_result = torch.stack(cfm_metric_result).cpu().numpy()[:, 0] cfm_metric.reset() - return dice_metric_result, metric_0, metric_1, metric_2, auc_metric_result, cfm_metric_result[0], cfm_metric_result[1], cfm_metric_result[2] + return dice_metric_result, metric, cfm_metric_result[0], cfm_metric_result[1] -def choice_optimizer(args, model): +def choice_optimizer(args, lr, model): global optimizer if args.optimizer == 'adam': - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) + optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay) elif args.optimizer == 'adamw': - optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) + optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay) elif args.optimizer == 'sgd': - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum) + optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum) return optimizer +def cosing_schedule(lr, epoch, warmup_step, epoch_total, min=1e-7): + if epoch < warmup_step: + mul = epoch / warmup_step + else: + mul = 0.5 * (1.0 + math.cos(math.pi * (epoch - warmup_step) / (epoch_total - warmup_step))) + return lr * mul + min + + +def get_logger(args): + logger = logging.getLogger('train') + logger.setLevel(level=logging.DEBUG) + + formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s') + + file_handler = logging.FileHandler(f'train_{args.fold}.log') + file_handler.setLevel(level=logging.INFO) + file_handler.setFormatter(formatter) + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.DEBUG) + stream_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + logger.addHandler(stream_handler) + return logger diff --git a/Coronary_seg/.gitkeep b/Coronary_seg/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/Coronary_seg/train.py b/Coronary_seg/train.py deleted file mode 100644 index 3dff05d..0000000 --- a/Coronary_seg/train.py +++ /dev/null @@ -1,181 +0,0 @@ -import glob -import math -import os -import time - -import torch - -from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch -from monai.inferers import sliding_window_inference -from monai.losses import DiceLoss -from monai.metrics import DiceMetric -from monai.networks.layers import Norm -from monai.networks.nets import UNet -from monai.optimizers import Novograd -from monai.transforms import Compose, EnsureType, AsDiscrete -from monai.utils import set_determinism - -from transforms import transformations -from torch.utils.tensorboard import SummaryWriter -from monai.visualize import plot_2d_or_3d_image -writer = SummaryWriter() - -root_dir = r'./dataset' -data_dir = os.path.join(root_dir, "train") -train_images = sorted(glob.glob(os.path.join(data_dir, "images", "*.nii.gz"))) -train_labels = sorted(glob.glob(os.path.join(data_dir, "labels", "*.nii.gz"))) -data_dicts = [ - {"image": image_name, "label": label_name} - for image_name, label_name in zip(train_images, train_labels) -] -train_files, val_files = data_dicts[:-20], data_dicts[-20:] - - -def train(): - max_epochs = 300 - learning_rate = 1e-4 - val_interval = 1 # do validation for every epoch - - train_trans, val_trans = transformations() - train_ds = CacheDataset(data=train_files, transform=train_trans, cache_rate=1.0, num_workers=0) - val_ds = CacheDataset(data=val_files, transform=val_trans, cache_rate=1.0, num_workers=0) - # don't need many workers because already cached the data - loader_workers = 0 - train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=loader_workers) - val_loader = DataLoader(val_ds, batch_size=1, num_workers=loader_workers) - - device = torch.device("cuda:0") - model = UNet( - dimensions=3, - in_channels=1, - out_channels=2, - channels=(16, 32, 64, 128, 256), - strides=(2, 2, 2, 2), - num_res_units=2, - norm=Norm.BATCH, - ).to(device) - loss_function = DiceLoss(to_onehot_y=True, softmax=True) - dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False) - - post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, n_classes=2)]) - post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=2)]) - optimizer = Novograd(model.parameters(), learning_rate * 10) - scaler = torch.cuda.amp.GradScaler() - - best_metric = -1 - best_metric_epoch = -1 - best_metrics_epochs_and_time = [[], [], []] - epoch_loss_values = [] - metric_values = [] - epoch_times = [] - total_start = time.time() - for epoch in range(max_epochs): - epoch_start = time.time() - print("-" * 10) - print(f"epoch {epoch + 1}/{max_epochs}") - model.train() - epoch_loss = 0 - step = 0 - for batch_data in train_loader: - step_start = time.time() - step += 1 - inputs, labels = ( - batch_data["image"].to(device), - batch_data["label"].to(device), - ) - optimizer.zero_grad() - with torch.cuda.amp.autocast(): - outputs = model(inputs) - loss = loss_function(outputs, labels) - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() - - epoch_loss += loss.item() - epoch_len = math.ceil(len(train_ds) / train_loader.batch_size) - print( - f"{step}/{epoch_len}, train_loss: {loss.item():.4f}" - f" step time: {(time.time() - step_start):.4f}" - ) - epoch_loss /= step - epoch_loss_values.append(epoch_loss) - writer.add_scalar("train_loss", epoch_loss, epoch + 1) - print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") - if (epoch + 1) % val_interval == 0: - model.eval() - with torch.no_grad(): - for val_data in val_loader: - val_inputs, val_labels = ( - val_data["image"].to(device), - val_data["label"].to(device), - ) - roi_size = (160, 160, 128) - sw_batch_size = 4 - with torch.cuda.amp.autocast(): - val_outputs = sliding_window_inference( - val_inputs, roi_size, sw_batch_size, model - ) - val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)] - val_labels = [post_label(i) for i in decollate_batch(val_labels)] - dice_metric(y_pred=val_outputs, y=val_labels) - metric = dice_metric.aggregate().item() - dice_metric.reset() - metric_values.append(metric) - writer.add_scalar("val_mean_dice", metric, epoch + 1) # tensorboard可视化 - plot_2d_or_3d_image(val_inputs, epoch + 1, writer, index=0, tag="image") - plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label") - plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output") - if metric > best_metric: - best_metric = metric - best_metric_epoch = epoch + 1 - best_metrics_epochs_and_time[0].append(best_metric) - best_metrics_epochs_and_time[1].append(best_metric_epoch) - best_metrics_epochs_and_time[2].append( - time.time() - total_start - ) - torch.save(model.state_dict(), "best_metric_model.pth") - print("saved new best metric model") - print( - f"current epoch: {epoch + 1} current" - f" mean dice: {metric:.4f}" - f" best mean dice: {best_metric:.4f}" - f" at epoch: {best_metric_epoch}" - ) - print( - f"time consuming of epoch {epoch + 1} is:" - f" {(time.time() - epoch_start):.4f}" - ) - epoch_times.append(time.time() - epoch_start) - - print( - f"train completed, best_metric: {best_metric:.4f}" - f" at epoch: {best_metric_epoch}" - f" total time: {(time.time() - total_start):.4f}" - ) - return ( - max_epochs, - epoch_loss_values, - metric_values, - epoch_times, - best_metrics_epochs_and_time, - ) - - -def monai_optimized(): - set_determinism(seed=0) - monai_start = time.time() - ( - epoch_num, - m_epoch_loss_values, - m_metric_values, - m_epoch_times, - m_best, - ) = train() - m_total_time = time.time() - monai_start - print( - f"total training time of {epoch_num} epochs with MONAI: {m_total_time:.4f}" - ) - - -if __name__ == '__main__': - monai_optimized() diff --git a/Coronary_seg/transforms.py b/Coronary_seg/transforms.py deleted file mode 100644 index e685225..0000000 --- a/Coronary_seg/transforms.py +++ /dev/null @@ -1,84 +0,0 @@ -from monai.transforms import ( - AddChanneld, - Compose, - CropForegroundd, - DeleteItemsd, - FgBgToIndicesd, - LoadImaged, - Orientationd, - RandCropByPosNegLabeld, - ScaleIntensityRanged, - Spacingd, - EnsureTyped, -) - - -def transformations(): - train_transforms = Compose( - [ - LoadImaged(keys=["image", "label"]), - AddChanneld(keys=["image", "label"]), - Spacingd( - keys=["image", "label"], - pixdim=(0.5, 0.5, 0.5), - mode=("bilinear", "nearest"), - ), - Orientationd(keys=["image", "label"], axcodes="RAI"), - ScaleIntensityRanged( - keys=["image"], - a_min=150, - a_max=850, - b_min=0.0, - b_max=1.0, - clip=True, - ), - CropForegroundd(keys=["image", "label"], source_key="image"), - # pre-compute foreground and background indexes - # and cache them to accelerate training - FgBgToIndicesd( - keys="label", - fg_postfix="_fg", - bg_postfix="_bg", - image_key="image", - ), - # randomly crop out patch samples from big - # image based on pos / neg ratio - # the image centers of negative samples - # must be in valid image area - RandCropByPosNegLabeld( - keys=["image", "label"], - label_key="label", - spatial_size=(128, 128, 96), - pos=1, - neg=1, - num_samples=4, - fg_indices_key="label_fg", - bg_indices_key="label_bg", - ), - DeleteItemsd(keys=["label_fg", "label_bg"]), - EnsureTyped(keys=["image", "label"]), - ] - ) - val_transforms = Compose( - [ - LoadImaged(keys=["image", "label"]), - AddChanneld(keys=["image", "label"]), - Spacingd( - keys=["image", "label"], - pixdim=(0.5, 0.5, 0.5), - mode=("bilinear", "nearest"), - ), - Orientationd(keys=["image", "label"], axcodes="RAI"), - ScaleIntensityRanged( - keys=["image"], - a_min=150, - a_max=850, - b_min=0.0, - b_max=1.0, - clip=True, - ), - CropForegroundd(keys=["image", "label"], source_key="image"), - EnsureTyped(keys=["image", "label"]), - ] - ) - return train_transforms, val_transforms -- GitLab