CUDA out of memory, Time limit and Low DSC with nnU-Net
were the three main problems we encountered while developing our algorithm. Here are our solutions, which we hope will be helpful for the other teams:
-
Q1 CUDA out of memory: When testing locally, the model does not use more than 10 GB of GPU memory, but the job fails with CUDA out of memory after submission.
-
A1: Make sure your input and output file handling matches the official template: read input_files[0] and write output.mha!
from glob import glob

import SimpleITK

# Input path setting: `location` is the input directory provided by the platform
input_files = glob(str(location / "*.tiff")) + glob(str(location / "*.mha"))
result = SimpleITK.ReadImage(input_files[0])

# Output path setting: here `location` is the output directory and the file
# must be named output.mha (i.e., suffix = ".mha")
SimpleITK.WriteImage(
    image,
    location / f"output{suffix}",  # i.e., location/output.mha
    useCompression=True,
)
-
Q2 Time Limit: During sliding window inference, the excessive number of windows causes an inference time of more than 5 minutes.
-
A2: The aortic structure is concentrated in the center of the image, so the number of sliding windows along the coronal and sagittal axes can be reduced accordingly, as done for nnU-Net V1 in the reference below (a minimal sketch follows the reference).
@incollection{huang2023revisiting,
  title={Revisiting nnU-Net for Iterative Pseudo Labeling and Efficient Sliding Window Inference},
  author={Huang, Ziyan and Wang, Haoyu and Ye, Jin and Niu, Jingqi and Tu, Can and Yang, Yuncheng and Du, Shiyi and Deng, Zhongying and Gu, Lixu and He, Junjun},
  booktitle={Fast and Low-Resource Semi-supervised Abdominal Organ Segmentation: MICCAI 2022 Challenge, FLARE 2022, Held in Conjunction with MICCAI 2022, Singapore, September 22, 2022, Proceedings},
  pages={178--189},
  year={2023},
  publisher={Springer}
}
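To illustrate the idea (this is only a sketch, not our exact submitted code), the snippet below restricts sliding-window inference to a central crop along the coronal and sagittal axes and pastes the prediction back into a full-size volume. predict_fn and the crop fractions in crop_frac are placeholders that must be tuned and validated on your own data.

import numpy as np

def predict_center_crop(predict_fn, image, crop_frac=(1.0, 0.6, 0.6)):
    # predict_fn: callable mapping a [c, z, y, x] array to [num_classes, z, y, x]
    #             (e.g. a wrapper around the nnU-Net sliding-window predictor)
    # image:      preprocessed volume of shape [c, z, y, x]
    # crop_frac:  fraction of each spatial axis (z, y, x) to keep, centered;
    #             the values above are assumptions, not validated settings
    _, Z, Y, X = image.shape
    keep = [max(1, int(round(f * s))) for f, s in zip(crop_frac, (Z, Y, X))]
    starts = [(s - k) // 2 for s, k in zip((Z, Y, X), keep)]
    sl = tuple(slice(st, st + k) for st, k in zip(starts, keep))

    # Fewer voxels -> fewer sliding windows -> shorter inference time
    pred_crop = predict_fn(image[(slice(None),) + sl])

    # Paste the cropped prediction back; voxels outside the crop stay background
    full = np.zeros((pred_crop.shape[0], Z, Y, X), dtype=pred_crop.dtype)
    full[(slice(None),) + sl] = pred_crop
    return full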
-
Q3 Low DSC and other metrics: Performance and metrics are normal when testing locally, but the submission obtains nearly 0% DSC.
-
A3: This may be helpful for teams using nnU-Net: the problem is the Direction of the test images (see the related issue); many thanks to the Organizers and Team shipc1220 (NexToU-cbDice).
If you are not using mirroring TTA, flip the window data before feeding it into the network, for example in nnU-Net V1:
# x: [b, c, z, y, x], loaded from SimpleITK (e.g. inside nnU-Net V1's
# _internal_maybe_mirror_and_pred_3D): flip along the x and y axes (dims 4 and 3),
# run the network, then flip the prediction back before accumulating
pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3))))
result_torch += 1 / num_results * torch.flip(pred, (4, 3))
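Alternatively, the Direction mismatch can be handled on the image side instead of inside the network. The following is only an illustrative sketch (not the code we submitted), assuming SimpleITK >= 2.0 for DICOMOrient and that reorienting to "LPS" matches the orientation your model was trained on:

import SimpleITK

# Load the test image and force a fixed anatomical orientation ("LPS" is an
# assumed choice) so the voxel array reaches the network with the same axis
# directions regardless of the Direction stored in the file.
image = SimpleITK.ReadImage(input_files[0])
oriented = SimpleITK.DICOMOrient(image, "LPS")  # requires SimpleITK >= 2.0

print(image.GetDirection())     # original direction cosines
print(oriented.GetDirection())  # identity after reorienting to LPS

# ... run preprocessing and inference on `oriented` ...
# Note: before writing output.mha, the prediction must be mapped back to the
# original image geometry, otherwise the evaluation will again see
# misaligned labels.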
Again, great thanks to the Organizers, shipc1220 (NexToU-cbDice), and Sonwe1e (Hanglok_AortaSeg24); more discussion can be found on the Forum.