Quick start
Train like a Pro in a few lines
Download a dataset, wrap it in a dataset class, and hand it to a trainer. Below is an example using
the ACDC dataset. The example code replicates most of nnU-Net's features, but without
augmentations.
Example: train on the ACDC dataset
from typing import override
from os.path import exists
from monai.networks.nets import DynUNet
from torch import nn
from torch.utils.data import DataLoader
from mipcandy import SegmentationTrainer, AmbiguousShape, auto_device, download_dataset, NNUNetDataset, inspect, \
load_inspection_annotations, RandomROIDataset
class UNetTrainer(SegmentationTrainer):
    """Segmentation trainer whose network is a 3D MONAI ``DynUNet`` with residual blocks."""

    @override
    def build_network(self, example_shape: AmbiguousShape) -> nn.Module:
        """Build the ``DynUNet`` used for training.

        Args:
            example_shape: Shape of an example input; its first entry is used as the
                number of input channels.

        Returns:
            A five-level 3D ``DynUNet`` with ``self.num_classes`` output channels.
        """
        num_levels = 5
        kernel_size = [[3, 3, 3]] * num_levels
        # The first level keeps the resolution; each later level downsamples by 2.
        strides = [[1, 1, 1]] + [[2, 2, 2]] * (num_levels - 1)
        return DynUNet(
            spatial_dims=3,
            in_channels=example_shape[0],
            out_channels=self.num_classes,
            kernel_size=kernel_size,
            strides=strides,
            # DynUNet documents that the transposed-convolution kernel sizes should
            # equal strides[1:], i.e. one entry per upsampling stage.
            upsample_kernel_size=strides[1:],
            deep_supervision=self.deep_supervision,
            deep_supr_num=2,
            res_block=True,
        )
if __name__ == "__main__":
    # Locations used by this example: the dataset folder and the cached
    # inspection annotations stored inside it.
    dataset_root = "tutorial/datasets/ACDC"
    annotations_path = "tutorial/datasets/ACDC/annotations.json"

    device = auto_device()
    download_dataset("nnunet_datasets/ACDC", dataset_root)
    dataset = NNUNetDataset(dataset_root, align_spacing=True)

    if not exists(annotations_path):
        # No cached annotations yet: inspect the dataset with it set to the selected
        # device, set it back to "cpu", and save the result so later runs can reuse it.
        dataset.device(device=device)
        annotations = inspect(dataset)
        dataset.device(device="cpu")
        annotations.save(annotations_path)
    else:
        # Reuse the annotations saved by a previous run.
        annotations = load_inspection_annotations(annotations_path, dataset)

    roi_dataset = RandomROIDataset(annotations, 2)
    train_set, val_set = roi_dataset.fold()

    train_loader = DataLoader(
        train_set,
        batch_size=2,
        shuffle=True,
        num_workers=2,
        prefetch_factor=2,
        persistent_workers=True,
    )
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

    trainer = UNetTrainer("tutorial", train_loader, val_loader, device=device)
    trainer.train(1000, note="example with the ACDC dataset")