nnUNet MONAI Bundle#

In this notebook, we demonstrate how to create a MONAI Bundle that supports nnUNet experiments for training and inference. In this step-by-step tutorial, we describe how to create all the Python code and YAML configuration files needed to train and evaluate an nnUNet model using the MONAI Bundle format.

nnUNet Trainer#

The core component for the nnUNet MONAI Bundle is the get_nnunet_trainer function. This function is responsible for creating the nnUNet trainer object from the native nnUNetv2 implementation. From the nnUNet trainer object, we can access the training components, such as the data loaders, model, learning rate scheduler, optimizer, and loss function, and perform training and inference tasks.

[1]:
import torch
from typing import Union, Optional
import json
from pathlib import Path
import os
from torch.backends import cudnn

def get_nnunet_trainer(dataset_name_or_id: Union[str, int],
                       configuration: str, fold: Union[int, str],
                       pymaia_config_file: Optional[str] = None,  # To set env variables
                       trainer_class_name: str = 'nnUNetTrainer',
                       plans_identifier: str = 'nnUNetPlans',
                       pretrained_weights: Optional[str] = None,
                       num_gpus: int = 1,
                       use_compressed_data: bool = False,
                       export_validation_probabilities: bool = False,
                       continue_training: bool = False,
                       only_run_validation: bool = False,
                       disable_checkpointing: bool = False,
                       val_with_best: bool = False,
                       device: torch.device = torch.device('cuda'),
                       pretrained_model=None,
                       ):  # From nnUNet/nnunetv2/run/run_training.py#run_training
    """Create and initialize a native nnUNetv2 trainer without running nnUNet's training loop.

    Adapted from ``nnunetv2.run.run_training.run_training``. The trainer is built,
    optionally restored with ``maybe_load_checkpoint`` and initialized with
    ``on_train_start()``. The nnUNet training and validation phases are skipped, so the
    trainer's components (network, optimizer, LR scheduler, loss, data loaders) can be
    driven by MONAI engines instead.

    Args:
        dataset_name_or_id: nnUNet dataset name or numeric ID.
        configuration: nnUNet configuration, e.g. ``"3d_fullres"``.
        fold: cross-validation fold, as an int, a numeric string, or ``"all"``.
        pymaia_config_file: optional PyMAIA JSON config. When given, its ``base_folder``,
            ``preprocessing_folder`` and ``results_folder`` entries set the ``nnUNet_raw``,
            ``nnUNet_preprocessed`` and ``nnUNet_results`` environment variables.
        trainer_class_name: name of the nnUNet trainer class.
        plans_identifier: nnUNet plans identifier, e.g. ``"nnUNetPlans"``.
        pretrained_weights: forwarded to nnUNet's ``maybe_load_checkpoint``.
        num_gpus: only ``1`` is supported.
        use_compressed_data: forwarded to ``get_trainer_from_args``.
        export_validation_probabilities: unused; kept for signature parity with nnUNet.
        continue_training: forwarded to ``maybe_load_checkpoint``.
        only_run_validation: forwarded to ``maybe_load_checkpoint``.
        disable_checkpointing: if True, disables the nnUNet trainer's own checkpointing.
        val_with_best: unused; kept for signature parity with nnUNet.
        device: device the trainer is created on.
        pretrained_model: optional checkpoint path. Its ``network_weights`` entry is
            loaded into the network.

    Returns:
        The initialized nnUNet trainer.

    Raises:
        ValueError: if ``fold`` is neither ``"all"`` nor convertible to int.
        NotImplementedError: if ``num_gpus > 1``.
        AssertionError: if both ``continue_training`` and ``only_run_validation`` are set.
    """
    # Validate arguments first, so bad input fails before environment variables are
    # modified, nnUNet is imported, or the (expensive) trainer is built.
    if isinstance(fold, str) and fold != 'all':
        try:
            fold = int(fold)
        except ValueError as e:
            raise ValueError(
                f'Unable to convert given value for fold to int: {fold}. '
                f'fold must be either "all" or an integer!') from e

    if int(num_gpus) > 1:
        # Fail loudly; falling through would return None and break every caller later on.
        raise NotImplementedError('Multi-GPU training is not supported yet; use num_gpus=1.')

    assert not (continue_training and only_run_validation), \
        'Cannot set --c and --val flag at the same time. Dummy.'

    ## Block Added
    if pymaia_config_file is not None:
        with open(pymaia_config_file, "r") as f:
            pymaia_config_dict = json.load(f)

        os.environ["nnUNet_raw"] = str(Path(pymaia_config_dict["base_folder"]).joinpath("nnUNet_raw"))
        os.environ["nnUNet_preprocessed"] = pymaia_config_dict["preprocessing_folder"]
        os.environ["nnUNet_results"] = pymaia_config_dict["results_folder"]

    # Imported only after the environment variables are set.
    # NOTE(review): presumably nnunetv2 reads them at import time; confirm.
    from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint
    ## End Block

    nnunet_trainer = get_trainer_from_args(str(dataset_name_or_id), configuration, fold, trainer_class_name,
                                           plans_identifier, use_compressed_data, device=device)

    if disable_checkpointing:
        nnunet_trainer.disable_checkpointing = disable_checkpointing

    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
    nnunet_trainer.on_train_start()  # Added to Initialize Trainer
    if torch.cuda.is_available():
        cudnn.deterministic = False
        cudnn.benchmark = True

    if pretrained_model is not None:
        # map_location lets a checkpoint saved on one device be loaded on another.
        state_dict = torch.load(pretrained_model, map_location=nnunet_trainer.device)
        if 'network_weights' in state_dict:
            # A torch.compile'd network wraps the original module as `_orig_mod`;
            # fall back to the network itself when it was not compiled.
            network = getattr(nnunet_trainer.network, '_orig_mod', nnunet_trainer.network)
            network.load_state_dict(state_dict['network_weights'])
        else:
            import warnings
            warnings.warn(f"No 'network_weights' entry in {pretrained_model}; network weights were not loaded.")
    # Skip Training and Validation Phase
    return nnunet_trainer
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 1
----> 1 import torch
      2 from typing import Union, Optional
      3 import json

ModuleNotFoundError: No module named 'torch'

The function get_nnunet_trainer accepts the following parameters:

  • dataset_name_or_id: The dataset name or ID to be used for training and evaluation.

  • fold: The fold number for the cross-validation experiment (an integer, a numeric string, or "all").

  • configuration: The training configuration for the nnUNet trainer, usually 3d_fullres.

  • trainer_class_name: The nnUNet trainer class name to be used for training, e.g. nnUNetTrainer.

  • plans_identifier: The nnUNet plans identifier for the dataset, e.g. nnUNetPlans.

  • pretrained_model: Optional path to a pre-trained checkpoint whose network weights (stored under "network_weights") are loaded into the network, e.g. for transfer learning.

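Additionally, the function accepts an optional pymaia_config_file (generated by running [nnunet_prepare_data_folder](https://pymaia.readthedocs.io/en/latest/apidocs/nnunet_prepare_data_folder.html)). When provided, it is used to set the nnUNet_raw, nnUNet_preprocessed and nnUNet_results environment variables before nnUNet is imported.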

[ ]:
# Experiment settings reused by the cells below.
task_id = "109"
nnunet_plans_identifier = "nnUNetResEncUNetLPlans"
nnunet_trainer_class_name = "nnUNetTrainer"
pymaia_config_file = "/home/maia-user/Tutorials/MAIA/Experiments/Task09_Spleen/Task09_Spleen_results/Dataset109_Task09_Spleen.json"

# Optional checkpoints to initialise the network from: uncomment one and pass it as `pretrained_model`.
#pretrained_model = "/home/maia-user/Tutorials/nnunetmonaibundle/model/nnUNet_Bundle/models/checkpoint_epoch=10.pt"
#pretrained_model = "/home/maia-user/Tutorials/Task09_Spleen_Bundle/models/Dataset109_Spleen/nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres/fold_0/checkpoint_final.pth"

Get nnUNet Trainer from Preprocessing Folder#

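An alternative way to get the nnUNet Trainer is to use the get_nnunet_trainer_from_preprocessing_folder function. Instead of locating the dataset through the nnUNet environment variables, this function builds the trainer directly from a plans JSON file and a dataset.json file (for example, those in the nnUNet preprocessed dataset folder) and returns the initialized nnUNet Trainer object. Unlike get_nnunet_trainer, it calls initialize() rather than on_train_start(). In the bundle it is used to obtain the network definition (network_def).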

[ ]:
from batchgenerators.utilities.file_and_folder_operations import join, load_json
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
import nnunetv2

def get_nnunet_trainer_from_preprocessing_folder(
        plans_file,
        dataset_file,
        configuration: str,
        fold: Union[int, str],
        trainer_class_name: str = 'nnUNetTrainer',
        plans_identifier: str = 'nnUNetPlans',
        pretrained_weights: Optional[str] = None,
        num_gpus: int = 1,
        use_compressed_data: bool = False,
        export_validation_probabilities: bool = False,
        continue_training: bool = False,
        only_run_validation: bool = False,
        disable_checkpointing: bool = False,
        val_with_best: bool = False,
        device: torch.device = torch.device('cuda'),
        pretrained_model=None,
):  # From nnUNet/nnunetv2/run/run_training.py#run_training
    """Create an nnUNet trainer from explicit plans and dataset JSON files.

    The trainer class is looked up by name inside ``nnunetv2.training.nnUNetTrainer`` and
    instantiated from the loaded plans and dataset JSON. It is optionally restored with
    ``maybe_load_checkpoint`` and then initialized with ``initialize()``. nnUNet's training
    and validation phases are skipped.

    Args:
        plans_file: path to the nnUNet plans JSON file.
        dataset_file: path to the nnUNet ``dataset.json`` file.
        configuration: nnUNet configuration, e.g. ``"3d_fullres"``.
        fold: cross-validation fold, as an int, a numeric string, or ``"all"``.
        trainer_class_name: name of the nnUNet trainer class to instantiate.
        plans_identifier: unused here (the plans come from ``plans_file``); kept for
            signature parity with ``get_nnunet_trainer``.
        pretrained_weights: forwarded to nnUNet's ``maybe_load_checkpoint``.
        num_gpus: only ``1`` is supported.
        use_compressed_data: unused; kept for signature parity.
        export_validation_probabilities: unused; kept for signature parity.
        continue_training: forwarded to ``maybe_load_checkpoint``.
        only_run_validation: forwarded to ``maybe_load_checkpoint``.
        disable_checkpointing: if True, disables the nnUNet trainer's own checkpointing.
        val_with_best: unused; kept for signature parity.
        device: device the trainer is created on.
        pretrained_model: optional checkpoint path. Its ``network_weights`` entry is
            loaded into the network.

    Returns:
        The initialized nnUNet trainer.

    Raises:
        ValueError: if ``fold`` is neither ``"all"`` nor convertible to int.
        NotImplementedError: if ``num_gpus > 1``.
        AssertionError: if both ``continue_training`` and ``only_run_validation`` are set.
        RuntimeError: if no trainer class named ``trainer_class_name`` is found.
    """
    # Validate arguments before touching the environment or importing nnUNet.
    if isinstance(fold, str) and fold != 'all':
        try:
            fold = int(fold)
        except ValueError as e:
            raise ValueError(
                f'Unable to convert given value for fold to int: {fold}. '
                f'fold must be either "all" or an integer!') from e

    if int(num_gpus) > 1:
        # Fail loudly; falling through would return None and break every caller later on.
        raise NotImplementedError('Multi-GPU training is not supported yet; use num_gpus=1.')

    assert not (continue_training and only_run_validation), \
        'Cannot set --c and --val flag at the same time. Dummy.'

    ## Block Added
    # Point the nnUNet folders at the current directory; the plans and dataset
    # information are passed explicitly below instead.
    os.environ["nnUNet_raw"] = str(Path(".").joinpath("nnUNet_raw"))
    os.environ["nnUNet_preprocessed"] = "."
    os.environ["nnUNet_results"] = "."
    from nnunetv2.run.run_training import maybe_load_checkpoint
    ##

    trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                trainer_class_name, 'nnunetv2.training.nnUNetTrainer')
    if trainer_class is None:
        # Explicit error instead of a confusing "'NoneType' object is not callable".
        raise RuntimeError(f'Could not find the requested nnUNet trainer class {trainer_class_name!r} '
                           f'in nnunetv2.training.nnUNetTrainer.')

    plans = load_json(plans_file)
    dataset_json = load_json(dataset_file)

    # Honour the `device` argument (it used to be ignored in favour of a hard-coded CUDA device).
    nnunet_trainer = trainer_class(plans=plans, configuration=configuration, fold=fold,
                                   dataset_json=dataset_json, unpack_dataset=False, device=device)

    if disable_checkpointing:
        nnunet_trainer.disable_checkpointing = disable_checkpointing

    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
    nnunet_trainer.initialize()  # To Initialize Trainer
    if torch.cuda.is_available():
        cudnn.deterministic = False
        cudnn.benchmark = True

    # Skip Training and Validation Phase
    if pretrained_model is not None:
        # map_location lets a checkpoint saved on one device be loaded on another.
        state_dict = torch.load(pretrained_model, map_location=nnunet_trainer.device)
        if 'network_weights' in state_dict:
            # A torch.compile'd network wraps the original module as `_orig_mod`;
            # fall back to the network itself when it was not compiled.
            network = getattr(nnunet_trainer.network, '_orig_mod', nnunet_trainer.network)
            network.load_state_dict(state_dict['network_weights'])
        else:
            import warnings
            warnings.warn(f"No 'network_weights' entry in {pretrained_model}; network weights were not loaded.")

    return nnunet_trainer
[ ]:
# Build the nnUNet trainer for fold 0 of the 3d_fullres configuration.
nnunet_trainer = get_nnunet_trainer(
    dataset_name_or_id=task_id,
    configuration="3d_fullres",
    fold="0",
    trainer_class_name=nnunet_trainer_class_name,
    plans_identifier=nnunet_plans_identifier,
    pymaia_config_file=pymaia_config_file,
    # pretrained_model=pretrained_model,
)
[ ]:
# Plans and dataset JSON files from the nnUNet preprocessed folder of the dataset used above.
# NOTE(review): relies on the trainer attribute `preprocessed_dataset_folder_base` and on the plans
# file being named `<plans_identifier>.json` (nnUNetv2 conventions); confirm for your nnUNet version.
preprocessed_dataset_folder = Path(nnunet_trainer.preprocessed_dataset_folder_base)
plans_file = str(preprocessed_dataset_folder.joinpath(nnunet_plans_identifier + ".json"))
dataset_file = str(preprocessed_dataset_folder.joinpath("dataset.json"))

# Stored under a separate name (mirroring `nnunet_trainer_def` in the bundle's train.yaml), so that
# `nnunet_trainer`, whose data loaders, network and optimizer are used below, is not overwritten.
nnunet_trainer_def = get_nnunet_trainer_from_preprocessing_folder(
    plans_file=plans_file,
    dataset_file=dataset_file,
    configuration="3d_fullres",
    fold="0",
    trainer_class_name=nnunet_trainer_class_name,
    plans_identifier=nnunet_plans_identifier,
    # pretrained_model=pretrained_model,
)
[ ]:
from monai.data import Dataset
from monai.handlers import StatsHandler, from_engine, MeanDice, ValidationHandler, LrScheduleHandler, CheckpointSaver, CheckpointLoader, TensorBoardStatsHandler, MLFlowHandler
from monai.engines import SupervisedTrainer, SupervisedEvaluator

from monai.transforms import Compose, Lambdad, Activationsd, AsDiscreted

Train and Val Data Loaders#

[ ]:
# nnUNet's training data loader, plus a MONAI Dataset listing the training case identifiers.
train_dataloader = nnunet_trainer.dataloader_train
# NOTE(review): relies on private attributes of the nnUNet data loader (`generator._data.dataset`).
train_case_identifiers = train_dataloader.generator._data.dataset.keys()
train_data = [dict(case_identifier=case_id) for case_id in train_case_identifiers]
train_dataset = Dataset(data=train_data)
[ ]:
# nnUNet's validation data loader, plus a MONAI Dataset listing the validation case identifiers.
val_dataloader = nnunet_trainer.dataloader_val
# NOTE(review): relies on private attributes of the nnUNet data loader (`generator._data.dataset`).
val_case_identifiers = val_dataloader.generator._data.dataset.keys()
val_data = [dict(case_identifier=case_id) for case_id in val_case_identifiers]
val_dataset = Dataset(data=val_data)

Network, Optimizer, and Loss Function#

[ ]:
# Reuse the training components created by the nnUNet trainer.
device = nnunet_trainer.device
network, optimizer = nnunet_trainer.network, nnunet_trainer.optimizer
lr_scheduler, loss = nnunet_trainer.lr_scheduler, nnunet_trainer.loss

Prepare Batch Function#

The nnUNet DataLoader returns a dictionary with the "data" and "target" keys, where "target" may be a list of tensors (one per deep-supervision output). The SupervisedTrainer used in the MONAI Bundle expects the input and the target as separate values, so we create a custom prepare batch function that extracts them from the dictionary and moves them to the target device.

[ ]:
def prepare_nnunet_batch(batch, device=None, non_blocking=False, **kwargs):
    """Split an nnUNet batch dict into ``(data, target)`` moved to ``device``.

    The signature matches the ``prepare_batch`` callable that MONAI engines invoke as
    ``prepare_batch(batchdata, device, non_blocking, **kwargs)``. Previously any extra
    keyword arguments (the engine's ``to_kwargs``) made the call fail.

    Args:
        batch: dict with a ``"data"`` tensor and a ``"target"`` entry. The target is either
            a tensor or a list/tuple of tensors (e.g. one per deep-supervision output).
        device: device to move all tensors to.
        non_blocking: forwarded to ``Tensor.to``.
        **kwargs: extra keyword arguments forwarded to ``Tensor.to`` (e.g. ``dtype``).

    Returns:
        ``(data, target)``. ``target`` is a list when the batch target is a list or tuple.
    """
    def _move(tensor):
        return tensor.to(device=device, non_blocking=non_blocking, **kwargs)

    data = _move(batch["data"])
    target = batch["target"]
    if isinstance(target, (list, tuple)):
        target = [_move(t) for t in target]
    else:
        target = _move(target)
    return data, target
[ ]:
# Inspect one training batch on the CPU.
first_train_batch = next(iter(train_dataloader))
image, label = prepare_nnunet_batch(first_train_batch, device="cpu", non_blocking=True)

MONAI Supervised Trainer#

The SupervisedTrainer class from MONAI is used to train the nnUNet model. For a minimal setup, we need to provide the model, optimizer, loss function, data loaders, number of epochs and the device to run the training.

[ ]:
# Log the training loss with MONAI's StatsHandler.
train_loss_handler = StatsHandler(tag_name="train_loss", output_transform=from_engine(["loss"], first=True))
train_handlers = [train_loss_handler]

[ ]:
# Iterations per epoch and number of epochs for this demo run.
iterations, epochs = 100, 50
[ ]:
# Minimal MONAI training engine driving nnUNet's network, optimizer and loss.
trainer = SupervisedTrainer(
    device=device,
    max_epochs=epochs,
    epoch_length=iterations,
    train_data_loader=train_dataloader,
    prepare_batch=prepare_nnunet_batch,
    network=network,
    optimizer=optimizer,
    loss_function=loss,
    amp=True,
    train_handlers=train_handlers,
)
[ ]:
# Run the minimal training loop configured above (no validation yet).
trainer.run()

Adding Validation and Validation Metrics#

For a complete training setup, we need to add the validation data loader and the validation metrics to the SupervisedTrainer. Using the MONAI class SupervisedEvaluator, we can evaluate the model on the validation data loader and calculate the validation metrics (Dice Score).

[ ]:
# Validation Dice over the foreground classes: a single mean value (the key metric) ...
val_key_metric = MeanDice(include_background=False, reduction="mean",
                          output_transform=from_engine(['pred', 'label']))

# ... and one value per foreground class.
additional_metrics = {
    "Val_Dice_Per_Class": MeanDice(include_background=False, reduction="mean_batch",
                                   output_transform=from_engine(['pred', 'label'])),
}

Additionally, to compute the Mean Dice score we need to apply a post-processing transformation to the nnUNet model output. MeanDice expects y_pred and y as one-hot encoded tensors (batch-first BCHW[D] tensors or lists of channel-first tensors). We therefore create a custom post-processing chain that keeps only the full-resolution deep-supervision output, applies a softmax, discretizes the predictions, and converts the labels to one-hot.

[ ]:
num_classes = 2

postprocessing = Compose(
    transforms=[
        # Keep only the high-resolution output of the deep-supervision lists
        # (first entry) for both prediction and label.
        Lambdad(keys=["pred", "label"], func=lambda x: x[0]),
        # Softmax over the class channel.
        Activationsd(keys="pred", softmax=True),
        # Discretize via argmax + one-hot. Unlike a fixed 0.5 threshold on softmax probabilities,
        # this assigns exactly one class per voxel for any number of classes.
        AsDiscreted(keys="pred", argmax=True, to_onehot=num_classes),
        # One-hot encode the labels so they match the prediction layout expected by MeanDice.
        AsDiscreted(keys="label", to_onehot=num_classes),
    ]
)
[ ]:
# Print validation statistics once per validation run (no per-iteration logging).
val_stats_handler = StatsHandler(iteration_log=False)
val_handlers = [val_stats_handler]
[ ]:
# Validation batches per run, and how often (in training epochs) validation runs.
val_iterations, val_interval = 100, 1
[ ]:
# Validation engine: evaluates the network on `val_iterations` batches and computes the Dice metrics.
evaluator = SupervisedEvaluator(
    device=device,
    val_data_loader=val_dataloader,
    network=network,
    epoch_length=val_iterations,
    prepare_batch=prepare_nnunet_batch,
    postprocessing=postprocessing,
    key_val_metric={"Val_Dice": val_key_metric},
    additional_metrics=additional_metrics,
    val_handlers=val_handlers,
    amp=True,
)

And finally, we add the evaluator to the SupervisedTrainer to calculate the validation metrics during training.

[ ]:
# Run the evaluator every `val_interval` training epochs.
validation_handler = ValidationHandler(validator=evaluator, interval=val_interval, epoch_level=True)
train_handlers.append(validation_handler)

We can also add the MeanDice metric to the SupervisedTrainer to calculate the mean dice score over the batch during training.

[ ]:
# Training Dice over the foreground classes: one mean value (the key metric) ...
train_key_metric = MeanDice(include_background=False, reduction="mean",
                            output_transform=from_engine(['pred', 'label']))

# ... and one value per class. Rebinding `additional_metrics` does not affect the
# evaluator, which already received its own dict.
additional_metrics = {
    "Train_Dice_Per_Class": MeanDice(include_background=False, reduction="mean_batch",
                                     output_transform=from_engine(['pred', 'label'])),
}
[ ]:
# Training engine with validation (via the handlers) and training Dice metrics.
trainer = SupervisedTrainer(
    device=device,
    max_epochs=epochs,
    epoch_length=iterations,
    train_data_loader=train_dataloader,
    prepare_batch=prepare_nnunet_batch,
    network=network,
    optimizer=optimizer,
    loss_function=loss,
    postprocessing=postprocessing,
    key_train_metric={"Train_Dice": train_key_metric},
    additional_metrics=additional_metrics,
    train_handlers=train_handlers,
    amp=True,
)
[ ]:
# Train again, now with periodic validation and training Dice metrics.
trainer.run()

Learning Rate Scheduler#

One last component to add to the SupervisedTrainer, in order to replicate the training behaviour of the native nnUNet, is the learning rate scheduler.

[ ]:
# Drive nnUNet's learning-rate scheduler from a MONAI handler and print the learning rate.
lr_schedule_handler = LrScheduleHandler(print_lr=True, lr_scheduler=lr_scheduler)
train_handlers.append(lr_schedule_handler)
[ ]:
# Rebuild the training engine so it picks up the learning-rate scheduler handler.
trainer = SupervisedTrainer(
    device=device,
    max_epochs=epochs,
    epoch_length=iterations,
    train_data_loader=train_dataloader,
    prepare_batch=prepare_nnunet_batch,
    network=network,
    optimizer=optimizer,
    loss_function=loss,
    postprocessing=postprocessing,
    key_train_metric={"Train_Dice": train_key_metric},
    additional_metrics=additional_metrics,
    train_handlers=train_handlers,
    amp=True,
)
[ ]:
# Train again with the learning-rate scheduler handler added.
trainer.run()
[ ]:
# Current learning rate(s). Query the scheduler object directly: `train_handlers[-1]` only refers
# to the LrScheduleHandler until more handlers are appended to `train_handlers` in later cells.
lr_scheduler.get_last_lr()

Checkpointing#

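To save the model weights (together with the optimizer and learning-rate scheduler states) during training, we can use the CheckpointSaver handler from MONAI. Attached to the evaluator, it saves a checkpoint after every validation run and also keeps the checkpoint with the best validation metric. We can later use the CheckpointLoader to load these states and perform inference or resume training.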

[ ]:
# Save the plain network weights: a torch.compile'd network wraps the original module as
# `_orig_mod`; fall back to the network itself when it was not compiled.
network_to_save = getattr(nnunet_trainer.network, "_orig_mod", nnunet_trainer.network)

checkpoint_saver = CheckpointSaver(
    save_dir="Bundle/models",
    save_dict={"network_weights": network_to_save, "optimizer_state": nnunet_trainer.optimizer, "scheduler": nnunet_trainer.lr_scheduler},
    save_interval=1,       # save a checkpoint on every validation run
    save_key_metric=True,  # also keep the checkpoint with the best key metric
    n_saved=1,
)
val_handlers.append(checkpoint_saver)
# `evaluator` attached its handlers when it was constructed, so a handler appended to
# `val_handlers` afterwards must be attached explicitly to take effect.
checkpoint_saver.attach(evaluator)
[ ]:
ckpt_dir = "Bundle/models"
reload_checkpoint_epoch = "latest"

# NOTE: `get_checkpoint` is defined in the "Reload Checkpoint" section below; run that cell first.
# Load into the same (unwrapped) module the CheckpointSaver saves from.
network_to_load = getattr(nnunet_trainer.network, "_orig_mod", nnunet_trainer.network)

train_handlers.append(
    CheckpointLoader(
        load_path=f"{ckpt_dir}/checkpoint_epoch={get_checkpoint(reload_checkpoint_epoch, ckpt_dir)}.pt",
        load_dict={"network_weights": network_to_load, "optimizer_state": nnunet_trainer.optimizer, "scheduler": nnunet_trainer.lr_scheduler},
        map_location=device,
    )
)

Initial nnUNet Checkpoint#

To provide compatibility with the native nnUNet, we need to save the nnUNet-specific configuration together with the regular MONAI checkpoint. This is done only once, before the training starts. At the end of the training we will have a MONAI checkpoint and an nnUNet checkpoint; by combining the two, the MONAI checkpoint can be converted to an nnUNet checkpoint at any time.

[ ]:
# nnUNet-specific metadata stored next to the MONAI checkpoints, so that a MONAI checkpoint can
# later be combined with it into a checkpoint the native nnUNet can load.
checkpoint = {
    "inference_allowed_mirroring_axes": nnunet_trainer.inference_allowed_mirroring_axes,
    "init_args": nnunet_trainer.my_init_kwargs,
    "trainer_name": nnunet_trainer.__class__.__name__,
}
checkpoint_filename = 'Bundle/models/nnunet_checkpoint.pth'

# Make sure the target directory exists; torch.save does not create it.
Path(checkpoint_filename).parent.mkdir(parents=True, exist_ok=True)
torch.save(checkpoint, checkpoint_filename)

MLFlow and Tensorboard Monitoring#

To monitor the training process, we can use MLFlow and Tensorboard. We can log the training metrics, hyperparameters, and model weights to MLFlow, and visualize the training metrics using Tensorboard.

[ ]:
# Training loss to TensorBoard; picked up by the next SupervisedTrainer built from `train_handlers`.
train_handlers.append(
    TensorBoardStatsHandler(
        log_dir="Bundle/logs",
        output_transform=from_engine(['loss'], first=True),
        tag_name="train_loss",
    )
)

# Validation metrics to TensorBoard, once per validation run.
val_tensorboard_handler = TensorBoardStatsHandler(log_dir="Bundle/logs", iteration_log=False)
val_handlers.append(val_tensorboard_handler)
# `evaluator` attached its handlers when it was constructed; attach this one explicitly.
val_tensorboard_handler.attach(evaluator)
[ ]:
def mlflow_transform(state_output):
    """Return the ``'loss'`` entry of the first item of the engine output (logged to MLflow)."""
    first_item = state_output[0]
    return first_item['loss']

class MLFlowPyMAIAHandler(MLFlowHandler):
    """MLFlowHandler that logs tensor-valued metrics as one MLflow metric per foreground label.

    Scalar metrics are logged unchanged. A tensor-valued metric (e.g. a per-class Dice computed
    with ``reduction="mean_batch"``) is split into one value per element, named
    ``<metric>_<label key>``. The first key of ``label_dict`` is skipped, assuming it is the
    background label.

    Args:
        label_dict: mapping whose keys, in insertion order, identify the labels.
        **kwargs: forwarded to ``monai.handlers.MLFlowHandler``.
    """

    def __init__(self, label_dict, **kwargs):
        super().__init__(**kwargs)
        self.label_dict = label_dict

    def _default_epoch_log(self, engine) -> None:
        """
        Execute epoch level log operation.
        Default to track the values from Ignite `engine.state.metrics` dict and
        track the values of specified attributes of `engine.state`.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.

        """
        log_dict = engine.state.metrics
        if not log_dict:
            return

        current_epoch = self.global_epoch_transform(engine.state.epoch)
        # Label keys without the first (background) one; computed once rather than per value.
        foreground_labels = list(self.label_dict.keys())[1:]

        new_log_dict = {}
        for metric, value in log_dict.items():
            # isinstance (not an exact type check) so torch.Tensor subclasses are split as well.
            if isinstance(value, torch.Tensor):
                # reshape(-1) makes 0-d tensors iterable too; .item() logs plain Python numbers.
                for i, class_value in enumerate(value.reshape(-1)):
                    new_log_dict[f"{metric}_{foreground_labels[i]}"] = class_value.item()
            else:
                new_log_dict[metric] = value
        self._log_metrics(new_log_dict, step=current_epoch)

        if self.state_attributes is not None:
            attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}
            self._log_metrics(attrs, step=current_epoch)
[ ]:
import re
import yaml
from monai.bundle import ConfigParser
import monai

def create_mlflow_experiment_params(params_file, custom_params=None):
    """Collect the parameters to log for an MLflow experiment.

    The result combines:
    * MONAI configuration values, optional dependency information and GPU information
      (GPU values as strings), with parentheses in their names replaced by spaces;
    * every top-level entry of the YAML ``params_file``, parsed and instantiated with
      MONAI's ``ConfigParser`` (``bundle_root`` is set to the grandparent folder of
      ``params_file``);
    * ``custom_params``, added last so they override earlier entries.

    Args:
        params_file: path to a YAML file with bundle parameters.
        custom_params: optional mapping of additional parameters.

    Returns:
        dict mapping parameter names to values.
    """
    def _param_name(key):
        return re.sub("[()]", " ", str(key))

    device_config = monai.config.deviceconfig
    params_dict = {}
    for values in (device_config.get_config_values(), device_config.get_optional_config_values()):
        params_dict.update({_param_name(k): values[k] for k in values})

    gpu_info = device_config.get_gpu_info()
    params_dict.update({_param_name(k): str(gpu_info[k]) for k in gpu_info})

    monai_config = {}
    with open(params_file, 'r') as yaml_file:
        monai_config.update(yaml.safe_load(yaml_file))
    monai_config["bundle_root"] = str(Path(params_file).parent.parent)

    parser = ConfigParser(
        monai_config,
        globals={"os": "os", "pathlib": "pathlib", "json": "json", "ignite": "ignite"},
    )
    parser.parse(True)
    params_dict.update({key: parser.get_parsed_content(key, instantiate=True) for key in monai_config})

    if custom_params is not None:
        params_dict.update({k: custom_params[k] for k in custom_params})
    return params_dict
[ ]:
%%writefile Bundle/mlflow_params.yaml

# Parameters logged to MLflow for the notebook run; keep them in sync with the values used above.
num_classes: 2
task_id: "109"
pymaia_config_file: "/home/maia-user/Tutorials/MAIA/Experiments/Task09_Spleen/Task09_Spleen_results/Dataset109_Task09_Spleen.json"
tracking_uri: "http://localhost:5000"
mlflow_experiment_name: "nnUNet_Bundle_Spleen"
mlflow_run_name: "nnUNet_Bundle_Spleen"
nnunet_trainer_class_name: "nnUNetTrainer"
# Must match `nnunet_plans_identifier` used to build the trainer in this notebook.
nnunet_plans_identifier: "nnUNetResEncUNetLPlans"

[ ]:
mlflow_experiment_name = "nnUNet_Bundle_Spleen"
mlflow_run_name = "nnUNet_Bundle_Spleen"
label_dict = {0: "background", 1: "Spleen"}
tracking_uri = "http://localhost:5000"
params_file = "Bundle/mlflow_params.yaml"

# Training-side MLflow logging: loss per iteration, the training case list and the experiment parameters.
train_handlers.append(
    MLFlowPyMAIAHandler(
        dataset_dict={"train": train_dataset},
        dataset_keys="case_identifier",
        experiment_param=create_mlflow_experiment_params(params_file),
        experiment_name=mlflow_experiment_name,
        label_dict=label_dict,
        output_transform=mlflow_transform,
        run_name=mlflow_run_name,
        state_attributes=["best_metric", "best_metric_epoch"],
        tag_name="Train_Loss",
        tracking_uri=tracking_uri,
    )
)

# Validation-side MLflow logging: metrics once per validation run.
val_mlflow_handler = MLFlowPyMAIAHandler(
    experiment_name=mlflow_experiment_name,
    iteration_log=False,
    label_dict=label_dict,
    output_transform=mlflow_transform,
    run_name=mlflow_run_name,
    state_attributes=["best_metric", "best_metric_epoch"],
    tracking_uri=tracking_uri,
)
val_handlers.append(val_mlflow_handler)
# `evaluator` attached its handlers when it was constructed; attach this one explicitly.
val_mlflow_handler.attach(evaluator)

To start the MLflow server (by default it listens on http://localhost:5000, the tracking_uri used above), run the following command in a terminal:

mkdir -p Bundle/MLFlow && cd Bundle/MLFlow && mlflow server

To run Tensorboard, we can use the following command:

tensorboard --logdir Bundle/logs
[ ]:
# Rebuild the training engine so it uses every handler appended to `train_handlers` so far.
trainer = SupervisedTrainer(
    device=device,
    max_epochs=epochs,
    epoch_length=iterations,
    train_data_loader=train_dataloader,
    prepare_batch=prepare_nnunet_batch,
    network=network,
    optimizer=optimizer,
    loss_function=loss,
    postprocessing=postprocessing,
    key_train_metric={"Train_Dice": train_key_metric},
    additional_metrics=additional_metrics,
    train_handlers=train_handlers,
    amp=True,
)
[ ]:
# Train with all handlers appended to `train_handlers` so far (including TensorBoard and MLflow).
trainer.run()

Reload Checkpoint#

When resuming the training from a checkpoint, we also want to restart the training from the same epoch. To do this, we load the checkpoint and update the trainer.state.epoch and trainer.state.iteration parameters of the SupervisedTrainer.

[ ]:
from PyMAIA.utils.file_utils import subfiles

def get_checkpoint(epoch, ckpt_dir):
    """Resolve which checkpoint epoch to load.

    Args:
        epoch: an explicit epoch (returned unchanged) or ``"latest"``.
        ckpt_dir: directory containing ``checkpoint_epoch=<N>.pt`` files.

    Returns:
        ``epoch`` itself; for ``"latest"``, the highest ``<N>`` found in ``ckpt_dir``.

    Raises:
        FileNotFoundError: if ``epoch == "latest"`` and ``ckpt_dir`` holds no epoch checkpoint.
    """
    if epoch != "latest":
        return epoch

    from pathlib import Path

    prefix, suffix = "checkpoint_epoch=", ".pt"
    epochs = []
    for checkpoint_path in Path(ckpt_dir).glob(f"{prefix}*{suffix}"):
        epoch_str = checkpoint_path.name[len(prefix):-len(suffix)]
        # Skip anything whose middle part is not an integer epoch instead of crashing on it.
        if checkpoint_path.is_file() and epoch_str.isdigit():
            epochs.append(int(epoch_str))

    if not epochs:
        raise FileNotFoundError(f"No '{prefix}<N>{suffix}' checkpoints found in {ckpt_dir}")
    return max(epochs)


def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir):
    """Set the trainer state so that ``trainer.run()`` continues after the checkpointed epoch.

    Args:
        trainer: MONAI/Ignite engine whose ``state.epoch`` and ``state.iteration`` are updated.
        epoch: epoch to resume from, or ``"latest"`` (see ``get_checkpoint``).
        num_train_batches_per_epoch: iterations per training epoch.
        ckpt_dir: directory containing the ``checkpoint_epoch=<N>.pt`` files.
    """
    # int() guards against string epochs (e.g. from a config), which would otherwise be
    # repeated as strings by the multiplication below.
    epoch_to_load = int(get_checkpoint(epoch, ckpt_dir))
    trainer.state.epoch = epoch_to_load
    # N completed epochs of L iterations correspond to N * L iterations. An extra +1 would not
    # be a multiple of the epoch length.
    # NOTE(review): Ignite derives the position inside an epoch from `iteration % epoch_length`,
    # so +1 would shorten the first resumed epoch; confirm against the installed ignite version.
    trainer.state.iteration = epoch_to_load * num_train_batches_per_epoch

Create MONAI Bundle#

[ ]:
%%bash

# Scaffold an empty MONAI bundle with the MAIA environment's interpreter.
/home/maia-user/.conda/envs/MAIA/bin/python -m monai.bundle init_bundle nnUNetBundle
# Folders for the nnUNet configs (incl. the evaluator configs) and the bundle's Python sources.
mkdir -p nnUNetBundle/nnUNet/evaluator nnUNetBundle/src
# Show the resulting layout if `tree` is installed (you may need "sudo apt install tree").
which tree && tree nnUNetBundle || true
[ ]:
%%writefile nnUNetBundle/nnUNet/global.yaml

# Values shared by the bundle configs.
# Iterations per epoch, device and number of epochs are taken from the nnUNet trainer.
iterations: $@nnunet_trainer.num_iterations_per_epoch
device: $@nnunet_trainer.device
epochs: $@nnunet_trainer.num_epochs
# Parsed PyMAIA config; its 'label_dict' entry is used by the MLflow handlers.
# NOTE(review): the file handle opened here is never explicitly closed.
pymaia_config_dict: "$json.load(open(@pymaia_config_file))"
bundle_root: .
ckpt_dir: "$@bundle_root + '/models'"
[ ]:
%%writefile nnUNetBundle/nnUNet/params.yaml

# Experiment parameters. This file is also read by create_mlflow_experiment_params
# (see train_handlers.yaml), so its values are logged to MLflow.
# NOTE(review): the empty strings look like placeholders to fill in (or override) before running.
num_classes: 2
task_id: ""
pymaia_config_file: ""
tracking_uri: "mlruns"
mlflow_experiment_name: ""
mlflow_run_name: ""
# Folder holding the nnUNet dataset.json and plans.json copied into models/ by run.yaml.
nnunet_model_folder: ""
nnunet_trainer_class_name: "nnUNetTrainer"
nnunet_plans_identifier: "nnUNetPlans"
[ ]:
%%writefile nnUNetBundle/nnUNet/imports.yaml

# Modules made available to the `$` expressions of the other bundle config files.
# `src` is the bundle's own package; its `nnUNet_Trainer` and `utils` modules are referenced as src.*.
# NOTE(review): `monai` is used by other configs but not imported here; presumably it is a
# default ConfigParser global. Confirm.
imports:
- $import glob
- $import os
- $import ignite
- $import torch
- $import shutil
- $import json
- $import src
- $from src.utils import create_mlflow_experiment_params
- $from pathlib import Path
[ ]:
%%writefile nnUNetBundle/nnUNet/run.yaml

# Expressions evaluated, in order, for the `run` section of the training bundle.
run:
# The leading `$` marks this entry as an expression; without it the call would stay a plain
# string and never execute (train_resume.yaml already uses the `$` form).
# NOTE(review): `token` must be defined in another config file or on the command line.
- "$src.utils.set_mlflow_token(@token)"
# Save the nnUNet-specific metadata next to the MONAI checkpoints.
- "$torch.save(@checkpoint,@checkpoint_filename)"
# Copy the nnUNet dataset/plans files into models/ (read by nnunet_trainer_def in train.yaml).
- "$shutil.copy(Path(@nnunet_model_folder).joinpath('dataset.json'), @bundle_root+'/models/dataset.json')"
- "$shutil.copy(Path(@nnunet_model_folder).joinpath('plans.json'), @bundle_root+'/models/plans.json')"
- "$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})"
- "$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])"
- $@train#trainer.run()

initialize:
- $monai.utils.set_determinism(seed=123)
[ ]:
%%writefile nnUNetBundle/nnUNet/train.yaml

# nnUNet trainer built via get_nnunet_trainer; it supplies the loss, LR scheduler,
# network, optimizer and training data loader used below.
nnunet_trainer:
  _target_ : src.nnUNet_Trainer.get_nnunet_trainer
  dataset_name_or_id: "@task_id"
  configuration: "3d_fullres"
  fold: "0"
  pymaia_config_file: "@pymaia_config_file"
  trainer_class_name: "@nnunet_trainer_class_name"
  plans_identifier: "@nnunet_plans_identifier"

# Second trainer built only from the plans.json/dataset.json copied into models/
# (see run.yaml); it provides `network_def`.
nnunet_trainer_def:
  _target_ : src.nnUNet_Trainer.get_nnunet_trainer_from_preprocessing_folder
  plans_file: "$@bundle_root+'/models/plans.json'"
  dataset_file: "$@bundle_root+'/models/dataset.json'"
  configuration: "3d_fullres"
  fold: "0"
  trainer_class_name: "@nnunet_trainer_class_name"
  plans_identifier: "@nnunet_plans_identifier"

loss: $@nnunet_trainer.loss
lr_scheduler: $@nnunet_trainer.lr_scheduler

# `network_def` comes from nnunet_trainer_def; the trained `network` from nnunet_trainer.
network_def: $@nnunet_trainer_def.network
network: $@nnunet_trainer.network

optimizer: $@nnunet_trainer.optimizer

# nnUNet-specific metadata saved once by run.yaml, so MONAI checkpoints can later be
# combined with it into a native nnUNet checkpoint.
checkpoint:
  init_args: '$@nnunet_trainer.my_init_kwargs'
  trainer_name: '$@nnunet_trainer.__class__.__name__'
  inference_allowed_mirroring_axes: '$@nnunet_trainer.inference_allowed_mirroring_axes'

checkpoint_filename: "$@bundle_root+'/models/nnunet_checkpoint.pth'"
# Also used as the TensorBoard log directory by the training handlers.
output_dir: $@bundle_root + '/eval'

train:
  # Progress bar; attached to the trainer in run.yaml.
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  dataloader: $@nnunet_trainer.dataloader_train
  # Training case identifiers, wrapped in a Dataset that the MLflow handler logs.
  # NOTE(review): relies on private attributes of the nnUNet data loader.
  train_data: "$[{'case_identifier':k} for k in @nnunet_trainer.dataloader_train.generator._data.dataset.keys()]"
  train_dataset:
    _target_: Dataset
    data: "@train#train_data"
  # NOTE(review): left empty here, while the handler list is defined under
  # `train_handlers#handlers` (train_handlers.yaml / train_resume.yaml), a different key.
  # Presumably it is wired in with an override such as
  # `--train#handlers "@train_handlers#handlers"`. Confirm.
  handlers:
  inferer:
    _target_: SimpleInferer
  # Mean Dice over the foreground classes (key metric) and per-class Dice.
  key_metric:
    Train_Dice:
      _target_: "MeanDice"
      include_background: False
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean"
  additional_metrics:
    Train_Dice_per_class:
      _target_: "MeanDice"
      include_background: False
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean_batch"
  # Keep the first (high-resolution) deep-supervision output, apply softmax, binarize
  # at 0.5, and one-hot encode the labels.
  # NOTE(review): with more than two classes, a 0.5 threshold on softmax probabilities can leave
  # voxels with no predicted class; argmax + to_onehot may be preferable.
  postprocessing:
    _target_: "Compose"
    transforms:
    - _target_: Lambdad
      keys:
        - "pred"
        - "label"
      func: "$lambda x: x[0]"
    - _target_: Activationsd
      keys:
        - "pred"
      softmax: True
    - _target_: AsDiscreted
      keys:
       - "pred"
      threshold: 0.5
    - _target_: AsDiscreted
      keys:
        - "label"
      to_onehot: "@num_classes"
  # Sigmoid-based variant (presumably for region-based nnUNet training); the trainer
  # below references `train#postprocessing`, not this one.
  postprocessing_region_based:
    _target_: "Compose"
    transforms:
    - _target_: Lambdad
      keys:
        - "pred"
        - "label"
      func: "$lambda x: x[0]"
    - _target_: Activationsd
      keys:
        - "pred"
      sigmoid: True
    - _target_: AsDiscreted
      keys:
       - "pred"
      threshold: 0.5
  trainer:
    _target_: SupervisedTrainer
    amp: true
    device: '@device'
    additional_metrics: "@train#additional_metrics"
    epoch_length: "@iterations"
    inferer: '@train#inferer'
    key_train_metric: '@train#key_metric'
    loss_function: '@loss'
    max_epochs: '@epochs'
    network: '@network'
    prepare_batch: "$src.nnUNet_Trainer.prepare_nnunet_batch"
    optimizer: '@optimizer'
    postprocessing: '@train#postprocessing'
    train_data_loader: '@train#dataloader'
    train_handlers: '@train#handlers'
[1]:
%%writefile nnUNetBundle/nnUNet/train_resume.yaml

# Resume-training variant: its `run` list replaces the one in run.yaml, and its handler list
# adds a CheckpointLoader.
# NOTE(review): `token` and `reload_checkpoint_epoch` must be defined elsewhere
# (another config file or the command line).
run:
- "$src.utils.set_mlflow_token(@token)"
# Move the trainer's epoch/iteration counters to the checkpointed epoch before running.
- '$src.utils.reload_checkpoint(@train#trainer,@reload_checkpoint_epoch,@nnunet_trainer.num_iterations_per_epoch,@bundle_root+"/models")'
- "$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})"
- "$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])"
- $@train#trainer.run()

# NOTE(review): defined under `train_handlers#handlers`, while train.yaml's trainer reads
# `train#handlers`; presumably wired in with a CLI override. Confirm.
train_handlers:
  handlers:
  # MLflow logging of the training loss, the training case list and the parameters in params.yaml.
  - _target_: "$src.nnUNet_Trainer.MLFlowPyMAIAHandler"
    label_dict: "$@pymaia_config_dict['label_dict']"
    tracking_uri: "@tracking_uri"
    experiment_name: "@mlflow_experiment_name"
    run_name: "@mlflow_run_name"
    output_transform: "$src.nnUNet_Trainer.mlflow_transform"
    dataset_dict:
        train: "@train#train_dataset"
    dataset_keys: 'case_identifier'
    state_attributes:
    - "iteration"
    - "epoch"
    tag_name: 'Train_Loss'
    experiment_param: "$src.utils.create_mlflow_experiment_params( @bundle_root + '/nnUNet/params.yaml')"
    #artifacts=None
    optimizer_param_names: 'lr'
    #close_on_complete: False
  - _target_: LrScheduleHandler
    lr_scheduler: '@lr_scheduler'
    print_lr: true
  # Run the evaluator every `val_interval` epochs.
  - _target_: ValidationHandler
    epoch_level: true
    interval: '@val_interval'
    validator: '@validate#evaluator'
  #- _target_: StatsHandler
  #  output_transform: $monai.handlers.from_engine(['loss'], first=True)
  #  tag_name: train_loss
  - _target_: TensorBoardStatsHandler
    log_dir: '@output_dir'
    output_transform: $monai.handlers.from_engine(['loss'], first=True)
    tag_name: train_loss
  # Restore network, optimizer and scheduler from models/checkpoint_epoch=<N>.pt.
  # NOTE(review): `network._orig_mod` assumes the network was wrapped by torch.compile.
  - _target_: CheckpointLoader
    load_dict:
      network_weights: '$@nnunet_trainer.network._orig_mod'
      optimizer_state: '$@nnunet_trainer.optimizer'
      scheduler: '$@nnunet_trainer.lr_scheduler'
    load_path: '$@bundle_root + "/models/checkpoint_epoch="+str(src.utils.get_checkpoint(@reload_checkpoint_epoch, @bundle_root+"/models"))+".pt"'
    map_location: '@device'
Writing nnUNetBundle/nnUNet/train_resume.yaml
[ ]:
%%writefile nnUNetBundle/nnUNet/train_handlers.yaml

# Handlers attached to the training engine for a fresh training run.
train_handlers:
  handlers:
  # MLflow logging: training loss (via mlflow_transform), the iteration/epoch
  # state attributes, the optimizer 'lr' param and the params from params.yaml.
  - _target_: "$src.nnUNet_Trainer.MLFlowPyMAIAHandler"
    label_dict: "$@pymaia_config_dict['label_dict']"
    tracking_uri: "@tracking_uri"
    experiment_name: "@mlflow_experiment_name"
    run_name: "@mlflow_run_name"
    output_transform: "$src.nnUNet_Trainer.mlflow_transform"
    dataset_dict:
        train: "@train#train_dataset"
    dataset_keys: 'case_identifier'
    state_attributes:
    - "iteration"
    - "epoch"
    tag_name: 'Train_Loss'
    experiment_param: "$src.utils.create_mlflow_experiment_params( @bundle_root + '/nnUNet/params.yaml')"
    #artifacts=None
    optimizer_param_names: 'lr'
    #close_on_complete: False
  # Drives @lr_scheduler and prints the current learning rate.
  - _target_: LrScheduleHandler
    lr_scheduler: '@lr_scheduler'
    print_lr: true
  # Runs the validation evaluator every @val_interval epochs.
  - _target_: ValidationHandler
    epoch_level: true
    interval: '@val_interval'
    validator: '@validate#evaluator'
  #- _target_: StatsHandler
  #  output_transform: $monai.handlers.from_engine(['loss'], first=True)
  #  tag_name: train_loss
  # Writes the training loss to TensorBoard under @output_dir.
  - _target_: TensorBoardStatsHandler
    log_dir: '@output_dir'
    output_transform: $monai.handlers.from_engine(['loss'], first=True)
    tag_name: train_loss
[ ]:
%%writefile nnUNetBundle/nnUNet/validate.yaml

# Validate after every `val_interval` training epochs (read by the ValidationHandler).
val_interval: 1
validate:
  # Progress bar; attached to the evaluator by a `run` section (e.g. train_resume.yaml).
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  # Key metric: Dice averaged over all non-background classes.
  key_metric:
    Val_Dice:
      _target_: "MeanDice"
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean"
      include_background: False
  # Per-class Dice (reduction over the batch only); MLFlowPyMAIAHandler logs
  # each element under its label name.
  additional_metrics:
    Val_Dice_per_class:
      _target_: "MeanDice"
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean_batch"
      include_background: False
  # nnUNet's own validation loader.
  dataloader: $@nnunet_trainer.dataloader_val
  evaluator:
    _target_: SupervisedEvaluator
    additional_metrics: '@validate#additional_metrics'
    amp: true
    # One validation "epoch" is nnUNet's configured number of validation iterations.
    epoch_length: $@nnunet_trainer.num_val_iterations_per_epoch
    device: '@device'
    inferer: '@validate#inferer'
    key_val_metric: '@validate#key_metric'
    network: '@network'
    postprocessing: '@validate#postprocessing'
    val_data_loader: '@validate#dataloader'
    val_handlers: '@validate#handlers'
    prepare_batch: "$src.nnUNet_Trainer.prepare_nnunet_batch"
  handlers:
  - _target_: StatsHandler
    iteration_log: false
  - _target_: TensorBoardStatsHandler
    iteration_log: false
    log_dir: '@output_dir'
  # MLflow logging of the epoch metrics plus best_metric / best_metric_epoch.
  - _target_: "$src.nnUNet_Trainer.MLFlowPyMAIAHandler"
    label_dict: "$@pymaia_config_dict['label_dict']"
    tracking_uri: "@tracking_uri"
    experiment_name: "@mlflow_experiment_name"
    run_name: "@mlflow_run_name"
    output_transform: "$src.nnUNet_Trainer.mlflow_transform"
    iteration_log: False
    state_attributes:
    - "best_metric"
    - "best_metric_epoch"
  # Saves network/optimizer/scheduler states into <bundle_root>/models after each
  # validation (n_saved: 1 keeps a single interval checkpoint) and on a new best
  # key metric. The network is saved unwrapped (`_orig_mod`), i.e. without torch.compile.
  - _target_: "CheckpointSaver"
    save_dir: "$str(@bundle_root)+'/models'"
    save_interval: 1
    n_saved: 1
    save_key_metric: true
    save_dict:
      network_weights: '$@nnunet_trainer.network._orig_mod'
      optimizer_state: '$@nnunet_trainer.optimizer'
      scheduler: '$@nnunet_trainer.lr_scheduler'
  inferer:
    _target_: SimpleInferer
  # `%` copies the training postprocessing definition.
  postprocessing: '%train#postprocessing'

[ ]:
%%writefile nnUNetBundle/nnUNet/evaluator/evaluator.yaml

# Standalone evaluation on nnUNet's validation loader (no checkpoint saving).
validate:
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  # Key metric: Dice averaged over all non-background classes.
  key_metric:
    Val_Dice:
      _target_: "MeanDice"
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean"
      include_background: False
  # Per-class Dice; MLFlowPyMAIAHandler logs each element under its label name.
  additional_metrics:
    Val_Dice_per_class:
      _target_: "MeanDice"
      output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
      reduction: "mean_batch"
      include_background: False
  dataloader: $@nnunet_trainer.dataloader_val
  evaluator:
    _target_: SupervisedEvaluator
    additional_metrics: '@validate#additional_metrics'
    amp: true
    # One evaluation "epoch" is nnUNet's configured number of validation iterations.
    epoch_length: $@nnunet_trainer.num_val_iterations_per_epoch
    device: '@device'
    inferer: '@validate#inferer'
    key_val_metric: '@validate#key_metric'
    network: '@network'
    postprocessing: '@validate#postprocessing'
    val_data_loader: '@validate#dataloader'
    val_handlers: '@validate#handlers'
    prepare_batch: "$src.nnUNet_Trainer.prepare_nnunet_batch"
  handlers:
  - _target_: StatsHandler
    iteration_log: false
  - _target_: TensorBoardStatsHandler
    iteration_log: false
    log_dir: '@output_dir'
  # MLflow logging of the metrics plus best_metric / best_metric_epoch.
  - _target_: "$src.nnUNet_Trainer.MLFlowPyMAIAHandler"
    label_dict: "$@pymaia_config_dict['label_dict']"
    tracking_uri: "@tracking_uri"
    experiment_name: "@mlflow_experiment_name"
    run_name: "@mlflow_run_name"
    output_transform: "$src.nnUNet_Trainer.mlflow_transform"
    iteration_log: False
    state_attributes:
    - "best_metric"
    - "best_metric_epoch"
  inferer:
    _target_: SimpleInferer
  # `%` copies the training postprocessing definition (must be present in a merged config).
  postprocessing: '%train#postprocessing'

# Standalone evaluation: set the MLflow token, attach the progress bar, run the evaluator.
run:
# The leading "$" makes MONAI evaluate the item as an expression; without it the
# item is kept as a plain string and the MLflow token is never set.
- "$src.utils.set_mlflow_token(@token)"
- "$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])"
- $@validate#evaluator.run()

# Executed when the config is initialized: let cuDNN benchmark and pick convolution algorithms.
initialize:
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
[ ]:
%%writefile nnUNetBundle/src/__init__.py


[ ]:
%%writefile nnUNetBundle/src/nnUNet_Trainer.py
import torch
from typing import Union, Optional
import json
from pathlib import Path
import os
from torch.backends import cudnn
from batchgenerators.utilities.file_and_folder_operations import join, load_json
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
import nnunetv2
from monai.handlers import MLFlowHandler


def get_nnunet_trainer(dataset_name_or_id: Union[str, int],
                       configuration: str, fold: Union[int, str],
                       pymaia_config_file: str = None,  # To set env variables
                       trainer_class_name: str = 'nnUNetTrainer',
                       plans_identifier: str = 'nnUNetPlans',
                       pretrained_weights: Optional[str] = None,
                       num_gpus: int = 1,
                       use_compressed_data: bool = False,
                       export_validation_probabilities: bool = False,
                       continue_training: bool = False,
                       only_run_validation: bool = False,
                       disable_checkpointing: bool = False,
                       val_with_best: bool = False,
                       device: torch.device = torch.device(
                           'cuda'),
                       pretrained_model = None
                           ):  # From nnUNet/nnunetv2/run/run_training.py#run_training
    """Create and initialize a native nnUNetv2 trainer for use from a MONAI Bundle.

    Adapted from ``nnunetv2.run.run_training.run_training``. The trainer is
    built and initialized (``on_train_start``), but nnUNet's own training and
    validation loops are skipped. The bundle can then drive the trainer's
    network, optimizer, LR scheduler and data loaders with MONAI/Ignite engines.

    Args:
        dataset_name_or_id: nnUNet dataset name or id.
        configuration: nnUNet configuration name.
        fold: fold index, or ``'all'``. Numeric strings are converted to ``int``.
        pymaia_config_file: optional JSON file. When given, it sets these
            environment variables before nnunetv2's training module is imported:
            ``nnUNet_raw`` from ``base_folder`` (+ ``/nnUNet_raw``),
            ``nnUNet_preprocessed`` from ``preprocessing_folder`` and
            ``nnUNet_results`` from ``results_folder``.
        trainer_class_name, plans_identifier, use_compressed_data: forwarded
            to nnUNet's ``get_trainer_from_args``.
        pretrained_weights, continue_training, only_run_validation: forwarded
            to nnUNet's ``maybe_load_checkpoint``.
        num_gpus: only single-GPU is supported; values > 1 return ``None``.
        export_validation_probabilities, val_with_best: accepted for signature
            parity with nnUNet's ``run_training``; unused here.
        disable_checkpointing: copied onto the trainer when true.
        device: device passed to ``get_trainer_from_args``.
        pretrained_model: optional path to a MONAI checkpoint. Its
            ``network_weights`` entry is loaded into the trainer's network.

    Returns:
        The initialized nnUNet trainer, or ``None`` when ``num_gpus > 1``.

    Raises:
        ValueError: if ``fold`` is a string that is neither ``'all'`` nor an integer.
        AssertionError: if both ``continue_training`` and ``only_run_validation`` are set.
    """
    import warnings

    if pymaia_config_file is not None:
        with open(pymaia_config_file, "r") as f:
            pymaia_config_dict = json.load(f)

        # Set before the nnunetv2 import below (presumably read by nnunetv2 at import time).
        os.environ["nnUNet_raw"] = str(Path(pymaia_config_dict["base_folder"]).joinpath("nnUNet_raw"))
        os.environ["nnUNet_preprocessed"] = pymaia_config_dict["preprocessing_folder"]
        os.environ["nnUNet_results"] = pymaia_config_dict["results_folder"]

    from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint

    if isinstance(fold, str) and fold != 'all':
        try:
            fold = int(fold)
        except ValueError as e:
            print(
                f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!')
            raise e

    if int(num_gpus) > 1:
        # Multi-GPU (DDP) training is not supported yet; callers receive None.
        return None

    nnunet_trainer = get_trainer_from_args(str(dataset_name_or_id), configuration, fold, trainer_class_name,
                                           plans_identifier, use_compressed_data, device=device)

    if disable_checkpointing:
        nnunet_trainer.disable_checkpointing = disable_checkpointing

    assert not (
            continue_training and only_run_validation), 'Cannot set --c and --val flag at the same time. Dummy.'

    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
    nnunet_trainer.on_train_start()  # Added to Initialize Trainer
    if torch.cuda.is_available():
        cudnn.deterministic = False
        cudnn.benchmark = True

    if pretrained_model is not None:
        # Load on CPU: load_state_dict copies the values onto the network's own device.
        # This also works when the checkpoint was saved on a different device.
        state_dict = torch.load(pretrained_model, map_location="cpu")
        if 'network_weights' in state_dict:
            # `_orig_mod` exists only when the network was wrapped by torch.compile.
            network = getattr(nnunet_trainer.network, "_orig_mod", nnunet_trainer.network)
            network.load_state_dict(state_dict['network_weights'])
        else:
            warnings.warn(f"No 'network_weights' entry in {pretrained_model}; pretrained model not loaded.")
    # Skip Training and Validation Phase
    return nnunet_trainer

def get_nnunet_trainer_from_preprocessing_folder(
                       plans_file,
                       dataset_file,
                       configuration: str,
                       fold: Union[int, str],
                       trainer_class_name: str = 'nnUNetTrainer',
                       plans_identifier: str = 'nnUNetPlans',
                       pretrained_weights: Optional[str] = None,
                       num_gpus: int = 1,
                       use_compressed_data: bool = False,
                       export_validation_probabilities: bool = False,
                       continue_training: bool = False,
                       only_run_validation: bool = False,
                       disable_checkpointing: bool = False,
                       val_with_best: bool = False,
                       device: torch.device = torch.device(
                           'cuda'),
                       pretrained_model = None
):  # From nnUNet/nnunetv2/run/run_training.py#run_training
    """Create and initialize an nnUNet trainer directly from plans/dataset JSON files.

    This variant of ``get_nnunet_trainer`` does not need an nnUNet folder layout.
    The nnUNet path environment variables are pointed at the current working
    directory (``./nnUNet_raw``, ``.``, ``.``). The trainer class is looked up by
    name inside ``nnunetv2.training.nnUNetTrainer`` and constructed from the
    given plans and dataset JSON files. It is then initialized, but no training
    is run.

    Args:
        plans_file: path to the nnUNet plans JSON.
        dataset_file: path to the nnUNet ``dataset.json``.
        configuration: nnUNet configuration name.
        fold: fold index, or ``'all'``. Numeric strings are converted to ``int``.
        trainer_class_name: trainer class to look up in nnunetv2.
        pretrained_weights, continue_training, only_run_validation: forwarded
            to nnUNet's ``maybe_load_checkpoint``.
        num_gpus: only single-GPU is supported; values > 1 return ``None``.
        disable_checkpointing: copied onto the trainer when true.
        device: device the trainer is created on.
        pretrained_model: optional path to a MONAI checkpoint. Its
            ``network_weights`` entry is loaded into the trainer's network.
        plans_identifier, use_compressed_data, export_validation_probabilities,
        val_with_best: accepted for signature parity; unused here.

    Returns:
        The initialized nnUNet trainer, or ``None`` when ``num_gpus > 1``.

    Raises:
        ValueError: if ``fold`` is a string that is neither ``'all'`` nor an integer.
        RuntimeError: if the trainer class cannot be found.
        AssertionError: if both ``continue_training`` and ``only_run_validation`` are set.
    """
    import warnings

    ## Block Added
    os.environ["nnUNet_raw"] = str(Path(".").joinpath("nnUNet_raw"))
    os.environ["nnUNet_preprocessed"] = "."
    os.environ["nnUNet_results"] = "."
    from nnunetv2.run.run_training import maybe_load_checkpoint
    ##

    if isinstance(fold, str) and fold != 'all':
        try:
            fold = int(fold)
        except ValueError as e:
            print(
                f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!')
            raise e

    if int(num_gpus) > 1:
        # Multi-GPU (DDP) training is not supported yet; callers receive None.
        return None

    trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                trainer_class_name, 'nnunetv2.training.nnUNetTrainer')
    if trainer_class is None:
        raise RuntimeError(f'Unable to locate trainer class {trainer_class_name} in nnunetv2.training.nnUNetTrainer. '
                           f'Please place it there (in any .py file)!')

    plans = load_json(plans_file)
    dataset_json = load_json(dataset_file)

    # Honour the caller's device (it was previously hard-coded to CUDA).
    # unpack_dataset stays False as before.
    nnunet_trainer = trainer_class(plans=plans, configuration=configuration, fold=fold,
                                   dataset_json=dataset_json, unpack_dataset=False, device=device)

    if disable_checkpointing:
        nnunet_trainer.disable_checkpointing = disable_checkpointing

    assert not (
            continue_training and only_run_validation), 'Cannot set --c and --val flag at the same time. Dummy.'

    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
    nnunet_trainer.initialize()  # To Initialize Trainer
    if torch.cuda.is_available():
        cudnn.deterministic = False
        cudnn.benchmark = True

    # Skip Training and Validation Phase
    if pretrained_model is not None:
        # Load on CPU: load_state_dict copies the values onto the network's own device.
        state_dict = torch.load(pretrained_model, map_location="cpu")
        if 'network_weights' in state_dict:
            # `_orig_mod` exists only when the network was wrapped by torch.compile.
            network = getattr(nnunet_trainer.network, "_orig_mod", nnunet_trainer.network)
            network.load_state_dict(state_dict['network_weights'])
        else:
            warnings.warn(f"No 'network_weights' entry in {pretrained_model}; pretrained model not loaded.")

    return nnunet_trainer

def prepare_nnunet_batch(batch, device, non_blocking):
    """Move an nnUNet data-loader batch onto ``device`` for a MONAI engine.

    Args:
        batch: mapping with a ``"data"`` tensor and a ``"target"`` that is either
            a tensor or a list of tensors.
        device: destination device.
        non_blocking: forwarded to ``Tensor.to``.

    Returns:
        ``(data, target)``. ``target`` keeps its original form: a list is moved
        element by element, a single tensor is moved as is.
    """
    def _move(tensor):
        return tensor.to(device, non_blocking=non_blocking)

    data = _move(batch["data"])
    raw_target = batch["target"]
    if not isinstance(raw_target, list):
        return data, _move(raw_target)
    return data, [_move(item) for item in raw_target]

def mlflow_transform(state_output):
    """Extract the ``'loss'`` entry of the first item of an engine output."""
    first_item = state_output[0]
    return first_item['loss']

class MLFlowPyMAIAHandler(MLFlowHandler):
    """MONAI ``MLFlowHandler`` that logs tensor-valued metrics per class label.

    At epoch level, every tensor metric with one or more elements (e.g. a
    ``MeanDice`` with ``reduction="mean_batch"``) is split into scalars named
    ``<metric>_<label>``. Element ``i`` is named after the ``i + 1``-th key of
    ``label_dict``, because the first key is skipped (assumed to be the
    background, matching ``include_background: False`` in the bundle configs).

    Args:
        label_dict: mapping whose keys are label names, in class-index order.
        **kwargs: forwarded to ``MLFlowHandler``.
    """

    def __init__(self, label_dict, **kwargs):
        super().__init__(**kwargs)
        self.label_dict = label_dict

    def _default_epoch_log(self, engine) -> None:
        """
        Execute epoch level log operation.
        Default to track the values from Ignite `engine.state.metrics` dict and
        track the values of specified attributes of `engine.state`.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.

        Raises:
            IndexError: if a tensor metric has more elements than non-background labels.
        """
        log_dict = engine.state.metrics
        if not log_dict:
            return

        current_epoch = self.global_epoch_transform(engine.state.epoch)
        # Computed once per call; position i + 1 of label_dict names element i.
        label_names = list(self.label_dict.keys())

        new_log_dict = {}
        for metric, value in log_dict.items():
            # isinstance also accepts tensor subclasses such as MONAI's MetaTensor.
            # 0-d tensors cannot be iterated, so they are logged as a single value.
            if isinstance(value, torch.Tensor) and value.ndim > 0:
                for i, val in enumerate(value):
                    new_log_dict[f"{metric}_{label_names[i + 1]}"] = val
            else:
                new_log_dict[metric] = value
        self._log_metrics(new_log_dict, step=current_epoch)

        if self.state_attributes is not None:
            attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}
            self._log_metrics(attrs, step=current_epoch)
[ ]:
%%writefile nnUNetBundle/src/utils.py

import re
from PyMAIA.utils.file_utils import subfiles
import yaml
from monai.bundle import ConfigParser
import monai
from pathlib import Path
import os

def set_mlflow_token(token):
    """Expose ``token`` to MLflow via the ``MLFLOW_TRACKING_TOKEN`` environment variable."""
    os.environ.update({"MLFLOW_TRACKING_TOKEN": token})

def create_mlflow_experiment_params(params_file, custom_params=None):
    """Collect the parameters to record for an MLflow experiment.

    The result combines, in this order (later keys overwrite earlier ones):

    1. MONAI environment information: core and optional package versions and
       GPU info (GPU values are converted to ``str``). Parentheses in these
       keys are replaced by spaces.
    2. Every top-level entry of the YAML ``params_file``, parsed and
       instantiated with MONAI's ``ConfigParser``. ``bundle_root`` is first set
       to the grandparent directory of ``params_file``.
    3. ``custom_params``, if given.

    Args:
        params_file: path to a MONAI-bundle style YAML file. An empty file is
            treated as an empty config.
        custom_params: optional mapping of extra parameters.

    Returns:
        dict of parameter name -> value.
    """
    def _clean_key(key):
        # Replace "(" and ")" in environment-info keys with spaces.
        return re.sub(r"[()]", " ", str(key))

    deviceconfig = monai.config.deviceconfig
    params_dict = {}
    for key, value in deviceconfig.get_config_values().items():
        params_dict[_clean_key(key)] = value
    for key, value in deviceconfig.get_optional_config_values().items():
        params_dict[_clean_key(key)] = value
    for key, value in deviceconfig.get_gpu_info().items():
        params_dict[_clean_key(key)] = str(value)

    with open(params_file, 'r') as file:
        # safe_load returns None for an empty file; dict.update(None) would raise.
        monai_config = dict(yaml.safe_load(file) or {})

    monai_config["bundle_root"] = str(Path(params_file).parent.parent)

    parser = ConfigParser(monai_config, globals={"os": "os",
                                                 "pathlib": "pathlib",
                                                 "json": "json",
                                                 "ignite": "ignite"
                                                 })

    parser.parse(True)

    for key in monai_config:
        params_dict[key] = parser.get_parsed_content(key, instantiate=True)

    if custom_params is not None:
        params_dict.update(custom_params)
    return params_dict

def get_checkpoint(epoch, ckpt_dir):
    """Resolve which checkpoint epoch to load.

    Args:
        epoch: an explicit epoch (returned unchanged), or the string ``"latest"``.
        ckpt_dir: directory holding ``checkpoint_epoch=<N>.pt`` files, the name
            pattern the bundle's resume ``CheckpointLoader`` builds.

    Returns:
        ``epoch`` itself. For ``"latest"``, the largest ``N`` among the
        ``checkpoint_epoch=<N>.pt`` files in ``ckpt_dir``.

    Raises:
        FileNotFoundError: if ``epoch == "latest"`` and no such file exists.
    """
    if epoch != "latest":
        return epoch

    # Only exact "checkpoint_epoch=<int>.pt" file names count. Other files sharing
    # the prefix (e.g. temporary files) are ignored instead of crashing int().
    pattern = re.compile(r"checkpoint_epoch=(\d+)\.pt")
    epochs = []
    for entry in Path(ckpt_dir).iterdir():
        match = pattern.fullmatch(entry.name)
        if match is not None and entry.is_file():
            epochs.append(int(match.group(1)))

    if not epochs:
        raise FileNotFoundError(f"No 'checkpoint_epoch=<N>.pt' files found in {ckpt_dir}")
    return max(epochs)

def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir):
    """Restore a trainer engine's epoch/iteration counters before resuming training.

    Only the counters are updated. Network, optimizer and scheduler states must
    be restored separately (e.g. by a ``CheckpointLoader`` handler).

    Args:
        trainer: Ignite/MONAI engine whose ``state`` is updated.
        epoch: epoch to resume from, or ``"latest"`` (resolved by ``get_checkpoint``).
        num_train_batches_per_epoch: iterations per training epoch.
        ckpt_dir: checkpoint directory searched when ``epoch == "latest"``.
    """
    # int() also accepts an epoch given as a numeric string in the config.
    epoch_to_load = int(get_checkpoint(epoch, ckpt_dir))
    trainer.state.epoch = epoch_to_load
    # After `epoch_to_load` complete epochs the engine has run exactly this many
    # iterations (Ignite's convention: iteration == epoch_length * epoch). Ignite
    # derives the position inside the next epoch from `iteration % epoch_length`,
    # so the former "+ 1" made the first resumed epoch one iteration short.
    trainer.state.iteration = epoch_to_load * num_train_batches_per_epoch

Inference#

After training the nnUNet model, we can perform inference on new data. We use nnUNetModelWrapper, a wrapper around the nnUNet predictor, to run inference from the MONAI Bundle. The nnUNet preprocessing, inference and postprocessing steps are all handled inside the nnUNetModelWrapper. The Bundle components therefore only need to load the input data, pass it to the nnUNet wrapper, and postprocess the nnUNet predictions.

The nnUNetModelWrapper receives as input the image MetaTensor loaded by the DataLoader, from whose metadata it reads the source file paths. It returns the model predictions as a MetaTensor.

[ ]:
import torch
import os
from typing import Union, Optional
import torch
from monai.data.meta_tensor import MetaTensor
from torch.backends import cudnn
import setuptools
from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json
import numpy as np
import monai
from tqdm import tqdm
from pathlib import Path
import json
[ ]:


class nnUNetModelWrapper(torch.nn.Module):
    """Wrap an nnUNet predictor so MONAI engines can call it like a network.

    ``forward`` reads the source file path(s) from the input MetaTensor's
    ``filename_or_obj`` metadata. nnUNet then runs its own preprocessing,
    inference and postprocessing on those files, so the tensor values passed
    in are not used.

    Args:
        predictor: an ``nnunetv2`` ``nnUNetPredictor``.
        model_folder: nnUNet training output folder.
        use_folds: folds whose checkpoints are loaded (default: fold 0 only).
        checkpoint_name: checkpoint file name loaded for each fold.
    """

    def __init__(self, predictor, model_folder, use_folds=(0,), checkpoint_name='checkpoint_final.pth'):
        super().__init__()
        self.predictor = predictor
        self.predictor.initialize_from_trained_model_folder(
            model_folder,
            use_folds=use_folds,
            checkpoint_name=checkpoint_name,
        )
        # Registered as a submodule so e.g. a CheckpointLoader can target the network.
        self.network_weights = self.predictor.network

    def forward(self, x):
        """Predict segmentations for the file(s) that ``x`` was loaded from.

        Args:
            x: MetaTensor, or tuple of MetaTensors (one per modality), whose
                ``meta['filename_or_obj']`` holds the source path(s).
                NOTE(review): all files are submitted as a single case, so this
                appears to support batch size 1 only.

        Returns:
            MetaTensor with two leading singleton axes added to each prediction,
            stacked along axis 0, carrying the metadata of ``x`` (first element
            for a tuple).
        """
        if isinstance(x, tuple):
            input_files = [img.meta['filename_or_obj'][0] for img in x]
        else:
            input_files = x.meta['filename_or_obj']
        if isinstance(input_files, str):
            input_files = [input_files]

        # No output folder is given, and the returned predictions are used below.
        output = self.predictor.predict_from_files(
            [input_files],
            None,
            save_probabilities=False, overwrite=True,
            num_processes_preprocessing=2, num_processes_segmentation_export=2,
            folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)

        out_tensors = [torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)) for out in output]
        out_tensor = torch.cat(out_tensors, 0)

        meta = x[0].meta if isinstance(x, tuple) else x.meta
        return MetaTensor(out_tensor, meta=meta)
[ ]:
def get_nnunet_predictor(model_folder):
    """Build an nnUNet predictor for ``model_folder`` and wrap it as a torch module.

    The predictor runs on ``cuda:0`` with Gaussian-weighted sliding windows
    (``tile_step_size=0.5``) and with test-time mirroring disabled. Constructing
    the wrapper initializes the network and loads the checkpoint.
    """
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    settings = dict(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        device=torch.device('cuda', 0),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    return nnUNetModelWrapper(nnUNetPredictor(**settings), model_folder)


[ ]:
# Wrap the trained native nnUNet model (results folder <dataset>/<trainer>__<plans>__<configuration>).
# Adjust this path to your own nnUNet results folder.
network = get_nnunet_predictor("/home/maia-user/Tutorials/MAIA/Experiments/Task09_Spleen/Task09_Spleen_results/Dataset109_Task09_Spleen/nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres")

Test Data Preparation#

The Bundle expects the test dataset in the following layout: one subfolder per case, each containing an image file named after the case:

Dataset
├── Case1
│   └── Case1.nii.gz
├── Case2
│   └── Case2.nii.gz
└── Case3
    └── Case3.nii.gz
[ ]:
%%bash

# Create the Bundle input/output folders and one case folder.
mkdir -p MAIA/MONAI_Bundle/input
mkdir -p MAIA/MONAI_Bundle/output
mkdir -p MAIA/MONAI_Bundle/input/spleen_1

# Copy a test image into its case folder, named <case>/<case>.nii.gz.
cp MAIA/Task09_Spleen/imagesTs/spleen_1.nii.gz MAIA/MONAI_Bundle/input/spleen_1
[ ]:
%%bash

# Show the resulting input folder layout.
tree MAIA/MONAI_Bundle/input
[ ]:
import pathlib

def get_subfolder_dataset(data_dir,modality_conf):
    """Build a MONAI-style data list from a folder of per-case subfolders.

    Each subfolder ``<case>`` of ``data_dir`` yields one dict. The dict maps
    every key of ``modality_conf`` to ``<data_dir>/<case>/<case><suffix>``, with
    ``suffix = modality_conf[key]['suffix']``. Plain files in ``data_dir`` are
    ignored, and the resulting paths are not checked for existence.

    Args:
        data_dir: folder containing one subfolder per case.
        modality_conf: mapping ``key -> {"suffix": str}``.

    Returns:
        list of dicts, ordered by case-folder name. ``os.scandir`` order is
        filesystem-dependent, so sorting keeps runs reproducible.
    """
    # Context manager closes the scandir iterator promptly.
    with os.scandir(data_dir) as entries:
        case_dirs = sorted((entry for entry in entries if entry.is_dir()), key=lambda entry: entry.name)

    return [
        {
            key: str(pathlib.Path(case_dir.path).joinpath(case_dir.name + conf['suffix']))
            for key, conf in modality_conf.items()
        }
        for case_dir in case_dirs
    ]

Data Loading#

[ ]:
# Modality key -> filename suffix: each case folder <case>/ must contain <case><suffix>.
modalities = {
    "image": {"suffix": ".nii.gz"},
}

# One {"image": path} dict per case folder in the input directory.
data = get_subfolder_dataset("MAIA/MONAI_Bundle/input",modalities)
[ ]:
from monai.transforms import LoadImaged
from monai.data import Dataset, DataLoader

# Load each image channel-first; image_only=False keeps the metadata dictionary as well.
preprocessing = LoadImaged(keys=["image"],ensure_channel_first=True, image_only=False)


test_dataset = Dataset(data,transform=preprocessing)

# batch_size=1: the nnUNet wrapper appears to submit all files of a batch as a single case.
test_loader = DataLoader(test_dataset, batch_size=1)

Test nnUNetModelWrapper#

To test the nnUNetModelWrapper, we can provide a test case to the nnUNetModelWrapper and extract the model predictions returned by the wrapper.

[ ]:
# Smoke test: run the wrapper on the first batch of the test loader.
batch = next(iter(test_loader))

pred = network(batch["image"])

Postprocessing and Save Predictions#

After obtaining the model predictions, we can apply postprocessing transformations to the predictions and save the results to disk.

The Transposed transform is required to unify the axis-order conventions of MONAI and nnUNet: the nnUNet model uses the zyx axis order, while MONAI uses the xyz axis order. With indices [0, 3, 2, 1], the channel axis stays first and the three spatial axes are reversed.

[ ]:
from monai.transforms import Compose, Transposed, SaveImaged


class PreprocessNameFormatter:
    """``output_name_formatter`` for MONAI ``SaveImage`` / ``SaveImaged``.

    Builds the output ``subject`` from the input's ``filename_or_obj`` metadata.
    The trailing ``filename_key`` (the modality suffix) is replaced by ``".nii.gz"``.

    Args:
        filename_key: suffix to strip from the source file name.
    """

    def __init__(self, filename_key):
        self.filename_key = filename_key


    def __call__(self, metadict: dict, saver) -> dict:
        """Returns a kwargs dict for :py:meth:`FolderLayout.filename`,
        according to the input metadata and SaveImage transform.

        Args:
            metadict: metadata of the image being saved; may be empty or ``None``.
            saver: the saving transform. Its ``_data_index`` (default 0) is the
                fallback subject when no file name is available.

        Returns:
            ``{"subject": <name>.nii.gz, "idx": patch index or None}``.
        """
        fallback = getattr(saver, "_data_index", 0)
        subject = metadict.get(monai.utils.ImageMetaKey.FILENAME_OR_OBJ, fallback) if metadict else fallback
        patch_index = metadict.get(monai.utils.ImageMetaKey.PATCH_INDEX, None) if metadict else None

        # str() handles the integer fallback, which could not be sliced before.
        # The suffix is stripped only if present; an empty key used to turn
        # subject[:-0] into "" and lose the whole name.
        subject = str(subject)
        if self.filename_key and subject.endswith(self.filename_key):
            subject = subject[:-len(self.filename_key)]
        return {"subject": f"{subject}.nii.gz", "idx": patch_index}


# Postprocessing for each prediction:
#   1. Transposed keeps the channel axis first and reverses the three spatial
#      axes (nnUNet -> MONAI axis order).
#   2. SaveImaged writes the prediction to MAIA/MONAI_Bundle/output. The file
#      name comes from the input file, with the first modality's suffix
#      replaced by ".nii.gz".
postprocessing = Compose([
    Transposed(keys="pred",indices=[0,3,2,1]),
    SaveImaged(keys="pred",
               output_dir="MAIA/MONAI_Bundle/output",
               output_postfix="prediction",
               meta_keys="image_meta_dict",
               output_name_formatter=PreprocessNameFormatter(modalities[list(modalities.keys())[0]]['suffix'])
               )
])

Evaluator#

Combining everything together, we can create an Evaluator that encapsulates the data loading, model inference, postprocessing, and evaluation steps. The Evaluator can then be run on the whole test dataset.

[ ]:
def prepare_nnunet_inference_batch(batch, device, non_blocking):
    """Return ``(batch["image"], None)`` for an evaluator without labels.

    ``device`` and ``non_blocking`` are part of the ``prepare_batch`` signature
    but are ignored: the image is returned as is. The nnUNet wrappers only
    read the source file paths from its metadata.
    """
    image = batch["image"]
    return image, None
[ ]:
from monai.engines import SupervisedEvaluator

# Inference-only evaluator: no metrics. prepare_batch passes the image (without
# labels) to the nnUNet wrapper, and postprocessing saves each prediction.
validator = SupervisedEvaluator(
    val_data_loader=test_loader,
    device = "cuda:0",
    prepare_batch=prepare_nnunet_inference_batch,
    network = network,
    postprocessing= postprocessing
)
[ ]:
# Run inference on the whole test set.
validator.run()

nnUNetModelWrapper from MONAI Weights#

As an alternative to the inference presented above, we can load the weights saved during the MONAI Bundle training and create an nnUNetMONAIModelWrapper from them. This way, inference runs with the MONAI-trained weights instead of the native nnUNet checkpoint.

[ ]:
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from torch._dynamo import OptimizedModule
import nnunetv2
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class

class nnUNetMONAIModelWrapper(torch.nn.Module):
    """nnUNet predictor wrapper that runs with weights saved by the MONAI Bundle.

    Architecture and inference settings come from the nnUNet files in
    ``model_folder``:
    - ``dataset.json`` and ``plans.json``;
    - ``nnunet_checkpoint.pth``, which supplies the trainer name, the
      configuration and the allowed mirroring axes.

    The network weights come from the ``network_weights`` entry of the MONAI
    checkpoint ``model_name``. Adapted from nnUNet's
    ``nnUNetPredictor.initialize_from_trained_model_folder``.

    Args:
        predictor: an ``nnUNetPredictor``; it is configured in place.
        model_folder: folder containing the files listed above.
        model_name: MONAI checkpoint file name inside ``model_folder``.

    Raises:
        RuntimeError: if the trainer class named in the nnUNet checkpoint cannot be found.
    """

    def __init__(self, predictor, model_folder, model_name="model.pt"):
        super().__init__()
        self.predictor = predictor

        model_training_output_dir = model_folder
        use_folds = ['0']

        ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
        dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
        plans = load_json(join(model_training_output_dir, 'plans.json'))
        plans_manager = PlansManager(plans)

        # Neither file depends on the fold, so each is loaded once instead of per fold.
        checkpoint = torch.load(join(model_training_output_dir, 'nnunet_checkpoint.pth'),
                                map_location=torch.device('cpu'))
        monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device('cpu'))
        trainer_name = checkpoint['trainer_name']
        configuration_name = checkpoint['init_args']['configuration']
        inference_allowed_mirroring_axes = checkpoint.get('inference_allowed_mirroring_axes')
        # One parameter set per requested fold.
        parameters = [monai_checkpoint['network_weights'] for _ in use_folds]

        configuration_manager = plans_manager.get_configuration(configuration_name)
        # restore network
        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')
        if trainer_class is None:
            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
                               f'Please place it there (in any .py file)!')
        label_manager = plans_manager.get_label_manager(dataset_json)
        network = trainer_class.build_network_architecture(
            configuration_manager.network_arch_class_name,
            configuration_manager.network_arch_init_kwargs,
            configuration_manager.network_arch_init_kwargs_req_import,
            num_input_channels,
            label_manager.num_segmentation_heads,
            enable_deep_supervision=False
        )

        predictor.plans_manager = plans_manager
        predictor.configuration_manager = configuration_manager
        predictor.list_of_parameters = parameters
        predictor.network = network
        predictor.dataset_json = dataset_json
        predictor.trainer_name = trainer_name
        predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes
        predictor.label_manager = label_manager
        if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
                and not isinstance(predictor.network, OptimizedModule):
            print('Using torch.compile')
            # Compile the predictor's network. The wrapper has no `network` attribute,
            # so the former `self.network` raised AttributeError.
            predictor.network = torch.compile(predictor.network)
        ## End Block
        # Expose the uncompiled module, so its state-dict keys carry no "_orig_mod." prefix.
        self.network_weights = network

    def forward(self, x):
        """Predict segmentations for the file(s) that ``x`` was loaded from.

        Args:
            x: MetaTensor, or tuple of MetaTensors (one per modality), whose
                ``meta['filename_or_obj']`` holds the source path(s).
                NOTE(review): all files are submitted as a single case, so this
                appears to support batch size 1 only.

        Returns:
            MetaTensor with two leading singleton axes added to each prediction,
            stacked along axis 0, carrying the metadata of ``x`` (first element
            for a tuple).
        """
        if isinstance(x, tuple):
            input_files = [img.meta['filename_or_obj'][0] for img in x]
        else:
            input_files = x.meta['filename_or_obj']
        if isinstance(input_files, str):
            input_files = [input_files]

        output = self.predictor.predict_from_files(
            [input_files],
            None,
            save_probabilities=False, overwrite=True,
            num_processes_preprocessing=2, num_processes_segmentation_export=2,
            folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)

        out_tensors = [torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)) for out in output]
        out_tensor = torch.cat(out_tensors, 0)

        meta = x[0].meta if isinstance(x, tuple) else x.meta
        return MetaTensor(out_tensor, meta=meta)

def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
    """Build an nnUNet predictor that uses MONAI-saved weights, wrapped as a module.

    The predictor runs on ``cuda:0`` with Gaussian-weighted sliding windows
    (``tile_step_size=0.5``) and with test-time mirroring disabled. The wrapper
    builds the network from ``model_folder`` and takes the weights from
    ``model_name``.
    """
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    settings = dict(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        device=torch.device('cuda', 0),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    return nnUNetMONAIModelWrapper(nnUNetPredictor(**settings), model_folder, model_name)
[ ]:
from monai.handlers import CheckpointLoader

# Predictor built from the bundle's models folder (nnUNet files + model.pt weights).
network = get_nnunet_monai_predictor("/home/maia-user/Tutorials/nnUNetBundle/models")

# Optional: Load the best model, not needed since the checkpoint is already loaded in the wrapper.
# NOTE(review): the wrapper also stores model.pt's weights in predictor.list_of_parameters.
# If nnUNet reloads those at prediction time, the weights loaded here into
# network_weights would be overridden. Verify before relying on best_model.pt.
val_handlers = [
    CheckpointLoader(
        load_dict={
            'network_weights': network.network_weights,
        },
        strict=True,
        load_path="/home/maia-user/Tutorials/nnUNetBundle/models/best_model.pt",

    )
]

validator = SupervisedEvaluator(
    val_data_loader=test_loader,
    device = "cuda:0",
    prepare_batch=prepare_nnunet_inference_batch,
    network = network,
    postprocessing= postprocessing,
    val_handlers= val_handlers
)

validator.run()
[ ]:
%%writefile nnUNetBundle/configs/inference.yaml

# Modules imported before the `$` expressions in this config are evaluated.
imports:
  - $import json
  - $import src
  - $import src.inferer
  - $import src.dataset
  - $from pathlib import Path
  - $import os
  - $from ignite.contrib.handlers.tqdm_logger import ProgressBar
  - $import shutil

# Default locations (all "."); presumably overridden when the bundle is run.
output_dir: "."
bundle_root: "."
data_dir: "."
model_folder: "."
# Postfix added by SaveImaged to the saved prediction names.
prediction_suffix: "prediction"
# Modality key -> filename suffix: each case folder <case>/ holds <case><suffix>.
modality_conf:
  image:
    suffix: ".nii.gz"

# One dict per case folder in @data_dir.
test_data_list: "$src.dataset.get_subfolder_dataset(@data_dir,@modality_conf)"
image_modality_keys: "$list(@modality_conf.keys())"
image_key: "image"
image_suffix: "@image_key"

# Load every modality channel-first, keeping the metadata (image_only: False).
preprocessing:
  _target_: Compose
  transforms:
  - _target_: LoadImaged
    keys: "@image_modality_keys"
    ensure_channel_first: True
    image_only: False

test_dataset:
  _target_: Dataset
  data: "$@test_data_list"
  transform: "@preprocessing"

# batch_size 1: the nnUNet wrappers appear to submit all files of a batch as one case.
test_loader:
  _target_: DataLoader
  dataset: "@test_dataset"
  batch_size: 1
  #collate_fn: "$monai.data.utils.no_collation"


device: "$torch.device('cuda')"

# Keyword arguments for the network factory below.
nnunet_config:
  model_folder: "$@bundle_root + '/models'"
  #model_folder: "@model_folder"

# nnUNet predictor wrapped as a torch module (MONAI-weights variant). The
# commented line selects the native nnUNet-checkpoint variant instead.
#network_def: "$src.inferer.get_nnunet_predictor(**@nnunet_config)"
network_def: "$src.inferer.get_nnunet_monai_predictor(**@nnunet_config)"

# Reverse the three spatial axes (channel axis 0 stays first) to go from nnUNet's
# to MONAI's axis order, then save each prediction to @output_dir.
# The file name comes from the input file, with the first modality's suffix
# replaced by ".nii.gz".
postprocessing:
  _target_: "Compose"
  transforms:
    - _target_: Transposed
      keys: "pred"
      indices:
      - 0
      - 3
      - 2
      - 1
    - _target_: SaveImaged
      keys: "pred"
      resample: False
      output_postfix: "@prediction_suffix"
      output_dir: "@output_dir"
      meta_keys: "image_meta_dict"
      output_name_formatter: "$src.inferer.PreprocessNameFormatter(@modality_conf[list(@modality_conf.keys())[0]]['suffix'])"


testing:
  dataloader: "$@test_loader"
  pbar:
    _target_: "ignite.contrib.handlers.tqdm_logger.ProgressBar"
  test_inferer: "$@inferer"

# The wrapped nnUNet predictor does its own sliding-window inference, so a plain inferer suffices.
inferer:
  _target_: "SimpleInferer"

# Inference-only evaluator: no metrics, predictions are saved by @postprocessing.
validator:
  _target_: "SupervisedEvaluator"
  postprocessing: "$@postprocessing"
  device: "$@device"
  inferer: "$@testing#test_inferer"
  val_data_loader: "$@testing#dataloader"
  network: "@network_def"
  prepare_batch: "$src.inferer.prepare_nnunet_inference_batch"
  val_handlers:
  # Loads models/model.pt into the wrapper's network_weights module.
  # NOTE(review): the MONAI-weights wrapper may already hand these weights to the
  # predictor itself; verify that this loader is needed.
  - _target_: "CheckpointLoader"
    load_path: "$@bundle_root+'/models/model.pt'"
    load_dict:
      network_weights: '$@network_def.network_weights'
# Attach the progress bar, then run inference over the test set.
run:
  - "$@testing#pbar.attach(@validator)"
  - "$@validator.run()"
[ ]:
%%writefile nnUNetBundle/src/inferer.py

import torch
import os
from typing import Union, Optional
import torch
from monai.data.meta_tensor import MetaTensor
from torch.backends import cudnn
import setuptools
from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json
import numpy as np
import monai
from tqdm import tqdm
from pathlib import Path
import json
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
from torch._dynamo import OptimizedModule
import nnunetv2
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class

class PreprocessNameFormatter:
    """Output-name formatter for MONAI ``SaveImage`` / ``SaveImaged``.

    Takes the input file name from the metadata (``filename_or_obj``), removes the
    modality suffix given at construction time and appends ``.nii.gz``.
    """

    def __init__(self, filename_key):
        # Modality-specific filename suffix to strip from the input file name.
        self.filename_key = filename_key

    def __call__(self, metadict: dict, saver) -> dict:
        """Returns a kwargs dict for :py:meth:`FolderLayout.filename`,
        according to the input metadata and SaveImage transform.

        Args:
            metadict: image metadata; may be empty or ``None``.
            saver: the calling SaveImage transform; its ``_data_index`` is used as
                the subject name when no file name is available.

        Returns:
            ``{"subject": <name without suffix> + ".nii.gz", "idx": <patch index or None>}``
        """
        data_index = getattr(saver, "_data_index", 0)
        if metadict:
            subject = metadict.get(monai.utils.ImageMetaKey.FILENAME_OR_OBJ, data_index)
            patch_index = metadict.get(monai.utils.ImageMetaKey.PATCH_INDEX, None)
        else:
            subject = data_index
            patch_index = None
        # The fallback subject is an int index, so normalise to str before slicing.
        subject = str(subject)
        # Only strip a suffix that is really present: for an empty suffix ``[:-0]`` would
        # return "", and for a non-matching name unrelated trailing characters would be lost.
        if self.filename_key and subject.endswith(self.filename_key):
            subject = subject[: -len(self.filename_key)]
        return {"subject": f"{subject}.nii.gz", "idx": patch_index}

class nnUNetModelWrapper(torch.nn.Module):
    """Expose a natively trained nnUNet model as a ``torch.nn.Module`` for MONAI engines.

    The predictor is initialised from an nnUNet trained-model folder and its network
    is stored as the ``network_weights`` attribute, so MONAI handlers (e.g. the
    bundle's ``CheckpointLoader``) can address it. ``forward`` does not use the
    tensor values: it runs nnUNet's file-based inference on the file names found
    in the input's metadata.

    Args:
        predictor: ``nnUNetPredictor`` to initialise (modified in place).
        model_folder: folder passed to ``initialize_from_trained_model_folder``.
        use_folds: folds whose checkpoints are loaded (default: fold 0 only).
        checkpoint_name: checkpoint file loaded for each fold.
    """

    def __init__(self, predictor, model_folder, use_folds=(0,), checkpoint_name='checkpoint_final.pth'):
        super().__init__()
        self.predictor = predictor
        self.predictor.initialize_from_trained_model_folder(
            model_folder,
            use_folds=use_folds,
            checkpoint_name=checkpoint_name,
        )
        # When the network is an nn.Module, assigning it here registers it as a
        # sub-module of this wrapper under the name "network_weights".
        self.network_weights = self.predictor.network

    def forward(self, x):
        """Segment the image(s) referenced by ``x`` with nnUNet.

        Args:
            x: a MetaTensor, or a tuple of MetaTensors; only ``meta['filename_or_obj']``
                is read.

        Returns:
            MetaTensor of the stacked predictions, carrying the input metadata.
        """
        if isinstance(x, tuple):
            input_files = [img.meta['filename_or_obj'][0] for img in x]
        else:
            input_files = x.meta['filename_or_obj']
        if isinstance(input_files, str):
            input_files = [input_files]

        # One case made of all its input files; no output folder is given and the
        # returned predictions are consumed below as numpy arrays.
        output = self.predictor.predict_from_files(
                                [input_files],
                                 None,
                                 save_probabilities=False, overwrite=True,
                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,
                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)

        # Add batch and channel axes to every prediction, then stack along the batch axis.
        out_tensor = torch.cat(
            [torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)) for out in output], 0
        )

        meta = x[0].meta if isinstance(x, tuple) else x.meta
        return MetaTensor(out_tensor, meta=meta)

def get_nnunet_predictor(model_folder):
    """Create an nnUNet predictor on ``cuda:0`` and wrap it for use in a MONAI engine.

    Sliding-window settings: 50% tile overlap, Gaussian weighting, no mirroring (TTA).

    Args:
        model_folder: nnUNet trained-model folder handed to ``nnUNetModelWrapper``,
            which initialises the predictor from it.

    Returns:
        The ``nnUNetModelWrapper`` around the initialised predictor.
    """
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    predictor_options = dict(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        device=torch.device('cuda', 0),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    return nnUNetModelWrapper(nnUNetPredictor(**predictor_options), model_folder)

def prepare_nnunet_inference_batch(batch, device, non_blocking):
    """``prepare_batch`` hook for the evaluator: feed the image, provide no target.

    ``device`` and ``non_blocking`` belong to the hook's signature but are not used;
    the image is returned exactly as found in ``batch["image"]``.
    """
    image = batch["image"]
    target = None
    return image, target

class nnUNetMONAIModelWrapper(torch.nn.Module):
    """Wrap an ``nnUNetPredictor`` whose network weights come from a MONAI Bundle checkpoint.

    Instead of ``initialize_from_trained_model_folder``, the predictor is configured
    by hand from ``dataset.json``, ``plans.json`` and ``nnunet_checkpoint.pth`` in
    ``model_folder``; the network parameters are taken from the MONAI checkpoint
    ``model_name`` (its ``network_weights`` entry). The network is stored as the
    ``network_weights`` attribute so the bundle's ``CheckpointLoader`` can address it.

    Args:
        predictor: ``nnUNetPredictor`` to configure (modified in place).
        model_folder: folder holding the JSON files and both checkpoints.
        model_name: file name of the MONAI checkpoint.

    Raises:
        RuntimeError: if the trainer class recorded in the checkpoint cannot be found.
    """

    def __init__(self, predictor, model_folder, model_name="model.pt"):
        super().__init__()
        self.predictor = predictor

        ## Block adapted from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
        dataset_json = load_json(join(model_folder, 'dataset.json'))
        plans = load_json(join(model_folder, 'plans.json'))
        plans_manager = PlansManager(plans)

        # nnunet_checkpoint.pth provides the trainer/configuration metadata, while the
        # weights come from the MONAI checkpoint (a single parameter set, i.e. one fold).
        nnunet_checkpoint = torch.load(join(model_folder, 'nnunet_checkpoint.pth'),
                                       map_location=torch.device('cpu'))
        monai_checkpoint = torch.load(join(model_folder, model_name), map_location=torch.device('cpu'))
        trainer_name = nnunet_checkpoint['trainer_name']
        configuration_name = nnunet_checkpoint['init_args']['configuration']
        inference_allowed_mirroring_axes = nnunet_checkpoint.get('inference_allowed_mirroring_axes')
        # nnUNet's predictor expects a list of parameter sets (one per fold).
        parameters = [monai_checkpoint['network_weights']]

        configuration_manager = plans_manager.get_configuration(configuration_name)
        label_manager = plans_manager.get_label_manager(dataset_json)
        # restore network
        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')
        if trainer_class is None:
            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
                               f'Please place it there (in any .py file)!')
        network = trainer_class.build_network_architecture(
            configuration_manager.network_arch_class_name,
            configuration_manager.network_arch_init_kwargs,
            configuration_manager.network_arch_init_kwargs_req_import,
            num_input_channels,
            label_manager.num_segmentation_heads,
            enable_deep_supervision=False
        )

        predictor.plans_manager = plans_manager
        predictor.configuration_manager = configuration_manager
        predictor.list_of_parameters = parameters
        predictor.network = network
        predictor.dataset_json = dataset_json
        predictor.trainer_name = trainer_name
        predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes
        predictor.label_manager = label_manager
        compile_requested = os.environ.get('nnUNet_compile', '').lower() in ('true', '1', 't')
        if compile_requested and not isinstance(predictor.network, OptimizedModule):
            print('Using torch.compile')
            # Compile the predictor's network: this wrapper has no ``self.network``
            # attribute, so referencing it raised AttributeError.
            # NOTE(review): a compiled module's state-dict keys carry an "_orig_mod."
            # prefix; confirm CheckpointLoader still matches them when compiling.
            predictor.network = torch.compile(predictor.network)
        ## End Block
        self.network_weights = self.predictor.network

    def forward(self, x):
        """Segment the image(s) referenced by ``x`` with nnUNet.

        Args:
            x: a MetaTensor, or a tuple of MetaTensors; only ``meta['filename_or_obj']``
                is read.

        Returns:
            MetaTensor of the stacked predictions, carrying the input metadata.
        """
        if isinstance(x, tuple):
            input_files = [img.meta['filename_or_obj'][0] for img in x]
        else:
            input_files = x.meta['filename_or_obj']
        if isinstance(input_files, str):
            input_files = [input_files]

        # One case made of all its input files; no output folder is given and the
        # returned predictions are consumed below as numpy arrays.
        output = self.predictor.predict_from_files(
                                [input_files],
                                 None,
                                 save_probabilities=False, overwrite=True,
                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,
                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)

        # Add batch and channel axes to every prediction, then stack along the batch axis.
        out_tensor = torch.cat(
            [torch.from_numpy(np.expand_dims(np.expand_dims(out, 0), 0)) for out in output], 0
        )

        meta = x[0].meta if isinstance(x, tuple) else x.meta
        return MetaTensor(out_tensor, meta=meta)

def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
    """Create an nnUNet predictor on ``cuda:0`` whose weights come from a MONAI checkpoint.

    Sliding-window settings: 50% tile overlap, Gaussian weighting, no mirroring (TTA).

    Args:
        model_folder: folder handed to ``nnUNetMONAIModelWrapper`` (JSON files and checkpoints).
        model_name: file name of the MONAI checkpoint inside ``model_folder``.

    Returns:
        The ``nnUNetMONAIModelWrapper`` around the configured predictor.
    """
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    settings = {
        "tile_step_size": 0.5,
        "use_gaussian": True,
        "use_mirroring": False,
        "device": torch.device('cuda', 0),
        "verbose": False,
        "verbose_preprocessing": False,
        "allow_tqdm": True,
    }
    predictor = nnUNetPredictor(**settings)
    return nnUNetMONAIModelWrapper(predictor, model_folder, model_name)

[ ]:
%%writefile nnUNetBundle/src/dataset.py

import pathlib
import os

def get_subfolder_dataset(data_dir, modality_conf):
    """Build a MONAI data list from a folder containing one sub-folder per subject.

    For every sub-folder ``<data_dir>/<subject>`` one dict is produced that maps each
    key of ``modality_conf`` to ``<data_dir>/<subject>/<subject><suffix>``, where
    ``suffix`` is ``modality_conf[key]['suffix']``. Files are not checked for existence;
    plain files directly inside ``data_dir`` are ignored.

    Args:
        data_dir: folder to scan.
        modality_conf: mapping of data-list key -> dict with a ``'suffix'`` entry.

    Returns:
        List of dicts, ordered by subject folder name.
    """
    # Close the scandir handle deterministically and sort by name: os.scandir yields
    # entries in arbitrary order, which would make the data list non-reproducible.
    with os.scandir(data_dir) as entries:
        subject_dirs = sorted((entry for entry in entries if entry.is_dir()), key=lambda entry: entry.name)
    return [
        {
            key: str(pathlib.Path(entry.path).joinpath(entry.name + modality_conf[key]['suffix']))
            for key in modality_conf
        }
        for entry in subject_dirs
    ]

MONAI Bundle to nnUNet Conversion#

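To convert a MONAI Bundle back into an nnUNet model folder, we combine the MONAI checkpoints with the nnUNet checkpoint stored in the bundle (`models/nnunet_checkpoint.pth`). The network weights and optimizer state are taken from the latest (`checkpoint_epoch=*.pt`) and best (`checkpoint_key_metric=*.pt`) MONAI checkpoints. They are written into copies of the nnUNet checkpoint, which are saved as nnUNet's `checkpoint_final.pth` and `checkpoint_best.pth` in `fold_0`. Finally, the two MONAI checkpoints are renamed to `model.pt` and `best_model.pt`.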

[ ]:
from PyMAIA.utils.file_utils import subfiles
from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
from pathlib import Path
import torch
from odict import odict
import os
import shutil
[ ]:
def _select_checkpoint(models_dir, prefix, parse):
    """Return ``(value, filename)`` of the ``<prefix>=<value>.pt`` file in ``models_dir`` with the largest value.

    ``parse`` converts the ``<value>`` text (``int`` for epochs, ``float`` for metrics)
    so candidates are compared numerically; a string sort misorders e.g. negative values.

    Raises:
        FileNotFoundError: if no matching checkpoint exists.
    """
    marker = prefix + "="
    candidates = [
        name
        for name in subfiles(models_dir, prefix=prefix, sort=True, join=False)
        if name.startswith(marker) and name.endswith(".pt")
    ]
    if not candidates:
        raise FileNotFoundError(f"No '{marker}*.pt' checkpoint found in {models_dir}")
    return max(
        ((parse(name[len(marker):-len(".pt")]), name) for name in candidates),
        key=lambda item: item[0],
    )


def convert_MONAI_to_nnUNet(nnunet_root_folder, nnunet_config, bundle_config):
    """Write the weights trained through the MONAI Bundle back into nnUNet checkpoints.

    The nnUNet checkpoint stored in the bundle (``models/nnunet_checkpoint.pth``) is
    used as a template. Its weights and optimizer state are replaced with those of the
    latest (``checkpoint_epoch=*.pt``) and best (``checkpoint_key_metric=*.pt``) MONAI
    checkpoints and saved as ``fold_0/checkpoint_final.pth`` and
    ``fold_0/checkpoint_best.pth`` of the nnUNet model folder. The two MONAI
    checkpoints are then renamed to ``model.pt`` and ``best_model.pt``.

    Side effects: sets the ``ROOT_FOLDER`` and ``RESULTS_FOLDER`` environment variables.

    Args:
        nnunet_root_folder: root folder of the nnUNet experiments.
        nnunet_config: dict with ``"Experiment Name"`` and ``"task_ID"``, optionally
            ``"nnunet_trainer"`` and ``"nnunet_plans"``.
        bundle_config: dict with ``"Bundle_Name"``, the bundle folder.

    Raises:
        FileNotFoundError: if the bundle holds no epoch or key-metric checkpoint.
    """
    from collections import OrderedDict

    os.environ["ROOT_FOLDER"] = nnunet_root_folder
    os.environ["RESULTS_FOLDER"] = str(
        Path(os.environ["ROOT_FOLDER"]).joinpath(
            nnunet_config["Experiment Name"], nnunet_config["Experiment Name"] + "_results"
        )
    )

    nnunet_trainer = nnunet_config.get("nnunet_trainer", "nnUNetTrainer")
    nnunet_plans = nnunet_config.get("nnunet_plans", "nnUNetPlans")

    nnunet_model_folder = Path(os.environ["RESULTS_FOLDER"]).joinpath(
        "Dataset" + nnunet_config["task_ID"] + "_" + nnunet_config["Experiment Name"],
        f"{nnunet_trainer}__{nnunet_plans}__3d_fullres",
    )
    fold_dir = nnunet_model_folder.joinpath("fold_0")
    fold_dir.mkdir(parents=True, exist_ok=True)

    models_dir = Path(bundle_config["Bundle_Name"]).joinpath("models")

    # Load on CPU so the conversion also works without the GPU used for training.
    nnunet_checkpoint = torch.load(models_dir.joinpath("nnunet_checkpoint.pth"), map_location="cpu")

    final_epoch, last_name = _select_checkpoint(models_dir, "checkpoint_epoch", int)
    _, best_name = _select_checkpoint(models_dir, "checkpoint_key_metric", float)
    monai_last_checkpoint = torch.load(models_dir.joinpath(last_name), map_location="cpu")
    monai_best_checkpoint = torch.load(models_dir.joinpath(best_name), map_location="cpu")

    # OrderedDict is what ``state_dict()`` returns natively; a third-party dict type
    # would have to be importable wherever nnUNet later unpickles the checkpoint.
    nnunet_checkpoint['optimizer_state'] = monai_last_checkpoint['optimizer_state']
    nnunet_checkpoint['network_weights'] = OrderedDict(monai_last_checkpoint['network_weights'])
    nnunet_checkpoint['current_epoch'] = final_epoch
    nnunet_checkpoint['logging'] = nnUNetLogger().get_checkpoint()
    nnunet_checkpoint['_best_ema'] = 0
    nnunet_checkpoint['grad_scaler_state'] = None
    torch.save(nnunet_checkpoint, fold_dir.joinpath("checkpoint_final.pth"))

    nnunet_checkpoint['optimizer_state'] = monai_best_checkpoint['optimizer_state']
    nnunet_checkpoint['network_weights'] = OrderedDict(monai_best_checkpoint['network_weights'])
    torch.save(nnunet_checkpoint, fold_dir.joinpath("checkpoint_best.pth"))

    # Reuse the real file names rather than rebuilding them from parsed values.
    shutil.move(str(models_dir.joinpath(last_name)), str(models_dir.joinpath("model.pt")))
    shutil.move(str(models_dir.joinpath(best_name)), str(models_dir.joinpath("best_model.pt")))
[ ]:
# Arguments expected by ``convert_MONAI_to_nnUNet``. This cell only defines them;
# it does not run the conversion.
nnunet_root_folder = "MAIA/Experiments"

nnunet_config = {}
nnunet_config["Experiment Name"] = "Task09_Spleen"
nnunet_config["task_ID"] = "109"
nnunet_config["nnunet_plans"] = "nnUNetResEncUNetLPlans"

bundle_config = dict(Bundle_Name="nnUNetBundle")

nnUNet to MONAI Bundle Conversion#

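To convert an nnUNet model into a MONAI Bundle, we split the nnUNet checkpoints into the files the bundle expects. The trainer metadata (`trainer_name`, `init_args`, `inference_allowed_mirroring_axes`) from `checkpoint_final.pth` is saved to `models/nnunet_checkpoint.pth`. The network weights of `checkpoint_final.pth` and `checkpoint_best.pth` are saved to `models/model.pt` and `models/best_model.pt`. Finally, `plans.json` and `dataset.json` are copied alongside them.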

[ ]:
def convert_nnunet_to_monai_bundle(nnunet_root_folder, nnunet_config, bundle_root_folder):
    """Export a trained nnUNet model (fold 0) into the ``models`` folder of a MONAI Bundle.

    Writes, inside ``<bundle_root_folder>/models`` (created if missing):
      * ``nnunet_checkpoint.pth``: ``trainer_name``, ``init_args`` and
        ``inference_allowed_mirroring_axes`` from ``checkpoint_final.pth``;
      * ``model.pt`` / ``best_model.pt``: ``{"network_weights": ...}`` taken from
        ``checkpoint_final.pth`` / ``checkpoint_best.pth``;
      * copies of the model folder's ``plans.json`` and ``dataset.json``.

    Side effects: sets the ``ROOT_FOLDER`` and ``RESULTS_FOLDER`` environment variables.

    Args:
        nnunet_root_folder: root folder of the nnUNet experiments.
        nnunet_config: dict with ``"Experiment Name"`` and ``"task_ID"``, optionally
            ``"nnunet_trainer"`` and ``"nnunet_plans"``.
        bundle_root_folder: root folder of the MONAI Bundle to write into.
    """
    os.environ["ROOT_FOLDER"] = nnunet_root_folder
    os.environ["RESULTS_FOLDER"] = str(
        Path(os.environ["ROOT_FOLDER"]).joinpath(
            nnunet_config["Experiment Name"], nnunet_config["Experiment Name"] + "_results"
        )
    )

    nnunet_trainer = nnunet_config.get("nnunet_trainer", "nnUNetTrainer")
    nnunet_plans = nnunet_config.get("nnunet_plans", "nnUNetPlans")

    nnunet_model_folder = Path(os.environ["RESULTS_FOLDER"]).joinpath(
        "Dataset" + nnunet_config["task_ID"] + "_" + nnunet_config["Experiment Name"],
        f"{nnunet_trainer}__{nnunet_plans}__3d_fullres",
    )
    fold_dir = nnunet_model_folder.joinpath("fold_0")
    models_dir = Path(bundle_root_folder).joinpath("models")
    models_dir.mkdir(parents=True, exist_ok=True)

    # nnUNet checkpoints are trusted local files that also contain non-tensor objects
    # (e.g. logging values). weights_only=False keeps them loadable on torch>=2.6, as
    # nnUNet's own loader does. Load on CPU so no training GPU is required.
    nnunet_checkpoint_final = torch.load(fold_dir.joinpath("checkpoint_final.pth"),
                                         map_location="cpu", weights_only=False)
    nnunet_checkpoint_best = torch.load(fold_dir.joinpath("checkpoint_best.pth"),
                                        map_location="cpu", weights_only=False)

    nnunet_checkpoint = {
        # Missing key -> None, the same fallback nnUNetMONAIModelWrapper applies.
        'inference_allowed_mirroring_axes': nnunet_checkpoint_final.get('inference_allowed_mirroring_axes'),
        'init_args': nnunet_checkpoint_final['init_args'],
        'trainer_name': nnunet_checkpoint_final['trainer_name'],
    }
    torch.save(nnunet_checkpoint, models_dir.joinpath("nnunet_checkpoint.pth"))

    torch.save({'network_weights': nnunet_checkpoint_final['network_weights']}, models_dir.joinpath("model.pt"))
    torch.save({'network_weights': nnunet_checkpoint_best['network_weights']}, models_dir.joinpath("best_model.pt"))

    shutil.copy(nnunet_model_folder.joinpath("plans.json"), models_dir.joinpath("plans.json"))
    shutil.copy(nnunet_model_folder.joinpath("dataset.json"), models_dir.joinpath("dataset.json"))

[ ]:
# Export the trained nnUNet model into a new bundle folder with
# ``convert_nnunet_to_monai_bundle``.
nnunet_root_folder = "MAIA/Experiments"

nnunet_config = {}
nnunet_config["Experiment Name"] = "Task09_Spleen"
nnunet_config["task_ID"] = "109"
nnunet_config["nnunet_plans"] = "nnUNetResEncUNetLPlans"

bundle_root_folder = "nnUNetBundle_Test"
bundle_models_dir = Path(bundle_root_folder) / "models"
bundle_models_dir.mkdir(parents=True, exist_ok=True)

convert_nnunet_to_monai_bundle(nnunet_root_folder, nnunet_config, bundle_root_folder)
[ ]: