{ "cells": [ { "cell_type": "markdown", "id": "bec25bff", "metadata": {}, "source": [ "# nnUNet MONAI Bundle\n", "\n", "This notebook demonstrates how to create a MONAI Bundle that supports nnUNet experiments for training and inference. The step-by-step tutorial describes all the Python code and YAML configuration files needed to train and evaluate an nnUNet model using the MONAI Bundle format." ] }, { "cell_type": "markdown", "id": "c6757489", "metadata": {}, "source": [ "## nnUNet Trainer\n", "\n", "The core component of the nnUNet MONAI Bundle is the `get_nnunet_trainer` function. This function creates the nnUNet trainer object from the native nnUNetv2 implementation. From the nnUNet trainer object, we can access the training components, such as the data loaders, model, learning rate scheduler, optimizer, and loss function, and perform training and inference tasks." ] }, { "cell_type": "code", "execution_count": null, "id": "b4090170-b7c4-402b-9d70-9b59c463354b", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from typing import Union, Optional\n", "import json\n", "from pathlib import Path\n", "import os\n", "from torch.backends import 
cudnn\n", "\n", "def get_nnunet_trainer(dataset_name_or_id: Union[str, int],\n", "                       configuration: str, fold: Union[int, str],\n", "                       pymaia_config_file: Optional[str] = None,  # To set env variables\n", "                       trainer_class_name: str = 'nnUNetTrainer',\n", "                       plans_identifier: str = 'nnUNetPlans',\n", "                       pretrained_weights: Optional[str] = None,\n", "                       num_gpus: int = 1,\n", "                       use_compressed_data: bool = False,\n", "                       export_validation_probabilities: bool = False,\n", "                       continue_training: bool = False,\n", "                       only_run_validation: bool = False,\n", "                       disable_checkpointing: bool = False,\n", "                       val_with_best: bool = False,\n", "                       device: torch.device = torch.device(\n", "                           'cuda'),\n", "                       pretrained_model = None\n", "                       ):  # From nnUNet/nnunetv2/run/run_training.py#run_training\n", "\n", "    ## Block Added\n", "\n", "    if pymaia_config_file is not None:\n", "        with open(pymaia_config_file, \"r\") as f:\n", "            pymaia_config_dict = json.load(f)\n", "\n", "        os.environ[\"nnUNet_raw\"] = str(Path(pymaia_config_dict[\"base_folder\"]).joinpath(\"nnUNet_raw\"))\n", "        os.environ[\"nnUNet_preprocessed\"] = pymaia_config_dict[\"preprocessing_folder\"]\n", "        os.environ[\"nnUNet_results\"] = pymaia_config_dict[\"results_folder\"]\n", "\n", "    # Imported after the environment variables are set\n", "    from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint\n", "    ## End Block\n", "\n", "    if isinstance(fold, str):\n", "        if fold != 'all':\n", "            try:\n", "                fold = int(fold)\n", "            except ValueError as e:\n", "                print(\n", "                    f'Unable to convert given value for fold to int: {fold}. fold must be either \"all\" or an integer!')\n", "                raise e\n", "\n", "    if int(num_gpus) > 1:\n", "        ... 
# Disable for now\n", "    else:\n", "        nnunet_trainer = get_trainer_from_args(str(dataset_name_or_id), configuration, fold, trainer_class_name,\n", "                                               plans_identifier, use_compressed_data, device=device)\n", "\n", "        if disable_checkpointing:\n", "            nnunet_trainer.disable_checkpointing = disable_checkpointing\n", "\n", "        assert not (\n", "            continue_training and only_run_validation), 'Cannot set --c and --val flag at the same time. Dummy.'\n", "\n", "        maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)\n", "        nnunet_trainer.on_train_start()  # Added to Initialize Trainer\n", "        if torch.cuda.is_available():\n", "            cudnn.deterministic = False\n", "            cudnn.benchmark = True\n", "\n", "        if pretrained_model is not None:\n", "            state_dict = torch.load(pretrained_model)\n", "            if 'network_weights' in state_dict:\n", "                nnunet_trainer.network._orig_mod.load_state_dict(state_dict['network_weights'])\n", "            #nnunet_trainer.network.load_state_dict(torch.load(pretrained_model)['model'])\n", "        # Skip Training and Validation Phase\n", "        return nnunet_trainer" ] }, { "cell_type": "markdown", "id": "f355baa4", "metadata": {}, "source": [ "The function `get_nnunet_trainer` accepts the following parameters:\n", "\n", "- `dataset_name_or_id`: The dataset name or ID to be used for training and evaluation.\n", "- `fold`: The fold number for the cross-validation experiment.\n", "- `configuration`: The training configuration for the nnUNet trainer, usually `3d_fullres`.\n", "- `trainer_class_name`: The nnUNet trainer class name to be used for training, e.g. `nnUNetTrainer`.\n", "- `plans_identifier`: The nnUNet plans identifier for the dataset, e.g. 
`nnUNetPlans`.\n", "- `pretrained_model`: Optional parameter to specify the pre-trained model for transfer learning.\n", "\n", "Additionally, the function takes the `pymaia_config_file` (generated after running [`nnunet_prepare_data_folder`](https://pymaia.readthedocs.io/en/latest/apidocs/nnunet_prepare_data_folder.html)) as an input parameter and uses it to set the `nnUNet_raw`, `nnUNet_preprocessed`, and `nnUNet_results` environment variables." ] }, { "cell_type": "code", "execution_count": null, "id": "e90eb700-6e8b-4f55-b35c-0ff1e8d1019c", "metadata": {}, "outputs": [], "source": [ "task_id = \"109\"\n", "pymaia_config_file = \"/home/maia-user/Tutorials/MAIA/Experiments/Task09_Spleen/Task09_Spleen_results/Dataset109_Task09_Spleen.json\"\n", "nnunet_trainer_class_name = \"nnUNetTrainer\"\n", "nnunet_plans_identifier = \"nnUNetResEncUNetLPlans\"\n", "\n", "#pretrained_model = \"/home/maia-user/Tutorials/nnunetmonaibundle/model/nnUNet_Bundle/models/checkpoint_epoch=10.pt\"\n", "#pretrained_model = \"/home/maia-user/Tutorials/Task09_Spleen_Bundle/models/Dataset109_Spleen/nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres/fold_0/checkpoint_final.pth\"" ] }, { "cell_type": "markdown", "id": "8e161502", "metadata": {}, "source": [ "### Get nnUNet Trainer from Preprocessing Folder\n", "\n", "An alternative way to get the nnUNet Trainer is to use the `get_nnunet_trainer_from_preprocessing_folder` function. This function reads the plans and dataset JSON files (`plans_file` and `dataset_file`) from the preprocessing folder and returns the nnUNet Trainer object." 
] }, { "cell_type": "code", "execution_count": null, "id": "5a10aabb", "metadata": {}, "outputs": [], "source": [ "from batchgenerators.utilities.file_and_folder_operations import join, load_json\n", "from nnunetv2.utilities.find_class_by_name import recursive_find_python_class\n", "import nnunetv2\n", "\n", "def get_nnunet_trainer_from_preprocessing_folder(\n", "    plans_file,\n", "    dataset_file,\n", "    configuration: str,\n", "    fold: Union[int, str],\n", "    trainer_class_name: str = 'nnUNetTrainer',\n", "    plans_identifier: str = 'nnUNetPlans',\n", "    pretrained_weights: Optional[str] = None,\n", "    num_gpus: int = 1,\n", "    use_compressed_data: bool = False,\n", "    export_validation_probabilities: bool = False,\n", "    continue_training: bool = False,\n", "    only_run_validation: bool = False,\n", "    disable_checkpointing: bool = False,\n", "    val_with_best: bool = False,\n", "    device: torch.device = torch.device(\n", "        'cuda'),\n", "    pretrained_model = None\n", "):  # From nnUNet/nnunetv2/run/run_training.py#run_training\n", "\n", "    ## Block Added\n", "    os.environ[\"nnUNet_raw\"] = str(Path(\".\").joinpath(\"nnUNet_raw\"))\n", "    os.environ[\"nnUNet_preprocessed\"] = \".\"\n", "    os.environ[\"nnUNet_results\"] = \".\"\n", "    from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint\n", "    ##\n", "\n", "    if isinstance(fold, str):\n", "        if fold != 'all':\n", "            try:\n", "                fold = int(fold)\n", "            except ValueError as e:\n", "                print(\n", "                    f'Unable to convert given value for fold to int: {fold}. fold must be either \"all\" or an integer!')\n", "                raise e\n", "\n", "    if int(num_gpus) > 1:\n", "        ... 
# Disable for now\n", "    else:\n", "        nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n", "                                                     trainer_class_name, 'nnunetv2.training.nnUNetTrainer')\n", "\n", "        plans = load_json(plans_file)\n", "        dataset_json = load_json(dataset_file)\n", "\n", "        # Use the `device` argument instead of a hard-coded CUDA device\n", "        nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,\n", "                                        dataset_json=dataset_json, unpack_dataset=False, device=device)\n", "\n", "        if disable_checkpointing:\n", "            nnunet_trainer.disable_checkpointing = disable_checkpointing\n", "\n", "        assert not (\n", "            continue_training and only_run_validation), 'Cannot set --c and --val flag at the same time. Dummy.'\n", "\n", "        maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)\n", "        nnunet_trainer.initialize()  # To Initialize Trainer\n", "        if torch.cuda.is_available():\n", "            cudnn.deterministic = False\n", "            cudnn.benchmark = True\n", "\n", "        # Skip Training and Validation Phase\n", "        if pretrained_model is not None:\n", "            state_dict = torch.load(pretrained_model)\n", "            if 'network_weights' in state_dict:\n", "                nnunet_trainer.network._orig_mod.load_state_dict(state_dict['network_weights'])\n", "\n", "        return nnunet_trainer" ] }, { "cell_type": "code", "execution_count": null, "id": "fc612107-0be8-44bf-8e0d-515cf3014a2b", "metadata": { "scrolled": true }, "outputs": [], "source": [ "nnunet_trainer = get_nnunet_trainer(dataset_name_or_id = task_id,\n", "    configuration = \"3d_fullres\",\n", "    fold = \"0\",\n", "    pymaia_config_file=pymaia_config_file,\n", "    trainer_class_name = nnunet_trainer_class_name,\n", "    plans_identifier = nnunet_plans_identifier,\n", "    #pretrained_model=pretrained_model\n", "    )" ] }, { "cell_type": "code", "execution_count": null, "id": "a42b066f", "metadata": {}, "outputs": [], "source": [ "nnunet_trainer = get_nnunet_trainer_from_preprocessing_folder(plans_file=plans_file,\n", "    dataset_file=dataset_file,\n", "    
configuration = \"3d_fullres\",\n", " fold = \"0\",\n", " trainer_class_name = nnunet_trainer_class_name,\n", " plans_identifier = nnunet_plans_identifier,\n", " #pretrained_model=pretrained_model\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "63d65cf0-56ea-4c45-88bc-e271a5e2195c", "metadata": {}, "outputs": [], "source": [ "from monai.data import Dataset\n", "from monai.handlers import StatsHandler, from_engine, MeanDice, ValidationHandler, LrScheduleHandler, CheckpointSaver, CheckpointLoader, TensorBoardStatsHandler, MLFlowHandler\n", "from monai.engines import SupervisedTrainer, SupervisedEvaluator\n", "\n", "from monai.transforms import Compose, Lambdad, Activationsd, AsDiscreted" ] }, { "cell_type": "markdown", "id": "765619ea", "metadata": {}, "source": [ "## Train and Val Data Loaders" ] }, { "cell_type": "code", "execution_count": null, "id": "f8e60cdb", "metadata": {}, "outputs": [], "source": [ "train_dataloader = nnunet_trainer.dataloader_train\n", "train_data = [{'case_identifier':k} for k in nnunet_trainer.dataloader_train.generator._data.dataset.keys()]\n", "train_dataset = Dataset(data=train_data)" ] }, { "cell_type": "code", "execution_count": null, "id": "9b80cc9a", "metadata": {}, "outputs": [], "source": [ "val_dataloader = nnunet_trainer.dataloader_val\n", "val_data = [{'case_identifier':k} for k in nnunet_trainer.dataloader_val.generator._data.dataset.keys()]\n", "val_dataset = Dataset(data=val_data)" ] }, { "cell_type": "markdown", "id": "3c7756d7", "metadata": {}, "source": [ "## Network, Optimizer, and Loss Function" ] }, { "cell_type": "code", "execution_count": null, "id": "26b88a8e", "metadata": {}, "outputs": [], "source": [ "device = nnunet_trainer.device\n", "\n", "network = nnunet_trainer.network\n", "optimizer = nnunet_trainer.optimizer\n", "lr_scheduler = nnunet_trainer.lr_scheduler\n", "loss = nnunet_trainer.loss" ] }, { "cell_type": "markdown", "id": "5d7d6023", "metadata": {}, "source": [ "## Prepare Batch 
Function\n", "\n", "The nnUNet data loader returns a dictionary with the `data` and `target` keys. Since the `SupervisedTrainer` used in the MONAI Bundle expects the data and target as separate tensors, we need a custom `prepare_batch` function that extracts them from the dictionary. The `target` entry may also be a list of tensors (for example, with deep supervision), in which case each element is moved to the device." ] }, { "cell_type": "code", "execution_count": null, "id": "155e9460-9f69-4b15-bfb9-eae032afbc92", "metadata": {}, "outputs": [], "source": [ "def prepare_nnunet_batch(batch, device, non_blocking):\n", "    data = batch[\"data\"].to(device, non_blocking=non_blocking)\n", "    if isinstance(batch[\"target\"], list):\n", "        target = [i.to(device, non_blocking=non_blocking) for i in batch[\"target\"]]\n", "    else:\n", "        target = batch[\"target\"].to(device, non_blocking=non_blocking)\n", "    return data, target" ] }, { "cell_type": "code", "execution_count": null, "id": "70774a2a-84ff-4297-98e5-6452567f13a1", "metadata": {}, "outputs": [], "source": [ "image, label = prepare_nnunet_batch(next(iter(train_dataloader)),device=\"cpu\",non_blocking=True)" ] }, { "cell_type": "markdown", "id": "54b5c684", "metadata": {}, "source": [ "## MONAI Supervised Trainer\n", "\n", "The `SupervisedTrainer` class from MONAI is used to train the nnUNet model. For a minimal setup, we need to provide the model, optimizer, loss function, training data loader, number of epochs, and the device on which to run the training." 
] }, { "cell_type": "code", "execution_count": null, "id": "d480fbe6", "metadata": {}, "outputs": [], "source": [ "train_handlers = [\n", " StatsHandler(\n", " output_transform= from_engine(['loss'], first=True),\n", " tag_name= \"train_loss\"\n", " )\n", "]\n" ] }, { "cell_type": "code", "execution_count": null, "id": "844bb28a", "metadata": {}, "outputs": [], "source": [ "iterations = 100\n", "epochs = 50" ] }, { "cell_type": "code", "execution_count": null, "id": "415bfc68", "metadata": {}, "outputs": [], "source": [ "trainer = SupervisedTrainer(\n", " amp= True,\n", " device = device,\n", " epoch_length = iterations,\n", " loss_function = loss,\n", " max_epochs = epochs,\n", " network = network,\n", " prepare_batch = prepare_nnunet_batch,\n", " optimizer = optimizer,\n", " train_data_loader = train_dataloader,\n", " train_handlers= train_handlers\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "ba2ce831", "metadata": {}, "outputs": [], "source": [ "trainer.run()" ] }, { "cell_type": "markdown", "id": "c41fcf2a", "metadata": {}, "source": [ "## Adding Validation and Validation Metrics\n", "\n", "For a complete training setup, we need to add the validation data loader and the validation metrics to the `SupervisedTrainer`. Using the MONAI class `SupervisedEvaluator`, we can evaluate the model on the validation data loader and calculate the validation metrics (`Dice Score`)." 
] }, { "cell_type": "code", "execution_count": null, "id": "e5f713a2", "metadata": {}, "outputs": [], "source": [ "val_key_metric = MeanDice(\n", "    output_transform = from_engine(['pred', 'label']),\n", "    reduction = \"mean\",\n", "    include_background = False\n", "\n", ")\n", "\n", "additional_metrics = {\n", "    \"Val_Dice_Per_Class\": MeanDice(\n", "        output_transform = from_engine(['pred', 'label']),\n", "        reduction = \"mean_batch\",\n", "        include_background = False,\n", "    )\n", "}" ] }, { "cell_type": "markdown", "id": "fa8287fb", "metadata": {}, "source": [ "Additionally, to compute the mean Dice score over the batch, we need to apply a post-processing transformation to the nnUNet model output. Since `MeanDice` accepts `y_pred` and `y` as batch-first tensors (BCHW[D]), we need to create a custom post-processing transform to convert the nnUNet model output to the required format." ] }, { "cell_type": "code", "execution_count": null, "id": "9e28ec37", "metadata": {}, "outputs": [], "source": [ "num_classes = 2\n", "\n", "postprocessing = Compose(\n", "    transforms=[\n", "        ## Extract only high-res predictions from Deep Supervision\n", "        Lambdad(\n", "            keys= [\"pred\",\"label\"],\n", "            func = lambda x: x[0]\n", "        ),\n", "        ## Apply Softmax to the predictions\n", "        Activationsd(\n", "            keys= \"pred\",\n", "            softmax= True\n", "        ),\n", "        ## Binarize the predictions\n", "        AsDiscreted(\n", "            keys= \"pred\",\n", "            threshold= 0.5\n", "        ),\n", "        ## Convert the labels to one-hot\n", "        AsDiscreted(\n", "            keys= \"label\",\n", "            to_onehot= num_classes\n", "        )\n", "    ]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "7a90e728", "metadata": {}, "outputs": [], "source": [ "val_handlers = [StatsHandler(\n", "    iteration_log = False\n", ")]" ] }, { "cell_type": "code", "execution_count": null, "id": "e9586476", "metadata": {}, "outputs": [], "source": [ "val_iterations = 100\n", "val_interval = 1" ] }, { "cell_type": "code", "execution_count": null, "id": 
"e8081ce4", "metadata": {}, "outputs": [], "source": [ "evaluator = SupervisedEvaluator(\n", " amp= True,\n", " device = device,\n", " epoch_length = val_iterations,\n", " network = network,\n", " key_val_metric={\"Val_Dice\": val_key_metric},\n", " prepare_batch= prepare_nnunet_batch,\n", " val_data_loader = val_dataloader,\n", " val_handlers= val_handlers,\n", " postprocessing= postprocessing,\n", " additional_metrics= additional_metrics,\n", ")" ] }, { "cell_type": "markdown", "id": "aadfd315", "metadata": {}, "source": [ "And finally, we add the evaluator to the `SupervisedTrainer` to calculate the validation metrics during training." ] }, { "cell_type": "code", "execution_count": null, "id": "e2bc29ad", "metadata": {}, "outputs": [], "source": [ "train_handlers.append(\n", " ValidationHandler(\n", " epoch_level = True,\n", " interval= val_interval,\n", " validator = evaluator\n", " )\n", ")" ] }, { "cell_type": "markdown", "id": "3c904b0a", "metadata": {}, "source": [ "We can also add the `MeanDice` metric to the `SupervisedTrainer` to calculate the mean dice score over the batch during training." 
] }, { "cell_type": "code", "execution_count": null, "id": "a9d44de3", "metadata": {}, "outputs": [], "source": [ "train_key_metric = MeanDice(\n", " output_transform = from_engine(['pred', 'label']),\n", " reduction = \"mean\",\n", " include_background = False\n", "\n", ")\n", "\n", "additional_metrics = {\n", " \"Train_Dice_Per_Class\": MeanDice(\n", " output_transform = from_engine(['pred', 'label']),\n", " reduction = \"mean_batch\",\n", " include_background = False,\n", " )\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "a339901b", "metadata": {}, "outputs": [], "source": [ "trainer = SupervisedTrainer(\n", " amp= True,\n", " device = device,\n", " epoch_length = iterations,\n", " loss_function = loss,\n", " max_epochs = epochs,\n", " network = network,\n", " prepare_batch = prepare_nnunet_batch,\n", " optimizer = optimizer,\n", " train_data_loader = train_dataloader,\n", " train_handlers= train_handlers,\n", " key_train_metric = {\"Train_Dice\": train_key_metric},\n", " postprocessing= postprocessing,\n", " additional_metrics = additional_metrics\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "8f1869f5", "metadata": {}, "outputs": [], "source": [ "trainer.run()" ] }, { "cell_type": "markdown", "id": "fbfb0762", "metadata": {}, "source": [ "## Learning Rate Scheduler\n", "\n", "One last component to add to the `SupervisedTrainer`, in order to replicate the training behaviour of the native nnUNet, is the learning rate scheduler." 
] }, { "cell_type": "code", "execution_count": null, "id": "9b92598c", "metadata": {}, "outputs": [], "source": [ "train_handlers.append(\n", " LrScheduleHandler(\n", " lr_scheduler = lr_scheduler,\n", " print_lr = True\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "54efe274", "metadata": {}, "outputs": [], "source": [ "trainer = SupervisedTrainer(\n", " amp= True,\n", " device = device,\n", " epoch_length = iterations,\n", " loss_function = loss,\n", " max_epochs = epochs,\n", " network = network,\n", " prepare_batch = prepare_nnunet_batch,\n", " optimizer = optimizer,\n", " train_data_loader = train_dataloader,\n", " train_handlers= train_handlers,\n", " key_train_metric = {\"Train_Dice\": train_key_metric},\n", " postprocessing= postprocessing,\n", " additional_metrics = additional_metrics\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "1f687bc6", "metadata": {}, "outputs": [], "source": [ "trainer.run()" ] }, { "cell_type": "code", "execution_count": null, "id": "fac3c8c5", "metadata": {}, "outputs": [], "source": [ "train_handlers[-1].lr_scheduler.get_last_lr()" ] }, { "cell_type": "markdown", "id": "52a36367", "metadata": {}, "source": [ "## Checkpointing\n", "\n", "To save the model weights during training, we can use the `CheckpointSaver` callback from MONAI. This callback saves the model weights after each epoch.\n", "We can later use the `CheckpointLoader` to load the model weights and perform inference or resume training." 
] }, { "cell_type": "code", "execution_count": null, "id": "54c4fb21", "metadata": {}, "outputs": [], "source": [ "val_handlers.append(\n", "    CheckpointSaver(\n", "        save_dir= \"Bundle/models\",\n", "        save_dict= {\"network_weights\": nnunet_trainer.network._orig_mod, \"optimizer_state\": nnunet_trainer.optimizer, \"scheduler\": nnunet_trainer.lr_scheduler},\n", "        #save_final= True,\n", "        save_interval= 1,\n", "        save_key_metric= True,\n", "        #final_filename= \"model_final.pt\",\n", "        #key_metric_filename= \"model.pt\",\n", "        n_saved= 1\n", "    )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "ece9e988", "metadata": {}, "outputs": [], "source": [ "ckpt_dir = \"Bundle/models\"\n", "reload_checkpoint_epoch = \"latest\"\n", "\n", "# `get_checkpoint` is defined in the \"Reload Checkpoint\" section below; run that cell first.\n", "train_handlers.append(\n", "    CheckpointLoader(\n", "        load_path= ckpt_dir+'/checkpoint_epoch='+str(get_checkpoint(reload_checkpoint_epoch, ckpt_dir))+'.pt',\n", "        load_dict= {\"network_weights\": nnunet_trainer.network._orig_mod, \"optimizer_state\": nnunet_trainer.optimizer, \"scheduler\": nnunet_trainer.lr_scheduler},\n", "        map_location= device\n", "    )\n", ")" ] }, { "cell_type": "markdown", "id": "78db7528", "metadata": {}, "source": [ "## Initial nnUNet Checkpoint\n", "\n", "To provide compatibility with the native nnUNet, we need to save the nnUNet-specific configuration together with the regular MONAI checkpoint. This is done only once, before the training starts. At the end of the training, we will have a MONAI checkpoint and an nnUNet checkpoint. Combining the two allows us to convert the MONAI checkpoint to an nnUNet checkpoint at any time." 
] }, { "cell_type": "code", "execution_count": null, "id": "8131d003", "metadata": {}, "outputs": [], "source": [ "checkpoint = {\n", "    \"inference_allowed_mirroring_axes\": nnunet_trainer.inference_allowed_mirroring_axes,\n", "    \"init_args\": nnunet_trainer.my_init_kwargs,\n", "    \"trainer_name\": nnunet_trainer.__class__.__name__\n", "}\n", "checkpoint_filename = 'Bundle/models/nnunet_checkpoint.pth'\n", "\n", "torch.save(checkpoint, checkpoint_filename)" ] }, { "cell_type": "markdown", "id": "c0e26cdb", "metadata": {}, "source": [ "## MLFlow and TensorBoard Monitoring\n", "\n", "To monitor the training process, we can use MLFlow and TensorBoard. We can log the training metrics, hyperparameters, and model weights to MLFlow, and visualize the training metrics using TensorBoard." ] }, { "cell_type": "code", "execution_count": null, "id": "b2976402", "metadata": {}, "outputs": [], "source": [ "train_handlers.append(\n", "    TensorBoardStatsHandler(\n", "        log_dir= \"Bundle/logs\",\n", "        output_transform= from_engine(['loss'], first=True),\n", "        tag_name = \"train_loss\"\n", "    )\n", ")\n", "\n", "val_handlers.append(\n", "    TensorBoardStatsHandler(\n", "        log_dir= \"Bundle/logs\",\n", "        iteration_log = False\n", "    )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "86d54487", "metadata": {}, "outputs": [], "source": [ "def mlflow_transform(state_output):\n", "    return state_output[0]['loss']\n", "\n", "class MLFlowPyMAIAHandler(MLFlowHandler):\n", "    def __init__(self, label_dict, **kwargs):\n", "        super().__init__(**kwargs)\n", "        self.label_dict = label_dict\n", "\n", "    def _default_epoch_log(self, engine) -> None:\n", "        \"\"\"\n", "        Execute epoch level log operation.\n", "        Default to track the values from Ignite `engine.state.metrics` dict and\n", "        track the values of specified attributes of `engine.state`.\n", "\n", "        Args:\n", "            engine: Ignite Engine, it can be a trainer, validator or evaluator.\n", "\n", "        \"\"\"\n", "        
log_dict = engine.state.metrics\n", "        if not log_dict:\n", "            return\n", "\n", "        current_epoch = self.global_epoch_transform(engine.state.epoch)\n", "\n", "        new_log_dict = {}\n", "\n", "        for metric in log_dict:\n", "            # Per-class metrics arrive as tensors: log one entry per class label\n", "            if isinstance(log_dict[metric], torch.Tensor):\n", "                for i,val in enumerate(log_dict[metric]):\n", "                    new_log_dict[metric+\"_{}\".format(list(self.label_dict.keys())[i+1])] = val\n", "            else:\n", "                new_log_dict[metric] = log_dict[metric]\n", "        self._log_metrics(new_log_dict, step=current_epoch)\n", "\n", "        if self.state_attributes is not None:\n", "            attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}\n", "            self._log_metrics(attrs, step=current_epoch)" ] }, { "cell_type": "code", "execution_count": null, "id": "6e8fdc2e", "metadata": {}, "outputs": [], "source": [ "import re\n", "import yaml\n", "from monai.bundle import ConfigParser\n", "import monai\n", "\n", "def create_mlflow_experiment_params(params_file, custom_params=None):\n", "    params_dict = {}\n", "    config_values = monai.config.deviceconfig.get_config_values()\n", "    for k in config_values:\n", "        params_dict[re.sub(\"[()]\",\" \",str(k))] = config_values[k]\n", "\n", "    optional_config_values = monai.config.deviceconfig.get_optional_config_values()\n", "    for k in optional_config_values:\n", "        params_dict[re.sub(\"[()]\",\" \",str(k))] = optional_config_values[k]\n", "\n", "    gpu_info = monai.config.deviceconfig.get_gpu_info()\n", "    for k in gpu_info:\n", "        params_dict[re.sub(\"[()]\",\" \",str(k))] = str(gpu_info[k])\n", "\n", "    yaml_config_files = [params_file]\n", "    # %%\n", "    monai_config = {}\n", "    for config_file in yaml_config_files:\n", "        with open(config_file, 'r') as file:\n", "            monai_config.update(yaml.safe_load(file))\n", "\n", "    monai_config[\"bundle_root\"] = str(Path(Path(params_file).parent).parent)\n", "\n", "    parser = ConfigParser(monai_config, globals={\"os\": \"os\",\n", "                                                 \"pathlib\": \"pathlib\",\n", "                                                 \"json\": \"json\",\n", "                                                 \"ignite\": 
\"ignite\"\n", "                                                 })\n", "\n", "    parser.parse(True)\n", "\n", "    for k in monai_config:\n", "        params_dict[k] = parser.get_parsed_content(k,instantiate=True)\n", "\n", "    if custom_params is not None:\n", "        for k in custom_params:\n", "            params_dict[k] = custom_params[k]\n", "    return params_dict" ] }, { "cell_type": "code", "execution_count": null, "id": "8fb3858f", "metadata": {}, "outputs": [], "source": [ "%%writefile Bundle/mlflow_params.yaml\n", "\n", "num_classes: 2\n", "task_id: \"109\"\n", "pymaia_config_file: \"/home/maia-user/Tutorials/MAIA/Experiments/Task09_Spleen/Task09_Spleen_results/Dataset109_Task09_Spleen.json\"\n", "tracking_uri: \"http://localhost:5000\"\n", "mlflow_experiment_name: \"nnUNet_Bundle_Spleen\"\n", "mlflow_run_name: \"nnUNet_Bundle_Spleen\"\n", "nnunet_trainer_class_name: \"nnUNetTrainer\"\n", "nnunet_plans_identifier: \"nnUNetPlans\"\n" ] }, { "cell_type": "code", "execution_count": null, "id": "174505fe", "metadata": {}, "outputs": [], "source": [ "mlflow_experiment_name = \"nnUNet_Bundle_Spleen\"\n", "mlflow_run_name = \"nnUNet_Bundle_Spleen\"\n", "label_dict = {0: \"background\", 1: \"Spleen\"}\n", "tracking_uri = \"http://localhost:5000\"\n", "params_file = \"Bundle/mlflow_params.yaml\"\n", "\n", "\n", "train_handlers.append(\n", "    MLFlowPyMAIAHandler(\n", "        dataset_dict = {\"train\": train_dataset},\n", "        dataset_keys = \"case_identifier\",\n", "        experiment_param = create_mlflow_experiment_params(params_file),\n", "        experiment_name= mlflow_experiment_name,\n", "        label_dict = label_dict,\n", "        output_transform = mlflow_transform,\n", "        run_name = mlflow_run_name,\n", "        state_attributes = [\"best_metric\", \"best_metric_epoch\"],\n", "        tag_name = \"Train_Loss\",\n", "        tracking_uri = tracking_uri,\n", "    )\n", ")\n", "\n", "val_handlers.append(\n", "    MLFlowPyMAIAHandler(\n", "        experiment_name= mlflow_experiment_name,\n", "        iteration_log = False,\n", "        label_dict = label_dict,\n", "        output_transform = mlflow_transform,\n", "        run_name = 
mlflow_run_name,\n", "        state_attributes = [\"best_metric\", \"best_metric_epoch\"],\n", "        tracking_uri = tracking_uri,\n", "    )\n", ")" ] }, { "cell_type": "markdown", "id": "20722504", "metadata": {}, "source": [ "To start the MLFlow server, we can run the following command in the terminal:\n", "\n", "```bash\n", "cd Bundle/MLFlow && mlflow server\n", "```\n", "\n", "To run TensorBoard, we can use the following command:\n", "\n", "```bash\n", "tensorboard --logdir Bundle/logs\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "0d11d8c8", "metadata": {}, "outputs": [], "source": [ "trainer = SupervisedTrainer(\n", "    amp= True,\n", "    device = device,\n", "    epoch_length = iterations,\n", "    loss_function = loss,\n", "    max_epochs = epochs,\n", "    network = network,\n", "    prepare_batch = prepare_nnunet_batch,\n", "    optimizer = optimizer,\n", "    train_data_loader = train_dataloader,\n", "    train_handlers= train_handlers,\n", "    key_train_metric = {\"Train_Dice\": train_key_metric},\n", "    postprocessing= postprocessing,\n", "    additional_metrics = additional_metrics\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "fcc921bf", "metadata": {}, "outputs": [], "source": [ "trainer.run()" ] }, { "cell_type": "markdown", "id": "b26642eb", "metadata": {}, "source": [ "## Reload Checkpoint\n", "\n", "When resuming the training from a checkpoint, we also want to restart the training from the same epoch. To do this, we need to load the checkpoint and update the `trainer.state.epoch` and `trainer.state.iteration` attributes of the `SupervisedTrainer`." 
] }, { "cell_type": "code", "execution_count": null, "id": "93a57835", "metadata": {}, "outputs": [], "source": [ "from PyMAIA.utils.file_utils import subfiles\n", "\n", "def get_checkpoint(epoch, ckpt_dir):\n", "    if epoch == \"latest\":\n", "\n", "        latest_checkpoints = subfiles(ckpt_dir, prefix=\"checkpoint_epoch\", sort=True,\n", "                                      join=False)\n", "        epochs = []\n", "        for latest_checkpoint in latest_checkpoints:\n", "            epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\"):-len(\".pt\")]))\n", "\n", "        epochs.sort()\n", "        latest_epoch = epochs[-1]\n", "        return latest_epoch\n", "    else:\n", "        return epoch\n", "\n", "def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir):\n", "\n", "    epoch_to_load = get_checkpoint(epoch, ckpt_dir)\n", "    trainer.state.epoch = epoch_to_load\n", "    trainer.state.iteration = (epoch_to_load* num_train_batches_per_epoch) +1" ] }, { "cell_type": "markdown", "id": "b353be42", "metadata": {}, "source": [ "## Create MONAI Bundle" ] }, { "cell_type": "code", "execution_count": null, "id": "564700ff", "metadata": {}, "outputs": [], "source": [ "%%bash \n", "\n", "/home/maia-user/.conda/envs/MAIA/bin/python -m monai.bundle init_bundle nnUNetBundle\n", "# you may need to install tree with \"sudo apt install tree\"\n", "mkdir -p nnUNetBundle/nnUNet\n", "mkdir -p nnUNetBundle/src\n", "mkdir -p nnUNetBundle/nnUNet/evaluator\n", "which tree && tree nnUNetBundle || true" ] }, { "cell_type": "code", "execution_count": null, "id": "cb7aa3da", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/nnUNet/global.yaml\n", "\n", "iterations: $@nnunet_trainer.num_iterations_per_epoch\n", "device: $@nnunet_trainer.device\n", "epochs: $@nnunet_trainer.num_epochs\n", "pymaia_config_dict: \"$json.load(open(@pymaia_config_file))\"\n", "bundle_root: .\n", "ckpt_dir: \"$@bundle_root + '/models'\"" ] }, { "cell_type": "code", "execution_count": null, "id": "33f32c17", "metadata": {}, "outputs": [], "source": [ "%%writefile 
nnUNetBundle/nnUNet/params.yaml\n", "\n", "num_classes: 2\n", "task_id: \"\"\n", "pymaia_config_file: \"\"\n", "tracking_uri: \"mlruns\"\n", "mlflow_experiment_name: \"\"\n", "mlflow_run_name: \"\"\n", "nnunet_model_folder: \"\"\n", "nnunet_trainer_class_name: \"nnUNetTrainer\"\n", "nnunet_plans_identifier: \"nnUNetPlans\"" ] }, { "cell_type": "code", "execution_count": null, "id": "e31c3314", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/nnUNet/imports.yaml\n", "\n", "imports:\n", "- $import glob\n", "- $import os\n", "- $import ignite\n", "- $import torch\n", "- $import shutil\n", "- $import json\n", "- $import src\n", "- $from src.utils import create_mlflow_experiment_params\n", "- $from pathlib import Path" ] }, { "cell_type": "code", "execution_count": null, "id": "dc7fc76e", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/nnUNet/run.yaml\n", "\n", "run:\n", "- \"$src.utils.set_mlflow_token(@token)\"\n", "- \"$torch.save(@checkpoint,@checkpoint_filename)\"\n", "- \"$shutil.copy(Path(@nnunet_model_folder).joinpath('dataset.json'), @bundle_root+'/models/dataset.json')\"\n", "- \"$shutil.copy(Path(@nnunet_model_folder).joinpath('plans.json'), @bundle_root+'/models/plans.json')\"\n", "- \"$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})\"\n", "- \"$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])\"\n", "- $@train#trainer.run()\n", "\n", "initialize:\n", "- $monai.utils.set_determinism(seed=123)" ] }, { "cell_type": "code", "execution_count": null, "id": "7268a30a", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/nnUNet/train.yaml\n", "\n", "nnunet_trainer:\n", " _target_ : src.nnUNet_Trainer.get_nnunet_trainer\n", " dataset_name_or_id: \"@task_id\"\n", " configuration: \"3d_fullres\"\n", " fold: \"0\"\n", " pymaia_config_file: \"@pymaia_config_file\"\n", " trainer_class_name: \"@nnunet_trainer_class_name\"\n", " plans_identifier: 
\"@nnunet_plans_identifier\"\n", "\n", "nnunet_trainer_def:\n", " _target_ : src.nnUNet_Trainer.get_nnunet_trainer_from_preprocessing_folder\n", " plans_file: \"$@bundle_root+'/models/plans.json'\"\n", " dataset_file: \"$@bundle_root+'/models/dataset.json'\"\n", " configuration: \"3d_fullres\"\n", " fold: \"0\"\n", " trainer_class_name: \"@nnunet_trainer_class_name\"\n", " plans_identifier: \"@nnunet_plans_identifier\"\n", "\n", "loss: $@nnunet_trainer.loss\n", "lr_scheduler: $@nnunet_trainer.lr_scheduler\n", "\n", "network_def: $@nnunet_trainer_def.network\n", "network: $@nnunet_trainer.network\n", "\n", "optimizer: $@nnunet_trainer.optimizer\n", "\n", "checkpoint:\n", " init_args: '$@nnunet_trainer.my_init_kwargs'\n", " trainer_name: '$@nnunet_trainer.__class__.__name__'\n", " inference_allowed_mirroring_axes: '$@nnunet_trainer.inference_allowed_mirroring_axes'\n", "\n", "checkpoint_filename: \"$@bundle_root+'/models/nnunet_checkpoint.pth'\"\n", "output_dir: $@bundle_root + '/eval'\n", "\n", "train:\n", " pbar:\n", " _target_: \"ignite.contrib.handlers.tqdm_logger.ProgressBar\"\n", " dataloader: $@nnunet_trainer.dataloader_train\n", " train_data: \"$[{'case_identifier':k} for k in @nnunet_trainer.dataloader_train.generator._data.dataset.keys()]\"\n", " train_dataset:\n", " _target_: Dataset\n", " data: \"@train#train_data\"\n", " handlers:\n", " inferer:\n", " _target_: SimpleInferer\n", " key_metric:\n", " Train_Dice:\n", " _target_: \"MeanDice\"\n", " include_background: False\n", " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", " reduction: \"mean\"\n", " additional_metrics:\n", " Train_Dice_per_class:\n", " _target_: \"MeanDice\"\n", " include_background: False\n", " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", " reduction: \"mean_batch\"\n", " postprocessing:\n", " _target_: \"Compose\"\n", " transforms:\n", " - _target_: Lambdad\n", " keys:\n", " - \"pred\"\n", " - \"label\"\n", " func: \"$lambda x: 
x[0]\"\n", " - _target_: Activationsd\n", " keys:\n", " - \"pred\"\n", " softmax: True\n", " - _target_: AsDiscreted\n", " keys:\n", " - \"pred\"\n", " threshold: 0.5\n", " - _target_: AsDiscreted\n", " keys:\n", " - \"label\"\n", " to_onehot: \"@num_classes\"\n", " postprocessing_region_based:\n", " _target_: \"Compose\"\n", " transforms:\n", " - _target_: Lambdad\n", " keys:\n", " - \"pred\"\n", " - \"label\"\n", " func: \"$lambda x: x[0]\"\n", " - _target_: Activationsd\n", " keys:\n", " - \"pred\"\n", " sigmoid: True\n", " - _target_: AsDiscreted\n", " keys:\n", " - \"pred\"\n", " threshold: 0.5\n", " trainer:\n", " _target_: SupervisedTrainer\n", " amp: true\n", " device: '@device'\n", " additional_metrics: \"@train#additional_metrics\"\n", " epoch_length: \"@iterations\"\n", " inferer: '@train#inferer'\n", " key_train_metric: '@train#key_metric'\n", " loss_function: '@loss'\n", " max_epochs: '@epochs'\n", " network: '@network'\n", " prepare_batch: \"$src.nnUNet_Trainer.prepare_nnunet_batch\"\n", " optimizer: '@optimizer'\n", " postprocessing: '@train#postprocessing'\n", " train_data_loader: '@train#dataloader'\n", " train_handlers: '@train#handlers'" ] }, { "cell_type": "code", "execution_count": 1, "id": "4933773b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Writing nnUNetBundle/nnUNet/train_resume.yaml\n" ] } ], "source": [ "%%writefile nnUNetBundle/nnUNet/train_resume.yaml\n", "\n", "run:\n", "- \"$src.utils.set_mlflow_token(@token)\"\n", "- '$src.utils.reload_checkpoint(@train#trainer,@reload_checkpoint_epoch,@nnunet_trainer.num_iterations_per_epoch,@bundle_root+\"/models\")'\n", "- \"$@train#pbar.attach(@train#trainer,output_transform=lambda x: {'loss': x[0]['loss']})\"\n", "- \"$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])\"\n", "- $@train#trainer.run()\n", "\n", "train_handlers:\n", " handlers:\n", " - _target_: \"$src.nnUNet_Trainer.MLFlowPyMAIAHandler\"\n", " label_dict: 
\"$@pymaia_config_dict['label_dict']\"\n", " tracking_uri: \"@tracking_uri\"\n", " experiment_name: \"@mlflow_experiment_name\"\n", " run_name: \"@mlflow_run_name\"\n", " output_transform: \"$src.nnUNet_Trainer.mlflow_transform\"\n", " dataset_dict:\n", " train: \"@train#train_dataset\"\n", " dataset_keys: 'case_identifier'\n", " state_attributes:\n", " - \"iteration\"\n", " - \"epoch\"\n", " tag_name: 'Train_Loss'\n", " experiment_param: \"$src.utils.create_mlflow_experiment_params( @bundle_root + '/nnUNet/params.yaml')\"\n", " #artifacts=None\n", " optimizer_param_names: 'lr'\n", " #close_on_complete: False\n", " - _target_: LrScheduleHandler\n", " lr_scheduler: '@lr_scheduler'\n", " print_lr: true\n", " - _target_: ValidationHandler\n", " epoch_level: true\n", " interval: '@val_interval'\n", " validator: '@validate#evaluator'\n", " #- _target_: StatsHandler\n", " # output_transform: $monai.handlers.from_engine(['loss'], first=True)\n", " # tag_name: train_loss\n", " - _target_: TensorBoardStatsHandler\n", " log_dir: '@output_dir'\n", " output_transform: $monai.handlers.from_engine(['loss'], first=True)\n", " tag_name: train_loss\n", " - _target_: CheckpointLoader\n", " load_dict:\n", " network_weights: '$@nnunet_trainer.network._orig_mod'\n", " optimizer_state: '$@nnunet_trainer.optimizer'\n", " scheduler: '$@nnunet_trainer.lr_scheduler'\n", " load_path: '$@bundle_root + \"/models/checkpoint_epoch=\"+str(src.utils.get_checkpoint(@reload_checkpoint_epoch, @bundle_root+\"/models\"))+\".pt\"'\n", " map_location: '@device'" ] }, { "cell_type": "code", "execution_count": null, "id": "944f75b6", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/nnUNet/train_handlers.yaml\n", "\n", "train_handlers:\n", " handlers:\n", " - _target_: \"$src.nnUNet_Trainer.MLFlowPyMAIAHandler\"\n", " label_dict: \"$@pymaia_config_dict['label_dict']\"\n", " tracking_uri: \"@tracking_uri\"\n", " experiment_name: \"@mlflow_experiment_name\"\n", " run_name: 
\"@mlflow_run_name\"\n", " output_transform: \"$src.nnUNet_Trainer.mlflow_transform\"\n", " dataset_dict:\n", " train: \"@train#train_dataset\"\n", " dataset_keys: 'case_identifier'\n", " state_attributes:\n", " - \"iteration\"\n", " - \"epoch\"\n", " tag_name: 'Train_Loss'\n", " experiment_param: \"$src.utils.create_mlflow_experiment_params( @bundle_root + '/nnUNet/params.yaml')\"\n", " #artifacts=None\n", " optimizer_param_names: 'lr'\n", " #close_on_complete: False\n", " - _target_: LrScheduleHandler\n", " lr_scheduler: '@lr_scheduler'\n", " print_lr: true\n", " - _target_: ValidationHandler\n", " epoch_level: true\n", " interval: '@val_interval'\n", " validator: '@validate#evaluator'\n", " #- _target_: StatsHandler\n", " # output_transform: $monai.handlers.from_engine(['loss'], first=True)\n", " # tag_name: train_loss\n", " - _target_: TensorBoardStatsHandler\n", " log_dir: '@output_dir'\n", " output_transform: $monai.handlers.from_engine(['loss'], first=True)\n", " tag_name: train_loss" ] }, { "cell_type": "code", "execution_count": null, "id": "4d3b2a5f", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/nnUNet/validate.yaml\n", "\n", "val_interval: 1\n", "validate:\n", " pbar:\n", " _target_: \"ignite.contrib.handlers.tqdm_logger.ProgressBar\"\n", " key_metric:\n", " Val_Dice:\n", " _target_: \"MeanDice\"\n", " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", " reduction: \"mean\"\n", " include_background: False\n", " additional_metrics:\n", " Val_Dice_per_class:\n", " _target_: \"MeanDice\"\n", " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", " reduction: \"mean_batch\"\n", " include_background: False\n", " dataloader: $@nnunet_trainer.dataloader_val\n", " evaluator:\n", " _target_: SupervisedEvaluator\n", " additional_metrics: '@validate#additional_metrics'\n", " amp: true\n", " epoch_length: $@nnunet_trainer.num_val_iterations_per_epoch\n", " device: '@device'\n", " inferer: 
'@validate#inferer'\n", " key_val_metric: '@validate#key_metric'\n", " network: '@network'\n", " postprocessing: '@validate#postprocessing'\n", " val_data_loader: '@validate#dataloader'\n", " val_handlers: '@validate#handlers'\n", " prepare_batch: \"$src.nnUNet_Trainer.prepare_nnunet_batch\"\n", " handlers:\n", " - _target_: StatsHandler\n", " iteration_log: false\n", " - _target_: TensorBoardStatsHandler\n", " iteration_log: false\n", " log_dir: '@output_dir'\n", " - _target_: \"$src.nnUNet_Trainer.MLFlowPyMAIAHandler\"\n", " label_dict: \"$@pymaia_config_dict['label_dict']\"\n", " tracking_uri: \"@tracking_uri\"\n", " experiment_name: \"@mlflow_experiment_name\"\n", " run_name: \"@mlflow_run_name\"\n", " output_transform: \"$src.nnUNet_Trainer.mlflow_transform\"\n", " iteration_log: False\n", " state_attributes:\n", " - \"best_metric\"\n", " - \"best_metric_epoch\"\n", " - _target_: \"CheckpointSaver\"\n", " save_dir: \"$str(@bundle_root)+'/models'\"\n", " save_interval: 1\n", " n_saved: 1\n", " save_key_metric: true\n", " save_dict:\n", " network_weights: '$@nnunet_trainer.network._orig_mod'\n", " optimizer_state: '$@nnunet_trainer.optimizer'\n", " scheduler: '$@nnunet_trainer.lr_scheduler'\n", " inferer:\n", " _target_: SimpleInferer\n", " postprocessing: '%train#postprocessing'\n" ] }, { "cell_type": "code", "execution_count": null, "id": "51fae1b1", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/nnUNet/evaluator/evaluator.yaml\n", "\n", "validate:\n", " pbar:\n", " _target_: \"ignite.contrib.handlers.tqdm_logger.ProgressBar\"\n", " key_metric:\n", " Val_Dice:\n", " _target_: \"MeanDice\"\n", " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", " reduction: \"mean\"\n", " include_background: False\n", " additional_metrics:\n", " Val_Dice_per_class:\n", " _target_: \"MeanDice\"\n", " output_transform: \"$monai.handlers.from_engine(['pred', 'label'])\"\n", " reduction: \"mean_batch\"\n", " include_background: False\n", 
" dataloader: $@nnunet_trainer.dataloader_val\n", " evaluator:\n", " _target_: SupervisedEvaluator\n", " additional_metrics: '@validate#additional_metrics'\n", " amp: true\n", " epoch_length: $@nnunet_trainer.num_val_iterations_per_epoch\n", " device: '@device'\n", " inferer: '@validate#inferer'\n", " key_val_metric: '@validate#key_metric'\n", " network: '@network'\n", " postprocessing: '@validate#postprocessing'\n", " val_data_loader: '@validate#dataloader'\n", " val_handlers: '@validate#handlers'\n", " prepare_batch: \"$src.nnUNet_Trainer.prepare_nnunet_batch\"\n", " handlers:\n", " - _target_: StatsHandler\n", " iteration_log: false\n", " - _target_: TensorBoardStatsHandler\n", " iteration_log: false\n", " log_dir: '@output_dir'\n", " - _target_: \"$src.nnUNet_Trainer.MLFlowPyMAIAHandler\"\n", " label_dict: \"$@pymaia_config_dict['label_dict']\"\n", " tracking_uri: \"@tracking_uri\"\n", " experiment_name: \"@mlflow_experiment_name\"\n", " run_name: \"@mlflow_run_name\"\n", " output_transform: \"$src.nnUNet_Trainer.mlflow_transform\"\n", " iteration_log: False\n", " state_attributes:\n", " - \"best_metric\"\n", " - \"best_metric_epoch\"\n", " inferer:\n", " _target_: SimpleInferer\n", " postprocessing: '%train#postprocessing'\n", "\n", "run:\n", "- \"src.utils.set_mlflow_token(@token)\"\n", "- \"$@validate#pbar.attach(@validate#evaluator,metric_names=['Val_Dice'])\"\n", "- $@validate#evaluator.run()\n", "\n", "initialize:\n", "- \"$setattr(torch.backends.cudnn, 'benchmark', True)\"" ] }, { "cell_type": "code", "execution_count": null, "id": "0f4fb6da", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/src/__init__.py\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "194a46ba", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/src/nnUNet_Trainer.py\n", "import torch\n", "from typing import Union, Optional\n", "import json\n", "from pathlib import Path\n", "import os\n", "from torch.backends import cudnn\n", 
"from batchgenerators.utilities.file_and_folder_operations import join, load_json\n", "from nnunetv2.utilities.find_class_by_name import recursive_find_python_class\n", "import nnunetv2\n", "from monai.handlers import MLFlowHandler\n", "\n", "\n", "def get_nnunet_trainer(dataset_name_or_id: Union[str, int],\n", " configuration: str, fold: Union[int, str],\n", " pymaia_config_file: str = None, # To set env variables\n", " trainer_class_name: str = 'nnUNetTrainer',\n", " plans_identifier: str = 'nnUNetPlans',\n", " pretrained_weights: Optional[str] = None,\n", " num_gpus: int = 1,\n", " use_compressed_data: bool = False,\n", " export_validation_probabilities: bool = False,\n", " continue_training: bool = False,\n", " only_run_validation: bool = False,\n", " disable_checkpointing: bool = False,\n", " val_with_best: bool = False,\n", " device: torch.device = torch.device(\n", " 'cuda'),\n", " pretrained_model = None\n", " ): # From nnUNet/nnunetv2/run/run_training.py#run_training\n", "\n", " ## Block Added\n", " \n", " if pymaia_config_file != None:\n", " with open(pymaia_config_file, \"r\") as f:\n", " pymaia_config_dict = json.load(f)\n", "\n", " os.environ[\"nnUNet_raw\"] = str(Path(pymaia_config_dict[\"base_folder\"]).joinpath(\"nnUNet_raw\"))\n", " os.environ[\"nnUNet_preprocessed\"] = pymaia_config_dict[\"preprocessing_folder\"]\n", " os.environ[\"nnUNet_results\"] = pymaia_config_dict[\"results_folder\"]\n", "\n", " from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint\n", " ## End Block\n", "\n", " if isinstance(fold, str):\n", " if fold != 'all':\n", " try:\n", " fold = int(fold)\n", " except ValueError as e:\n", " print(\n", " f'Unable to convert given value for fold to int: {fold}. fold must bei either \"all\" or an integer!')\n", " raise e\n", "\n", " if int(num_gpus) > 1:\n", " ... 
# Disable for now\n", " else:\n", " nnunet_trainer = get_trainer_from_args(str(dataset_name_or_id), configuration, fold, trainer_class_name,\n", " plans_identifier, use_compressed_data, device=device)\n", " \n", " if disable_checkpointing:\n", " nnunet_trainer.disable_checkpointing = disable_checkpointing\n", "\n", " assert not (\n", " continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.'\n", "\n", " maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)\n", " nnunet_trainer.on_train_start() # Added to Initialize Trainer\n", " if torch.cuda.is_available():\n", " cudnn.deterministic = False\n", " cudnn.benchmark = True\n", "\n", " if pretrained_model is not None:\n", " state_dict = torch.load(pretrained_model)\n", " if 'network_weights' in state_dict:\n", " nnunet_trainer.network._orig_mod.load_state_dict(state_dict['network_weights'])\n", " #nnunet_trainer.network.load_state_dict(torch.load(pretrained_model)['model'])\n", " # Skip Training and Validation Phase \n", " return nnunet_trainer\n", "\n", "def get_nnunet_trainer_from_preprocessing_folder(\n", " plans_file,\n", " dataset_file,\n", " configuration: str,\n", " fold: Union[int, str],\n", " trainer_class_name: str = 'nnUNetTrainer',\n", " plans_identifier: str = 'nnUNetPlans',\n", " pretrained_weights: Optional[str] = None,\n", " num_gpus: int = 1,\n", " use_compressed_data: bool = False,\n", " export_validation_probabilities: bool = False,\n", " continue_training: bool = False,\n", " only_run_validation: bool = False,\n", " disable_checkpointing: bool = False,\n", " val_with_best: bool = False,\n", " device: torch.device = torch.device(\n", " 'cuda'),\n", " pretrained_model = None\n", "): # From nnUNet/nnunetv2/run/run_training.py#run_training\n", "\n", " ## Block Added\n", " os.environ[\"nnUNet_raw\"] = str(Path(\".\").joinpath(\"nnUNet_raw\"))\n", " os.environ[\"nnUNet_preprocessed\"] = \".\"\n", " 
os.environ[\"nnUNet_results\"] = \".\"\n", " from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint\n", " ##\n", "\n", " if isinstance(fold, str):\n", " if fold != 'all':\n", " try:\n", " fold = int(fold)\n", " except ValueError as e:\n", " print(\n", " f'Unable to convert given value for fold to int: {fold}. fold must bei either \"all\" or an integer!')\n", " raise e\n", "\n", " if int(num_gpus) > 1:\n", " ... # Disable for now\n", " else:\n", " nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n", " trainer_class_name, 'nnunetv2.training.nnUNetTrainer')\n", "\n", " plans = load_json(plans_file)\n", " dataset_json = load_json(dataset_file)\n", "\n", " nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,\n", " dataset_json=dataset_json, unpack_dataset=False, device=torch.device(\"cuda\"))\n", "\n", " if disable_checkpointing:\n", " nnunet_trainer.disable_checkpointing = disable_checkpointing\n", "\n", " assert not (\n", " continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. 
Dummy.'\n", "\n", " maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)\n", " nnunet_trainer.initialize() # To Initialize Trainer\n", " if torch.cuda.is_available():\n", " cudnn.deterministic = False\n", " cudnn.benchmark = True\n", "\n", " # Skip Training and Validation Phase\n", " if pretrained_model is not None:\n", " state_dict = torch.load(pretrained_model)\n", " if 'network_weights' in state_dict:\n", " nnunet_trainer.network._orig_mod.load_state_dict(state_dict['network_weights'])\n", " #nnunet_trainer.network.load_state_dict(torch.load(pretrained_model)['model'])\n", "\n", " return nnunet_trainer\n", "\n", "def prepare_nnunet_batch(batch, device, non_blocking):\n", " data = batch[\"data\"].to(device, non_blocking=non_blocking)\n", " if isinstance(batch[\"target\"], list):\n", " target = [i.to(device, non_blocking=non_blocking) for i in batch[\"target\"]]\n", " else:\n", " target = batch[\"target\"].to(device, non_blocking=non_blocking)\n", " return data, target\n", "\n", "def mlflow_transform(state_output):\n", " return state_output[0]['loss']\n", "\n", "class MLFlowPyMAIAHandler(MLFlowHandler):\n", " def __init__(self, label_dict, **kwargs):\n", " super(MLFlowPyMAIAHandler, self).__init__(**kwargs)\n", " self.label_dict = label_dict\n", " \n", " def _default_epoch_log(self, engine) -> None:\n", " \"\"\"\n", " Execute epoch level log operation.\n", " Default to track the values from Ignite `engine.state.metrics` dict and\n", " track the values of specified attributes of `engine.state`.\n", "\n", " Args:\n", " engine: Ignite Engine, it can be a trainer, validator or evaluator.\n", "\n", " \"\"\"\n", " log_dict = engine.state.metrics\n", " if not log_dict:\n", " return\n", "\n", " current_epoch = self.global_epoch_transform(engine.state.epoch)\n", "\n", " new_log_dict = {}\n", "\n", " for metric in log_dict:\n", " if type(log_dict[metric]) == torch.Tensor:\n", " for i,val in enumerate(log_dict[metric]):\n", " 
new_log_dict[metric+\"_{}\".format(list(self.label_dict.keys())[i+1])] = val\n", " else:\n", " new_log_dict[metric] = log_dict[metric]\n", " self._log_metrics(new_log_dict, step=current_epoch)\n", "\n", " if self.state_attributes is not None:\n", " attrs = {attr: getattr(engine.state, attr, None) for attr in self.state_attributes}\n", " self._log_metrics(attrs, step=current_epoch)" ] }, { "cell_type": "code", "execution_count": null, "id": "d679a387", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/src/utils.py\n", "\n", "import re\n", "from PyMAIA.utils.file_utils import subfiles\n", "import yaml\n", "from monai.bundle import ConfigParser\n", "import monai\n", "from pathlib import Path\n", "import os\n", "\n", "def set_mlflow_token(token):\n", " os.environ[\"MLFLOW_TRACKING_TOKEN\"] = token\n", "\n", "def create_mlflow_experiment_params(params_file, custom_params=None):\n", " params_dict = {}\n", " config_values = monai.config.deviceconfig.get_config_values()\n", " for k in config_values:\n", " params_dict[re.sub(\"[()]\",\" \",str(k))] = config_values[k]\n", "\n", " optional_config_values = monai.config.deviceconfig.get_optional_config_values()\n", " for k in optional_config_values:\n", " params_dict[re.sub(\"[()]\",\" \",str(k))] = optional_config_values[k]\n", "\n", " gpu_info = monai.config.deviceconfig.get_gpu_info()\n", " for k in gpu_info:\n", " params_dict[re.sub(\"[()]\",\" \",str(k))] = str(gpu_info[k])\n", "\n", " yaml_config_files = [params_file]\n", " # %%\n", " monai_config = {}\n", " for config_file in yaml_config_files:\n", " with open(config_file, 'r') as file:\n", " monai_config.update(yaml.safe_load(file))\n", "\n", " monai_config[\"bundle_root\"] = str(Path(Path(params_file).parent).parent)\n", "\n", " parser = ConfigParser(monai_config, globals={\"os\": \"os\",\n", " \"pathlib\": \"pathlib\",\n", " \"json\": \"json\",\n", " \"ignite\": \"ignite\"\n", " })\n", "\n", " parser.parse(True)\n", "\n", " for k in 
monai_config:\n", " params_dict[k] = parser.get_parsed_content(k,instantiate=True)\n", "\n", " if custom_params is not None:\n", " for k in custom_params:\n", " params_dict[k] = custom_params[k]\n", " return params_dict\n", "\n", "def get_checkpoint(epoch, ckpt_dir):\n", " if epoch == \"latest\":\n", "\n", " latest_checkpoints = subfiles(ckpt_dir, prefix=\"checkpoint_epoch\", sort=True,\n", " join=False)\n", " epochs = []\n", " for latest_checkpoint in latest_checkpoints:\n", " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\"):-len(\".pt\")]))\n", "\n", " epochs.sort()\n", " latest_epoch = epochs[-1]\n", " return latest_epoch\n", " else:\n", " return epoch\n", "\n", "def reload_checkpoint(trainer, epoch, num_train_batches_per_epoch, ckpt_dir):\n", "\n", " epoch_to_load = get_checkpoint(epoch, ckpt_dir)\n", " trainer.state.epoch = epoch_to_load\n", " trainer.state.iteration = (epoch_to_load* num_train_batches_per_epoch) +1" ] }, { "cell_type": "markdown", "id": "3f221410", "metadata": {}, "source": [ "## Inference\n", "\n", "After training the nnUNet model, we can perform inference on new data. We wrap the nnUNet model in a `nnUNetModelWrapper` to run inference from the MONAI Bundle. The wrapper handles the nnUNet preprocessing, inference, and postprocessing steps, so the Bundle blocks only need to load the input data, pass it to the nnUNet wrapper, and postprocess the returned nnUNet prediction.\n", "\n", "The `nnUNetModelWrapper` receives the image `MetaTensor` loaded by the DataLoader, reads the input file path from its metadata, and returns the model predictions as a `MetaTensor`."
] }, { "cell_type": "code", "execution_count": null, "id": "9112d455", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import os\n", "from typing import Union, Optional\n", "import torch\n", "from monai.data.meta_tensor import MetaTensor\n", "from torch.backends import cudnn\n", "import setuptools\n", "from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json\n", "import numpy as np\n", "import monai\n", "from tqdm import tqdm\n", "from pathlib import Path\n", "import json" ] }, { "cell_type": "code", "execution_count": null, "id": "bd0f33e1", "metadata": {}, "outputs": [], "source": [ "\n", "\n", "class nnUNetModelWrapper(torch.nn.Module):\n", " def __init__(self, predictor, model_folder):\n", " super().__init__()\n", " self.predictor = predictor\n", " self.predictor.initialize_from_trained_model_folder(\n", " model_folder,\n", " use_folds=(0,),\n", " checkpoint_name='checkpoint_final.pth',\n", " )\n", " self.network_weights = self.predictor.network\n", "\n", " def forward(self, x):\n", " if type(x) is tuple:\n", " input_files = [img.meta['filename_or_obj'][0] for img in x]\n", " else:\n", " input_files = x.meta['filename_or_obj']\n", " if type(input_files) == str:\n", " input_files = [input_files]\n", "\n", " \n", " output = self.predictor.predict_from_files(\n", " [input_files],\n", " None,\n", " save_probabilities=False, overwrite=True,\n", " num_processes_preprocessing=2, num_processes_segmentation_export=2,\n", " folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n", "\n", " out_tensors= []\n", " for out in output:\n", " out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0),0)))\n", " out_tensor = torch.cat(out_tensors, 0)\n", "\n", " if type(x) is tuple:\n", " return MetaTensor(out_tensor, meta=x[0].meta)\n", " else:\n", " return MetaTensor(out_tensor, meta=x.meta)" ] }, { "cell_type": "code", "execution_count": null, "id": "8a192cc4", "metadata": {}, "outputs": [], "source": [ "def 
get_nnunet_predictor(model_folder):\n", " \n", " from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n", " predictor = nnUNetPredictor(\n", " tile_step_size=0.5,\n", " use_gaussian=True,\n", " use_mirroring=False,\n", " device=torch.device('cuda', 0),\n", " verbose=False,\n", " verbose_preprocessing=False,\n", " allow_tqdm=True\n", " )\n", " # initializes the network architecture, loads the checkpoint\n", " wrapper = nnUNetModelWrapper(predictor, model_folder)\n", " return wrapper\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6adfa0a0", "metadata": {}, "outputs": [], "source": [ "network = get_nnunet_predictor(\"/home/maia-user/Tutorials/MAIA/Experiments/Task09_Spleen/Task09_Spleen_results/Dataset109_Task09_Spleen/nnUNetTrainer__nnUNetResEncUNetLPlans__3d_fullres\")" ] }, { "cell_type": "markdown", "id": "d9258c3b", "metadata": {}, "source": [ "## Test Data Preparation\n", "\n", "The Bundle accepts the test dataset in the following format:\n", "\n", "```bash\n", "Dataset\n", "├── Case1\n", "│ └── Case1.nii.gz\n", "├── Case2\n", "│ └── Case2.nii.gz\n", "└── Case3\n", " └── Case3.nii.gz\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "07229e86", "metadata": {}, "outputs": [], "source": [ "%%bash\n", "\n", "mkdir -p MAIA/MONAI_Bundle/input\n", "mkdir -p MAIA/MONAI_Bundle/output\n", "mkdir -p MAIA/MONAI_Bundle/input/spleen_1\n", "\n", "cp MAIA/Task09_Spleen/imagesTs/spleen_1.nii.gz MAIA/MONAI_Bundle/input/spleen_1" ] }, { "cell_type": "code", "execution_count": null, "id": "308abf67", "metadata": {}, "outputs": [], "source": [ "%%bash\n", "\n", "tree MAIA/MONAI_Bundle/input" ] }, { "cell_type": "code", "execution_count": null, "id": "f10d59d7", "metadata": {}, "outputs": [], "source": [ "import pathlib\n", "\n", "def get_subfolder_dataset(data_dir,modality_conf):\n", " data_list = []\n", " for f in os.scandir(data_dir):\n", "\n", " if f.is_dir():\n", " subject_dict = 
{key:str(pathlib.Path(f.path).joinpath(f.name+modality_conf[key]['suffix'])) for key in modality_conf}\n", " data_list.append(subject_dict)\n", " return data_list" ] }, { "cell_type": "markdown", "id": "6d7e5b73", "metadata": {}, "source": [ "### Data Loading" ] }, { "cell_type": "code", "execution_count": null, "id": "5f3972b9", "metadata": {}, "outputs": [], "source": [ "modalities = {\n", " \"image\": {\"suffix\": \".nii.gz\"},\n", "}\n", "\n", "data = get_subfolder_dataset(\"MAIA/MONAI_Bundle/input\",modalities)" ] }, { "cell_type": "code", "execution_count": null, "id": "9e24a629", "metadata": {}, "outputs": [], "source": [ "from monai.transforms import LoadImaged\n", "from monai.data import Dataset, DataLoader\n", "\n", "preprocessing = LoadImaged(keys=[\"image\"],ensure_channel_first=True, image_only=False)\n", "\n", "\n", "test_dataset = Dataset(data,transform=preprocessing)\n", "\n", "test_loader = DataLoader(test_dataset, batch_size=1)" ] }, { "cell_type": "markdown", "id": "8c637fba", "metadata": {}, "source": [ "### Test nnUNetModelWrapper\n", "\n", "To test the `nnUNetModelWrapper`, we can provide a test case to the `nnUNetModelWrapper` and extract the model predictions returned by the wrapper." ] }, { "cell_type": "code", "execution_count": null, "id": "2386fc9c", "metadata": {}, "outputs": [], "source": [ "batch = next(iter(test_loader))\n", "\n", "pred = network(batch[\"image\"])" ] }, { "cell_type": "markdown", "id": "3e010b7e", "metadata": {}, "source": [ "### Postprocessing and Save Predictions\n", "\n", "After obtaining the model predictions, we can apply postprocessing transformations to the predictions and save the results to disk.\n", "\n", "The `Transposed` transform is required to unify the axis order convention between MONAI and nnUNet. The nnUNet model uses the `zyx` axis order, while MONAI uses the `xyz` axis order." 
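, "\n", "\n", "A minimal sketch of this axis swap on a dummy tensor (the sizes are arbitrary assumptions): `indices=[0,3,2,1]` keeps the channel axis in place and reverses the three spatial axes, which for a tensor corresponds to a `permute` with the same indices.\n", "\n", "```python\n", "import torch\n", "\n", "# Dummy prediction in nnUNet order (channel, z, y, x); sizes are arbitrary.\n", "pred_zyx = torch.zeros(1, 4, 5, 6)\n", "\n", "# Same axis order as Transposed(keys='pred', indices=[0, 3, 2, 1]).\n", "pred_xyz = pred_zyx.permute(0, 3, 2, 1)\n", "\n", "print(tuple(pred_xyz.shape))  # (1, 6, 5, 4)\n", "```"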
] }, { "cell_type": "code", "execution_count": null, "id": "ccd5a438", "metadata": {}, "outputs": [], "source": [ "from monai.transforms import Compose, Transposed, SaveImaged\n", "\n", "\n", "class PreprocessNameFormatter:\n", " def __init__(self, filename_key):\n", " self.filename_key = filename_key\n", "\n", "\n", " def __call__(self, metadict: dict, saver) -> dict:\n", " \"\"\"Returns a kwargs dict for :py:meth:`FolderLayout.filename`,\n", " according to the input metadata and SaveImage transform.\"\"\"\n", " subject = (\n", " metadict.get(monai.utils.ImageMetaKey.FILENAME_OR_OBJ, getattr(saver, \"_data_index\", 0))\n", " if metadict\n", " else getattr(saver, \"_data_index\", 0)\n", " )\n", " patch_index = metadict.get(monai.utils.ImageMetaKey.PATCH_INDEX, None) if metadict else None\n", " subject = subject[:-len(self.filename_key)]+\".nii.gz\"\n", " return {\"subject\": f\"{subject}\", \"idx\": patch_index}\n", "\n", "\n", "postprocessing = Compose([\n", " Transposed(keys=\"pred\",indices=[0,3,2,1]),\n", " SaveImaged(keys=\"pred\",\n", " output_dir=\"MAIA/MONAI_Bundle/output\",\n", " output_postfix=\"prediction\",\n", " meta_keys=\"image_meta_dict\",\n", " output_name_formatter=PreprocessNameFormatter(modalities[list(modalities.keys())[0]]['suffix'])\n", " )\n", "])" ] }, { "cell_type": "markdown", "id": "9c85dd88", "metadata": {}, "source": [ "## Evaluator\n", "\n", "Combining everything, we can create an `Evaluator` that encapsulates the data loading, model inference, postprocessing, and evaluation steps. The `Evaluator` can be used to evaluate the model on the test dataset."
] }, { "cell_type": "code", "execution_count": null, "id": "395e5615", "metadata": {}, "outputs": [], "source": [ "def prepare_nnunet_inference_batch(batch, device, non_blocking):\n", " # The wrapper reads the input file paths from the image metadata, so no label is returned\n", " return batch[\"image\"], None" ] }, { "cell_type": "code", "execution_count": null, "id": "bbf1fec9", "metadata": {}, "outputs": [], "source": [ "from monai.engines import SupervisedEvaluator\n", "\n", "validator = SupervisedEvaluator(\n", " val_data_loader=test_loader,\n", " device=\"cuda:0\",\n", " prepare_batch=prepare_nnunet_inference_batch,\n", " network=network,\n", " postprocessing=postprocessing\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "67970bc2", "metadata": {}, "outputs": [], "source": [ "validator.run()" ] }, { "cell_type": "markdown", "id": "a5950082", "metadata": {}, "source": [ "## nnUNetModelWrapper from MONAI Weights\n", "\n", "As an alternative to the inference approach presented above, we can create a `nnUNetMONAIModelWrapper` that loads the network weights from a MONAI checkpoint, so that inference runs on the MONAI weights."
] }, { "cell_type": "code", "execution_count": null, "id": "2692f561", "metadata": {}, "outputs": [], "source": [ "from nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n", "from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\n", "from torch._dynamo import OptimizedModule\n", "import nnunetv2\n", "from nnunetv2.utilities.find_class_by_name import recursive_find_python_class\n", "\n", "class nnUNetMONAIModelWrapper(torch.nn.Module):\n", " def __init__(self, predictor, model_folder, model_name=\"model.pt\"):\n", " super().__init__()\n", " self.predictor = predictor\n", "\n", " model_training_output_dir = model_folder\n", " use_folds = '0'\n", "\n", " ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor\n", " dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))\n", " plans = load_json(join(model_training_output_dir, 'plans.json'))\n", " plans_manager = PlansManager(plans)\n", "\n", " if isinstance(use_folds, str):\n", " use_folds = [use_folds]\n", "\n", " parameters = []\n", " for i, f in enumerate(use_folds):\n", " f = int(f) if f != 'all' else f\n", " checkpoint = torch.load(join(model_training_output_dir, 'nnunet_checkpoint.pth'),\n", " map_location=torch.device('cpu'))\n", " monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device('cpu'))\n", " if i == 0:\n", " trainer_name = checkpoint['trainer_name']\n", " configuration_name = checkpoint['init_args']['configuration']\n", " inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \\\n", " 'inference_allowed_mirroring_axes' in checkpoint.keys() else None\n", "\n", " parameters.append(monai_checkpoint['network_weights'])\n", "\n", " configuration_manager = plans_manager.get_configuration(configuration_name)\n", " # restore network\n", " num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, 
dataset_json)\n", " trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n", " trainer_name, 'nnunetv2.training.nnUNetTrainer')\n", " if trainer_class is None:\n", " raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '\n", " f'Please place it there (in any .py file)!')\n", " network = trainer_class.build_network_architecture(\n", " configuration_manager.network_arch_class_name,\n", " configuration_manager.network_arch_init_kwargs,\n", " configuration_manager.network_arch_init_kwargs_req_import,\n", " num_input_channels,\n", " plans_manager.get_label_manager(dataset_json).num_segmentation_heads,\n", " enable_deep_supervision=False\n", " )\n", "\n", " predictor.plans_manager = plans_manager\n", " predictor.configuration_manager = configuration_manager\n", " predictor.list_of_parameters = parameters\n", " predictor.network = network\n", " predictor.dataset_json = dataset_json\n", " predictor.trainer_name = trainer_name\n", " predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes\n", " predictor.label_manager = plans_manager.get_label_manager(dataset_json)\n", " if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \\\n", " and not isinstance(predictor.network, OptimizedModule):\n", " print('Using torch.compile')\n", " # The network lives on the predictor; the wrapper has no `network` attribute\n", " predictor.network = torch.compile(predictor.network)\n", " ## End Block\n", " self.network_weights = self.predictor.network\n", "\n", " def forward(self, x):\n", " if type(x) is tuple:\n", " input_files = [img.meta['filename_or_obj'][0] for img in x]\n", " else:\n", " input_files = x.meta['filename_or_obj']\n", " if isinstance(input_files, str):\n", " input_files = [input_files]\n", "\n", " \n", " output = self.predictor.predict_from_files(\n", " [input_files],\n", " None,\n", " save_probabilities=False, overwrite=True,\n", " num_processes_preprocessing=2, 
num_processes_segmentation_export=2,\n", " folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n", "\n", " out_tensors= []\n", " for out in output:\n", " out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0),0)))\n", " out_tensor = torch.cat(out_tensors, 0)\n", "\n", " if type(x) is tuple:\n", " return MetaTensor(out_tensor, meta=x[0].meta)\n", " else:\n", " return MetaTensor(out_tensor, meta=x.meta)\n", " \n", "def get_nnunet_monai_predictor(model_folder, model_name=\"model.pt\"):\n", " \n", " from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n", " predictor = nnUNetPredictor(\n", " tile_step_size=0.5,\n", " use_gaussian=True,\n", " use_mirroring=False,\n", " device=torch.device('cuda', 0),\n", " verbose=False,\n", " verbose_preprocessing=False,\n", " allow_tqdm=True\n", " )\n", " # initializes the network architecture, loads the checkpoint\n", " wrapper = nnUNetMONAIModelWrapper(predictor, model_folder, model_name)\n", " return wrapper" ] }, { "cell_type": "code", "execution_count": null, "id": "2a4aa3c2", "metadata": {}, "outputs": [], "source": [ "from monai.handlers import CheckpointLoader\n", "\n", "network = get_nnunet_monai_predictor(\"/home/maia-user/Tutorials/nnUNetBundle/models\")\n", "\n", "# Optional: Load the best model, not needed since the checkpoint is already loaded in the wrapper\n", "val_handlers = [\n", " CheckpointLoader(\n", " load_dict={\n", " 'network_weights': network.network_weights,\n", " },\n", " strict=True,\n", " load_path=\"/home/maia-user/Tutorials/nnUNetBundle/models/best_model.pt\",\n", "\n", " )\n", "]\n", "\n", "validator = SupervisedEvaluator(\n", " val_data_loader=test_loader,\n", " device = \"cuda:0\",\n", " prepare_batch=prepare_nnunet_inference_batch,\n", " network = network,\n", " postprocessing= postprocessing,\n", " val_handlers= val_handlers\n", ")\n", "\n", "validator.run()" ] }, { "cell_type": "code", "execution_count": null, "id": "f63d8dd8", "metadata": {}, 
"outputs": [], "source": [ "%%writefile nnUNetBundle/configs/inference.yaml\n", "\n", "imports: \n", " - $import json\n", " - $import src\n", " - $import src.inferer\n", " - $import src.dataset\n", " - $from pathlib import Path\n", " - $import os\n", " - $from ignite.contrib.handlers.tqdm_logger import ProgressBar\n", " - $import shutil\n", "\n", "output_dir: \".\"\n", "bundle_root: \".\"\n", "data_dir: \".\"\n", "model_folder: \".\"\n", "prediction_suffix: \"prediction\"\n", "modality_conf:\n", " image:\n", " suffix: \".nii.gz\"\n", "\n", "test_data_list: \"$src.dataset.get_subfolder_dataset(@data_dir,@modality_conf)\"\n", "image_modality_keys: \"$list(@modality_conf.keys())\"\n", "image_key: \"image\"\n", "image_suffix: \"@image_key\"\n", "\n", "preprocessing:\n", " _target_: Compose\n", " transforms:\n", " - _target_: LoadImaged\n", " keys: \"@image_modality_keys\"\n", " ensure_channel_first: True\n", " image_only: False\n", "\n", "test_dataset:\n", " _target_: Dataset\n", " data: \"$@test_data_list\"\n", " transform: \"@preprocessing\"\n", "\n", "test_loader:\n", " _target_: DataLoader\n", " dataset: \"@test_dataset\"\n", " batch_size: 1\n", " #collate_fn: \"$monai.data.utils.no_collation\"\n", "\n", "\n", "device: \"$torch.device('cuda')\"\n", "\n", "nnunet_config:\n", " model_folder: \"$@bundle_root + '/models'\"\n", " #model_folder: \"@model_folder\"\n", "\n", "#network_def: \"$src.inferer.get_nnunet_predictor(**@nnunet_config)\"\n", "network_def: \"$src.inferer.get_nnunet_monai_predictor(**@nnunet_config)\"\n", "\n", "postprocessing:\n", " _target_: \"Compose\"\n", " transforms:\n", " - _target_: Transposed\n", " keys: \"pred\"\n", " indices:\n", " - 0\n", " - 3\n", " - 2\n", " - 1\n", " - _target_: SaveImaged\n", " keys: \"pred\"\n", " resample: False\n", " output_postfix: \"@prediction_suffix\"\n", " output_dir: \"@output_dir\"\n", " meta_keys: \"image_meta_dict\"\n", " output_name_formatter: 
\"$src.inferer.PreprocessNameFormatter(@modality_conf[list(@modality_conf.keys())[0]]['suffix'])\"\n", "\n", "\n", "testing:\n", " dataloader: \"$@test_loader\"\n", " pbar:\n", " _target_: \"ignite.contrib.handlers.tqdm_logger.ProgressBar\"\n", " test_inferer: \"$@inferer\"\n", "\n", "inferer: \n", " _target_: \"SimpleInferer\"\n", "\n", "validator:\n", " _target_: \"SupervisedEvaluator\"\n", " postprocessing: \"$@postprocessing\"\n", " device: \"$@device\"\n", " inferer: \"$@testing#test_inferer\"\n", " val_data_loader: \"$@testing#dataloader\"\n", " network: \"@network_def\"\n", " prepare_batch: \"$src.inferer.prepare_nnunet_inference_batch\"\n", " val_handlers:\n", " - _target_: \"CheckpointLoader\"\n", " load_path: \"$@bundle_root+'/models/model.pt'\"\n", " load_dict:\n", " network_weights: '$@network_def.network_weights'\n", "run:\n", " - \"$@testing#pbar.attach(@validator)\"\n", " - \"$@validator.run()\"" ] }, { "cell_type": "code", "execution_count": null, "id": "0ae05956", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/src/inferer.py\n", "\n", "import torch\n", "import os\n", "from typing import Union, Optional\n", "import torch\n", "from monai.data.meta_tensor import MetaTensor\n", "from torch.backends import cudnn\n", "import setuptools\n", "from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json\n", "import numpy as np\n", "import monai\n", "from tqdm import tqdm\n", "from pathlib import Path\n", "import json\n", "from nnunetv2.utilities.plans_handling.plans_handler import PlansManager\n", "from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels\n", "from torch._dynamo import OptimizedModule\n", "import nnunetv2\n", "from nnunetv2.utilities.find_class_by_name import recursive_find_python_class\n", "\n", "class PreprocessNameFormatter:\n", " def __init__(self, filename_key):\n", " self.filename_key = filename_key\n", "\n", "\n", " def __call__(self, metadict: dict, 
saver) -> dict:\n", " \"\"\"Returns a kwargs dict for :py:meth:`FolderLayout.filename`,\n", " according to the input metadata and SaveImage transform.\"\"\"\n", " subject = (\n", " metadict.get(monai.utils.ImageMetaKey.FILENAME_OR_OBJ, getattr(saver, \"_data_index\", 0))\n", " if metadict\n", " else getattr(saver, \"_data_index\", 0)\n", " )\n", " patch_index = metadict.get(monai.utils.ImageMetaKey.PATCH_INDEX, None) if metadict else None\n", " subject = subject[:-len(self.filename_key)]+\".nii.gz\"\n", " return {\"subject\": f\"{subject}\", \"idx\": patch_index}\n", "\n", "class nnUNetModelWrapper(torch.nn.Module):\n", " def __init__(self, predictor, model_folder):\n", " super().__init__()\n", " self.predictor = predictor\n", " self.predictor.initialize_from_trained_model_folder(\n", " model_folder,\n", " use_folds=(0,),\n", " checkpoint_name='checkpoint_final.pth',\n", " )\n", " self.network_weights = self.predictor.network\n", "\n", " def forward(self, x):\n", " if type(x) is tuple:\n", " input_files = [img.meta['filename_or_obj'][0] for img in x]\n", " else:\n", " input_files = x.meta['filename_or_obj']\n", " if type(input_files) == str:\n", " input_files = [input_files]\n", "\n", " \n", " output = self.predictor.predict_from_files(\n", " [input_files],\n", " None,\n", " save_probabilities=False, overwrite=True,\n", " num_processes_preprocessing=2, num_processes_segmentation_export=2,\n", " folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n", "\n", " out_tensors= []\n", " for out in output:\n", " out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0),0)))\n", " out_tensor = torch.cat(out_tensors, 0)\n", "\n", " if type(x) is tuple:\n", " return MetaTensor(out_tensor, meta=x[0].meta)\n", " else:\n", " return MetaTensor(out_tensor, meta=x.meta)\n", "\n", "def get_nnunet_predictor(model_folder):\n", " \n", " from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n", " predictor = nnUNetPredictor(\n", " 
tile_step_size=0.5,\n", " use_gaussian=True,\n", " use_mirroring=False,\n", " #perform_everything_on_device=True,\n", " device=torch.device('cuda', 0),\n", " verbose=False,\n", " verbose_preprocessing=False,\n", " allow_tqdm=True\n", " )\n", " # initializes the network architecture, loads the checkpoint\n", " wrapper = nnUNetModelWrapper(predictor, model_folder)\n", " return wrapper\n", "\n", "def prepare_nnunet_inference_batch(batch, device, non_blocking):\n", " \n", " return batch[\"image\"], None\n", "\n", "class nnUNetMONAIModelWrapper(torch.nn.Module):\n", " def __init__(self, predictor, model_folder, model_name=\"model.pt\"):\n", " super().__init__()\n", " self.predictor = predictor\n", "\n", " model_training_output_dir = model_folder\n", " use_folds = '0'\n", "\n", " ## Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor\n", " dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))\n", " plans = load_json(join(model_training_output_dir, 'plans.json'))\n", " plans_manager = PlansManager(plans)\n", "\n", " if isinstance(use_folds, str):\n", " use_folds = [use_folds]\n", "\n", " parameters = []\n", " for i, f in enumerate(use_folds):\n", " f = int(f) if f != 'all' else f\n", " checkpoint = torch.load(join(model_training_output_dir, 'nnunet_checkpoint.pth'),\n", " map_location=torch.device('cpu'))\n", " monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device('cpu'))\n", " if i == 0:\n", " trainer_name = checkpoint['trainer_name']\n", " configuration_name = checkpoint['init_args']['configuration']\n", " inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \\\n", " 'inference_allowed_mirroring_axes' in checkpoint.keys() else None\n", "\n", " parameters.append(monai_checkpoint['network_weights'])\n", "\n", " configuration_manager = plans_manager.get_configuration(configuration_name)\n", " # restore network\n", " num_input_channels = 
determine_num_input_channels(plans_manager, configuration_manager, dataset_json)\n", " trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], \"training\", \"nnUNetTrainer\"),\n", " trainer_name, 'nnunetv2.training.nnUNetTrainer')\n", " if trainer_class is None:\n", " raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '\n", " f'Please place it there (in any .py file)!')\n", " network = trainer_class.build_network_architecture(\n", " configuration_manager.network_arch_class_name,\n", " configuration_manager.network_arch_init_kwargs,\n", " configuration_manager.network_arch_init_kwargs_req_import,\n", " num_input_channels,\n", " plans_manager.get_label_manager(dataset_json).num_segmentation_heads,\n", " enable_deep_supervision=False\n", " )\n", "\n", " predictor.plans_manager = plans_manager\n", " predictor.configuration_manager = configuration_manager\n", " predictor.list_of_parameters = parameters\n", " predictor.network = network\n", " predictor.dataset_json = dataset_json\n", " predictor.trainer_name = trainer_name\n", " predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes\n", " predictor.label_manager = plans_manager.get_label_manager(dataset_json)\n", " if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \\\n", " and not isinstance(predictor.network, OptimizedModule):\n", " print('Using torch.compile')\n", " # The network lives on the predictor; the wrapper has no `network` attribute\n", " predictor.network = torch.compile(predictor.network)\n", " ## End Block\n", " self.network_weights = self.predictor.network\n", "\n", " def forward(self, x):\n", " if type(x) is tuple:\n", " input_files = [img.meta['filename_or_obj'][0] for img in x]\n", " else:\n", " input_files = x.meta['filename_or_obj']\n", " if isinstance(input_files, str):\n", " input_files = [input_files]\n", "\n", " \n", " output = self.predictor.predict_from_files(\n", " [input_files],\n", " None,\n", " save_probabilities=False, overwrite=True,\n", " 
num_processes_preprocessing=2, num_processes_segmentation_export=2,\n", " folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)\n", "\n", " out_tensors= []\n", " for out in output:\n", " out_tensors.append(torch.from_numpy(np.expand_dims(np.expand_dims(out, 0),0)))\n", " out_tensor = torch.cat(out_tensors, 0)\n", "\n", " if type(x) is tuple:\n", " return MetaTensor(out_tensor, meta=x[0].meta)\n", " else:\n", " return MetaTensor(out_tensor, meta=x.meta)\n", " \n", "def get_nnunet_monai_predictor(model_folder, model_name=\"model.pt\"):\n", " \n", " from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor\n", " predictor = nnUNetPredictor(\n", " tile_step_size=0.5,\n", " use_gaussian=True,\n", " use_mirroring=False,\n", " device=torch.device('cuda', 0),\n", " verbose=False,\n", " verbose_preprocessing=False,\n", " allow_tqdm=True\n", " )\n", " # initializes the network architecture, loads the checkpoint\n", " wrapper = nnUNetMONAIModelWrapper(predictor, model_folder, model_name)\n", " return wrapper\n" ] }, { "cell_type": "code", "execution_count": null, "id": "62668710", "metadata": {}, "outputs": [], "source": [ "%%writefile nnUNetBundle/src/dataset.py\n", "\n", "import pathlib\n", "import os\n", "\n", "def get_subfolder_dataset(data_dir,modality_conf):\n", " data_list = []\n", " for f in os.scandir(data_dir):\n", "\n", " if f.is_dir():\n", " subject_dict = {key:str(pathlib.Path(f.path).joinpath(f.name+modality_conf[key]['suffix'])) for key in modality_conf}\n", " data_list.append(subject_dict)\n", " return data_list" ] }, { "cell_type": "markdown", "id": "e7d75b56-80b8-412c-97d8-b79e1111caa0", "metadata": {}, "source": [ "## MONAI Bundle to nnUNet Conversion" ] }, { "cell_type": "markdown", "id": "d65a50b3-063d-412b-a8e0-5ec45c003925", "metadata": {}, "source": [ "To convert a MONAI Bundle to a nnUNet Bundle, we need to combine the MONAI checkpoint with the nnUNet checkpoint. 
This is done by loading the MONAI checkpoint and the nnUNet checkpoint, and updating the nnUNet model weights with the MONAI model weights." ] }, { "cell_type": "code", "execution_count": null, "id": "62793601", "metadata": {}, "outputs": [], "source": [ "from PyMAIA.utils.file_utils import subfiles\n", "from nnunetv2.training.logging.nnunet_logger import nnUNetLogger\n", "from pathlib import Path\n", "import torch\n", "from odict import odict\n", "import os\n", "import shutil" ] }, { "cell_type": "code", "execution_count": null, "id": "ba864564", "metadata": {}, "outputs": [], "source": [ "def convert_MONAI_to_nnUNet(nnunet_root_folder, nnunet_config, bundle_config):\n", " os.environ[\"ROOT_FOLDER\"] = nnunet_root_folder\n", "\n", " os.environ[\"RESULTS_FOLDER\"] = str(\n", " Path(os.environ[\"ROOT_FOLDER\"]).joinpath(\n", " nnunet_config[\"Experiment Name\"], nnunet_config[\"Experiment Name\"] + \"_results\"\n", " )\n", " )\n", "\n", " nnunet_trainer = \"nnUNetTrainer\"\n", " nnunet_plans = \"nnUNetPlans\"\n", "\n", " if \"nnunet_trainer\" in nnunet_config:\n", " nnunet_trainer = nnunet_config[\"nnunet_trainer\"]\n", "\n", " if \"nnunet_plans\" in nnunet_config:\n", " nnunet_plans = nnunet_config[\"nnunet_plans\"]\n", "\n", " nnunet_model_folder = Path(os.environ[\"RESULTS_FOLDER\"]).joinpath(\n", " \"Dataset\" + nnunet_config[\"task_ID\"] + \"_\" + nnunet_config[\n", " \"Experiment Name\"],\n", " f\"{nnunet_trainer}__{nnunet_plans}__3d_fullres\")\n", "\n", " bundle_name = bundle_config[\"Bundle_Name\"]\n", "\n", " nnunet_checkpoint = torch.load(f\"{bundle_name}/models/nnunet_checkpoint.pth\")\n", " latest_checkpoints = subfiles(Path(bundle_name).joinpath(\"models\"),prefix=\"checkpoint_epoch\",sort=True,join=False)\n", " epochs = []\n", " for latest_checkpoint in latest_checkpoints:\n", " epochs.append(int(latest_checkpoint[len(\"checkpoint_epoch=\"):-len(\".pt\")]))\n", "\n", " epochs.sort()\n", " final_epoch = epochs[-1]\n", " monai_last_checkpoint = 
torch.load(f\"{bundle_name}/models/checkpoint_epoch={final_epoch}.pt\")\n", "\n", " best_checkpoints = subfiles(Path(bundle_name).joinpath(\"models\"), prefix=\"checkpoint_key_metric\", sort=True,\n", " join=False)\n", " key_metrics = []\n", " for best_checkpoint in best_checkpoints:\n", " key_metrics.append(str(best_checkpoint[len(\"checkpoint_key_metric=\"):-len(\".pt\")]))\n", "\n", " key_metrics.sort()\n", " best_key_metric = key_metrics[-1]\n", " monai_best_checkpoint = torch.load(f\"{bundle_name}/models/checkpoint_key_metric={best_key_metric}.pt\")\n", "\n", " nnunet_checkpoint['optimizer_state'] = monai_last_checkpoint['optimizer_state']\n", "\n", "\n", "\n", " nnunet_checkpoint['network_weights'] = odict()\n", "\n", " for key in monai_last_checkpoint['network_weights']:\n", " nnunet_checkpoint['network_weights'][key] = monai_last_checkpoint['network_weights'][key]\n", "\n", " nnunet_checkpoint['current_epoch'] = final_epoch\n", " nnunet_checkpoint['logging'] = nnUNetLogger().get_checkpoint()\n", " nnunet_checkpoint['_best_ema'] = 0\n", " nnunet_checkpoint['grad_scaler_state'] = None\n", "\n", "\n", "\n", " torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(\"fold_0\",\"checkpoint_final.pth\"))\n", "\n", " nnunet_checkpoint['network_weights'] = odict()\n", "\n", " nnunet_checkpoint['optimizer_state'] = monai_best_checkpoint['optimizer_state']\n", "\n", " for key in monai_best_checkpoint['network_weights']:\n", " nnunet_checkpoint['network_weights'][key] = \\\n", " monai_best_checkpoint['network_weights'][key]\n", "\n", " torch.save(nnunet_checkpoint, Path(nnunet_model_folder).joinpath(\"fold_0\", \"checkpoint_best.pth\"))\n", "\n", " shutil.move(f\"{bundle_name}/models/checkpoint_epoch={final_epoch}.pt\",f\"{bundle_name}/models/model.pt\")\n", " shutil.move(f\"{bundle_name}/models/checkpoint_key_metric={best_key_metric}.pt\",f\"{bundle_name}/models/best_model.pt\")" ] }, { "cell_type": "code", "execution_count": null, "id": "9c2ce85f", 
"metadata": {}, "outputs": [], "source": [ "nnunet_root_folder = \"MAIA/Experiments\"\n", "\n", "\n", "\n", "nnunet_config = {\n", " \"Experiment Name\": \"Task09_Spleen\",\n", " \"task_ID\": \"109\",\n", " \"nnunet_plans\":\"nnUNetResEncUNetLPlans\"\n", "}\n", "\n", "bundle_config = {\n", " \"Bundle_Name\": \"nnUNetBundle\"\n", "}" ] }, { "cell_type": "markdown", "id": "8190d533", "metadata": {}, "source": [ "## nnUNet to MONAI Bundle Conversion\n", "\n", "To convert a nnUNet Bundle to a MONAI Bundle, we need to separate the MONAI checkpoint from the nnUNet checkpoint. This is done by loading the nnUNet checkpoint and the MONAI checkpoint, and updating the MONAI model weights with the nnUNet model weights." ] }, { "cell_type": "code", "execution_count": null, "id": "aa5244a1", "metadata": {}, "outputs": [], "source": [ "def convert_nnunet_to_monai_bundle(nnunet_root_folder, nnunet_config, bundle_root_folder):\n", " os.environ[\"ROOT_FOLDER\"] = nnunet_root_folder\n", "\n", " os.environ[\"RESULTS_FOLDER\"] = str(\n", " Path(os.environ[\"ROOT_FOLDER\"]).joinpath(\n", " nnunet_config[\"Experiment Name\"], nnunet_config[\"Experiment Name\"] + \"_results\"\n", " )\n", " )\n", "\n", " nnunet_trainer = \"nnUNetTrainer\"\n", " nnunet_plans = \"nnUNetPlans\"\n", "\n", " if \"nnunet_trainer\" in nnunet_config:\n", " nnunet_trainer = nnunet_config[\"nnunet_trainer\"]\n", "\n", " if \"nnunet_plans\" in nnunet_config:\n", " nnunet_plans = nnunet_config[\"nnunet_plans\"]\n", "\n", " nnunet_model_folder = Path(os.environ[\"RESULTS_FOLDER\"]).joinpath(\n", " \"Dataset\" + nnunet_config[\"task_ID\"] + \"_\" + nnunet_config[\n", " \"Experiment Name\"],\n", " f\"{nnunet_trainer}__{nnunet_plans}__3d_fullres\")\n", " \n", " nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(\"fold_0\",\"checkpoint_final.pth\"))\n", " nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(\"fold_0\",\"checkpoint_best.pth\"))\n", "\n", " nnunet_checkpoint = {}\n", " 
nnunet_checkpoint['inference_allowed_mirroring_axes'] = nnunet_checkpoint_final['inference_allowed_mirroring_axes']\n", " nnunet_checkpoint['init_args'] = nnunet_checkpoint_final['init_args']\n", " nnunet_checkpoint['trainer_name'] = nnunet_checkpoint_final['trainer_name']\n", "\n", " torch.save(nnunet_checkpoint, Path(bundle_root_folder).joinpath(\"models\",\"nnunet_checkpoint.pth\"))\n", "\n", " monai_last_checkpoint = {}\n", " monai_last_checkpoint['network_weights'] = nnunet_checkpoint_final['network_weights']\n", " torch.save(monai_last_checkpoint, Path(bundle_root_folder).joinpath(\"models\",\"model.pt\"))\n", "\n", " monai_best_checkpoint = {}\n", " monai_best_checkpoint['network_weights'] = nnunet_checkpoint_best['network_weights']\n", " torch.save(monai_best_checkpoint, Path(bundle_root_folder).joinpath(\"models\",\"best_model.pt\"))\n", "\n", " shutil.copy(Path(nnunet_model_folder).joinpath(\"plans.json\"),Path(bundle_root_folder).joinpath(\"models\",\"plans.json\"))\n", " shutil.copy(Path(nnunet_model_folder).joinpath(\"dataset.json\"),Path(bundle_root_folder).joinpath(\"models\",\"dataset.json\"))\n", " \n" ] }, { "cell_type": "code", "execution_count": null, "id": "773c8160", "metadata": {}, "outputs": [], "source": [ "nnunet_root_folder = \"MAIA/Experiments\"\n", "\n", "\n", "\n", "nnunet_config = {\n", " \"Experiment Name\": \"Task09_Spleen\",\n", " \"task_ID\": \"109\",\n", " \"nnunet_plans\":\"nnUNetResEncUNetLPlans\"\n", "}\n", "\n", "bundle_root_folder = \"nnUNetBundle_Test\"\n", "\n", "Path(bundle_root_folder).joinpath(\"models\").mkdir(parents=True, exist_ok=True)\n", "\n", "convert_nnunet_to_monai_bundle(nnunet_root_folder, nnunet_config, bundle_root_folder)" ] }, { "cell_type": "code", "execution_count": null, "id": "1fb17e4f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }