Brain Tumor Segmentation Challenge

The BraTS Challenge is a yearly brain tumor segmentation competition held in conjunction with the International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI).

The dataset consists of brain MRI scans of patients with gliomas, which are the most common type of malignant primary brain tumor. The dataset is quite large, so it will take a while to download. The code below downloads the training data for you; loading is implemented in utils/brats_dataset.py. We have simplified the challenge in two ways: we consider only 2D slices from the MRI scans, and we segment only the tumor without classifying it into different tumor types. You may of course try to classify into different tumor types once you have tuned your model for the binary segmentation task.

The dataset is split into training and testing sets; the testing set is used for the final evaluation of the submitted segmentation algorithms.

The goal of this tutorial is to train and tune a model for the BraTS Challenge. We provide an example training script, which you can modify and tune for the challenge, or you can implement your own. You rarely need to run 10 epochs to know whether a model is working, so do not wait until training has finished before checking; doing so can waste a lot of time. Debug by training on a single batch, tune hyperparameters on the validation set with short training runs, and only then train for longer. Also check that each part of your pipeline works: if you add augmentations, for example, plot the augmented images and masks to confirm that they look correct.
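
The single-batch check described above can be sketched as follows. This is a minimal illustration and not the lab's `train` function: `overfit_single_batch` is a hypothetical helper, and the small convolutional model and random tensors are stand-ins for the U-Net and a real BraTS batch (the shapes are assumptions).

```python
import torch
from torch import nn


def overfit_single_batch(model, batch, loss_fn, optimizer, steps=200):
    """Train repeatedly on one fixed batch and return the loss history.

    If the pipeline is wired correctly, the loss should fall well below
    its starting value; a flat curve usually points to a bug.
    """
    images, masks = batch
    model.train()
    history = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(images), masks)
        loss.backward()
        optimizer.step()
        history.append(loss.item())
    return history


# Stand-in data: 2 slices, 4 MRI channels, 16x16 pixels (assumed shapes).
torch.manual_seed(0)
images = torch.randn(2, 4, 16, 16)
masks = (torch.rand(2, 1, 16, 16) > 0.7).float()

# Stand-in model; in the lab this would be the U-Net from get_unet_model.
toy_model = nn.Sequential(
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),
)
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)
losses = overfit_single_batch(
    toy_model, (images, masks), nn.BCEWithLogitsLoss(), toy_optimizer
)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

If the last loss is not clearly lower than the first, look for problems such as a mismatch between mask and output shapes, a missing `optimizer.zero_grad()`, or a learning rate that is far too small before starting longer runs.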

Tasks

  1. Explore the dataset and the example training script.

  2. Consider how you may improve the model’s performance.

  3. Explore: data augmentation, image size, model architecture, loss function, optimizer, learning rate, etc.

  4. Evaluate the performance of your model with metrics appropriate for the task. See the MONAI documentation for available metrics.

  5. Submit your code and a short report/description of your approach.
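
For the data augmentation part of task 3, one thing that is easy to get wrong in segmentation is transforming the image without transforming the mask in the same way. The sketch below shows one simple augmentation, a random horizontal flip, applied to both together. `random_flip_pair` is a hypothetical helper, not part of the lab's utils, and the `(C, H, W)` tensor layout is an assumption.

```python
import torch


def random_flip_pair(image, mask, p=0.5, generator=None):
    """Flip image and mask together along the width axis with probability p.

    image: (C, H, W) tensor, mask: (1, H, W) tensor. Applying the same
    flip to both keeps every mask pixel aligned with its image pixel.
    """
    if torch.rand(1, generator=generator).item() < p:
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    return image, mask


# Stand-in slice: 4 MRI channels, 8x8 pixels, "tumor" in the left columns.
image = torch.zeros(4, 8, 8)
mask = torch.zeros(1, 8, 8)
image[:, 2:5, 0:3] = 1.0
mask[:, 2:5, 0:3] = 1.0

flipped_image, flipped_mask = random_flip_pair(image, mask, p=1.0)

# Sanity check: the bright region and the mask still coincide after the flip.
assert torch.equal(flipped_image[0], flipped_mask[0])
print(flipped_mask[0, 2])  # the marked pixels are now in the right columns
```

A programmatic check like this complements plotting: it catches misalignment automatically, while a plot shows whether the augmented slices still look plausible.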

We encourage you to follow good training practice in this lab. This means that you should:

  1. Use a validation set to tune your model.

  2. Use a loss function and optimizer that are appropriate for the task.

  3. Use a learning rate that is appropriate for the task.

  4. Use a batch size that is appropriate for the task and the size of your GPU.

  5. Use data augmentation techniques that are appropriate for the task.

  6. Implement regularization if needed to prevent overfitting.

  7. Avoid loading the entire dataset into memory.

  8. Save your model and checkpoints.

  9. Optional: Monitor and log training progress using tools like TensorBoard.

  10. Evaluate the performance of your model with metrics appropriate for the task, including at least the Dice coefficient, precision, recall, and AUC.
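
To make the metrics in the last item concrete, here is a minimal sketch that computes the Dice coefficient, precision, and recall for binary masks in plain PyTorch. `binary_segmentation_metrics` is a hypothetical helper, not the lab's `evaluate` function; the 0.5 threshold and the pooling of counts over the whole batch are assumptions. AUC is left out because it depends on the ranking of probabilities rather than on thresholded predictions.

```python
import torch


def binary_segmentation_metrics(logits, targets, threshold=0.5, eps=1e-7):
    """Compute Dice, precision, and recall for binary masks.

    logits: raw model outputs, targets: 0/1 masks of the same shape.
    Counts are pooled over the whole batch before forming the ratios.
    """
    preds = (torch.sigmoid(logits) > threshold).float()
    targets = targets.float()
    tp = (preds * targets).sum()          # tumor pixels correctly found
    fp = (preds * (1 - targets)).sum()    # background predicted as tumor
    fn = ((1 - preds) * targets).sum()    # tumor pixels that were missed
    return {
        "dice": (2 * tp / (2 * tp + fp + fn + eps)).item(),
        "precision": (tp / (tp + fp + eps)).item(),
        "recall": (tp / (tp + fn + eps)).item(),
    }


# Hand-checkable example: 4 tumor pixels, 3 found, 1 false alarm.
targets = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]])
logits = torch.tensor([[5.0, 5.0, 5.0, -5.0, 5.0, -5.0, -5.0, -5.0]])
metrics = binary_segmentation_metrics(logits, targets)
print(metrics)  # Dice, precision, and recall are all approximately 0.75
```

Note that with this formula a batch with no tumor and no predicted tumor scores 0 on all three metrics, so decide explicitly how you want to treat empty slices when reporting results.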

We have provided a basic model implementation and a code structure for loading the data and training the model. You will need to modify the code to improve the model’s performance.
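
As one example of such a modification, tumor pixels are usually a small fraction of each slice, and a common option for imbalanced segmentation is to add a soft Dice term to the binary cross-entropy loss. The sketch below is one possible formulation, not the loss used by the lab: `DiceBCELoss`, `dice_weight`, and `smooth` are illustrative names and defaults.

```python
import torch
from torch import nn


class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss, both on logits."""

    def __init__(self, dice_weight=1.0, smooth=1.0):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice_weight = dice_weight
        self.smooth = smooth

    def forward(self, logits, targets):
        targets = targets.float()
        probs = torch.sigmoid(logits)
        # Sum over channel and spatial axes so each sample gets its own Dice.
        dims = tuple(range(1, logits.dim()))
        intersection = (probs * targets).sum(dim=dims)
        denominator = probs.sum(dim=dims) + targets.sum(dim=dims)
        dice = (2 * intersection + self.smooth) / (denominator + self.smooth)
        return self.bce(logits, targets) + self.dice_weight * (1 - dice.mean())


combined_loss = DiceBCELoss()
targets = torch.zeros(2, 1, 8, 8)
targets[:, :, 2:6, 2:6] = 1.0

good_logits = (targets * 2 - 1) * 10.0  # confident and correct
bad_logits = -good_logits               # confident and wrong
good_value = combined_loss(good_logits, targets).item()
bad_value = combined_loss(bad_logits, targets).item()
print(f"good prediction: {good_value:.4f}, bad prediction: {bad_value:.4f}")
```

Because it takes logits and targets like `BCEWithLogitsLoss`, a module of this shape could be passed wherever a `loss_fn` with that interface is expected. Whether it actually improves results is something to verify on the validation set.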

[ ]:
# The lab example has a requirement for tensorboard, which you can install by running the command below in the terminal
#conda activate DL_labs_GPU && pip install -U tensorboard

# In your own code you choose whether to use Tensorboard or not.
[ ]:
import torch
from utils.brats_dataset import get_brats_dataloader
from utils.model import get_unet_model
from utils.train import train
from utils.evaluation import evaluate

[ ]:
root_dir = "datasets"  # Update this to your dataset directory
batch_size = 8
num_workers = 4

epochs = 3
save_dir = "utils/checkpoints"
log_path = "utils/logs"
load_path = None  # Set to a checkpoint path if you want to resume training

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data
print("Loading training data...")
train_dataloader = get_brats_dataloader(root_dir, "training", batch_size, num_workers, shuffle=True)
print("Loading validation data...")
val_dataloader = get_brats_dataloader(root_dir, "validation", batch_size, num_workers, shuffle=False)

# Initialize model, loss function, and optimizer
model = get_unet_model(in_channels=4, out_channels=1)
model.to(device)
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train the model
print("Starting training...")
trained_model = train(
    model=model,
    train_loader=train_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    epochs=epochs,
    save_dir=save_dir,
    save_interval=2,
    log_path=log_path,
    load_path=load_path,
    overfit_batch=False
)

# Evaluate the model
print("Starting evaluation on validation set...")
evaluate(trained_model, val_dataloader, device)
print("Starting evaluation on training set...")
evaluate(trained_model, train_dataloader, device)
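
The `save_dir`, `save_interval`, and `load_path` arguments above relate to checkpointing. How `utils/train.py` stores checkpoints is not shown here, so the following is only a sketch of one common pattern for saving and resuming with `torch.save` and `torch.load`. The helper names, dictionary keys, and file name are assumptions, and a small convolution stands in for the U-Net.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, optimizer, epoch):
    """Store everything needed to resume training in one file."""
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer, device="cpu"):
    """Restore model and optimizer in place and return the saved epoch."""
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]


# Stand-in model; the lab's U-Net would be handled the same way.
net = nn.Conv2d(4, 1, kernel_size=3, padding=1)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "epoch_2.pt")
    save_checkpoint(ckpt_path, net, opt, epoch=2)

    # A freshly initialised model and optimizer receive the saved state.
    restored = nn.Conv2d(4, 1, kernel_size=3, padding=1)
    restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
    resumed_epoch = load_checkpoint(ckpt_path, restored, restored_opt)

print(resumed_epoch)  # 2
```

Saving the optimizer state along with the model weights matters for optimizers such as Adam, which keep running statistics that would otherwise restart from zero when training resumes.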