提交 e79655cd 编写于 作者: 小吕同学吖's avatar 小吕同学吖 😲

Add the files for 3D segmentation.

上级 15f22f83
import json
import os

# Build one Decathlon-style dataset JSON per cross-validation fold for the
# BrainSeg task: pair every image in ./imagesTr with the label of the same
# rank in ./labelsTr, then slice the pair list into `folds` validation chunks.
image_path = r"./imagesTr"
label_path = r"./labelsTr"
# Sort both listings: os.listdir order is filesystem-dependent, and unsorted
# listings can silently pair an image with the wrong label (the training
# script sorts its globs for the same reason).
image_files = sorted(os.listdir(image_path))
label_files = sorted(os.listdir(label_path))
list_data = []
lengths = len(image_files)
print(lengths)
# Walk images and labels in lockstep instead of indexing by position.
for image, label in zip(image_files, label_files):
    print(image)
    print(label)
    list_data.append({
        "image": "imagesTr/" + image,
        "label": "labelsTr/" + label,
    })
folds = 6
fold_size = lengths // folds  # cases per validation fold (floor division)
for i in range(folds):
    # Fold i: hold out the i-th contiguous slice for validation, train on
    # everything else. Trailing cases beyond folds*fold_size stay in training.
    data = {
        "name": "BrainSeg",
        "modality": {
            "0": "FLAIR",
            "1": "T1w",
            "2": "t1gd",
            "3": "T2w"
        },
        "labels": {
            "0": "background",
            "1": "edema",
            "2": "non-enhancing tumor",
            "3": "enhancing tumour"
        },
        "numTraining": int((folds - 1) * lengths / folds),
        "numVal": fold_size,
        "training": list_data[:i * fold_size] + list_data[(i + 1) * fold_size:],
        "validation": list_data[i * fold_size:(i + 1) * fold_size],
    }
    file_name = f'dataset_{i}.json'  # .json extension marks the stored format
    with open(file_name, 'w') as file_object:
        json.dump(data, file_object)
# This is a sample Python script.
# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
def print_hi(name):
    """Print a greeting for *name* (PyCharm starter-template function)."""
    # Use a breakpoint in the code line below to debug your script.
    print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    # The dangling `for i in range(4): metric_value_` residue was removed:
    # `metric_value_` is an undefined name and raised NameError at run time.
    print_hi('PyCharm')
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
from utils.utils import get_param
from monai.networks.layers.factories import Act, Norm
from monai.networks.nets import UNet, VNet, BasicUNet, DynUNet, UNETR
def get_model(args):
    """Build the segmentation network selected by ``args.mode``.

    Supported modes: ``UNet``, ``VNet``, ``BasicUNet``, ``DynUNet``, ``UNETR``.
    All networks take a single input channel and produce ``args.num_classes``
    output channels; ``args.dim`` selects 2D/3D where the network supports it.

    Raises:
        ValueError: if ``args.mode`` is not one of the supported names.
        (Previously an unknown mode fell through to a stale ``global model``
        or a NameError; failing fast with a clear message is safer.)
    """
    if args.mode == "UNet":
        model = UNet(
            dimensions=args.dim,
            in_channels=1,
            out_channels=args.num_classes,
            channels=(32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            act=Act.PRELU,
            norm=Norm.INSTANCE,
            dropout=0.0,
        )
    elif args.mode == "VNet":
        model = VNet(
            spatial_dims=args.dim,
            in_channels=1,
            out_channels=args.num_classes,
            act=("elu", {"inplace": True}),
            dropout_prob=0.5,
            dropout_dim=args.dim,
        )
    elif args.mode == "BasicUNet":
        model = BasicUNet(
            dimensions=args.dim,
            in_channels=1,
            out_channels=args.num_classes,
            features=(32, 32, 64, 128, 256, 32),
            act=("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
            norm=("instance", {"affine": True}),
            dropout=0.0,
            upsample="deconv",
        )
    elif args.mode == "DynUNet":
        # Per-layer kernel sizes / strides are derived from the spacing config.
        strides, kernels = get_param(args)
        model = DynUNet(
            spatial_dims=args.dim,
            in_channels=1,
            out_channels=args.num_classes,
            kernel_size=kernels,
            strides=strides,
            upsample_kernel_size=strides[1:],
            norm_name=("INSTANCE", {"affine": True}),
            deep_supervision=False,
            deep_supr_num=1,
            res_block=True,
        )
    elif args.mode == "UNETR":
        model = UNETR(
            in_channels=1,
            out_channels=args.num_classes,
            img_size=args.spatial_size,
            feature_size=16,
            hidden_size=144,  # divisible by num_heads (12) as UNETR requires
            mlp_dim=576,
            num_heads=12,
            pos_embed='conv',
            norm_name='instance',
            conv_block=True,
            res_block=True,
            dropout_rate=0.0,
        )
    else:
        raise ValueError(f"unknown model mode: {args.mode}")
    return model
import os
import nibabel as nib
import numpy as np
import pandas as pd
import torch
from monai.data import load_decathlon_datalist, CacheDataset, DataLoader, decollate_batch
from monai.handlers import from_engine
from monai.inferers import sliding_window_inference
from nets.models import get_model
from utils.utils import evaluate, get_args, post_transforms, test_transforms
def test(args):
    """Run inference on the validation split of fold ``args.fold``.

    Loads ``dataset_{fold}.json`` from ``args.data``, restores the best
    checkpoint for ``args.mode``, predicts with sliding-window inference,
    saves per-class and merged label volumes as NIfTI files under
    ``args.pred_path/fold{fold}``, and writes per-case Dice/precision/recall
    metrics to ``test-{fold}.csv`` in the same directory.
    """
    data_dir = args.data
    split_json = f"dataset_{args.fold}.json"
    datasets = os.path.join(data_dir, split_json)
    val_files = load_decathlon_datalist(datasets, True, "validation")
    test_ds = CacheDataset(
        data=val_files, transform=test_transforms(args), cache_rate=args.cache, num_workers=args.num_workers
    )
    test_loader = DataLoader(
        test_ds, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True
    )
    if args.gpus:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    else:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(args).to(device)
    model_dir = os.path.join(args.save_model, args.mode)
    model.load_state_dict(torch.load(os.path.join(model_dir, f"best_metric_model_{args.fold}.pth")))
    model.eval()
    # Prediction output directory, created once instead of per case.
    pred_dir = os.path.join(args.pred_path, f"fold{args.fold}")
    os.makedirs(pred_dir, exist_ok=True)
    M = []  # per-case metric tuples, in loader order
    with torch.no_grad():
        for test_data in test_loader:
            # Case name from the source file name (strip all extensions).
            img_name = os.path.split(test_data["image_meta_dict"]["filename_or_obj"][0])[1]
            name = img_name.split('.')[0]
            # BUG FIX: use .to(device) instead of .cuda() so CPU-only hosts
            # work; labels are read from test_data by the post-transforms, so
            # the old dead GPU copy of the label tensor was dropped.
            test_inputs = test_data["image"].to(device)
            test_data["pred"] = sliding_window_inference(
                test_inputs, args.spatial_size, args.sw_batch_size, model, overlap=0.8
            )
            test_data = [post_transforms(args)(d) for d in decollate_batch(test_data)]
            test_outputs, test_labels = from_engine(["pred", "label"])(test_data)
            # Save one NIfTI per one-hot output channel (class).
            print(test_outputs[0].shape[0])
            for k in range(test_outputs[0].shape[0]):
                array = np.array(test_outputs[0][k])
                nii_class = nib.Nifti1Image(array, np.eye(4))
                nib.save(nii_class, os.path.join(pred_dir, f"{name}_pred_{k}.nii.gz"))
            # Collapse the one-hot channels into a single integer label map.
            out_array = 0
            for o in range(1, args.num_classes):
                out_array += o * test_outputs[0][o]
            out_array = np.array(out_array)
            nii_img = nib.Nifti1Image(out_array, np.eye(4))
            nib.save(nii_img, os.path.join(pred_dir, f"{name}_pred.nii.gz"))
            # Per-case metrics: (mean dice, per-class dices, precision, recall).
            metric = evaluate(test_outputs, test_labels, num_classes=args.num_classes)
            M.append(metric)
            print(name, f"dice={metric[0]}")
    # Assemble the CSV: image, label, dice, dice_1..dice_{C-1}, precision, recall.
    header = ['image', 'label', 'dice']
    for h1 in range(1, args.num_classes):
        header.append(f'dice_{h1}')
    header.append('precision')
    header.append('recall')
    for j, case in enumerate(val_files):
        case['dice'] = M[j][0]
        for h2 in range(1, args.num_classes):
            case[f'dice_{h2}'] = M[j][1][h2 - 1]
        case['precision'] = M[j][2]
        case['recall'] = M[j][3]
    # Columns now follow `header` order exactly (the old hand-built dict put
    # precision/recall before the per-class dice columns).
    data_dict = {col: [case[col] for case in val_files] for col in header}
    datas = pd.DataFrame(data_dict)
    datas.to_csv(os.path.join(pred_dir, f"test-{args.fold}.csv"), index=False, sep=',')
# CLI entry point: parse arguments and run inference on the selected fold.
if __name__ == '__main__':
    args = get_args()
    test(args)
import glob
import math
import os
import time
import torch
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.visualize.img2tensorboard import SummaryWriter, plot_2d_or_3d_image
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.optimizers import Novograd
from monai.transforms import Compose, EnsureType, AsDiscrete
from monai.utils import set_determinism
from transforms import transformations
from torch.utils.tensorboard import SummaryWriter
from monai.visualize import plot_2d_or_3d_image
from monai.data import DataLoader, CacheDataset, load_decathlon_datalist, decollate_batch
from monai.transforms import AsDiscrete
from utils.utils import train_transforms, val_transforms, get_args, choice_optimizer, get_logger, cosing_schedule
from nets.models import get_model
# Module-level experiment setup (runs at import time).
args = get_args()          # parsed CLI configuration
logger = get_logger(args)  # file + console logger for this fold
writer = SummaryWriter()   # TensorBoard writer (default ./runs directory)
# NOTE(review): the glob-based split below belongs to the older tutorial-style
# pipeline; the JSON-split train(args) path reads its file lists from
# dataset_{fold}.json instead and does not use these variables — confirm
# before removing.
root_dir = r'./dataset'
data_dir = os.path.join(root_dir, "train")
# Sorted globs so images and labels pair up by rank.
train_images = sorted(glob.glob(os.path.join(data_dir, "images", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir, "labels", "*.nii.gz")))
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
# Last 20 cases held out for validation (hard-coded split).
train_files, val_files = data_dicts[:-20], data_dicts[-20:]
def train(args):
    """Train the network selected by ``args.mode`` on fold ``args.fold``.

    Loads the Decathlon-style split ``dataset_{fold}.json`` from ``args.data``,
    trains for ``args.max_epoch`` epochs with DiceCE loss, validates every
    ``val_interval`` epochs using sliding-window inference, logs scalars and
    sample images to TensorBoard, and checkpoints the best mean-Dice model
    under ``args.save_model/args.mode``.

    NOTE(review): this block was reconstructed from interleaved old/new diff
    lines (a shadowed no-arg ``train()``, duplicate epoch loops, a mixed
    AMP/non-AMP training step referencing an undefined ``scaler``, and an
    unbalanced ``return (``). The retained statements are the ones consistent
    with the argument-driven version; confirm against the upstream repository.
    """
    data_dir = args.data
    split_json = f"dataset_{args.fold}.json"
    datasets = os.path.join(data_dir, split_json)
    train_files = load_decathlon_datalist(datasets, True, "training")
    val_files = load_decathlon_datalist(datasets, True, "validation")
    train_ds = CacheDataset(
        data=train_files, transform=train_transforms(args), cache_rate=args.cache, num_workers=args.num_workers,
    )
    train_loader = DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True
    )
    val_ds = CacheDataset(
        data=val_files, transform=val_transforms(args), cache_rate=args.cache, num_workers=args.num_workers
    )
    val_loader = DataLoader(
        val_ds, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True
    )
    if args.gpus:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    else:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(args).to(device)
    torch.backends.cudnn.benchmark = True
    loss_function = DiceCELoss(include_background=False, to_onehot_y=True, softmax=True)
    post_label = AsDiscrete(to_onehot=True, n_classes=args.num_classes)
    post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=args.num_classes)
    dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
    dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch", get_not_nans=False)
    val_interval = 1  # validate after every epoch
    best_metric = -1
    best_metric_epoch = -1
    best_metrics_epochs_and_time = [[], [], []]
    epoch_loss_values = []
    metric_values = []
    epoch_times = []
    # One Dice history list per foreground class, keyed by class index string.
    metric_values_class = {}
    for m1 in range(1, args.num_classes):
        metric_values_class[f"{m1}"] = []
    total_start = time.time()
    for epoch in range(args.max_epoch):
        epoch_start = time.time()
        print("-" * 100)
        print(f"epoch {epoch + 1}/{args.max_epoch}")
        logger.info(f"epoch {epoch + 1}/{args.max_epoch}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step_start = time.time()
            step += 1
            inputs, labels = (batch_data["image"].to(device), batch_data["label"].to(device))
            # Cosine schedule with 10-epoch warmup when --scheduler is set.
            if args.scheduler:
                new_lr = cosing_schedule(args.learning_rate, epoch, warmup_step=10, epoch_total=args.max_epoch)
            else:
                new_lr = args.learning_rate
            optimizer = choice_optimizer(args, lr=new_lr, model=model)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f} step time: {(time.time() - step_start):.4f}")
            logger.info(f"{step}/{epoch_len}, train_loss: {loss.item():.4f} step time: {(time.time() - step_start):.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        writer.add_scalar("train_loss", epoch_loss, epoch + 1)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_labels = (val_data["image"].to(device), val_data["label"].to(device))
                    val_outputs = sliding_window_inference(val_inputs, args.spatial_size, args.sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    dice_metric(y_pred=val_outputs, y=val_labels)
                    dice_metric_batch(y_pred=val_outputs, y=val_labels)
                metric = dice_metric.aggregate().item()
                metric_values.append(metric)
                metric_batch = dice_metric_batch.aggregate()
                for m2 in range(args.num_classes - 1):
                    metric_num = metric_batch[m2].item()
                    metric_values_class[f"{m2 + 1}"].append(metric_num)
                    print(f"class {m2+1} dice: {metric_num:.4f}")
                    logger.info(f"class {m2+1} dice: {metric_num:.4f}")
                    writer.add_scalar(f"dice_{m2+1}", metric_num, epoch + 1)  # TensorBoard
                dice_metric.reset()
                dice_metric_batch.reset()
                writer.add_scalar("val_mean_dice", metric, epoch + 1)  # TensorBoard
                plot_2d_or_3d_image(val_inputs, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    best_metrics_epochs_and_time[0].append(best_metric)
                    best_metrics_epochs_and_time[1].append(best_metric_epoch)
                    best_metrics_epochs_and_time[2].append(time.time() - total_start)
                    os.makedirs(os.path.join(args.save_model, args.mode), exist_ok=True)
                    # BUG FIX: include the fold id in the checkpoint name so it
                    # matches what test() loads (best_metric_model_{fold}.pth).
                    torch.save(model.state_dict(),
                               os.path.join(args.save_model, args.mode, f"best_metric_model_{args.fold}.pth"))
                    print("saved new best metric model")
                    logger.info("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current"
                    f" mean dice: {metric:.4f}"
                    f" best mean dice: {best_metric:.4f}"
                    f" at epoch: {best_metric_epoch}"
                )
                logger.info(
                    f"current epoch: {epoch + 1} current"
                    f" mean dice: {metric:.4f}"
                    f" best mean dice: {best_metric:.4f}"
                    f" at epoch: {best_metric_epoch}"
                )
        print(
            f"time consuming of epoch {epoch + 1} is:"
            f" {(time.time() - epoch_start):.4f}"
        )
        epoch_times.append(time.time() - epoch_start)
        logger.info(
            f"time consuming of epoch {epoch + 1} is:"
            f" {(time.time() - epoch_start):.4f}"
        )
    print(
        f"train completed, best_metric: {best_metric:.4f}"
        f" at epoch: {best_metric_epoch}"
        f" total time: {(time.time() - total_start):.4f}"
    )
    logger.info(
        f"train completed, best_metric: {best_metric:.4f}"
        f" at epoch: {best_metric_epoch}"
        f" total time: {(time.time() - total_start):.4f}"
    )
# Script entry point. The old tutorial helper `monai_optimized()` was removed:
# it called `train()` with no arguments and unpacked a 5-tuple return, neither
# of which matches the argument-driven `train(args)` in this file — it was
# dead diff residue that raised TypeError when invoked.
if __name__ == '__main__':
    train(args)
import logging
import math
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from typing import Sequence
......@@ -15,7 +17,7 @@ from monai.transforms import (
ScaleIntensityRanged,
Spacingd,
RandRotate90d,
ToTensord, EnsureTyped, Invertd, AsDiscreted,
ToTensord, EnsureTyped, Invertd, AsDiscreted, EnsureChannelFirstd
)
......@@ -41,13 +43,13 @@ def get_args(strings=None):
arg("--learning_rate", type=float, default=1e-4, help="")
arg("--weight_decay", type=float, default=1e-5, help="")
arg("--momentum", type=float, default=0.99, help="")
arg("--max_epoch", type=int, default=1200, help="")
arg("--max_epoch", type=int, default=200, help="")
arg("--sw_batch_size", type=int, default=4, help="")
arg("--save_model", type=str, default="./models", help="")
arg("--pred_path", type=str, default="./preds", help="")
arg("--mode", type=str, default="UNet", choices=["UNet", "VNet", "BasicUNet", "DynUNet", "UNETR"])
arg("--mode", type=str, default="DynUNet", choices=["UNet", "VNet", "BasicUNet", "DynUNet", "UNETR"])
arg("--optimizer", type=str, default="adam", choices=["adam", "adamw", "sgd"])
arg("--scheduler", type=str, default="", choices=["cosing", "plateau"])
arg("--scheduler", action="store_true", help="")
args = parser.parse_args()
return args
......@@ -56,7 +58,7 @@ def train_transforms(args):
train_trans = Compose(
[
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=args.pixdim,
......@@ -206,7 +208,7 @@ def get_param(args):
# 计算metrics
def evaluate(y_pred, y_label):
def evaluate(y_pred, y_label, num_classes):
# print(y_pred, y_pred[0].shape)
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch", get_not_nans=False)
......@@ -214,40 +216,60 @@ def evaluate(y_pred, y_label):
dice_metric_batch(y_pred=y_pred, y=y_label)
dice_metric_result = dice_metric.aggregate().item()
metric_batch = dice_metric_batch.aggregate()
metric_0 = metric_batch[0].item()
metric_1 = metric_batch[1].item()
metric_2 = metric_batch[2].item()
# print(dice_metric_result)
# dice_metric.reset()
auc_metric = ROCAUCMetric(average="weighted")
auc_metric(y_pred=y_pred[0].view(-1), y=y_label[0].view(-1))
auc_metric_result = auc_metric.aggregate().item()
auc_metric.reset()
# auc_metric_result = 0
metric = []
for m in range(num_classes - 1):
metric_num = metric_batch[m].item()
metric.append(metric_num)
dice_metric.reset()
dice_metric_batch.reset()
'''["sensitivity-TPR", "specificity-TNR", "precision-P", "negative predictive value", "miss rate-FNR", "fall out",
"false discovery rate", "false omission rate", "prevalence threshold", "threat score", "accuracy",
"balanced accuracy", "f1 score", "matthews correlation coefficient", "fowlkes mallows index", "informedness",
"markedness"]'''
cfm_metric = ConfusionMatrixMetric(include_background=False, metric_name=("specificity", "precision", "recall"),
cfm_metric = ConfusionMatrixMetric(include_background=False, metric_name=("precision", "recall"),
compute_sample=True)
cfm_metric(y_pred=y_pred, y=y_label)
cfm_metric_result = cfm_metric.aggregate()
cfm_metric_result = torch.stack(cfm_metric_result).cpu().numpy()[:, 0]
cfm_metric.reset()
return dice_metric_result, metric_0, metric_1, metric_2, auc_metric_result, cfm_metric_result[0], cfm_metric_result[1], cfm_metric_result[2]
return dice_metric_result, metric, cfm_metric_result[0], cfm_metric_result[1]
def choice_optimizer(args, lr, model):
    """Build the optimizer named by ``args.optimizer`` for ``model``.

    Args:
        args: configuration with ``optimizer`` ("adam"/"adamw"/"sgd"),
            ``weight_decay`` (adam/adamw) and ``momentum`` (sgd).
        lr: learning rate for this step (may vary per epoch with a schedule).
        model: network whose parameters are optimized.

    Returns:
        A ``torch.optim.Optimizer`` instance.

    Raises:
        ValueError: for an unrecognized optimizer name (previously the
        ``global optimizer`` fallback raised an opaque NameError, or silently
        reused a stale optimizer from an earlier call).
    """
    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum)
    else:
        raise ValueError(f"unknown optimizer: {args.optimizer}")
    return optimizer
def cosing_schedule(lr, epoch, warmup_step, epoch_total, min=1e-7):
    """Linear-warmup + cosine-decay learning-rate schedule.

    For ``epoch < warmup_step`` the rate ramps linearly from 0 to ``lr``;
    afterwards it follows a half-cosine from ``lr`` down to 0 over the
    remaining epochs. ``min`` is a small constant floor added to the result.
    """
    if epoch < warmup_step:
        factor = epoch / warmup_step
    else:
        progress = (epoch - warmup_step) / (epoch_total - warmup_step)
        factor = 0.5 * (1.0 + math.cos(math.pi * progress))
    return lr * factor + min
def get_logger(args):
    """Return the shared 'train' logger: INFO+ to ``train_{fold}.log``, DEBUG+ to console.

    Args:
        args: configuration providing ``fold`` (used in the log file name).

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger('train')
    logger.setLevel(level=logging.DEBUG)
    # BUG FIX: logging.getLogger returns the same object on every call, so the
    # original appended a fresh pair of handlers each time — every log line
    # was then emitted once per call. Configure the handlers only once.
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
        file_handler = logging.FileHandler(f'train_{args.fold}.log')
        file_handler.setLevel(level=logging.INFO)
        file_handler.setFormatter(formatter)
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.addHandler(stream_handler)
    return logger
from monai.transforms import (
AddChanneld,
Compose,
CropForegroundd,
DeleteItemsd,
FgBgToIndicesd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
EnsureTyped,
)
def transformations():
    """Build the (train, validation) preprocessing pipelines.

    Both pipelines share the same deterministic front end: load, add channel,
    resample to 0.5mm isotropic spacing, reorient to RAI, window intensities
    to [150, 850] -> [0, 1], and crop to the foreground bounding box. The
    training pipeline additionally pre-computes foreground/background indices
    and draws four random pos/neg-balanced patches per volume.

    Returns:
        A ``(train_transforms, val_transforms)`` pair of ``Compose`` objects.
    """

    def shared_head():
        # Fresh instances per pipeline so train/val do not share transform objects.
        return [
            LoadImaged(keys=["image", "label"]),
            AddChanneld(keys=["image", "label"]),
            Spacingd(
                keys=["image", "label"],
                pixdim=(0.5, 0.5, 0.5),
                mode=("bilinear", "nearest"),
            ),
            Orientationd(keys=["image", "label"], axcodes="RAI"),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=150,
                a_max=850,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
        ]

    train_transforms = Compose(
        shared_head()
        + [
            # Pre-compute foreground/background indexes and cache them to
            # accelerate the random cropping during training.
            FgBgToIndicesd(
                keys="label",
                fg_postfix="_fg",
                bg_postfix="_bg",
                image_key="image",
            ),
            # Randomly crop patch samples from the big image based on the
            # pos/neg ratio; negative sample centers must lie in valid area.
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(128, 128, 96),
                pos=1,
                neg=1,
                num_samples=4,
                fg_indices_key="label_fg",
                bg_indices_key="label_bg",
            ),
            # Drop the cached index arrays once the crop has consumed them.
            DeleteItemsd(keys=["label_fg", "label_bg"]),
            EnsureTyped(keys=["image", "label"]),
        ]
    )
    val_transforms = Compose(shared_head() + [EnsureTyped(keys=["image", "label"])])
    return train_transforms, val_transforms
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册