diff --git a/README.md b/README.md index eb1e11a..d7929e0 100644 --- a/README.md +++ b/README.md @@ -1,111 +1,140 @@


## Overview -This is a fully automated framework for object detection featuring: +This is a comprehensive framework for object detection featuring: - 2D + 3D implementations of prevalent object detectors: e.g. Mask R-CNN [1], Retina Net [2], Retina U-Net [3]. - Modular and light-weight structure ensuring sharing of all processing steps (incl. backbone architecture) for comparability of models. - training with bounding box and/or pixel-wise annotations. - dynamic patching and tiling of 2D + 3D images (for training and inference). - weighted consolidation of box predictions across patch-overlaps, ensembles, and dimensions [3]. - monitoring + evaluation simultaneously on object and patient level. - 2D + 3D output visualizations. - integration of COCO mean average precision metric [5]. - integration of MIC-DKFZ batch generators for extensive data augmentation [6]. - easy modification to evaluation of instance segmentation and/or semantic segmentation.
[1] He, Kaiming, et al. "Mask R-CNN" ICCV, 2017
[2] Lin, Tsung-Yi, et al. "Focal Loss for Dense Object Detection" TPAMI, 2018.
[3] Jaeger, Paul et al. "Retina U-Net: Embarrassingly Simple Exploitation of Segmentation Supervision for Medical Object Detection" , 2018 [5] https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py
[6] https://github.com/MIC-DKFZ/batchgenerators

## Installation Setup package in virtual environment ``` git clone https://github.com/pfjaeger/medicaldetectiontoolkit.git . cd medicaldetectiontoolkit virtualenv -p python3 venv source venv/bin/activate pip3 install -e . ``` Install MIC-DKFZ batch-generators ``` cd .. git clone https://github.com/MIC-DKFZ/batchgenerators cd batchgenerators pip3 install -e . -cd mdt +cd ../medicaldetectiontoolkit +``` +We use two cuda functions: Non-Maximum Suppression (taken from [pytorch-faster-rcnn](https://github.com/ruotianluo/pytorch-faster-rcnn) and added adaption for 3D) and RoiAlign (taken from [RoiAlign](https://github.com/longcw/RoIAlign.pytorch), fixed according to [this bug report](https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35), and added adaption for 3D). In this framework, they come pre-compile for TitanX. If you have a different GPU you need to re-compile these functions: + + +| GPU | arch | +| --- | --- | +| TitanX | sm_52 | +| GTX 960M | sm_50 | +| GTX 1070 | sm_61 | +| GTX 1080 (Ti) | sm_61 | + +``` +cd cuda_functions/nms_xD/src/cuda/ +nvcc -c -o nms_kernel.cu.o nms_kernel.cu -x cu -Xcompiler -fPIC -arch=[arch] +cd ../../ +python build.py +cd ../ + +cd cuda_functions/roi_align_xD/roi_align/src/cuda/ +nvcc -c -o crop_and_resize_kernel.cu.o crop_and_resize_kernel.cu -x cu -Xcompiler -fPIC -arch=[arch] +cd ../../ +python build.py +cd ../../ ``` ## Prepare the Data This framework is meant for you to be able to train models on your own data sets. -An example data loader is provided in medicaldetectiontoolkit/experiments including thorough documentation to ensure a quick start for your own project. +Two example data loaders are provided in medicaldetectiontoolkit/experiments including thorough documentation to ensure a quick start for your own project. The way I load Data is to have a preprocessing script, which after preprocessing saves the Data of whatever data type (in the case of LIDC, those are .nrrd files obtained from [this data conversion tool](https://github.com/MIC-DKFZ/LIDC-IDRI-processing/tree/v1.0.1)) into numpy arrays (this step is just done once). During training / testing, the data loader then loads these numpy arrays dynamically. (Please note the Data Input side is meant to be customized by you according to your own needs and the provided Data loaders are merely examples: LIDC has a powerful Dataloader that handles 2D/3D inputs and is optimized for patch-based training and inference. Toy-Experiments have a lightweight Dataloader, only handling 2D without patching. The latter makes sense if you want to get familiar with the framework.). ## Execute 1. Set I/O paths, model and training specifics in the configs file: medicaldetectiontoolkit/experiments/your_experiment/configs.py 2. Train the model: ``` python exec.py --mode train --exp_source experiments/my_experiment --exp_dir path/to/experiment/directory ``` This copies snapshots of configs and model to the specified exp_dir, where all outputs will be saved. By default, the data is split into 60% training and 20% validation and 20% testing data to perform a 5-fold cross validation (can be changed to hold-out test set in configs) and all folds will be trained iteratively. In order to train a single fold, specify it using the folds arg: ``` python exec.py --folds 0 1 2 .... # specify any combination of folds [0-4] ``` 3. Run inference: ``` python exec.py --mode test --exp_dir path/to/experiment/directory ``` This runs the prediction pipeline and saves all results to exp_dir. ## Models This framework features all models explored in [3] (implemented in 2D + 3D): The proposed Retina U-Net, a simple but effective Architecture fusing state-of-the-art semantic segmentation with object detection,


also implementations of prevalent object detectors, such as Mask R-CNN, Faster R-CNN+ (Faster R-CNN w\ RoIAlign), Retina Net, U-Faster R-CNN+ (the two stage counterpart of Retina U-Net: Faster R-CNN with auxiliary semantic segmentation), DetU-Net (a U-Net like segmentation architecture with heuristics for object detection.)



## Training annotations This framework features training with pixelwise and/or bounding box annotations. To overcome the issue of box coordinates in data augmentation, we feed the annotation masks through data augmentation (create a pseudo mask, if only bounding box annotations provided) and draw the boxes afterwards.


+The framework further handles two types of pixel-wise annotations: + +1. A label map with individual ROIs identified by increasing label values, accompanied by a vector containing in each position the class target for the lesion with the corresponding label (for this mode set get_rois_from_seg_flag = False when calling ConvertSegToBoundingBoxCoordinates in your Data Loader). +2. A binary label map. There is only one foreground class and single lesions are not identified. All lesions have the same class target (foreground). In this case the Dataloader runs a Connected Component Labelling algorithm to create processable lesion - class target pairs on the fly (for this mode set get_rois_from_seg_flag = True when calling ConvertSegToBoundingBoxCoordinates in your Data Loader). + ## Prediction pipeline This framework provides an inference module, which automatically handles patching of inputs, and tiling, ensembling, and weighted consolidation of output predictions:




## Consolidation of predictions (Weighted Box Clustering) Multiple predictions of the same image (from test time augmentations, tested epochs and overlapping patches), result in a high amount of boxes (or cubes), which need to be consolidated. In semantic segmentation, the final output would typically be obtained by averaging every pixel over all predictions. As described in [3], **weighted box clustering** (WBC) does this for box predictions:





## Visualization / Monitoring By default, loss functions and performance metrics are monitored:




Histograms of matched output predictions for training/validation/testing are plotted per foreground class:



Input images + ground truth annotations + output predictions of a sampled validation abtch are plotted after each epoch (here 2D sampled slice with +-3 neighbouring context slices in channels):



Zoomed into the last two lines of the plot:


## How to cite this code Please cite the original publication [3]. ## License The code is published under the [Apache License Version 2.0](LICENSE). + diff --git a/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.c b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.c index e1fce67..a5ff973 100644 --- a/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.c +++ b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.c @@ -1,252 +1,269 @@ #include #include #include void CropAndResizePerBox( const float * image_data, const int batch_size, const int depth, const int image_height, const int image_width, const float * boxes_data, const int * box_index_data, const int start_box, const int limit_box, float * corps_data, const int crop_height, const int crop_width, const float extrapolation_value ) { const int image_channel_elements = image_height * image_width; const int image_elements = depth * image_channel_elements; const int channel_elements = crop_height * crop_width; const int crop_elements = depth * channel_elements; int b; #pragma omp parallel for for (b = start_box; b < limit_box; ++b) { const float * box = boxes_data + b * 4; const float y1 = box[0]; const float x1 = box[1]; const float y2 = box[2]; const float x2 = box[3]; const int b_in = box_index_data[b]; if (b_in < 0 || b_in >= batch_size) { printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); exit(-1); } const float height_scale = (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) : 0; const float width_scale = (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; for (int y = 0; y < crop_height; ++y) { const float in_y = (crop_height > 1) ? y1 * (image_height - 1) + y * height_scale : 0.5 * (y1 + y2) * (image_height - 1); if (in_y < 0 || in_y > image_height - 1) { for (int x = 0; x < crop_width; ++x) { for (int d = 0; d < depth; ++d) { // crops(b, y, x, d) = extrapolation_value; corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; } } continue; } const int top_y_index = floorf(in_y); const int bottom_y_index = ceilf(in_y); const float y_lerp = in_y - top_y_index; for (int x = 0; x < crop_width; ++x) { const float in_x = (crop_width > 1) ? x1 * (image_width - 1) + x * width_scale : 0.5 * (x1 + x2) * (image_width - 1); if (in_x < 0 || in_x > image_width - 1) { for (int d = 0; d < depth; ++d) { corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; } continue; } const int left_x_index = floorf(in_x); const int right_x_index = ceilf(in_x); const float x_lerp = in_x - left_x_index; for (int d = 0; d < depth; ++d) { const float *pimage = image_data + b_in * image_elements + d * image_channel_elements; const float top_left = pimage[top_y_index * image_width + left_x_index]; const float top_right = pimage[top_y_index * image_width + right_x_index]; const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; const float top = top_left + (top_right - top_left) * x_lerp; const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = top + (bottom - top) * y_lerp; } } // end for x } // end for y } // end for b } void crop_and_resize_forward( THFloatTensor * image, THFloatTensor * boxes, // [y1, x1, y2, x2] THIntTensor * box_index, // range in [0, batch_size) const float extrapolation_value, const int crop_height, const int crop_width, THFloatTensor * crops ) { - const int batch_size = image->size[0]; - const int depth = image->size[1]; - const int image_height = image->size[2]; - const int image_width = image->size[3]; + //const int batch_size = image->size[0]; + //const int depth = image->size[1]; + //const int image_height = image->size[2]; + //const int image_width = image->size[3]; - const int num_boxes = boxes->size[0]; + //const int num_boxes = boxes->size[0]; + + const int batch_size = THFloatTensor_size(image, 0); + const int depth = THFloatTensor_size(image, 1); + const int image_height = THFloatTensor_size(image, 2); + const int image_width = THFloatTensor_size(image, 3); + + const int num_boxes = THFloatTensor_size(boxes, 0); // init output space THFloatTensor_resize4d(crops, num_boxes, depth, crop_height, crop_width); THFloatTensor_zero(crops); // crop_and_resize for each box CropAndResizePerBox( THFloatTensor_data(image), batch_size, depth, image_height, image_width, THFloatTensor_data(boxes), THIntTensor_data(box_index), 0, num_boxes, THFloatTensor_data(crops), crop_height, crop_width, extrapolation_value ); } void crop_and_resize_backward( THFloatTensor * grads, THFloatTensor * boxes, // [y1, x1, y2, x2] THIntTensor * box_index, // range in [0, batch_size) THFloatTensor * grads_image // resize to [bsize, c, hc, wc] ) -{ +{ // shape - const int batch_size = grads_image->size[0]; - const int depth = grads_image->size[1]; - const int image_height = grads_image->size[2]; - const int image_width = grads_image->size[3]; - - const int num_boxes = grads->size[0]; - const int crop_height = grads->size[2]; - const int crop_width = grads->size[3]; + //const int batch_size = grads_image->size[0]; + //const int depth = grads_image->size[1]; + //const int image_height = grads_image->size[2]; + //const int image_width = grads_image->size[3]; + + //const int num_boxes = grads->size[0]; + //const int crop_height = grads->size[2]; + //const int crop_width = grads->size[3]; + + const int batch_size = THFloatTensor_size(grads_image, 0); + const int depth = THFloatTensor_size(grads_image, 1); + const int image_height = THFloatTensor_size(grads_image, 2); + const int image_width = THFloatTensor_size(grads_image, 3); + + const int num_boxes = THFloatTensor_size(grads, 0); + const int crop_height = THFloatTensor_size(grads,2); + const int crop_width = THFloatTensor_size(grads,3); + // n_elements const int image_channel_elements = image_height * image_width; const int image_elements = depth * image_channel_elements; const int channel_elements = crop_height * crop_width; const int crop_elements = depth * channel_elements; // init output space THFloatTensor_zero(grads_image); // data pointer const float * grads_data = THFloatTensor_data(grads); const float * boxes_data = THFloatTensor_data(boxes); const int * box_index_data = THIntTensor_data(box_index); float * grads_image_data = THFloatTensor_data(grads_image); for (int b = 0; b < num_boxes; ++b) { const float * box = boxes_data + b * 4; const float y1 = box[0]; const float x1 = box[1]; const float y2 = box[2]; const float x2 = box[3]; const int b_in = box_index_data[b]; if (b_in < 0 || b_in >= batch_size) { printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); exit(-1); } const float height_scale = (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) : 0; const float width_scale = (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; for (int y = 0; y < crop_height; ++y) { const float in_y = (crop_height > 1) ? y1 * (image_height - 1) + y * height_scale : 0.5 * (y1 + y2) * (image_height - 1); if (in_y < 0 || in_y > image_height - 1) { continue; } const int top_y_index = floorf(in_y); const int bottom_y_index = ceilf(in_y); const float y_lerp = in_y - top_y_index; for (int x = 0; x < crop_width; ++x) { const float in_x = (crop_width > 1) ? x1 * (image_width - 1) + x * width_scale : 0.5 * (x1 + x2) * (image_width - 1); if (in_x < 0 || in_x > image_width - 1) { continue; } const int left_x_index = floorf(in_x); const int right_x_index = ceilf(in_x); const float x_lerp = in_x - left_x_index; for (int d = 0; d < depth; ++d) { float *pimage = grads_image_data + b_in * image_elements + d * image_channel_elements; const float grad_val = grads_data[crop_elements * b + channel_elements * d + y * crop_width + x]; const float dtop = (1 - y_lerp) * grad_val; pimage[top_y_index * image_width + left_x_index] += (1 - x_lerp) * dtop; pimage[top_y_index * image_width + right_x_index] += x_lerp * dtop; const float dbottom = y_lerp * grad_val; pimage[bottom_y_index * image_width + left_x_index] += (1 - x_lerp) * dbottom; pimage[bottom_y_index * image_width + right_x_index] += x_lerp * dbottom; } // end d } // end x } // end y } // end b } \ No newline at end of file diff --git a/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.c b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.c index e1fce67..a5ff973 100644 --- a/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.c +++ b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.c @@ -1,252 +1,269 @@ #include #include #include void CropAndResizePerBox( const float * image_data, const int batch_size, const int depth, const int image_height, const int image_width, const float * boxes_data, const int * box_index_data, const int start_box, const int limit_box, float * corps_data, const int crop_height, const int crop_width, const float extrapolation_value ) { const int image_channel_elements = image_height * image_width; const int image_elements = depth * image_channel_elements; const int channel_elements = crop_height * crop_width; const int crop_elements = depth * channel_elements; int b; #pragma omp parallel for for (b = start_box; b < limit_box; ++b) { const float * box = boxes_data + b * 4; const float y1 = box[0]; const float x1 = box[1]; const float y2 = box[2]; const float x2 = box[3]; const int b_in = box_index_data[b]; if (b_in < 0 || b_in >= batch_size) { printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); exit(-1); } const float height_scale = (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) : 0; const float width_scale = (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; for (int y = 0; y < crop_height; ++y) { const float in_y = (crop_height > 1) ? y1 * (image_height - 1) + y * height_scale : 0.5 * (y1 + y2) * (image_height - 1); if (in_y < 0 || in_y > image_height - 1) { for (int x = 0; x < crop_width; ++x) { for (int d = 0; d < depth; ++d) { // crops(b, y, x, d) = extrapolation_value; corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; } } continue; } const int top_y_index = floorf(in_y); const int bottom_y_index = ceilf(in_y); const float y_lerp = in_y - top_y_index; for (int x = 0; x < crop_width; ++x) { const float in_x = (crop_width > 1) ? x1 * (image_width - 1) + x * width_scale : 0.5 * (x1 + x2) * (image_width - 1); if (in_x < 0 || in_x > image_width - 1) { for (int d = 0; d < depth; ++d) { corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; } continue; } const int left_x_index = floorf(in_x); const int right_x_index = ceilf(in_x); const float x_lerp = in_x - left_x_index; for (int d = 0; d < depth; ++d) { const float *pimage = image_data + b_in * image_elements + d * image_channel_elements; const float top_left = pimage[top_y_index * image_width + left_x_index]; const float top_right = pimage[top_y_index * image_width + right_x_index]; const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; const float top = top_left + (top_right - top_left) * x_lerp; const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = top + (bottom - top) * y_lerp; } } // end for x } // end for y } // end for b } void crop_and_resize_forward( THFloatTensor * image, THFloatTensor * boxes, // [y1, x1, y2, x2] THIntTensor * box_index, // range in [0, batch_size) const float extrapolation_value, const int crop_height, const int crop_width, THFloatTensor * crops ) { - const int batch_size = image->size[0]; - const int depth = image->size[1]; - const int image_height = image->size[2]; - const int image_width = image->size[3]; + //const int batch_size = image->size[0]; + //const int depth = image->size[1]; + //const int image_height = image->size[2]; + //const int image_width = image->size[3]; - const int num_boxes = boxes->size[0]; + //const int num_boxes = boxes->size[0]; + + const int batch_size = THFloatTensor_size(image, 0); + const int depth = THFloatTensor_size(image, 1); + const int image_height = THFloatTensor_size(image, 2); + const int image_width = THFloatTensor_size(image, 3); + + const int num_boxes = THFloatTensor_size(boxes, 0); // init output space THFloatTensor_resize4d(crops, num_boxes, depth, crop_height, crop_width); THFloatTensor_zero(crops); // crop_and_resize for each box CropAndResizePerBox( THFloatTensor_data(image), batch_size, depth, image_height, image_width, THFloatTensor_data(boxes), THIntTensor_data(box_index), 0, num_boxes, THFloatTensor_data(crops), crop_height, crop_width, extrapolation_value ); } void crop_and_resize_backward( THFloatTensor * grads, THFloatTensor * boxes, // [y1, x1, y2, x2] THIntTensor * box_index, // range in [0, batch_size) THFloatTensor * grads_image // resize to [bsize, c, hc, wc] ) -{ +{ // shape - const int batch_size = grads_image->size[0]; - const int depth = grads_image->size[1]; - const int image_height = grads_image->size[2]; - const int image_width = grads_image->size[3]; - - const int num_boxes = grads->size[0]; - const int crop_height = grads->size[2]; - const int crop_width = grads->size[3]; + //const int batch_size = grads_image->size[0]; + //const int depth = grads_image->size[1]; + //const int image_height = grads_image->size[2]; + //const int image_width = grads_image->size[3]; + + //const int num_boxes = grads->size[0]; + //const int crop_height = grads->size[2]; + //const int crop_width = grads->size[3]; + + const int batch_size = THFloatTensor_size(grads_image, 0); + const int depth = THFloatTensor_size(grads_image, 1); + const int image_height = THFloatTensor_size(grads_image, 2); + const int image_width = THFloatTensor_size(grads_image, 3); + + const int num_boxes = THFloatTensor_size(grads, 0); + const int crop_height = THFloatTensor_size(grads,2); + const int crop_width = THFloatTensor_size(grads,3); + // n_elements const int image_channel_elements = image_height * image_width; const int image_elements = depth * image_channel_elements; const int channel_elements = crop_height * crop_width; const int crop_elements = depth * channel_elements; // init output space THFloatTensor_zero(grads_image); // data pointer const float * grads_data = THFloatTensor_data(grads); const float * boxes_data = THFloatTensor_data(boxes); const int * box_index_data = THIntTensor_data(box_index); float * grads_image_data = THFloatTensor_data(grads_image); for (int b = 0; b < num_boxes; ++b) { const float * box = boxes_data + b * 4; const float y1 = box[0]; const float x1 = box[1]; const float y2 = box[2]; const float x2 = box[3]; const int b_in = box_index_data[b]; if (b_in < 0 || b_in >= batch_size) { printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); exit(-1); } const float height_scale = (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) : 0; const float width_scale = (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; for (int y = 0; y < crop_height; ++y) { const float in_y = (crop_height > 1) ? y1 * (image_height - 1) + y * height_scale : 0.5 * (y1 + y2) * (image_height - 1); if (in_y < 0 || in_y > image_height - 1) { continue; } const int top_y_index = floorf(in_y); const int bottom_y_index = ceilf(in_y); const float y_lerp = in_y - top_y_index; for (int x = 0; x < crop_width; ++x) { const float in_x = (crop_width > 1) ? x1 * (image_width - 1) + x * width_scale : 0.5 * (x1 + x2) * (image_width - 1); if (in_x < 0 || in_x > image_width - 1) { continue; } const int left_x_index = floorf(in_x); const int right_x_index = ceilf(in_x); const float x_lerp = in_x - left_x_index; for (int d = 0; d < depth; ++d) { float *pimage = grads_image_data + b_in * image_elements + d * image_channel_elements; const float grad_val = grads_data[crop_elements * b + channel_elements * d + y * crop_width + x]; const float dtop = (1 - y_lerp) * grad_val; pimage[top_y_index * image_width + left_x_index] += (1 - x_lerp) * dtop; pimage[top_y_index * image_width + right_x_index] += x_lerp * dtop; const float dbottom = y_lerp * grad_val; pimage[bottom_y_index * image_width + left_x_index] += (1 - x_lerp) * dbottom; pimage[bottom_y_index * image_width + right_x_index] += x_lerp * dbottom; } // end d } // end x } // end y } // end b } \ No newline at end of file diff --git a/experiments/lidc_exp/preprocessing.py b/experiments/lidc_exp/preprocessing.py index 73c6ef5..b838fa0 100644 --- a/experiments/lidc_exp/preprocessing.py +++ b/experiments/lidc_exp/preprocessing.py @@ -1,136 +1,144 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +''' +This preprocessing script loads nrrd files obtained by the data conversion tool: https://github.com/MIC-DKFZ/LIDC-IDRI-processing/tree/v1.0.1 +After applying preprocessing, images are saved as numpy arrays and the meta information for the corresponding patient is stored +as a line in the dataframe saved as info_df.pickle. +''' + import os import SimpleITK as sitk import numpy as np from multiprocessing import Pool import pandas as pd import numpy.testing as npt from skimage.transform import resize import subprocess +import pickle import configs cf = configs.configs() -# if a rater did not identify a nodule, this vote counts as 0s on the pixels. and as 0 == background (or 1?) on the mal. score. -# will this lead to many surpressed nodules. yes. they are not stored in segmentation map and the mal. labels are discarded. -# a pixel counts as foreground, if at least 2 raters drew it as foreground. def resample_array(src_imgs, src_spacing, target_spacing): src_spacing = np.round(src_spacing, 3) target_shape = [int(src_imgs.shape[ix] * src_spacing[::-1][ix] / target_spacing[::-1][ix]) for ix in range(len(src_imgs.shape))] - # print('target shape', target_shape, src_imgs.shape, src_spacing, target_spacing) for i in range(len(target_shape)): try: assert target_shape[i] > 0 except: raise AssertionError("AssertionError:", src_imgs.shape, src_spacing, target_spacing) img = src_imgs.astype(float) resampled_img = resize(img, target_shape, order=1, clip=True, mode='edge').astype('float32') return resampled_img def pp_patient(inputs): ix, path = inputs pid = path.split('/')[-1] img = sitk.ReadImage(os.path.join(path, '{}_ct_scan.nrrd'.format(pid))) img_arr = sitk.GetArrayFromImage(img) print('processing {}'.format(pid), img.GetSpacing(), img_arr.shape) img_arr = resample_array(img_arr, img.GetSpacing(), cf.target_spacing) img_arr = np.clip(img_arr, -1200, 600) #img_arr = (1200 + img_arr) / (600 + 1200) * 255 # a+x / (b-a) * (c-d) (c, d = new) img_arr = img_arr.astype(np.float32) img_arr = (img_arr - np.mean(img_arr)) / np.std(img_arr).astype(np.float16) - print('img arr shape after', img_arr.shape) - - # import matplotlib.pyplot as plt - # plt.figure() - # plt.hist(img_arr.flatten(), bins=100) - # plt.savefig(cf.root_dir + '/test.png') - # plt.close() df = pd.read_csv(os.path.join(cf.root_dir, 'characteristics.csv'), sep=';') df = df[df.PatientID == pid] final_rois = np.zeros_like(img_arr, dtype=np.uint8) mal_labels = [] roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii]) rix = 1 for rid in roi_ids: roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(rid) in ii] nodule_ids = [ii.split('_')[2].lstrip("0") for ii in roi_id_paths] rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0] for ii in nodule_ids] rater_labels.extend([0] * (4-len(rater_labels))) - # print(nodule_ids, roi_id_paths, df.Malignancy.values, pid) mal_label = np.mean([ii for ii in rater_labels if ii > -1]) roi_rater_list = [] for rp in roi_id_paths: roi = sitk.ReadImage(os.path.join(cf.raw_data_dir, pid, rp)) roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8) roi_arr = resample_array(roi_arr, roi.GetSpacing(), cf.target_spacing) assert roi_arr.shape == img_arr.shape, [roi_arr.shape, img_arr.shape, pid, roi.GetSpacing()] for ix in range(len(img_arr.shape)): npt.assert_almost_equal(roi.GetSpacing()[ix], img.GetSpacing()[ix]) roi_rater_list.append(roi_arr) roi_rater_list.extend([np.zeros_like(roi_rater_list[-1])]*(4-len(roi_id_paths))) roi_raters = np.array(roi_rater_list) roi_raters = np.mean(roi_raters, axis=0) roi_raters[roi_raters < 0.5] = 0 if np.sum(roi_raters) > 0: mal_labels.append(mal_label) final_rois[roi_raters >= 0.5] = rix rix += 1 else: - print('surpressed roi!', roi_id_paths) - with open(os.path.join(cf.pp_dir, 'surpressed_rois.txt'), 'a') as handle: + # indicate rois suppressed by majority voting of raters + print('suppressed roi!', roi_id_paths) + with open(os.path.join(cf.pp_dir, 'suppressed_rois.txt'), 'a') as handle: handle.write(" ".join(roi_id_paths)) fg_slices = [ii for ii in np.unique(np.argwhere(final_rois != 0)[:, 0])] mal_labels = np.array(mal_labels) assert len(mal_labels) + 1 == len(np.unique(final_rois)), [len(mal_labels), np.unique(final_rois), pid] - out_df = pd.read_pickle(os.path.join(cf.pp_dir, 'info_df.pickle')) - out_df.loc[len(out_df)] = {'pid': pid, 'class_target': mal_labels, 'spacing': img.GetSpacing(), 'fg_slices': fg_slices} - out_df.to_pickle(os.path.join(cf.pp_dir, 'info_df.pickle')) + np.save(os.path.join(cf.pp_dir, '{}_rois.npy'.format(pid)), final_rois) np.save(os.path.join(cf.pp_dir, '{}_img.npy'.format(pid)), img_arr) + with open(os.path.join(cf.pp_dir, 'meta_info_{}.pickle'.format(pid)), 'wb') as handle: + meta_info_dict = {'pid': pid, 'class_target': mal_labels, 'spacing': img.GetSpacing(), 'fg_slices': fg_slices} + pickle.dump(meta_info_dict, handle) + + + +def aggregate_meta_info(exp_dir): + + files = [os.path.join(exp_dir, f) for f in os.listdir(exp_dir) if 'meta_info' in f] + df = pd.DataFrame(columns=['pid', 'class_target', 'spacing', 'fg_slices']) + for f in files: + with open(f, 'rb') as handle: + df.loc[len(df)] = pickle.load(handle) + + df.to_pickle(os.path.join(exp_dir, 'info_df.pickle')) + print ("aggregated meta info to df with length", len(df)) if __name__ == "__main__": paths = [os.path.join(cf.raw_data_dir, ii) for ii in os.listdir(cf.raw_data_dir)] if not os.path.exists(cf.pp_dir): os.mkdir(cf.pp_dir) - df = pd.DataFrame(columns=['pid', 'class_target', 'spacing', 'fg_slices']) - df.to_pickle(os.path.join(cf.pp_dir, 'info_df.pickle')) - pool = Pool(processes=12) p1 = pool.map(pp_patient, enumerate(paths), chunksize=1) pool.close() pool.join() # for i in enumerate(paths): # pp_patient(i) + aggregate_meta_info(cf.pp_dir) subprocess.call('cp {} {}'.format(os.path.join(cf.pp_dir, 'info_df.pickle'), os.path.join(cf.pp_dir, 'info_df_bk.pickle')), shell=True) \ No newline at end of file diff --git a/experiments/toy_exp/generate_toys.py b/experiments/toy_exp/generate_toys.py index 9f336c9..4f44768 100644 --- a/experiments/toy_exp/generate_toys.py +++ b/experiments/toy_exp/generate_toys.py @@ -1,94 +1,107 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os import numpy as np import pandas as pd +import pickle from multiprocessing import Pool import configs as cf def multi_processing_create_image(inputs): out_dir, six, foreground_margin, class_diameters, mode = inputs print('proceesing {} {}'.format(out_dir, six)) img = np.random.rand(320, 320) seg = np.zeros((320, 320)).astype('uint8') center_x = np.random.randint(foreground_margin, img.shape[0] - foreground_margin) center_y = np.random.randint(foreground_margin, img.shape[1] - foreground_margin) class_id = np.random.randint(0, 2) for y in range(img.shape[0]): for x in range(img.shape[0]): if ((x - center_x) ** 2 + (y - center_y) ** 2 - class_diameters[class_id] ** 2) < 0: img[y][x] += 0.2 seg[y][x] = 1 if 'donuts' in mode: whole_diameter = 4 if class_id == 1: for y in range(img.shape[0]): for x in range(img.shape[0]): if ((x - center_x) ** 2 + (y - center_y) ** 2 - whole_diameter ** 2) < 0: img[y][x] -= 0.2 if mode == 'donuts_shape': seg[y][x] = 0 out = np.concatenate((img[None], seg[None])) out_path = os.path.join(out_dir, '{}.npy'.format(six)) - df = pd.read_pickle(os.path.join(out_dir, 'info_df.pickle')) - df.loc[len(df)] = [out_path, class_id, str(six)] - df.to_pickle(os.path.join(out_dir, 'info_df.pickle')) np.save(out_path, out) + with open(os.path.join(out_dir, 'meta_info_{}.pickle'.format(six)), 'wb') as handle: + pickle.dump([out_path, class_id, str(six)], handle) -def get_toy_image_info(mode, n_images, out_dir, class_diameters=(20, 20)): - if not os.path.exists(out_dir): - os.makedirs(out_dir) +def generate_experiment(exp_name, n_train_images, n_test_images, mode, class_diameters=(20, 20)): + + train_dir = os.path.join(cf.root_dir, exp_name, 'train') + test_dir = os.path.join(cf.root_dir, exp_name, 'test') + if not os.path.exists(train_dir): + os.makedirs(train_dir) + if not os.path.exists(test_dir): + os.makedirs(test_dir) # enforced distance between object center and image edge. foreground_margin = np.max(class_diameters) // 2 + info = [] + info += [[train_dir, six, foreground_margin, class_diameters, mode] for six in range(n_train_images)] + info += [[test_dir, six, foreground_margin, class_diameters, mode] for six in range(n_test_images)] + + print('starting creating {} images'.format(len(info))) + pool = Pool(processes=12) + pool.map(multi_processing_create_image, info, chunksize=1) + pool.close() + pool.join() + + aggregate_meta_info(train_dir) + aggregate_meta_info(test_dir) + + +def aggregate_meta_info(exp_dir): + + files = [os.path.join(exp_dir, f) for f in os.listdir(exp_dir) if 'meta_info' in f] df = pd.DataFrame(columns=['path', 'class_id', 'pid']) - df.to_pickle(os.path.join(out_dir, 'info_df.pickle')) - return [[out_dir, six, foreground_margin, class_diameters, mode] for six in range(n_images)] + for f in files: + with open(f, 'rb') as handle: + df.loc[len(df)] = pickle.load(handle) + + df.to_pickle(os.path.join(exp_dir, 'info_df.pickle')) + print ("aggregated meta info to df with length", len(df)) if __name__ == '__main__': cf = cf.configs() - root_dir = os.path.join(cf.root_dir, 'donuts_shape') - info = [] - info += get_toy_image_info(mode='donuts_shape', n_images=1500, out_dir=os.path.join(root_dir, 'train')) - info += get_toy_image_info(mode='donuts_shape', n_images=1000, out_dir=os.path.join(root_dir, 'test')) - - root_dir = os.path.join(cf.root_dir, 'donuts_pattern') - info += get_toy_image_info(mode='donuts_pattern', n_images=1500, out_dir=os.path.join(root_dir, 'train')) - info += get_toy_image_info(mode='donuts_pattern', n_images=1000, out_dir=os.path.join(root_dir, 'test')) + generate_experiment('donuts_shape_threads', n_train_images=1500, n_test_images=1000, mode='donuts_shape') + generate_experiment('donuts_pattern', n_train_images=1500, n_test_images=1000, mode='donuts_pattern') + generate_experiment('circles_scale', n_train_images=1500, n_test_images=1000, mode='circles_scale', class_diameters=(19, 20)) - root_dir = os.path.join(cf.root_dir, 'circles_scale') - info += get_toy_image_info(mode='circles_scale', n_images=1500, out_dir=os.path.join(root_dir, 'train'), class_diameters=(19, 20)) - info += get_toy_image_info(mode='circles_scale', n_images=1000, out_dir=os.path.join(root_dir, 'test'), class_diameters=(19, 20)) - print('starting creating {} images'.format(len(info))) - pool = Pool(processes=12) - pool.map(multi_processing_create_image, info, chunksize=1) - pool.close() - pool.join() diff --git a/models/backbone.py b/models/backbone.py index 8d249ac..418dedd 100644 --- a/models/backbone.py +++ b/models/backbone.py @@ -1,288 +1,218 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import torch.nn as nn import torch.nn.functional as F import torch class FPN(nn.Module): """ Feature Pyramid Network from https://arxiv.org/pdf/1612.03144.pdf with options for modifications. by default is constructed with Pyramid levels P2, P3, P4, P5. """ def __init__(self, cf, conv, operate_stride1=False): """ from configs: :param input_channels: number of channel dimensions in input data. :param start_filts: number of feature_maps in first layer. rest is scaled accordingly. :param out_channels: number of feature_maps for output_layers of all levels in decoder. :param conv: instance of custom conv class containing the dimension info. :param res_architecture: string deciding whether to use "resnet50" or "resnet101". :param operate_stride1: boolean flag. enables adding of Pyramid levels P1 (output stride 2) and P0 (output stride 1). :param norm: string specifying type of feature map normalization. If None, no normalization is applied. :param relu: string specifying type of nonlinearity. If None, no nonlinearity is applied. :param sixth_pooling: boolean flag. enables adding of Pyramid level P6. """ super(FPN, self).__init__() self.start_filts = cf.start_filts start_filts = self.start_filts self.n_blocks = [3, 4, {"resnet50": 6, "resnet101": 23}[cf.res_architecture], 3] self.block = ResBlock self.block_expansion = 4 self.operate_stride1 = operate_stride1 self.sixth_pooling = cf.sixth_pooling self.dim = conv.dim if operate_stride1: self.C0 = nn.Sequential(conv(cf.n_channels, start_filts, ks=3, pad=1, norm=cf.norm, relu=cf.relu), conv(start_filts, start_filts, ks=3, pad=1, norm=cf.norm, relu=cf.relu)) self.C1 = conv(start_filts, start_filts, ks=7, stride=(2, 2, 1) if conv.dim == 3 else 2, pad=3, norm=cf.norm, relu=cf.relu) else: self.C1 = conv(cf.n_channels, start_filts, ks=7, stride=(2, 2, 1) if conv.dim == 3 else 2, pad=3, norm=cf.norm, relu=cf.relu) start_filts_exp = start_filts * self.block_expansion C2_layers = [] C2_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if conv.dim == 2 else nn.MaxPool3d(kernel_size=3, stride=(2, 2, 1), padding=1)) C2_layers.append(self.block(start_filts, start_filts, conv=conv, stride=1, norm=cf.norm, relu=cf.relu, downsample=(start_filts, self.block_expansion, 1))) for i in range(1, self.n_blocks[0]): C2_layers.append(self.block(start_filts_exp, start_filts, conv=conv, norm=cf.norm, relu=cf.relu)) self.C2 = nn.Sequential(*C2_layers) C3_layers = [] C3_layers.append(self.block(start_filts_exp, start_filts * 2, conv=conv, stride=2, norm=cf.norm, relu=cf.relu, downsample=(start_filts_exp, 2, 2))) for i in range(1, self.n_blocks[1]): C3_layers.append(self.block(start_filts_exp * 2, start_filts * 2, conv=conv, norm=cf.norm, relu=cf.relu)) self.C3 = nn.Sequential(*C3_layers) C4_layers = [] C4_layers.append(self.block( start_filts_exp * 2, start_filts * 4, conv=conv, stride=2, norm=cf.norm, relu=cf.relu, downsample=(start_filts_exp * 2, 2, 2))) for i in range(1, self.n_blocks[2]): C4_layers.append(self.block(start_filts_exp * 4, start_filts * 4, conv=conv, norm=cf.norm, relu=cf.relu)) self.C4 = nn.Sequential(*C4_layers) C5_layers = [] C5_layers.append(self.block( start_filts_exp * 4, start_filts * 8, conv=conv, stride=2, norm=cf.norm, relu=cf.relu, downsample=(start_filts_exp * 4, 2, 2))) for i in range(1, self.n_blocks[3]): C5_layers.append(self.block(start_filts_exp * 8, start_filts * 8, conv=conv, norm=cf.norm, relu=cf.relu)) self.C5 = nn.Sequential(*C5_layers) if self.sixth_pooling: C6_layers = [] C6_layers.append(self.block( start_filts_exp * 8, start_filts * 16, conv=conv, stride=2, norm=cf.norm, relu=cf.relu, downsample=(start_filts_exp * 8, 2, 2))) for i in range(1, self.n_blocks[3]): C6_layers.append(self.block(start_filts_exp * 16, start_filts * 16, conv=conv, norm=cf.norm, relu=cf.relu)) self.C6 = nn.Sequential(*C6_layers) if conv.dim == 2: self.P1_upsample = Interpolate(scale_factor=2, mode='bilinear') self.P2_upsample = Interpolate(scale_factor=2, mode='bilinear') else: self.P1_upsample = Interpolate(scale_factor=(2, 2, 1), mode='trilinear') self.P2_upsample = Interpolate(scale_factor=(2, 2, 1), mode='trilinear') self.out_channels = cf.end_filts self.P5_conv1 = conv(start_filts*32 + cf.n_latent_dims, self.out_channels, ks=1, stride=1, relu=None) # self.P4_conv1 = conv(start_filts*16, self.out_channels, ks=1, stride=1, relu=None) self.P3_conv1 = conv(start_filts*8, self.out_channels, ks=1, stride=1, relu=None) self.P2_conv1 = conv(start_filts*4, self.out_channels, ks=1, stride=1, relu=None) self.P1_conv1 = conv(start_filts, self.out_channels, ks=1, stride=1, relu=None) if operate_stride1: self.P0_conv1 = conv(start_filts, self.out_channels, ks=1, stride=1, relu=None) self.P0_conv2 = conv(self.out_channels, self.out_channels, ks=3, stride=1, pad=1, relu=None) self.P1_conv2 = conv(self.out_channels, self.out_channels, ks=3, stride=1, pad=1, relu=None) self.P2_conv2 = conv(self.out_channels, self.out_channels, ks=3, stride=1, pad=1, relu=None) self.P3_conv2 = conv(self.out_channels, self.out_channels, ks=3, stride=1, pad=1, relu=None) self.P4_conv2 = conv(self.out_channels, self.out_channels, ks=3, stride=1, pad=1, relu=None) self.P5_conv2 = conv(self.out_channels, self.out_channels, ks=3, stride=1, pad=1, relu=None) if self.sixth_pooling: self.P6_conv1 = conv(start_filts * 64, self.out_channels, ks=1, stride=1, relu=None) self.P6_conv2 = conv(self.out_channels, self.out_channels, ks=3, stride=1, pad=1, relu=None) def forward(self, x): """ :param x: input image of shape (b, c, y, x, (z)) :return: list of output feature maps per pyramid level, each with shape (b, c, y, x, (z)). """ if self.operate_stride1: c0_out = self.C0(x) else: c0_out = x c1_out = self.C1(c0_out) c2_out = self.C2(c1_out) c3_out = self.C3(c2_out) c4_out = self.C4(c3_out) c5_out = self.C5(c4_out) if self.sixth_pooling: c6_out = self.C6(c5_out) p6_pre_out = self.P6_conv1(c6_out) p5_pre_out = self.P5_conv1(c5_out) + F.interpolate(p6_pre_out, scale_factor=2) else: p5_pre_out = self.P5_conv1(c5_out) p4_pre_out = self.P4_conv1(c4_out) + F.interpolate(p5_pre_out, scale_factor=2) p3_pre_out = self.P3_conv1(c3_out) + F.interpolate(p4_pre_out, scale_factor=2) p2_pre_out = self.P2_conv1(c2_out) + F.interpolate(p3_pre_out, scale_factor=2) # plot feature map shapes for debugging. # for ii in [c0_out, c1_out, c2_out, c3_out, c4_out, c5_out, c6_out]: # print ("encoder shapes:", ii.shape) # # for ii in [p6_out, p5_out, p4_out, p3_out, p2_out, p1_out]: # print("decoder shapes:", ii.shape) p2_out = self.P2_conv2(p2_pre_out) p3_out = self.P3_conv2(p3_pre_out) p4_out = self.P4_conv2(p4_pre_out) p5_out = self.P5_conv2(p5_pre_out) out_list = [p2_out, p3_out, p4_out, p5_out] if self.sixth_pooling: p6_out = self.P6_conv2(p6_pre_out) out_list.append(p6_out) if self.operate_stride1: p1_pre_out = self.P1_conv1(c1_out) + self.P2_upsample(p2_pre_out) p0_pre_out = self.P0_conv1(c0_out) + self.P1_upsample(p1_pre_out) # p1_out = self.P1_conv2(p1_pre_out) # usually not needed. p0_out = self.P0_conv2(p0_pre_out) out_list = [p0_out] + out_list return out_list - def encoder_forward(self, x): - """ - :param x: input image of shape (b, c, y, x, (z)) - :return: list of output feature maps per pyramid level, each with shape (b, c, y, x, (z)). - """ - if self.operate_stride1: - c0_out = self.C0(x) - else: - c0_out = x - - c1_out = self.C1(c0_out) - c2_out = self.C2(c1_out) - c3_out = self.C3(c2_out) - c4_out = self.C4(c3_out) - c5_out = self.C5(c4_out) - out_list = [c0_out, c1_out, c2_out, c3_out, c4_out, c5_out] - if self.sixth_pooling: - c6_out = self.C6(c5_out) - out_list += [c6_out] - - return out_list - - - def decoder_forward(self, encoder_list, inject=None): - - if inject is not None: - z = inject - - if self.dim == 2: - z = z.unsqueeze(-1).unsqueeze(-1).repeat( - 1, 1, encoder_list[-1].shape[-2], encoder_list[-1].shape[-1]) - else: - z = z.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat( - 1, 1, encoder_list[-1].shape[-3], encoder_list[-1].shape[-2], encoder_list[-1].shape[-1]) - - x = torch.cat((encoder_list[-1], z), 1) - - else: - x = encoder_list[-1] - - if self.sixth_pooling: - p6_pre_out = self.P6_conv1(x) - p5_pre_out = self.P5_conv1(encoder_list[5]) + F.interpolate(p6_pre_out, scale_factor=2) - else: - p5_pre_out = self.P5_conv1(x) - - p4_pre_out = self.P4_conv1(encoder_list[4]) + F.interpolate(p5_pre_out, scale_factor=2) - p3_pre_out = self.P3_conv1(encoder_list[3]) + F.interpolate(p4_pre_out, scale_factor=2) - p2_pre_out = self.P2_conv1(encoder_list[2]) + F.interpolate(p3_pre_out, scale_factor=2) - - p2_out = self.P2_conv2(p2_pre_out) - p3_out = self.P3_conv2(p3_pre_out) - p4_out = self.P4_conv2(p4_pre_out) - p5_out = self.P5_conv2(p5_pre_out) - out_list = [p2_out, p3_out, p4_out, p5_out] - - if self.sixth_pooling: - p6_out = self.P6_conv2(p6_pre_out) - out_list.append(p6_out) - - if self.operate_stride1: - p1_pre_out = self.P1_conv1(c1_out) + self.P2_upsample(p2_pre_out) - p0_pre_out = self.P0_conv1(c0_out) + self.P1_upsample(p1_pre_out) - # p1_out = self.P1_conv2(p1_pre_out) # usually not needed. - p0_out = self.P0_conv2(p0_pre_out) - out_list = [p0_out] + out_list - - return out_list - - class ResBlock(nn.Module): def __init__(self, start_filts, planes, conv, stride=1, downsample=None, norm=None, relu='relu'): super(ResBlock, self).__init__() self.conv1 = conv(start_filts, planes, ks=1, stride=stride, norm=norm, relu=relu) self.conv2 = conv(planes, planes, ks=3, pad=1, norm=norm, relu=relu) self.conv3 = conv(planes, planes * 4, ks=1, norm=norm, relu=None) self.relu = nn.ReLU(inplace=True) if relu == 'relu' else nn.LeakyReLU(inplace=True) if downsample is not None: self.downsample = conv(downsample[0], downsample[0] * downsample[1], ks=1, stride=downsample[2], norm=norm, relu=None) else: self.downsample = None self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.conv2(out) out = self.conv3(out) if self.downsample: residual = self.downsample(x) out += residual out = self.relu(out) return out class Interpolate(nn.Module): def __init__(self, scale_factor, mode): super(Interpolate, self).__init__() self.interp = nn.functional.interpolate self.scale_factor = scale_factor self.mode = mode def forward(self, x): x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False) return x \ No newline at end of file diff --git a/plotting.py b/plotting.py index 4e47646..4e15c74 100644 --- a/plotting.py +++ b/plotting.py @@ -1,266 +1,272 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import numpy as np import os from copy import deepcopy def plot_batch_prediction(batch, results_dict, cf, outfile= None): """ plot the input images, ground truth annotations, and output predictions of a batch. If 3D batch, plots a 2D projection of one randomly sampled element (patient) in the batch. Since plotting all slices of patient volume blows up costs of time and space, only a section containing a randomly sampled ground truth annotation is plotted. :param batch: dict with keys: 'data' (input image), 'seg' (pixelwise annotations), 'pid' :param results_dict: list over batch element. Each element is a list of boxes (prediction and ground truth), where every box is a dictionary containing box_coords, box_score and box_type. """ if outfile is None: outfile = os.path.join(cf.plot_dir, 'pred_example_{}.png'.format(cf.fold)) data = batch['data'] segs = batch['seg'] pids = batch['pid'] # for 3D, repeat pid over batch elements. if len(set(pids)) == 1: pids = [pids] * data.shape[0] seg_preds = results_dict['seg_preds'] roi_results = deepcopy(results_dict['boxes']) # Randomly sampled one patient of batch and project data into 2D slices for plotting. if cf.dim == 3: patient_ix = np.random.choice(data.shape[0]) data = np.transpose(data[patient_ix], axes=(3, 0, 1, 2)) # select interesting foreground section to plot. gt_boxes = [box['box_coords'] for box in roi_results[patient_ix] if box['box_type'] == 'gt'] if len(gt_boxes) > 0: z_cuts = [np.max((int(gt_boxes[0][4]) - 5, 0)), np.min((int(gt_boxes[0][5]) + 5, data.shape[0]))] else: z_cuts = [data.shape[0]//2 - 5, int(data.shape[0]//2 + np.min([10, data.shape[0]//2]))] p_roi_results = roi_results[patient_ix] roi_results = [[] for _ in range(data.shape[0])] # iterate over cubes and spread across slices. for box in p_roi_results: b = box['box_coords'] # dismiss negative anchor slices. slices = np.round(np.unique(np.clip(np.arange(b[4], b[5] + 1), 0, data.shape[0]-1))) for s in slices: roi_results[int(s)].append(box) roi_results[int(s)][-1]['box_coords'] = b[:4] roi_results = roi_results[z_cuts[0]: z_cuts[1]] data = data[z_cuts[0]: z_cuts[1]] segs = np.transpose(segs[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]] seg_preds = np.transpose(seg_preds[patient_ix], axes=(3, 0, 1, 2))[z_cuts[0]: z_cuts[1]] pids = [pids[patient_ix]] * data.shape[0] try: # all dimensions except for the 'channel-dimension' are required to match for i in [0, 2, 3]: assert data.shape[i] == segs.shape[i] == seg_preds.shape[i] except: raise Warning('Shapes of arrays to plot not in agreement!' 'Shapes {} vs. {} vs {}'.format(data.shape, segs.shape, seg_preds.shape)) show_arrays = np.concatenate([data, segs, seg_preds, data[:, 0][:, None]], axis=1).astype(float) approx_figshape = (4 * show_arrays.shape[0], 4 * show_arrays.shape[1]) fig = plt.figure(figsize=approx_figshape) gs = gridspec.GridSpec(show_arrays.shape[1] + 1, show_arrays.shape[0]) gs.update(wspace=0.1, hspace=0.1) for b in range(show_arrays.shape[0]): for m in range(show_arrays.shape[1]): ax = plt.subplot(gs[m, b]) ax.axis('off') if m < show_arrays.shape[1]: arr = show_arrays[b, m] if m < data.shape[1] or m == show_arrays.shape[1] - 1: cmap = 'gray' vmin = None vmax = None else: cmap = None vmin = 0 vmax = cf.num_seg_classes - 1 if m == 0: plt.title('{}'.format(pids[b][:10]), fontsize=20) plt.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax) if m >= (data.shape[1]): for box in roi_results[b]: if box['box_type'] != 'patient_tn_box': # don't plot true negative dummy boxes. coords = box['box_coords'] if box['box_type'] == 'det': # dont plot background preds or low confidence boxes. if box['box_pred_class_id'] > 0 and box['box_score'] > 0.1: plot_text = True score = np.max(box['box_score']) score_text = '{}|{:.0f}'.format(box['box_pred_class_id'], score*100) # if prob detection: plot only boxes from correct sampling instance. if 'sample_id' in box.keys() and int(box['sample_id']) != m - data.shape[1] - 2: continue # if prob detection: plot reconstructed boxes only in corresponding line. if not 'sample_id' in box.keys() and m != data.shape[1] + 1: continue score_font_size = 7 text_color = 'w' text_x = coords[1] + 10*(box['box_pred_class_id'] -1) #avoid overlap of scores in plot. text_y = coords[2] + 5 else: continue elif box['box_type'] == 'gt': plot_text = True score_text = int(box['box_label']) score_font_size = 7 text_color = 'r' text_x = coords[1] text_y = coords[0] - 1 else: plot_text = False color_var = 'extra_usage' if 'extra_usage' in list(box.keys()) else 'box_type' color = cf.box_color_palette[box[color_var]] plt.plot([coords[1], coords[3]], [coords[0], coords[0]], color=color, linewidth=1, alpha=1) # up plt.plot([coords[1], coords[3]], [coords[2], coords[2]], color=color, linewidth=1, alpha=1) # down plt.plot([coords[1], coords[1]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1) # left plt.plot([coords[3], coords[3]], [coords[0], coords[2]], color=color, linewidth=1, alpha=1) # right if plot_text: plt.text(text_x, text_y, score_text, fontsize=score_font_size, color=text_color) try: plt.savefig(outfile) except: raise Warning('failed to save plot.') plt.close(fig) class TrainingPlot_2Panel(): def __init__(self, cf): self.file_name = cf.plot_dir + '/monitor_{}'.format(cf.fold) self.exp_name = cf.fold_dir + self.do_validation = cf.do_validation self.separate_values_dict = cf.assign_values_to_extra_figure self.figure_list = [] for n in range(cf.n_monitoring_figures): self.figure_list.append(plt.figure(figsize=(10, 6))) self.figure_list[-1].ax1 = plt.subplot(111) self.figure_list[-1].ax1.set_xlabel('epochs') self.figure_list[-1].ax1.set_ylabel('loss / metrics') self.figure_list[-1].ax1.set_xlim(0, cf.num_epochs) self.figure_list[-1].ax1.grid() self.figure_list[0].ax1.set_ylim(0, 1.5) self.color_palette = ['b', 'c', 'r', 'purple', 'm', 'y', 'k', 'tab:gray'] def update_and_save(self, metrics, epoch): for figure_ix in range(len(self.figure_list)): fig = self.figure_list[figure_ix] - detection_monitoring_plot(fig.ax1, metrics, self.exp_name, self.color_palette, epoch, figure_ix, self.separate_values_dict) + detection_monitoring_plot(fig.ax1, metrics, self.exp_name, self.color_palette, epoch, figure_ix, + self.separate_values_dict, + self.do_validation) fig.savefig(self.file_name + '_{}'.format(figure_ix)) -def detection_monitoring_plot(ax1, metrics, exp_name, color_palette, epoch, figure_ix, separate_values_dict): +def detection_monitoring_plot(ax1, metrics, exp_name, color_palette, epoch, figure_ix, separate_values_dict, do_validation): monitor_values_keys = metrics['train']['monitor_values'][1][0].keys() separate_values = [v for fig_ix in separate_values_dict.values() for v in fig_ix] if figure_ix == 0: plot_keys = [ii for ii in monitor_values_keys if ii not in separate_values] plot_keys += [k for k in metrics['train'].keys() if k != 'monitor_values'] else: plot_keys = separate_values_dict[figure_ix] x = np.arange(1, epoch + 1) for kix, pk in enumerate(plot_keys): if pk in metrics['train'].keys(): y_train = metrics['train'][pk][1:] - y_val = metrics['val'][pk][1:] + if do_validation: + y_val = metrics['val'][pk][1:] else: y_train = [np.mean([er[pk] for er in metrics['train']['monitor_values'][e]]) for e in x] - y_val = [np.mean([er[pk] for er in metrics['val']['monitor_values'][e]]) for e in x] + if do_validation: + y_val = [np.mean([er[pk] for er in metrics['val']['monitor_values'][e]]) for e in x] ax1.plot(x, y_train, label='train_{}'.format(pk), linestyle='--', color=color_palette[kix]) - ax1.plot(x, y_val, label='val_{}'.format(pk), linestyle='-', color=color_palette[kix]) + if do_validation: + ax1.plot(x, y_val, label='val_{}'.format(pk), linestyle='-', color=color_palette[kix]) if epoch == 1: box = ax1.get_position() ax1.set_position([box.x0, box.y0, box.width * 0.8, box.height]) ax1.legend(loc='center left', bbox_to_anchor=(1, 0.5)) ax1.set_title(exp_name) def plot_prediction_hist(label_list, pred_list, type_list, outfile): """ plot histogram of predictions for a specific class. :param label_list: list of 1s and 0s specifying whether prediction is a true positive match (1) or a false positive (0). False negatives (missed ground truth objects) are artificially added predictions with score 0 and label 1. :param pred_list: list of prediction-scores. :param type_list: list of prediction-types for stastic-info in title. """ preds = np.array(pred_list) labels = np.array(label_list) title = outfile.split('/')[-1] + ' count:{}'.format(len(label_list)) plt.figure() plt.yscale('log') if 0 in labels: plt.hist(preds[labels == 0], alpha=0.3, color='g', range=(0, 1), bins=50, label='false pos.') if 1 in labels: plt.hist(preds[labels == 1], alpha=0.3, color='b', range=(0, 1), bins=50, label='true pos. (false neg. @ score=0)') if type_list is not None: fp_count = type_list.count('det_fp') fn_count = type_list.count('det_fn') tp_count = type_list.count('det_tp') pos_count = fn_count + tp_count title += ' tp:{} fp:{} fn:{} pos:{}'. format(tp_count, fp_count, fn_count, pos_count) plt.legend() plt.title(title) plt.xlabel('confidence score') plt.ylabel('log n') plt.savefig(outfile) plt.close() def plot_stat_curves(stats, outfile): for c in ['roc', 'prc']: plt.figure() for s in stats: if s[c] is not None: plt.plot(s[c][0], s[c][1], label=s['name'] + '_' + c) plt.title(outfile.split('/')[-1] + '_' + c) plt.legend(loc=3 if c == 'prc' else 4) plt.xlabel('precision' if c == 'prc' else '1-spec.') plt.ylabel('recall') plt.savefig(outfile + '_' + c) plt.close() diff --git a/predictor.py b/predictor.py index 0c32495..dd3ae23 100644 --- a/predictor.py +++ b/predictor.py @@ -1,819 +1,816 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os import numpy as np import torch from scipy.stats import norm from collections import OrderedDict from multiprocessing import Pool import pickle import pandas as pd class Predictor: """ Prediction pipeline: - receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader. - forwards patches through model in chunks of batch_size. (method: batch_tiling_forward) - unmolds predictions (boxes and segmentations) to original patient coordinates. (method: spatial_tiling_forward) Ensembling (mode == 'test'): - for inference, forwards 4 mirrored versions of image to through model and unmolds predictions afterwards accordingly (method: data_aug_forward) - for inference, loads multiple parameter-sets of the trained model corresponding to different epochs. for each parameter-set loops over entire test set, runs prediction pipeline for each patient. (method: predict_test_set) Consolidation of predictions: - consolidates a patient's predictions (boxes, segmentations) collected over patches, data_aug- and temporal ensembling, performs clustering and weighted averaging (external function: apply_wbc_to_patient) to obtain consistent outptus. - for 2D networks, consolidates box predictions to 3D cubes via clustering (adaption of non-maximum surpression). (external function: merge_2D_to_3D_preds_per_patient) Ground truth handling: - dissmisses any ground truth boxes returned by the model (happens in validation mode, patch-based groundtruth) - if provided by data loader, adds 3D ground truth to the final predictions to be passed to the evaluator. """ def __init__(self, cf, net, logger, mode): self.cf = cf self.logger = logger # mode is 'val' for patient-based validation/monitoring and 'test' for inference. self.mode = mode # model instance. In validation mode, contains parameters of current epoch. self.net = net # rank of current epoch loaded (for temporal averaging). this info is added to each prediction, # for correct weighting during consolidation. self.rank_ix = '0' # number of ensembled models. used to calculate the number of expected predictions per position # during consolidation of predictions. Default is 1 (no ensembling, e.g. in validation). self.n_ens = 1 if self.mode == 'test': try: self.epoch_ranking = np.load(os.path.join(self.cf.fold_dir, 'epoch_ranking.npy'))[:cf.test_n_epochs] except: raise RuntimeError('no epoch ranking file in fold directory. ' 'seems like you are trying to run testing without prior training...') self.n_ens = cf.test_n_epochs if self.cf.test_aug: self.n_ens *= 4 def predict_patient(self, batch): """ predicts one patient. called either directly via loop over validation set in exec.py (mode=='val') or from self.predict_test_set (mode=='test). in val mode: adds 3D ground truth info to predictions and runs consolidation and 2Dto3D merging of predictions. in test mode: returns raw predictions (ground truth addition, consolidation, 2D to 3D merging are done in self.predict_test_set, because patient predictions across several epochs might be needed to be collected first, in case of temporal ensembling). :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - monitor_values (only in validation mode) """ self.logger.info('evaluating patient {} for fold {} '.format(batch['pid'], self.cf.fold)) # True if patient is provided in patches and predictions need to be tiled. self.patched_patient = True if 'patch_crop_coords' in list(batch.keys()) else False # forward batch through prediction pipeline. results_dict = self.data_aug_forward(batch) if self.mode == 'val': for b in range(batch['patient_bb_target'].shape[0]): for t in range(len(batch['patient_bb_target'][b])): results_dict['boxes'][b].append({'box_coords': batch['patient_bb_target'][b][t], 'box_label': batch['patient_roi_labels'][b][t], 'box_type': 'gt'}) if self.patched_patient: wcs_input = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.wcs_iou, self.n_ens] results_dict['boxes'] = apply_wbc_to_patient(wcs_input)[0] if self.cf.merge_2D_to_3D_preds: merge_dims_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.merge_3D_iou] results_dict['boxes'] = merge_2D_to_3D_preds_per_patient(merge_dims_inputs)[0] return results_dict def predict_test_set(self, batch_gen, return_results=True): """ wrapper around test method, which loads multiple (or one) epoch parameters (temporal ensembling), loops through the test set and collects predictions per patient. Also flattens the results per patient and epoch and adds optional ground truth boxes for evaluation. Saves out the raw result list for later analysis and optionally consolidates and returns predictions immediately. :return: (optionally) list_of_results_per_patient: list over patient results. each entry is a dict with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': not implemented yet. todo for evaluation of instance/semantic segmentation. """ dict_of_patient_results = OrderedDict() # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling). weight_paths = [os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch)) for epoch in self.epoch_ranking] for rank_ix, weight_path in enumerate(weight_paths): self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path))) self.net.load_state_dict(torch.load(weight_path)) self.net.eval() self.rank_ix = str(rank_ix) # get string of current rank for unique patch ids. with torch.no_grad(): for _ in range(batch_gen['n_test']): batch = next(batch_gen['test']) # store batch info in patient entry of results dict. if rank_ix == 0: dict_of_patient_results[batch['pid']] = {} dict_of_patient_results[batch['pid']]['results_list'] = [] dict_of_patient_results[batch['pid']]['patient_bb_target'] = batch['patient_bb_target'] dict_of_patient_results[batch['pid']]['patient_roi_labels'] = batch['patient_roi_labels'] # call prediction pipeline and store results in dict. results_dict = self.predict_patient(batch) dict_of_patient_results[batch['pid']]['results_list'].append(results_dict['boxes']) self.logger.info('finished predicting test set. starting post-processing of predictions.') list_of_results_per_patient = [] # loop over patients again to flatten results across epoch predictions. # if provided, add ground truth boxes for evaluation. for pid, p_dict in dict_of_patient_results.items(): tmp_ens_list = p_dict['results_list'] results_dict = {} # collect all boxes/seg_preds of same batch_instance over temporal instances. results_dict['boxes'] = [[item for d in tmp_ens_list for item in d[batch_instance]] for batch_instance in range(len(tmp_ens_list[0]))] # TODO return for instance segmentation: # results_dict['seg_preds'] = np.mean(results_dict['seg_preds'], 1)[:, None] # results_dict['seg_preds'] = np.array([[item for d in tmp_ens_list for item in d['seg_preds'][batch_instance]] # for batch_instance in range(len(tmp_ens_list[0]['boxes']))]) # add 3D ground truth boxes for evaluation. for b in range(p_dict['patient_bb_target'].shape[0]): for t in range(len(p_dict['patient_bb_target'][b])): results_dict['boxes'][b].append({'box_coords': p_dict['patient_bb_target'][b][t], 'box_label': p_dict['patient_roi_labels'][b][t], 'box_type': 'gt'}) list_of_results_per_patient.append([results_dict['boxes'], pid]) # save out raw predictions. out_string = 'raw_pred_boxes_hold_out_list' if self.cf.hold_out_test_set else 'raw_pred_boxes_list' with open(os.path.join(self.cf.fold_dir, '{}.pickle'.format(out_string)), 'wb') as handle: pickle.dump(list_of_results_per_patient, handle) if return_results: # consolidate predictions. self.logger.info('applying wcs to test set predictions with iou = {} and n_ens = {}.'.format( self.cf.wcs_iou, self.n_ens)) pool = Pool(processes=6) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, self.n_ens] for ii in list_of_results_per_patient] list_of_results_per_patient = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() # merge 2D boxes to 3D cubes. (if model predicts 2D but evaluation is run in 3D) if self.cf.merge_2D_to_3D_preds: self.logger.info('applying 2Dto3D merging to test set predictions with iou = {}.'.format(self.cf.merge_3D_iou)) pool = Pool(processes=6) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in list_of_results_per_patient] list_of_results_per_patient = pool.map(merge_2D_to_3D_preds_per_patient, mp_inputs, chunksize=1) pool.close() pool.join() return list_of_results_per_patient def load_saved_predictions(self, apply_wbc=False): """ loads raw predictions saved by self.predict_test_set. consolidates and merges 2D boxes to 3D cubes for evaluation. (if model predicts 2D but evaluation is run in 3D) :return: (optionally) list_of_results_per_patient: list over patient results. each entry is a dict with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': not implemented yet. todo for evaluation of instance/semantic segmentation. """ # load predictions for a single test-set fold. if not self.cf.hold_out_test_set: with open(os.path.join(self.cf.fold_dir, 'raw_pred_boxes_list.pickle'), 'rb') as handle: list_of_results_per_patient = pickle.load(handle) da_factor = 4 if self.cf.test_aug else 1 n_ens = self.cf.test_n_epochs * da_factor self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format( len(list_of_results_per_patient), n_ens)) # if hold out test set was perdicted, aggregate predictions of all trained models # corresponding to all CV-folds and flatten them. else: boxes_list = [] for fold in self.cf.folds: fold_dir = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold)) with open(os.path.join(fold_dir, 'raw_pred_boxes_hold_out_list.pickle'), 'rb') as handle: fold_list = pickle.load(handle) pids = [ii[1] for ii in fold_list] boxes_list.append([ii[0] for ii in fold_list]) list_of_results_per_patient = [[[[box for fold_list in boxes_list for box in fold_list[pix][0] if box['box_type'] == 'det']], pid] for pix, pid in enumerate(pids)] da_factor = 4 if self.cf.test_aug else 1 n_ens = self.cf.test_n_epochs * da_factor * len(self.cf.folds) # consolidate predictions. if apply_wbc: self.logger.info('applying wcs to test set predictions with iou = {} and n_ens = {}.'.format( self.cf.wcs_iou, n_ens)) pool = Pool(processes=6) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.wcs_iou, n_ens] for ii in list_of_results_per_patient] list_of_results_per_patient = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() else: list_of_results_per_patient = list_of_results_per_patient # merge 2D box predictions to 3D cubes (if model predicts 2D but evaluation is run in 3D) if self.cf.merge_2D_to_3D_preds: self.logger.info( 'applying 2Dto3D merging to test set predictions with iou = {}.'.format(self.cf.merge_3D_iou)) pool = Pool(processes=6) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in list_of_results_per_patient] list_of_results_per_patient = pool.map(merge_2D_to_3D_preds_per_patient, mp_inputs, chunksize=1) pool.close() pool.join() return list_of_results_per_patient def data_aug_forward(self, batch): """ in val_mode: passes batch through to spatial_tiling method without data_aug. in test_mode: if cf.test_aug is set in configs, createst 4 mirrored versions of the input image, passes all of them to the next processing step (spatial_tiling method) and re-transforms returned predictions to original image version. :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - monitor_values (only in validation mode) """ patch_crops = batch['patch_crop_coords'] if self.patched_patient else None results_list = [self.spatial_tiling_forward(batch, patch_crops)] org_img_shape = batch['original_img_shape'] if self.mode == 'test' and self.cf.test_aug: if self.patched_patient: # apply mirror transformations to patch-crop coordinates, for correct tiling in spatial_tiling method. mirrored_patch_crops = get_mirrored_patch_crops(patch_crops, batch['original_img_shape']) else: mirrored_patch_crops = [None] * 3 img = np.copy(batch['data']) # first mirroring: y-axis. batch['data'] = np.flip(img, axis=2).copy() chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[0], n_aug='1') # re-transform coordinates. for ix in range(len(chunk_dict['boxes'])): for boxix in range(len(chunk_dict['boxes'][ix])): coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy() coords[0] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][2] coords[2] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][0] assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()] assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()] chunk_dict['boxes'][ix][boxix]['box_coords'] = coords # re-transform segmentation predictions. chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=2) results_list.append(chunk_dict) # second mirroring: x-axis. batch['data'] = np.flip(img, axis=3).copy() chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[1], n_aug='2') # re-transform coordinates. for ix in range(len(chunk_dict['boxes'])): for boxix in range(len(chunk_dict['boxes'][ix])): coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy() coords[1] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][3] coords[3] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][1] assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()] assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()] chunk_dict['boxes'][ix][boxix]['box_coords'] = coords # re-transform segmentation predictions. chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=3) results_list.append(chunk_dict) # third mirroring: y-axis and x-axis. batch['data'] = np.flip(np.flip(img, axis=2), axis=3).copy() chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[2], n_aug='3') # re-transform coordinates. for ix in range(len(chunk_dict['boxes'])): for boxix in range(len(chunk_dict['boxes'][ix])): coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy() coords[0] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][2] coords[2] = org_img_shape[2] - chunk_dict['boxes'][ix][boxix]['box_coords'][0] coords[1] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][3] coords[3] = org_img_shape[3] - chunk_dict['boxes'][ix][boxix]['box_coords'][1] assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()] assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords'].copy()] chunk_dict['boxes'][ix][boxix]['box_coords'] = coords # re-transform segmentation predictions. chunk_dict['seg_preds'] = np.flip(np.flip(chunk_dict['seg_preds'], axis=2), axis=3).copy() results_list.append(chunk_dict) batch['data'] = img # aggregate all boxes/seg_preds per batch element from data_aug predictions. results_dict = {} results_dict['boxes'] = [[item for d in results_list for item in d['boxes'][batch_instance]] for batch_instance in range(org_img_shape[0])] results_dict['seg_preds'] = np.array([[item for d in results_list for item in d['seg_preds'][batch_instance]] for batch_instance in range(org_img_shape[0])]) if self.mode == 'val': results_dict['monitor_values'] = results_list[0]['monitor_values'] return results_dict def spatial_tiling_forward(self, batch, patch_crops=None, n_aug='0'): """ forwards batch to batch_tiling_forward method and receives and returns a dictionary with results. if patch-based prediction, the results received from batch_tiling_forward will be on a per-patch-basis. this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates. Patch-origin information of all box-predictions will be needed for consolidation, hence it is stored as 'patch_id', which is a unique string for each patch (also takes current data aug and temporal epoch instances into account). all box predictions get additional information about the amount overlapping patches at the respective position (used for consolidation). :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - monitor_values (only in validation mode) """ if patch_crops is not None: patches_dict = self.batch_tiling_forward(batch) results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]} # instanciate segemntation output array. Will contain averages over patch predictions. out_seg_preds = np.zeros(batch['original_img_shape'], dtype=np.float16)[:, 0][:, None] # counts patch instances per pixel-position. patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8') #unmold segmentation outputs. loop over patches. for pix, pc in enumerate(patch_crops): if self.cf.dim == 3: - try: - out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix][None] - patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1 - except: - print('hi') + out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix][None] + patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1 else: out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix] patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1 # take mean in overlapping areas. out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0] results_dict['seg_preds'] = out_seg_preds # unmold box outputs. loop over patches. for pix, pc in enumerate(patch_crops): patch_boxes = patches_dict['boxes'][pix] for box in patch_boxes: # add unique patch id for consolidation of predictions. box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix) # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers. # hence they will be downweighted for consolidation, using the 'box_patch_center_factor', which is # obtained by a normal distribution over positions in the patch and average over spatial dimensions. # Also the info 'box_n_overlaps' is stored for consolidation, which depicts the amount over # overlapping patches at the box's position. c = box['box_coords'] box_centers = np.array([(c[ii+2] - c[ii])/2 for ii in range(len(c)//2)]) box['box_patch_center_factor'] = np.mean( [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in zip(box_centers, np.array(self.cf.patch_size) / 2)]) if self.cf.dim == 3: c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]]) int_c = [int(np.floor(ii)) if ix%2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)] box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, int_c[1]:int_c[3], int_c[0]:int_c[2], int_c[4]:int_c[5]]) results_dict['boxes'][0].append(box) else: c += np.array([pc[0], pc[2], pc[0], pc[2]]) int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)] box['box_n_overlaps'] = np.mean(patch_overlap_map[pc[4], :, int_c[1]:int_c[3], int_c[0]:int_c[2]]) results_dict['boxes'][pc[4]].append(box) if self.mode == 'val': results_dict['monitor_values'] = patches_dict['monitor_values'] # if predictions are not patch-based: # add patch-origin info to boxes (entire image is the same patch with overlap=1) and return results. else: results_dict = self.batch_tiling_forward(batch) for b in results_dict['boxes']: for box in b: box['box_patch_center_factor'] = 1 box['box_n_overlaps'] = 1 box['patch_id'] = self.rank_ix + '_' + n_aug return results_dict def batch_tiling_forward(self, batch): """ calls the actual network forward method. in patch-based prediction, the batch dimension might be overladed with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are discarded). test mode calls the test forward method, no ground truth required / involved. :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - monitor_values (only in validation mode) """ self.logger.info('forwarding (patched) patient with shape: {}'.format(batch['data'].shape)) img = batch['data'] if img.shape[0] <= self.cf.batch_size: if self.mode == 'val': # call training method to monitor losses results_dict = self.net.train_forward(batch, is_validation=True) # discard returned ground-truth boxes (also training info boxes). results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] else: results_dict = self.net.test_forward(batch, return_masks=self.cf.return_masks_in_test) else: split_ixs = np.split(np.arange(img.shape[0]), np.arange(img.shape[0])[::self.cf.batch_size]) chunk_dicts = [] for chunk_ixs in split_ixs[1:]: # first split is elements before 0, so empty b = {k: batch[k][chunk_ixs] for k in batch.keys() if (isinstance(batch[k], np.ndarray) and batch[k].shape[0] == img.shape[0])} if self.mode == 'val': chunk_dicts += [self.net.train_forward(b, is_validation=True)] else: chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)] results_dict = {} # flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...]) results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']] results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']]) if self.mode == 'val': # estimate metrics by mean over batch_chunks. Most similar to training metrics. results_dict['monitor_values'] = \ {k:np.mean([d['monitor_values'][k] for d in chunk_dicts]) for k in chunk_dicts[0]['monitor_values'].keys()} # discard returned ground-truth boxes (also training info boxes). results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] return results_dict def apply_wbc_to_patient(inputs): """ wrapper around prediction box consolidation: weighted cluster scoring (wcs). processes a single patient. loops over batch elements in patient results (1 in 3D, slices in 2D) and foreground classes, aggregates and stores results in new list. :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. :return. pid: string. patient id. """ in_patient_results_list, pid, class_dict, wcs_iou, n_ens = inputs out_patient_results_list = [[] for _ in range(len(in_patient_results_list))] for bix, b in enumerate(in_patient_results_list): for cl in list(class_dict.keys()): boxes = [(ix, box) for ix, box in enumerate(b) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] box_coords = np.array([b[1]['box_coords'] for b in boxes]) box_scores = np.array([b[1]['box_score'] for b in boxes]) box_center_factor = np.array([b[1]['box_patch_center_factor'] for b in boxes]) box_n_overlaps = np.array([b[1]['box_n_overlaps'] for b in boxes]) box_patch_id = np.array([b[1]['patch_id'] for b in boxes]) if 0 not in box_scores.shape: keep_scores, keep_coords = weighted_box_clustering( np.concatenate((box_coords, box_scores[:, None], box_center_factor[:, None], box_n_overlaps[:, None]), axis=1), box_patch_id, wcs_iou, n_ens) for boxix in range(len(keep_scores)): out_patient_results_list[bix].append({'box_type': 'det', 'box_coords': keep_coords[boxix], 'box_score': keep_scores[boxix], 'box_pred_class_id': cl}) # add gt boxes back to new output list. out_patient_results_list[bix].extend([box for box in b if box['box_type'] == 'gt']) return [out_patient_results_list, pid] def merge_2D_to_3D_preds_per_patient(inputs): """ wrapper around 2Dto3D merging operation. Processes a single patient. Takes 2D patient results (slices in batch dimension) and returns 3D patient results (dummy batch dimension of 1). Applies an adaption of Non-Maximum Surpression (Detailed methodology is described in nms_2to3D). :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. :return. pid: string. patient id. """ in_patient_results_list, pid, class_dict, merge_3D_iou = inputs out_patient_results_list = [] for cl in list(class_dict.keys()): boxes, slice_ids = [], [] # collect box predictions over batch dimension (slices) and store slice info as slice_ids. for bix, b in enumerate(in_patient_results_list): det_boxes = [(ix, box) for ix, box in enumerate(b) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] boxes += det_boxes slice_ids += [bix] * len(det_boxes) box_coords = np.array([b[1]['box_coords'] for b in boxes]) box_scores = np.array([b[1]['box_score'] for b in boxes]) slice_ids = np.array(slice_ids) if 0 not in box_scores.shape: keep_ix, keep_z = nms_2to3D( np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou) else: keep_ix, keep_z = [], [] # store kept predictions in new results list and add corresponding z-dimension info to coordinates. for kix, kz in zip(keep_ix, keep_z): out_patient_results_list.append({'box_type': 'det', 'box_coords': list(box_coords[kix]) + kz, 'box_score': box_scores[kix], 'box_pred_class_id': cl}) out_patient_results_list += [box for b in in_patient_results_list for box in b if box['box_type'] == 'gt'] out_patient_results_list = [out_patient_results_list] # add dummy batch dimension 1 for 3D. return [out_patient_results_list, pid] def weighted_box_clustering(dets, box_patch_id, thresh, n_ens): """ consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal ensembling. clusters predictions together with iou > thresh (like in NMS). Output score and coordinate for one cluster are the average weighted by individual patch center factors (how trustworthy is this candidate measured by how centered its position the patch is) and the size of the corresponding box. The number of expected predictions at a position is n_data_aug * n_temp_ens * n_overlaps_at_position (1 prediction per unique patch). Missing predictions at a cluster position are defined as the number of unique patches in the cluster, which did not contribute any predict any boxes. :param dets: (n_dets, (y1, x1, y2, x2, (z1), (z2), scores, box_pc_facts, box_n_ovs) :param thresh: threshold for iou_matching. :param n_ens: number of models, that are ensembled. (-> number of expected predicitions per position) :return: keep_scores: (n_keep) new scores of boxes to be kept. :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept. """ dim = 2 if dets.shape[1] == 7 else 3 y1 = dets[:, 0] x1 = dets[:, 1] y2 = dets[:, 2] x2 = dets[:, 3] scores = dets[:, -3] box_pc_facts = dets[:, -2] box_n_ovs = dets[:, -1] areas = (y2 - y1 + 1) * (x2 - x1 + 1) if dim == 3: z1 = dets[:, 4] z2 = dets[:, 5] areas *= (z2 - z1 + 1) # order is the sorted index. maps order to index o[1] = 24 (rank1, ix 24) order = scores.argsort()[::-1] keep = [] keep_scores = [] keep_coords = [] while order.size > 0: i = order[0] # higehst scoring element xx1 = np.maximum(x1[i], x1[order]) yy1 = np.maximum(y1[i], y1[order]) xx2 = np.minimum(x2[i], x2[order]) yy2 = np.minimum(y2[i], y2[order]) w = np.maximum(0.0, xx2 - xx1 + 1) h = np.maximum(0.0, yy2 - yy1 + 1) inter = w * h if dim == 3: zz1 = np.maximum(z1[i], z1[order]) zz2 = np.minimum(z2[i], z2[order]) d = np.maximum(0.0, zz2 - zz1 + 1) inter *= d # overall between currently highest scoring box and all boxes. ovr = inter / (areas[i] + areas[order] - inter) # get all the predictions that match the current box to build one cluster. matches = np.argwhere(ovr > thresh) match_n_ovs = box_n_ovs[order[matches]] match_pc_facts = box_pc_facts[order[matches]] match_patch_id = box_patch_id[order[matches]] match_ov_facts = ovr[matches] match_areas = areas[order[matches]] match_scores = scores[order[matches]] # weight all socres in cluster by patch factors, and size. match_score_weights = match_ov_facts * match_areas * match_pc_facts match_scores *= match_score_weights # for the weigted average, scores have to be divided by the number of total expected preds at the position # of the current cluster. 1 Prediction per patch is expected. therefore, the number of ensembled models is # multiplied by the mean overlaps of patches at this position (boxes of the cluster might partly be # in areas of different overlaps). n_expected_preds = n_ens * np.mean(match_n_ovs) # the number of missing predictions is obtained as the number of patches, # which did not contribute any prediction to the current cluster. n_missing_preds = np.max((0, n_expected_preds - np.unique(match_patch_id).shape[0])) # missing preds are given the mean weighting # (expected prediction is the mean over all predictions in cluster). denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights) # compute weighted average score for the cluster avg_score = np.sum(match_scores) / denom # compute weighted average of coordinates for the cluster. now only take existing # predictions into account. avg_coords = [np.sum(y1[order[matches]] * match_scores) / np.sum(match_scores), np.sum(x1[order[matches]] * match_scores) / np.sum(match_scores), np.sum(y2[order[matches]] * match_scores) / np.sum(match_scores), np.sum(x2[order[matches]] * match_scores) / np.sum(match_scores)] if dim == 3: avg_coords.append(np.sum(z1[order[matches]] * match_scores) / np.sum(match_scores)) avg_coords.append(np.sum(z2[order[matches]] * match_scores) / np.sum(match_scores)) # some clusters might have very low scores due to high amounts of missing predictions. # filter out the with a conservative threshold, to speed up evaluation. if avg_score > 0.01: keep_scores.append(avg_score) keep_coords.append(avg_coords) # get index of all elements that were not matched and discard all others. inds = np.where(ovr <= thresh)[0] order = order[inds] return keep_scores, keep_coords def nms_2to3D(dets, thresh): """ Merges 2D boxes to 3D cubes. Therefore, boxes of all slices are projected into one slices. An adaptation of Non-maximum surpression is applied, where clusters are found (like in NMS) with an extra constrained, that surpressed boxes have to have 'connected' z-coordinates w.r.t the core slice (cluster center, highest scoring box). 'connected' z-coordinates are determined as the z-coordinates with predictions until the first coordinate, where no prediction was found. example: a cluster of predictions was found overlap > iou thresh in xy (like NMS). The z-coordinate of the highest scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57. Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not because nothing was found in 47, so 47 is a 'hole', which interrupts the connection). Only the boxes corresponding to these coordinates are surpressed. All others are kept for building of further clusters. This algorithm works better with a certain min_confidence of predictions, because low confidence (e.g. noisy/cluttery) predictions can break the relatively strong assumption of defining cubes' z-boundaries at the first 'hole' in the cluster. :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id) :param thresh: iou matchin threshold (like in NMS). :return: keep: (n_keep) 1D tensor of indices to be kept. :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes. """ y1 = dets[:, 0] x1 = dets[:, 1] y2 = dets[:, 2] x2 = dets[:, 3] scores = dets[:, -2] slice_id = dets[:, -1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort()[::-1] keep = [] keep_z = [] while order.size > 0: # order is the sorted index. maps order to index o[1] = 24 (rank1, ix 24) i = order[0] # pop higehst scoring element xx1 = np.maximum(x1[i], x1[order]) yy1 = np.maximum(y1[i], y1[order]) xx2 = np.minimum(x2[i], x2[order]) yy2 = np.minimum(y2[i], y2[order]) w = np.maximum(0.0, xx2 - xx1 + 1) h = np.maximum(0.0, yy2 - yy1 + 1) inter = w * h ovr = inter / (areas[i] + areas[order] - inter) matches = np.argwhere(ovr > thresh) # get all the elements that match the current box and have a lower score slice_ids = slice_id[order[matches]] core_slice = slice_id[int(i)] upper_wholes = [ii for ii in np.arange(core_slice, np.max(slice_ids)) if ii not in slice_ids] lower_wholes = [ii for ii in np.arange(np.min(slice_ids), core_slice) if ii not in slice_ids] max_valid_slice_id = np.min(upper_wholes) if len(upper_wholes) > 0 else np.max(slice_ids) min_valid_slice_id = np.max(lower_wholes) if len(lower_wholes) > 0 else np.min(slice_ids) z_matches = matches[(slice_ids <= max_valid_slice_id) & (slice_ids >= min_valid_slice_id)] z1 = np.min(slice_id[order[z_matches]]) - 1 z2 = np.max(slice_id[order[z_matches]]) + 1 keep.append(i) keep_z.append([z1, z2]) order = np.delete(order, z_matches, axis=0) return keep, keep_z def get_mirrored_patch_crops(patch_crops, org_img_shape): """ apply 3 mirrror transformations (x-axis, y-axis, x&y-axis) to given patch crop coordinates and return the transformed coordinates. Handles 2D and 3D coordinates. :param patch_crops: list of crops: each element is a list of coordinates for given crop [[y1, x1, ...], [y1, x1, ..]] :param org_img_shape: shape of patient volume used as world coordinates. :return: list of mirrored patch crops: lenght=3. each element is a list of transformed patch crops. """ mirrored_patch_crops = [] # y-axis transform. mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3], ii[4], ii[5]] for ii in patch_crops]) # x-axis transform. mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) # y-axis and x-axis transform. mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) return mirrored_patch_crops