diff --git a/CodeDoc.odt b/CodeDoc.odt new file mode 100644 index 0000000..880f0c7 Binary files /dev/null and b/CodeDoc.odt differ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..6a62c4c --- /dev/null +++ b/README.md @@ -0,0 +1,180 @@ +## Introduction +This repository holds the code framework used in the paper Reg R-CNN: Lesion Detection and Grading under Noisy Labels [1]. +The framework is a fork of MIC's [medicaldetectiontoolkit](https://github.com/MIC-DKFZ/medicaldetectiontoolkit) with added regression +capabilities. + +As below figure shows, the regression capability allows for the preservation of ordinal relations in the training signal as opposed to a standard categorical classification loss like the cross entropy loss (see publication for details). +


+Network Reg R-CNN is a version of Mask R-CNN [2] but with a regressor in place of the object-class head (see figure below). In this scenario, the first stage makes foreground (fg) vs background (bg) detections, then the regression head determines the class on an ordinal scale. Consequently, prediction confidence scores are taken from the first stage as opposed to the head in the original Mask R-CNN. +


+ +In the configs file of a data set in the framework, you may set attribute self.prediction_tasks = ["task"] to a value "task" from ["class", "regression_bin", "regression"]. "class" produces the same behavior as the original framework, i.e., standard object-detection behavior. "regression" on the other hand, swaps the class head of network Mask R-CNN [2] for a regression head. Consequently, objects are identified as fg/bg and then the class is decided by the regressor. For the sake of comparability, "regression_bin" produces a similar behavior but with a classification head. Both methods should be evaluated with the (implemented) Average Viewpoint Precision instead of only Average Precision. + +Below you will found a description of the general framework operations and handling. Basic framework functionality and description are for the most part identical to the original [medicaldetectiontoolkit](https://github.com/MIC-DKFZ/medicaldetectiontoolkit). + +
+[1] Ramien, Gregor et al., "Reg R-CNN: Lesion Detection and Grading under Noisy Labels". In: UNSURE Workshop at MICCAI, 2019.
+[2] He, Kaiming, et al. "Mask R-CNN" ICCV, 2017
+
+ +## Overview +This is a comprehensive framework for object detection featuring: +- 2D + 3D implementations of common object detectors: e.g., Mask R-CNN [2], Retina Net [3], Retina U-Net [4]. +- Modular and light-weight structure ensuring sharing of all processing steps (incl. backbone architecture) for comparability of models. +- training with bounding box and/or pixel-wise annotations. +- dynamic patching and tiling of 2D + 3D images (for training and inference). +- weighted consolidation of box predictions across patch-overlaps, ensembles, and dimensions [4] or standard non-maximum suppression. +- monitoring + evaluation simultaneously on object and patient level. +- 2D + 3D output visualizations. +- integration of COCO mean average precision metric [5]. +- integration of MIC-DKFZ batch generators for extensive data augmentation [6]. +- possible evaluation of instance segmentation and/or semantic segmentation by dice scores. +
+ +[3] Lin, Tsung-Yi, et al. "Focal Loss for Dense Object Detection" TPAMI, 2018.
+[4] Jaeger, Paul et al. "Retina U-Net: Embarrassingly Simple Exploitation +of Segmentation Supervision for Medical Object Detection" , 2018 + +[5] https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py
+[6] https://github.com/MIC-DKFZ/batchgenerators

+ +## How to cite this code +Please cite the Reg R-CNN publication [1] or the original publication [4] depending on what features you use. + +## Installation +Setup package in virtual environment +``` +git clone https://github.com/MIC-DKFZ/RegRCNN.git. +cd RegRCNN +virtualenv -p python3 regrcnn_env +source regrcnn_env/bin/activate +pip install -e . +``` +We use two cuda functions: Non-Maximum Suppression (taken from [pytorch-faster-rcnn](https://github.com/ruotianluo/pytorch-faster-rcnn) and added adaption for 3D) and RoiAlign (taken from [RoiAlign](https://github.com/longcw/RoIAlign.pytorch), fixed according to [this bug report](https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35), and added adaption for 3D). In this framework, they come pre-compile for TitanX. If you have a different GPU you need to re-compile these functions: + + +| GPU | arch | +| --- | --- | +| TitanX | sm_52 | +| GTX 960M | sm_50 | +| GTX 1070 | sm_61 | +| GTX 1080 (Ti) | sm_61 | + +``` +cd cuda_functions/nms_xD/src/cuda/ +nvcc -c -o nms_kernel.cu.o nms_kernel.cu -x cu -Xcompiler -fPIC -arch=[arch] +cd ../../ +python build.py +cd ../ + +cd cuda_functions/roi_align_xD/roi_align/src/cuda/ +nvcc -c -o crop_and_resize_kernel.cu.o crop_and_resize_kernel.cu -x cu -Xcompiler -fPIC -arch=[arch] +cd ../../ +python build.py +cd ../../ +``` + +## Prepare the Data +This framework is meant for you to be able to train models on your own data sets. + +In order to include a data set in the framework, create a new folder in RegRCNN/datasets, for instance "example_data". Your data set needs to have a config file in the style of the provided example data sets "lidc" and "toy". It also needs a data loader meeting the same requirements as the provided examples. Likely, you will also need a preprocessing script that transforms your data (once per data set creation, i.e., not a repetitive operation) into a suitable and easily processable format. +Important requirements: +- The framework expects numpy arrays as data and segmentation ground truth input. +- Segmentations need to be suited for object detection, i.e., Regions of Interest (RoIs) need to be marked by integers (RoI-ID) in the segmentation volume (0 is background). Corresponding properties of a RoI, e.g., the "class_targets" need to be provided in a separate array or list with (RoI-ID - 1) corresponding to the index of the property in the list (-1 due to zero-indexing). Example: A data volume contains two RoIs. The second RoI is marked in the segmentation by number 2. The "class_targets" info associated with the data volume holds the list [2, 3]. Hence, RoI-ID 2 is assigned class 3. +- This framework uses a modified version of MIC's batchgenerators' segmentation-to-bounding-box conversion tool. In this version, "class_targets", i.e., object classes start at 1, 0 is reserved for background. Thus, if you use "ConvertSegToBoundingBoxCoordinates" classes in your preprocessed data need to start at 1, not 0. + + +Two example data loaders are provided in RegRCNN/datasets. The way I load data is to have a preprocessing script, which after preprocessing saves the data of whatever data type into numpy arrays (this is just run once). During training / testing, the data loader then loads these numpy arrays dynamically. Please note the data input side is meant to be customized by you according to your own needs and the provided data loaders are merely examples: LIDC has a powerful data loader that handles 2D/3D inputs and is optimized for patch-based training and inference. Due to the large data volumes of LIDC, this loader is slow. The provided toy data set, however, is light weight and a good starting point to get familiar with the framework. It is fully creatable from scratch within a few minutes with RegRCNN/datasets/toy/generate_toys.py. + +## Execute +1. Set I/O paths, model and training specifics in the configs file: RegRCNN/datasets/_your_dataset_/configs.py +2. i) Train the model: + + ``` + python exec.py --mode train --dataset_name your_dataset --exp_dir path/to/experiment/directory + ``` + This copies snapshots of configs and model to the specified exp_dir, where all outputs will be saved. By default, the data is split into 60% training and 20% validation and 20% testing data to perform a 5-fold cross validation (can be changed to hold-out test set in configs) and all folds will be trained iteratively. In order to train a single fold, specify it using the folds arg: + ``` + python exec.py --folds 0 1 2 .... # specify any combination of folds [0-configs.n_cv_splits] + ``` + ii) Alternatively, train and test consecutively: + ``` + python exec.py --mode train_test --dataset_name your_dataset --exp_dir path/to/experiment/directory + ``` +3. Run inference: + ``` + python exec.py --mode test --exp_dir path/to/experiment/directory + ``` + This runs the prediction pipeline and saves all results to exp_dir. +4. Additional settings: + - Check the args parser in exec.py to see which arguments and modes are available. + - E.g., you may pass ```-d``` or ```--dev``` to enable a short development run of the whole train_test procedure (small batch size, only one epoch, two folds, one test patient, etc.). + + +## Models + +This framework features models explored in [4] (implemented in 2D + 3D): The proposed Retina U-Net, a simple but effective Architecture fusing state-of-the-art semantic segmentation with object detection,

+


+also implementations of prevalent object detectors, such as Mask R-CNN, Faster R-CNN+ (Faster R-CNN w\ RoIAlign), Retina Net, Detection U-Net (a U-Net like segmentation architecture with heuristics for object detection.)


+


+ +## Training annotations +This framework features training with pixelwise and/or bounding box annotations. To overcome the issue of box coordinates in +data augmentation, we feed the annotation masks through data augmentation (create a pseudo mask, if only bounding box annotations provided) and draw the boxes afterwards.

+


+ + +The framework further handles two types of pixel-wise annotations: + +1. A label map with individual ROIs identified by increasing label values, accompanied by a vector containing in each position the class target for the lesion with the corresponding label (for this mode set get_rois_from_seg_flag = False when calling ConvertSegToBoundingBoxCoordinates in your Data Loader). This is usual use case as explained in section "Prepare the data". +2. A binary label map. There is only one foreground class and single lesions are not identified. All lesions have the same class target (foreground). In this case the data loader runs a Connected Component Labelling algorithm to create processable lesion - class target pairs on the fly (for this mode set get_rois_from_seg_flag = True when calling ConvertSegToBoundingBoxCoordinates in your data loader). + +## Prediction pipeline +This framework provides an inference module, which automatically handles patching of inputs, and tiling, ensembling, and weighted consolidation of output predictions:


+

+ + +## Consolidation of predictions +### Weighted Box Clustering +Multiple predictions of the same image (from test time augmentations, tested epochs and overlapping patches), result in a high amount of boxes (or cubes), which need to be consolidated. In semantic segmentation, the final output would typically be obtained by averaging every pixel over all predictions. As described in [4], **weighted box clustering** (WBC) does this for box predictions:
+



+



+ +To enable WBC, set self.clustering = "wbc" in your configs file. + +### Non-Maximum Suppression +Test-time predictions can alternatively be aggregated with standard non-maximum suppression. In your configs file, simply set self.clustering = "nms" instead of "wbc". + +As a further alternative you may also choose no test-time aggregation by setting self.clustering = None. + +## Visualization / Monitoring +In opposition to the original framework, this fork uses tensorboard for monitoring training and validation progress. Since, for now, the framework cannot easily be updated to pytorch >= 1.x, we need third-party package [tensorboardX](https://github.com/lanpa/tensorboardX) to use tensorboard with pytorch. + +You can set an applicable choice of implemented metrics like "ap" for Average Precision or "auc" for patient-level ROC-AUC in the configs under self.metrics = [...]. Metrics are then evaluated by evaluator.py and recorded in monitor_metrics. logger.metrics2tboard sends monitor_metrics to your tensorboard logfiles at the end of each epoch. +You need to separately start a virtual tensorboard server, pass it your experiment directory (or directories, but it crashes if its more than ~5 experiments) and navigate to the server address. (You can also read up on tensoardboard usage in the original documentation). + +### Example: +1. Activate your virtualenv where tensorboard is installed. +2. Start tensorboard server. For instance, your experiment directory is + _yourexpdir_:
+ ```tensorboard --port 6007 --logdir yourexpdir``` +3. Navigate to ```localhost:6007``` in your browser. + +### Output monitoring +For qualitative monitoring, example plots are saved to _yourexpdir_/plots for training and validation and _yourexpdir_/test/example_plots for testing. Note, that test-time example plots may contain unconsolidated predictions over test-time augmentations, thereby possibly showing many overlapping and/or noisy predictions. You may adapt/use separate file RegRCNN/inference_analysis.py to create clean and nice plots of (consolidated) test-time predictions. + +## Balancing Mechanism of Example Data Loader +The data loaders of the provided example data sets employ a custom mechanism with the goal of assembling target-balanced batches or training sequences. I.e., the amount of examples shown per target class should be near balance. + +The mechanism creates a sampling-likelihood distribution, as shown below, over all available patients (PIDs). At batch generation, some patients are drawn according to this distribution, others are drawn completely randomly (according to a uniform distribution across all patients). The ratio of uniformly and target-dependently drawn patients is set in your configs file by configs.batch_random_ratio. configs.balance_target determines which targets are considered for the balancing distribution. + +While the balancing distribution assigns probability 0 to empty patients (contains no object of desired target kind), the random ratio allows for inclusion of those empty patients in the training exposure. Experience has shown, that showing at least one foreground example in each batch is most critical, other properties have less impact. + +



+ +## Unittests +unittests.py contains some verification and testing procedures, which, however, need you to adjust paths in the TestCase classes before execution. Tests can be used, for instance, to verify if your cross-validation folds have been created correctly, or if separate experiments have the same fold splits. + + +# License +This framework is published under the [APACHE 2.0 License](https://github.com/MIC-DKFZ/RegRCNN/blob/master/LICENSE) \ No newline at end of file diff --git a/assets/.directory b/assets/.directory new file mode 100644 index 0000000..4e2d005 --- /dev/null +++ b/assets/.directory @@ -0,0 +1,4 @@ +[Dolphin] +Timestamp=2018,11,4,16,51,18 +Version=3 +ViewMode=1 diff --git a/code_optim/code_optim.py b/code_optim/code_optim.py new file mode 100644 index 0000000..2702b3c --- /dev/null +++ b/code_optim/code_optim.py @@ -0,0 +1,328 @@ +""" +Created at 04/02/19 13:50 +@author: gregor +""" +import plotting as plg + +import sys +import os +import pickle +import json, socket, subprocess, time, threading + +import numpy as np +import pandas as pd +import torch +from collections import OrderedDict +from matplotlib.lines import Line2D + +import utils.exp_utils as utils +import utils.model_utils as mutils +from predictor import Predictor +from evaluator import Evaluator + + +""" +Need to start this script as sudo for background logging thread to work (needs to set niceness<0) +""" + + +def measure_train_batch_loading(logger, batch_gen, iters=1, warm_up=20, is_val=False, out_dir=None): + torch.cuda.empty_cache() + timer_key = "val_fw" if is_val else "train_fw" + for i in range(warm_up): + batch = next(batch_gen) + print("\rloaded warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True) + sysmetrics_start_ix = len(logger.sysmetrics.index) + for i in range(iters): + logger.time(timer_key) + batch = next(batch_gen) + print("\r{} batch {} loading took {:.3f}s.".format("val" if is_val else "train", i+1, + logger.time(timer_key)), end="", flush=True) + print("Total avg fw {:.2f}s".format(logger.get_time(timer_key)/iters)) + if out_dir is not None: + assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "train loading: empty df" + logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join( + out_dir,"{}_loading.pickle".format("val" if is_val else "train"))) + return logger.sysmetrics[sysmetrics_start_ix:-1] + + +def measure_RPN(logger, net, batch, iters=1, warm_up=20, out_dir=None): + torch.cuda.empty_cache() + data = torch.from_numpy(batch["data"]).float().cuda() + fpn_outs = net.fpn(data) + rpn_feature_maps = [fpn_outs[i] for i in net.cf.pyramid_levels] + + for i in range(warm_up): + layer_outputs = [net.rpn(p_feats) for p_feats in rpn_feature_maps] + print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True) + sysmetrics_start_ix = len(logger.sysmetrics.index) + for i in range(iters): + logger.time("RPN_fw") + layer_outputs = [net.rpn(p_feats) for p_feats in rpn_feature_maps] + print("\r{} batch took {:.3f}s.".format("RPN", logger.time("RPN_fw")), end="", flush=True) + print("Total avg fw {:.2f}s".format(logger.get_time("RPN_fw")/iters)) + + if out_dir is not None: + assert len(logger.sysmetrics[sysmetrics_start_ix:-1])>0, "six {}, sysm ix {}".format(sysmetrics_start_ix, logger.sysmetrics.index) + logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(out_dir,"RPN_msrmts.pickle")) + return logger.sysmetrics[sysmetrics_start_ix:-1] + +def measure_FPN(logger, net, batch, iters=1, warm_up=20, out_dir=None): + torch.cuda.empty_cache() + data = torch.from_numpy(batch["data"]).float().cuda() + for i in range(warm_up): + outputs = net.fpn(data) + print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True) + sysmetrics_start_ix = len(logger.sysmetrics.index) + for i in range(iters): + logger.time("FPN_fw") + outputs = net.fpn(data) + #print("in mean thread", logger.sysmetrics.index) + print("\r{} batch took {:.3f}s.".format("FPN", logger.time("FPN_fw")), end="", flush=True) + print("Total avg fw {:.2f}s".format(logger.get_time("FPN_fw")/iters)) + + if out_dir is not None: + assert len(logger.sysmetrics[sysmetrics_start_ix:-1])>0, "six {}, sysm ix {}".format(sysmetrics_start_ix, logger.sysmetrics.index) + logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(out_dir,"FPN_msrmts.pickle")) + return logger.sysmetrics[sysmetrics_start_ix:-1] + +def measure_forward(logger, net, batch, iters=1, warm_up=20, out_dir=None): + torch.cuda.empty_cache() + data = torch.from_numpy(batch["data"]).float().cuda() + for i in range(warm_up): + outputs = net.forward(data) + print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True) + sysmetrics_start_ix = len(logger.sysmetrics.index) + for i in range(iters): + logger.time("net_fw") + outputs = net.forward(data) + print("\r{} batch took {:.3f}s.".format("forward", logger.time("net_fw")), end="", flush=True) + print("Total avg fw {:.2f}s".format(logger.get_time("net_fw")/iters)) + if out_dir is not None: + assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "fw: empty df" + logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(out_dir,"fw_msrmts.pickle")) + return logger.sysmetrics[sysmetrics_start_ix:-1].copy() + +def measure_train_forward(logger, net, batch, iters=1, warm_up=20, is_val=False, out_dir=None): + torch.cuda.empty_cache() + timer_key = "val_fw" if is_val else "train_fw" + optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) + for i in range(warm_up): + results_dict = net.train_forward(batch) + print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True) + sysmetrics_start_ix = len(logger.sysmetrics.index) + for i in range(iters): + logger.time(timer_key) + if not is_val: + optimizer.zero_grad() + results_dict = net.train_forward(batch, is_validation=is_val) + #results_dict["torch_loss"] *= torch.rand(1).cuda() + if not is_val: + results_dict["torch_loss"].backward() + optimizer.step() + print("\r{} batch took {:.3f}s.".format("val" if is_val else "train", logger.time(timer_key)), end="", flush=True) + print("Total avg fw {:.2f}s".format(logger.get_time(timer_key)/iters)) + if out_dir is not None: + assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "train_fw: empty df" + logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join( + out_dir,"{}_msrmts.pickle".format("val_fw" if is_val else "train_fwbw"))) + return logger.sysmetrics[sysmetrics_start_ix:-1].copy() + +def measure_train_fw_incl_batch_gen(logger, net, batch_gen, iters=1, warm_up=20, is_val=False, out_dir=None): + torch.cuda.empty_cache() + timer_key = "val_fw" if is_val else "train_fw" + for i in range(warm_up): + batch = next(batch_gen) + results_dict = net.train_forward(batch) + print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True) + sysmetrics_start_ix = len(logger.sysmetrics.index) + for i in range(iters): + logger.time(timer_key) + batch = next(batch_gen) + results_dict = net.train_forward(batch, is_validation=is_val) + if not is_val: + results_dict["torch_loss"].backward() + print("\r{} batch took {:.3f}s.".format("val" if is_val else "train", logger.time(timer_key)), end="", flush=True) + print("Total avg fw {:.2f}s".format(logger.get_time(timer_key)/iters)) + if out_dir is not None: + assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "train_fw incl batch: empty df" + logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join( + out_dir,"{}_incl_batch_msrmts.pickle".format("val_fw" if is_val else "train_fwbw"))) + return logger.sysmetrics[sysmetrics_start_ix:-1] + + + +def measure_train_backward(cf, logger, net, batch, iters=1, warm_up=20, out_dir=None): + torch.cuda.empty_cache() + optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) + results_dict = net.train_forward(batch, is_validation=False) + loss = results_dict["torch_loss"] + for i in range(warm_up): + loss.backward(retain_graph=True) + print("\rfinished warm-up batch {}/{}".format(i + 1, warm_up), end="", flush=True) + sysmetrics_start_ix = len(logger.sysmetrics.index) + for i in range(iters): + logger.time("train_bw") + optimizer.zero_grad() + loss.backward(retain_graph=True) + optimizer.step() + print("\r{} bw batch {} took {:.3f}s.".format("train", i+1, logger.time("train_bw")), end="", flush=True) + print("Total avg bw {:.2f}s".format(logger.get_time("train_bw") / iters)) + if out_dir is not None: + assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "train_bw: empty df" + logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(out_dir,"train_bw.pickle")) + return logger.sysmetrics[sysmetrics_start_ix:-1] + + + +def measure_test_forward(logger, net, batch, iters=1, return_masks=False): + torch.cuda.empty_cache() + for i in range(iters): + logger.time("test_fw") + results_dict = net.test_forward(batch, return_masks=return_masks) + print("\rtest batch took {:.3f}s.".format(logger.time("test_fw")), end="", flush=True) + print("Total avg test fw {:.2f}s".format(logger.get_time('test_fw')/iters)) + + +def perform_measurements(args, iters=20): + + cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, is_training=True, use_stored_settings=False) + + cf.exp_dir = args.exp_dir + + # pid = 1624 + # cf.fold = find_pid_in_splits(pid) + cf.fold = 0 + cf.merge_2D_to_3D_preds = False + cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) + + logger = utils.get_logger(cf.exp_dir, sysmetrics_interval=0.5) + model = utils.import_module('model', cf.model_path) + net = model.net(cf, logger).cuda() + test_predictor = Predictor(cf, None, logger, mode='test') + #cf.p_batchbalance = 0 + #cf.do_aug = False + batch_gens = data_loader.get_train_generators(cf, logger) + train_gen, val_gen = batch_gens['train'], batch_gens['val_sampling'] + test_gen = data_loader.get_test_generator(cf, logger)['test'] + weight_paths = [os.path.join(cf.fold_dir, '{}_best_params.pth'.format(rank)) for rank in + test_predictor.epoch_ranking] + + try: + pids = test_gen.dataset_pids + except: + pids = test_gen.generator.dataset_pids + print("pids in test set: ", pids) + pid = pids[0] + assert pid in pids + pid = "285" + + model_name = cf.model + + results_dir = "/home/gregor/Documents/medicaldetectiontoolkit/code_optim/"+model_name + os.makedirs(results_dir, exist_ok=True) + print("Model: {}.".format(model_name)) + #gpu_logger = utils.Nvidia_GPU_Logger() + #gpu_logger.start(interval=0.1) + #measure_train_batch_loading(logger, train_gen, iters=iters, out_dir=results_dir) + #measure_train_batch_loading(logger, val_gen, iters=iters, is_val=True, out_dir=results_dir) + #measure_RPN(logger, net, next(train_gen), iters=iters, out_dir=results_dir) + #measure_FPN(logger, net, next(train_gen), iters=iters, out_dir=results_dir) + #measure_forward(logger, net, next(train_gen), iters=iters, out_dir=results_dir) + measure_train_forward(logger, net, next(train_gen), iters=iters, out_dir=results_dir) #[['global_step', 'gpu_utilization (%)']] + #measure_train_forward(logger, net, next(val_gen), iters=iters, is_val=True, out_dir=results_dir) + #measure_train_fw_incl_batch_gen(logger, net, train_gen, iters=iters, out_dir=results_dir) + #measure_train_fw_incl_batch_gen(logger, net, val_gen, iters=iters, is_val=True, out_dir=results_dir) + #measure_train_backward(cf, logger, net, next(train_gen), iters=iters, out_dir=results_dir) + #measure_test_forward(logger, net, next(test_gen), iters=iters, return_masks=cf.return_masks_in_test) + + return results_dir, iters + +def plot_folder(cf, ax, results_dir, iters, markers='o', offset=(+0.01, -4)): + point_renaming = {"FPN_msrmts": ["FPN.forward", (offset[0], -4)], "fw_msrmts": "net.forward", + "train_bw": "backward+optimizer", + "train_fw_msrmts": "net.train_forward", + "train_fw_incl_batch": "train_fw+batch", "RPN_msrmts": "RPN.forward", + "train_fwbw_msrmts": ["train_fw+bw", (offset[0], +2)], + "val_fw_msrmts": ["val_fw", (offset[0], -4)], + "train_fwbw_incl_batch_msrmts": ["train_fw+bw+batchload", (offset[0], +2)], + "train_fwbw_incl_batch_aug_msrmts": ["train_fw+bw+batchload+aug", (-0.2, +2)], + "val_fw_incl_batch_msrmts": ["val_fw+batchload", (offset[0], -4)], + "val_loading": ["val_load", (-0.06, -4)], + "train_loading_wo_bal_fg_aug": ["train_load_w/o_bal,fg,aug", (offset[0], 2)], + "train_loading_wo_balancing": ["train_load_w/o_balancing", (-0.05, 2)], + "train_loading_wo_aug": ["train_load_w/o_aug", (offset[0], 2)], + "train_loading_wo_bal_fg": ["train_load_w/o_bal,fg", (offset[0], -4)], + "train_loading": ["train_load", (+0.01, -1.3)] + } + dfs = OrderedDict() + for file in os.listdir(results_dir): + if os.path.splitext(file)[-1]==".pickle": + dfs[file.split(os.sep)[-1].split(".")[0]] = pd.read_pickle(os.path.join(results_dir,file)) + + + for i, (name, df) in enumerate(dfs.items()): + time = (df["rel_time"].iloc[-1] - df["rel_time"].iloc[0])/iters + gpu_u = df["gpu_utilization (%)"].values.astype(int).mean() + + color = cf.color_palette[i%len(cf.color_palette)] + ax.scatter(time, gpu_u, color=color, marker=markers) + if name in point_renaming.keys(): + name = point_renaming[name] + if isinstance(name, list): + offset = name[1] + name = name[0] + ax.text(time+offset[0], gpu_u+offset[1], name, color=color) + +def analyze_measurements(cf, results_dir, iters, title=""): + fig, ax = plg.plt.subplots(1, 1) + + settings = [(results_dir, iters, 'o'), (os.path.join(results_dir, "200iters_pre_optim"), 200, 'v', (-0.08, 2)), + (os.path.join(results_dir, "200iters_after_optim"), 200, 'o')] + for args in settings: + plot_folder(cf, ax, *args) + labels = ["after optim", "pre optim"] + handles = [Line2D([0], [0], marker=settings[i][2], label=labels[i], color="w", markerfacecolor=cf.black, markersize=10) + for i in range(len(settings[:2]))] + plg.plt.legend(handles=handles, loc="best") + ax.set_xlim(0,ax.get_xlim()[1]*1.05) + ax.set_ylim(0, 100) + ax.set_ylabel("Mean GPU Utilization (%)") + ax.set_xlabel("Runtime (s)") + plg.plt.title(title+"GPU utilization vs Method Runtime\nMean Over {} Iterations".format(iters)) + + major_ticks = np.arange(0, 101, 10) + minor_ticks = np.arange(0, 101, 5) + ax.set_yticks(major_ticks) + ax.set_yticks(minor_ticks, minor=True) + ax.grid(which='minor', alpha=0.2) + ax.grid(which='major', alpha=0.5) + + + plg.plt.savefig(os.path.join(results_dir, "measurements.png")) + + + + return + + +if __name__=="__main__": + class Args(): + def __init__(self): + self.dataset_name = "datasets/prostate" + self.exp_dir = "datasets/prostate/experiments/dev" + self.server_env = False + + + args = Args() + + sys.path.append(args.dataset_name) + import data_loader + from configs import Configs + cf = configs(args.server_env) + iters = 200 + results_dir, iters = perform_measurements(args, iters=iters) + results_dir = "/home/gregor/Documents/medicaldetectiontoolkit/code_optim/" + cf.model + analyze_measurements(cf, results_dir, iters=iters, title=cf.model+": ") + + diff --git a/cuda_functions/nms_2D/__init__.py b/cuda_functions/nms_2D/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/nms_2D/_ext/__init__.py b/cuda_functions/nms_2D/_ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/nms_2D/_ext/nms/__init__.py b/cuda_functions/nms_2D/_ext/nms/__init__.py new file mode 100644 index 0000000..d71786f --- /dev/null +++ b/cuda_functions/nms_2D/_ext/nms/__init__.py @@ -0,0 +1,15 @@ + +from torch.utils.ffi import _wrap_function +from ._nms import lib as _lib, ffi as _ffi + +__all__ = [] +def _import_symbols(locals): + for symbol in dir(_lib): + fn = getattr(_lib, symbol) + if callable(fn): + locals[symbol] = _wrap_function(fn, _ffi) + else: + locals[symbol] = fn + __all__.append(symbol) + +_import_symbols(locals()) diff --git a/cuda_functions/nms_2D/_ext/nms/_nms.so b/cuda_functions/nms_2D/_ext/nms/_nms.so new file mode 100755 index 0000000..1856faf Binary files /dev/null and b/cuda_functions/nms_2D/_ext/nms/_nms.so differ diff --git a/cuda_functions/nms_2D/build.py b/cuda_functions/nms_2D/build.py new file mode 100644 index 0000000..4d9a96b --- /dev/null +++ b/cuda_functions/nms_2D/build.py @@ -0,0 +1,34 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/nms.c'] +headers = ['src/nms.h'] +defines = [] +with_cuda = False + +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/nms_cuda.c'] + headers += ['src/nms_cuda.h'] + defines += [('WITH_CUDA', None)] + with_cuda = True + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +extra_objects = ['src/cuda/nms_kernel.cu.o'] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.nms', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects +) + +if __name__ == '__main__': + ffi.build() diff --git a/cuda_functions/nms_2D/pth_nms.py b/cuda_functions/nms_2D/pth_nms.py new file mode 100644 index 0000000..bfdc29a --- /dev/null +++ b/cuda_functions/nms_2D/pth_nms.py @@ -0,0 +1,39 @@ +import torch +from ._ext import nms + + +def nms_gpu(dets, thresh): + """ + dets has to be a tensor + """ + + scores = dets[:, 4] + order = scores.sort(0, descending=True)[1] + dets = dets[order].contiguous() + + keep = torch.LongTensor(dets.size(0)) + num_out = torch.LongTensor(1) + nms.gpu_nms(keep, num_out, dets, thresh) + return order[keep[:num_out[0]].cuda()].contiguous() + + + +def nms_cpu(dets, thresh): + + dets = dets.cpu() + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.sort(0, descending=True)[1] + # order = torch.from_numpy(np.ascontiguousarray(scores.numpy().argsort()[::-1])).long() + + keep = torch.LongTensor(dets.size(0)) + num_out = torch.LongTensor(1) + nms.cpu_nms(keep, num_out, dets, order, areas, thresh) + + return keep[:num_out[0]] + diff --git a/cuda_functions/nms_2D/src/cuda/nms_kernel.cu b/cuda_functions/nms_2D/src/cuda/nms_kernel.cu new file mode 100644 index 0000000..1174f22 --- /dev/null +++ b/cuda_functions/nms_2D/src/cuda/nms_kernel.cu @@ -0,0 +1,87 @@ +// ------------------------------------------------------------------ +// Faster R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Shaoqing Ren +// ------------------------------------------------------------------ +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#include "nms_kernel.h" + +__device__ inline float devIoU(float const * const a, float const * const b) { + float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); + float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); + float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); + float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); + return interS / (Sa + Sb - interS); +} + +__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, unsigned long long *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ float block_boxes[threadsPerBlock * 5]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 5; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + + +void _nms(int boxes_num, float * boxes_dev, + unsigned long long * mask_dev, float nms_overlap_thresh) { + + dim3 blocks(DIVUP(boxes_num, threadsPerBlock), + DIVUP(boxes_num, threadsPerBlock)); + dim3 threads(threadsPerBlock); + nms_kernel<<>>(boxes_num, + nms_overlap_thresh, + boxes_dev, + mask_dev); +} + +#ifdef __cplusplus +} +#endif diff --git a/cuda_functions/nms_2D/src/cuda/nms_kernel.cu.o b/cuda_functions/nms_2D/src/cuda/nms_kernel.cu.o new file mode 100644 index 0000000..00135bf Binary files /dev/null and b/cuda_functions/nms_2D/src/cuda/nms_kernel.cu.o differ diff --git a/cuda_functions/nms_2D/src/cuda/nms_kernel.h b/cuda_functions/nms_2D/src/cuda/nms_kernel.h new file mode 100644 index 0000000..2f40582 --- /dev/null +++ b/cuda_functions/nms_2D/src/cuda/nms_kernel.h @@ -0,0 +1,19 @@ +#ifndef _NMS_KERNEL +#define _NMS_KERNEL + +#ifdef __cplusplus +extern "C" { +#endif + +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +int const threadsPerBlock = sizeof(unsigned long long) * 8; + +void _nms(int boxes_num, float * boxes_dev, + unsigned long long * mask_dev, float nms_overlap_thresh); + +#ifdef __cplusplus +} +#endif + +#endif + diff --git a/cuda_functions/nms_2D/src/nms.c b/cuda_functions/nms_2D/src/nms.c new file mode 100644 index 0000000..4795cc1 --- /dev/null +++ b/cuda_functions/nms_2D/src/nms.c @@ -0,0 +1,69 @@ +#include +#include + +int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh) { + // boxes has to be sorted + THArgCheck(THLongTensor_isContiguous(keep_out), 0, "keep_out must be contiguous"); + THArgCheck(THLongTensor_isContiguous(boxes), 2, "boxes must be contiguous"); + THArgCheck(THLongTensor_isContiguous(order), 3, "order must be contiguous"); + THArgCheck(THLongTensor_isContiguous(areas), 4, "areas must be contiguous"); + // Number of ROIs + long boxes_num = THFloatTensor_size(boxes, 0); + long boxes_dim = THFloatTensor_size(boxes, 1); + + long * keep_out_flat = THLongTensor_data(keep_out); + float * boxes_flat = THFloatTensor_data(boxes); + long * order_flat = THLongTensor_data(order); + float * areas_flat = THFloatTensor_data(areas); + + THByteTensor* suppressed = THByteTensor_newWithSize1d(boxes_num); + THByteTensor_fill(suppressed, 0); + unsigned char * suppressed_flat = THByteTensor_data(suppressed); + + // nominal indices + int i, j; + // sorted indices + int _i, _j; + // temp variables for box i's (the box currently under consideration) + float ix1, iy1, ix2, iy2, iarea; + // variables for computing overlap with box j (lower scoring box) + float xx1, yy1, xx2, yy2; + float w, h; + float inter, ovr; + + long num_to_keep = 0; + for (_i=0; _i < boxes_num; ++_i) { + i = order_flat[_i]; + if (suppressed_flat[i] == 1) { + continue; + } + keep_out_flat[num_to_keep++] = i; + ix1 = boxes_flat[i * boxes_dim]; + iy1 = boxes_flat[i * boxes_dim + 1]; + ix2 = boxes_flat[i * boxes_dim + 2]; + iy2 = boxes_flat[i * boxes_dim + 3]; + iarea = areas_flat[i]; + for (_j = _i + 1; _j < boxes_num; ++_j) { + j = order_flat[_j]; + if (suppressed_flat[j] == 1) { + continue; + } + xx1 = fmaxf(ix1, boxes_flat[j * boxes_dim]); + yy1 = fmaxf(iy1, boxes_flat[j * boxes_dim + 1]); + xx2 = fminf(ix2, boxes_flat[j * boxes_dim + 2]); + yy2 = fminf(iy2, boxes_flat[j * boxes_dim + 3]); + w = fmaxf(0.0, xx2 - xx1 + 1); + h = fmaxf(0.0, yy2 - yy1 + 1); + inter = w * h; + ovr = inter / (iarea + areas_flat[j] - inter); + if (ovr >= nms_overlap_thresh) { + suppressed_flat[j] = 1; + } + } + } + + long *num_out_flat = THLongTensor_data(num_out); + *num_out_flat = num_to_keep; + THByteTensor_free(suppressed); + return 1; +} \ No newline at end of file diff --git a/cuda_functions/nms_2D/src/nms.h b/cuda_functions/nms_2D/src/nms.h new file mode 100644 index 0000000..25ca0a3 --- /dev/null +++ b/cuda_functions/nms_2D/src/nms.h @@ -0,0 +1 @@ +int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh); \ No newline at end of file diff --git a/cuda_functions/nms_2D/src/nms_cuda.c b/cuda_functions/nms_2D/src/nms_cuda.c new file mode 100644 index 0000000..5a9a70f --- /dev/null +++ b/cuda_functions/nms_2D/src/nms_cuda.c @@ -0,0 +1,67 @@ +// ------------------------------------------------------------------ +// Faster R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Shaoqing Ren +// ------------------------------------------------------------------ +#include +#include +#include +#include + +#include "cuda/nms_kernel.h" + + +extern THCState *state; + +int gpu_nms(THLongTensor * keep, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh) { + // boxes has to be sorted + THArgCheck(THLongTensor_isContiguous(keep), 0, "boxes must be contiguous"); + THArgCheck(THCudaTensor_isContiguous(state, boxes), 2, "boxes must be contiguous"); + // Number of ROIs + int boxes_num = THCudaTensor_size(state, boxes, 0); + int boxes_dim = THCudaTensor_size(state, boxes, 1); + + float* boxes_flat = THCudaTensor_data(state, boxes); + + const int col_blocks = DIVUP(boxes_num, threadsPerBlock); + THCudaLongTensor * mask = THCudaLongTensor_newWithSize2d(state, boxes_num, col_blocks); + unsigned long long* mask_flat = THCudaLongTensor_data(state, mask); + + _nms(boxes_num, boxes_flat, mask_flat, nms_overlap_thresh); + + THLongTensor * mask_cpu = THLongTensor_newWithSize2d(boxes_num, col_blocks); + THLongTensor_copyCuda(state, mask_cpu, mask); + THCudaLongTensor_free(state, mask); + + unsigned long long * mask_cpu_flat = THLongTensor_data(mask_cpu); + + THLongTensor * remv_cpu = THLongTensor_newWithSize1d(col_blocks); + unsigned long long* remv_cpu_flat = THLongTensor_data(remv_cpu); + THLongTensor_fill(remv_cpu, 0); + + long * keep_flat = THLongTensor_data(keep); + long num_to_keep = 0; + + int i, j; + for (i = 0; i < boxes_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(remv_cpu_flat[nblock] & (1ULL << inblock))) { + keep_flat[num_to_keep++] = i; + unsigned long long *p = &mask_cpu_flat[0] + i * col_blocks; + for (j = nblock; j < col_blocks; j++) { + remv_cpu_flat[j] |= p[j]; + } + } + } + + long * num_out_flat = THLongTensor_data(num_out); + * num_out_flat = num_to_keep; + + THLongTensor_free(mask_cpu); + THLongTensor_free(remv_cpu); + + return 1; +} diff --git a/cuda_functions/nms_2D/src/nms_cuda.h b/cuda_functions/nms_2D/src/nms_cuda.h new file mode 100644 index 0000000..0826111 --- /dev/null +++ b/cuda_functions/nms_2D/src/nms_cuda.h @@ -0,0 +1 @@ +int gpu_nms(THLongTensor * keep_out, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh); \ No newline at end of file diff --git a/cuda_functions/nms_3D/__init__.py b/cuda_functions/nms_3D/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/nms_3D/_ext/__init__.py b/cuda_functions/nms_3D/_ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/nms_3D/_ext/nms/__init__.py b/cuda_functions/nms_3D/_ext/nms/__init__.py new file mode 100644 index 0000000..d71786f --- /dev/null +++ b/cuda_functions/nms_3D/_ext/nms/__init__.py @@ -0,0 +1,15 @@ + +from torch.utils.ffi import _wrap_function +from ._nms import lib as _lib, ffi as _ffi + +__all__ = [] +def _import_symbols(locals): + for symbol in dir(_lib): + fn = getattr(_lib, symbol) + if callable(fn): + locals[symbol] = _wrap_function(fn, _ffi) + else: + locals[symbol] = fn + __all__.append(symbol) + +_import_symbols(locals()) diff --git a/cuda_functions/nms_3D/_ext/nms/_nms.so b/cuda_functions/nms_3D/_ext/nms/_nms.so new file mode 100755 index 0000000..c8498a0 Binary files /dev/null and b/cuda_functions/nms_3D/_ext/nms/_nms.so differ diff --git a/cuda_functions/nms_3D/build.py b/cuda_functions/nms_3D/build.py new file mode 100644 index 0000000..4d9a96b --- /dev/null +++ b/cuda_functions/nms_3D/build.py @@ -0,0 +1,34 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/nms.c'] +headers = ['src/nms.h'] +defines = [] +with_cuda = False + +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/nms_cuda.c'] + headers += ['src/nms_cuda.h'] + defines += [('WITH_CUDA', None)] + with_cuda = True + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +extra_objects = ['src/cuda/nms_kernel.cu.o'] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.nms', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects +) + +if __name__ == '__main__': + ffi.build() diff --git a/cuda_functions/nms_3D/pth_nms.py b/cuda_functions/nms_3D/pth_nms.py new file mode 100644 index 0000000..3639b5b --- /dev/null +++ b/cuda_functions/nms_3D/pth_nms.py @@ -0,0 +1,38 @@ +import torch +from ._ext import nms + + +def nms_gpu(dets, thresh): + """ + dets has to be a tensor + """ + + scores = dets[:, -1] + order = scores.sort(0, descending=True)[1] + dets = dets[order].contiguous() + + keep = torch.LongTensor(dets.size(0)) + num_out = torch.LongTensor(1) + nms.gpu_nms(keep, num_out, dets, thresh) + return order[keep[:num_out[0]].cuda()].contiguous() + + +def nms_cpu(dets, thresh): + + dets = dets.cpu() + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + z1 = dets[:, 4] + z2 = dets[:, 5] + scores = dets[:, 6] + areas = (x2 - x1 +1) * (y2 - y1 +1) * (z2 - z1 +1) + order = scores.sort(0, descending=True)[1] + + keep = torch.LongTensor(dets.size(0)) + num_out = torch.LongTensor(1) + nms.cpu_nms(keep, num_out, dets, order, areas, thresh) + + return keep[:num_out[0]] + diff --git a/cuda_functions/nms_3D/src/cuda/nms_kernel.cu b/cuda_functions/nms_3D/src/cuda/nms_kernel.cu new file mode 100644 index 0000000..5692de8 --- /dev/null +++ b/cuda_functions/nms_3D/src/cuda/nms_kernel.cu @@ -0,0 +1,96 @@ +// ------------------------------------------------------------------ +// Faster R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Shaoqing Ren +// ------------------------------------------------------------------ +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#include "nms_kernel.h" + +__device__ inline float devIoU(float const * const a, float const * const b) { + float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); + float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); + float front = fmaxf(a[4], b[4]), back = fminf(a[5], b[5]); + + float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f), depth = fmaxf(back - front + 1, 0.f); + float interS = width * height * depth; + float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1) * (a[5] - a[4] + 1); + float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1) * (b[5] - b[4] + 1); + //printf("IoU 3D %f \n", interS / (Sa + Sb - interS)); + + return interS / (Sa + Sb - interS); +} + +__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, unsigned long long *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ float block_boxes[threadsPerBlock * 7]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 7 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 0]; + block_boxes[threadIdx.x * 7 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 1]; + block_boxes[threadIdx.x * 7 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 2]; + block_boxes[threadIdx.x * 7 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 3]; + block_boxes[threadIdx.x * 7 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 4]; + block_boxes[threadIdx.x * 7 + 5] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 5]; + block_boxes[threadIdx.x * 7 + 6] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 7 + 6]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 7; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 7) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + + +void _nms(int boxes_num, float * boxes_dev, + unsigned long long * mask_dev, float nms_overlap_thresh) { + + + dim3 blocks(DIVUP(boxes_num, threadsPerBlock), + DIVUP(boxes_num, threadsPerBlock)); + dim3 threads(threadsPerBlock); + nms_kernel<<>>(boxes_num, + nms_overlap_thresh, + boxes_dev, + mask_dev); +} + +#ifdef __cplusplus +} +#endif diff --git a/cuda_functions/nms_3D/src/cuda/nms_kernel.cu.o b/cuda_functions/nms_3D/src/cuda/nms_kernel.cu.o new file mode 100644 index 0000000..ee3ed41 Binary files /dev/null and b/cuda_functions/nms_3D/src/cuda/nms_kernel.cu.o differ diff --git a/cuda_functions/nms_3D/src/cuda/nms_kernel.h b/cuda_functions/nms_3D/src/cuda/nms_kernel.h new file mode 100644 index 0000000..2f40582 --- /dev/null +++ b/cuda_functions/nms_3D/src/cuda/nms_kernel.h @@ -0,0 +1,19 @@ +#ifndef _NMS_KERNEL +#define _NMS_KERNEL + +#ifdef __cplusplus +extern "C" { +#endif + +#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) +int const threadsPerBlock = sizeof(unsigned long long) * 8; + +void _nms(int boxes_num, float * boxes_dev, + unsigned long long * mask_dev, float nms_overlap_thresh); + +#ifdef __cplusplus +} +#endif + +#endif + diff --git a/cuda_functions/nms_3D/src/nms.c b/cuda_functions/nms_3D/src/nms.c new file mode 100644 index 0000000..dd64336 --- /dev/null +++ b/cuda_functions/nms_3D/src/nms.c @@ -0,0 +1,74 @@ +#include +#include + + +int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh) { + // boxes has to be sorted + THArgCheck(THLongTensor_isContiguous(keep_out), 0, "keep_out must be contiguous"); + THArgCheck(THLongTensor_isContiguous(boxes), 2, "boxes must be contiguous"); + THArgCheck(THLongTensor_isContiguous(order), 3, "order must be contiguous"); + THArgCheck(THLongTensor_isContiguous(areas), 4, "areas must be contiguous"); + // Number of ROIs + long boxes_num = THFloatTensor_size(boxes, 0); + long boxes_dim = THFloatTensor_size(boxes, 1); + + long * keep_out_flat = THLongTensor_data(keep_out); + float * boxes_flat = THFloatTensor_data(boxes); + long * order_flat = THLongTensor_data(order); + float * areas_flat = THFloatTensor_data(areas); + + THByteTensor* suppressed = THByteTensor_newWithSize1d(boxes_num); + THByteTensor_fill(suppressed, 0); + unsigned char * suppressed_flat = THByteTensor_data(suppressed); + // nominal indices + int i, j; + // sorted indices + int _i, _j; + // temp variables for box i's (the box currently under consideration) + float ix1, iy1, ix2, iy2, iz1, iz2, iarea; + // variables for computing overlap with box j (lower scoring box) + float xx1, yy1, xx2, yy2, zz1, zz2; + float w, h, d; + float inter, ovr; + + long num_to_keep = 0; + for (_i=0; _i < boxes_num; ++_i) { + i = order_flat[_i]; // from sorted index to nominal index in boxes list. + if (suppressed_flat[i] == 1) { //maybe flag for later. overlapping boxes are surpressed. + continue; + } + keep_out_flat[num_to_keep++] = i; //num to keep is read and then increased. the box index i is saved in keep_out. + ix1 = boxes_flat[i * boxes_dim]; + iy1 = boxes_flat[i * boxes_dim + 1]; + ix2 = boxes_flat[i * boxes_dim + 2]; + iy2 = boxes_flat[i * boxes_dim + 3]; + iz1 = boxes_flat[i * boxes_dim + 4]; + iz2 = boxes_flat[i * boxes_dim + 5]; + iarea = areas_flat[i]; + for (_j = _i + 1; _j < boxes_num; ++_j) { + j = order_flat[_j]; + if (suppressed_flat[j] == 1) { + continue; + } + xx1 = fmaxf(ix1, boxes_flat[j * boxes_dim]); + yy1 = fmaxf(iy1, boxes_flat[j * boxes_dim + 1]); + xx2 = fminf(ix2, boxes_flat[j * boxes_dim + 2]); + yy2 = fminf(iy2, boxes_flat[j * boxes_dim + 3]); + zz1 = fmaxf(iz1, boxes_flat[j * boxes_dim + 4]); + zz2 = fminf(iz2, boxes_flat[j * boxes_dim + 5]); + w = fmaxf(0.0, xx2 - xx1 + 1); + h = fmaxf(0.0, yy2 - yy1 + 1); + d = fmaxf(0.0, zz2 - zz1 + 1); + inter = w * h * d; + ovr = inter / (iarea + areas_flat[j] - inter); + if (ovr >= nms_overlap_thresh) { + suppressed_flat[j] = 1; // can be surpressed because score j < score i (from order: _j = _i + 1 ...) + } + } + } + + long *num_out_flat = THLongTensor_data(num_out); + *num_out_flat = num_to_keep; + THByteTensor_free(suppressed); + return 1; +} \ No newline at end of file diff --git a/cuda_functions/nms_3D/src/nms.h b/cuda_functions/nms_3D/src/nms.h new file mode 100644 index 0000000..d17d9c9 --- /dev/null +++ b/cuda_functions/nms_3D/src/nms.h @@ -0,0 +1 @@ +int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh); diff --git a/cuda_functions/nms_3D/src/nms_cuda.c b/cuda_functions/nms_3D/src/nms_cuda.c new file mode 100644 index 0000000..5a9a70f --- /dev/null +++ b/cuda_functions/nms_3D/src/nms_cuda.c @@ -0,0 +1,67 @@ +// ------------------------------------------------------------------ +// Faster R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Shaoqing Ren +// ------------------------------------------------------------------ +#include +#include +#include +#include + +#include "cuda/nms_kernel.h" + + +extern THCState *state; + +int gpu_nms(THLongTensor * keep, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh) { + // boxes has to be sorted + THArgCheck(THLongTensor_isContiguous(keep), 0, "boxes must be contiguous"); + THArgCheck(THCudaTensor_isContiguous(state, boxes), 2, "boxes must be contiguous"); + // Number of ROIs + int boxes_num = THCudaTensor_size(state, boxes, 0); + int boxes_dim = THCudaTensor_size(state, boxes, 1); + + float* boxes_flat = THCudaTensor_data(state, boxes); + + const int col_blocks = DIVUP(boxes_num, threadsPerBlock); + THCudaLongTensor * mask = THCudaLongTensor_newWithSize2d(state, boxes_num, col_blocks); + unsigned long long* mask_flat = THCudaLongTensor_data(state, mask); + + _nms(boxes_num, boxes_flat, mask_flat, nms_overlap_thresh); + + THLongTensor * mask_cpu = THLongTensor_newWithSize2d(boxes_num, col_blocks); + THLongTensor_copyCuda(state, mask_cpu, mask); + THCudaLongTensor_free(state, mask); + + unsigned long long * mask_cpu_flat = THLongTensor_data(mask_cpu); + + THLongTensor * remv_cpu = THLongTensor_newWithSize1d(col_blocks); + unsigned long long* remv_cpu_flat = THLongTensor_data(remv_cpu); + THLongTensor_fill(remv_cpu, 0); + + long * keep_flat = THLongTensor_data(keep); + long num_to_keep = 0; + + int i, j; + for (i = 0; i < boxes_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(remv_cpu_flat[nblock] & (1ULL << inblock))) { + keep_flat[num_to_keep++] = i; + unsigned long long *p = &mask_cpu_flat[0] + i * col_blocks; + for (j = nblock; j < col_blocks; j++) { + remv_cpu_flat[j] |= p[j]; + } + } + } + + long * num_out_flat = THLongTensor_data(num_out); + * num_out_flat = num_to_keep; + + THLongTensor_free(mask_cpu); + THLongTensor_free(remv_cpu); + + return 1; +} diff --git a/cuda_functions/nms_3D/src/nms_cuda.h b/cuda_functions/nms_3D/src/nms_cuda.h new file mode 100644 index 0000000..08bf147 --- /dev/null +++ b/cuda_functions/nms_3D/src/nms_cuda.h @@ -0,0 +1 @@ +int gpu_nms(THLongTensor * keep_out, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh); diff --git a/cuda_functions/roi_align_2D/__init__.py b/cuda_functions/roi_align_2D/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/roi_align_2D/roi_align/__init__.py b/cuda_functions/roi_align_2D/roi_align/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/roi_align_2D/roi_align/_ext/__init__.py b/cuda_functions/roi_align_2D/roi_align/_ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__init__.py b/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__init__.py new file mode 100644 index 0000000..4486c09 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__init__.py @@ -0,0 +1,15 @@ + +from torch.utils.ffi import _wrap_function +from ._crop_and_resize import lib as _lib, ffi as _ffi + +__all__ = [] +def _import_symbols(locals): + for symbol in dir(_lib): + fn = getattr(_lib, symbol) + if callable(fn): + locals[symbol] = _wrap_function(fn, _ffi) + else: + locals[symbol] = fn + __all__.append(symbol) + +_import_symbols(locals()) diff --git a/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/_crop_and_resize.so b/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/_crop_and_resize.so new file mode 100755 index 0000000..e852f11 Binary files /dev/null and b/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/_crop_and_resize.so differ diff --git a/cuda_functions/roi_align_2D/roi_align/build.py b/cuda_functions/roi_align_2D/roi_align/build.py new file mode 100755 index 0000000..3798d82 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/build.py @@ -0,0 +1,40 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/crop_and_resize.c'] +headers = ['src/crop_and_resize.h'] +defines = [] +with_cuda = False + +extra_objects = [] +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/crop_and_resize_gpu.c'] + headers += ['src/crop_and_resize_gpu.h'] + defines += [('WITH_CUDA', None)] + extra_objects += ['src/cuda/crop_and_resize_kernel.cu.o'] + with_cuda = True + +extra_compile_args = ['-fopenmp', '-std=c99'] + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +sources = [os.path.join(this_file, fname) for fname in sources] +headers = [os.path.join(this_file, fname) for fname in headers] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.crop_and_resize', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=extra_compile_args +) + +if __name__ == '__main__': + ffi.build() diff --git a/cuda_functions/roi_align_2D/roi_align/crop_and_resize.py b/cuda_functions/roi_align_2D/roi_align/crop_and_resize.py new file mode 100755 index 0000000..4291ae4 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/crop_and_resize.py @@ -0,0 +1,66 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function + +from ._ext import crop_and_resize as _backend + + +class CropAndResizeFunction(Function): + + def __init__(self, crop_height, crop_width, extrapolation_value=0): + self.crop_height = crop_height + self.crop_width = crop_width + self.extrapolation_value = extrapolation_value + + def forward(self, image, boxes, box_ind): + crops = torch.zeros_like(image) + if image.is_cuda: + _backend.crop_and_resize_gpu_forward( + image, boxes, box_ind, + self.extrapolation_value, self.crop_height, self.crop_width, crops) + else: + _backend.crop_and_resize_forward( + image, boxes, box_ind, + self.extrapolation_value, self.crop_height, self.crop_width, crops) + + # save for backward + self.im_size = image.size() + self.save_for_backward(boxes, box_ind) + + return crops + + def backward(self, grad_outputs): + boxes, box_ind = self.saved_tensors + + grad_outputs = grad_outputs.contiguous() + grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size) + + if grad_outputs.is_cuda: + _backend.crop_and_resize_gpu_backward( + grad_outputs, boxes, box_ind, grad_image + ) + else: + _backend.crop_and_resize_backward( + grad_outputs, boxes, box_ind, grad_image + ) + + return grad_image, None, None + + +class CropAndResize(nn.Module): + """ + Crop and resize ported from tensorflow + See more details on https://www.tensorflow.org/api_docs/python/tf/image/crop_and_resize + """ + + def __init__(self, crop_height, crop_width, extrapolation_value=0): + super(CropAndResize, self).__init__() + + self.crop_height = crop_height + self.crop_width = crop_width + self.extrapolation_value = extrapolation_value + + def forward(self, image, boxes, box_ind): + return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(image, boxes, box_ind) diff --git a/cuda_functions/roi_align_2D/roi_align/roi_align.py b/cuda_functions/roi_align_2D/roi_align/roi_align.py new file mode 100644 index 0000000..6931539 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/roi_align.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + +from .crop_and_resize import CropAndResizeFunction, CropAndResize + + +class RoIAlign(nn.Module): + + def __init__(self, crop_height, crop_width, extrapolation_value=0, transform_fpcoor=True): + super(RoIAlign, self).__init__() + + self.crop_height = crop_height + self.crop_width = crop_width + self.extrapolation_value = extrapolation_value + self.transform_fpcoor = transform_fpcoor + + def forward(self, featuremap, boxes, box_ind): + """ + RoIAlign based on crop_and_resize. + See more details on https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301 + :param featuremap: NxCxHxW + :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization** + :param box_ind: M + :return: MxCxoHxoW + """ + x1, y1, x2, y2 = torch.split(boxes, 1, dim=1) + image_height, image_width = featuremap.size()[2:4] + + if self.transform_fpcoor: + spacing_w = (x2 - x1) / float(self.crop_width) + spacing_h = (y2 - y1) / float(self.crop_height) + + nx0 = (x1 + spacing_w / 2 - 0.5) / float(image_width - 1) + ny0 = (y1 + spacing_h / 2 - 0.5) / float(image_height - 1) + nw = spacing_w * float(self.crop_width - 1) / float(image_width - 1) + nh = spacing_h * float(self.crop_height - 1) / float(image_height - 1) + + boxes = torch.cat((ny0, nx0, ny0 + nh, nx0 + nw), 1) + else: + x1 = x1 / float(image_width - 1) + x2 = x2 / float(image_width - 1) + y1 = y1 / float(image_height - 1) + y2 = y2 / float(image_height - 1) + boxes = torch.cat((y1, x1, y2, x2), 1) + + boxes = boxes.detach().contiguous() + box_ind = box_ind.detach() + return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind) diff --git a/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.c b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.c new file mode 100644 index 0000000..e1fce67 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.c @@ -0,0 +1,252 @@ +#include +#include +#include + + +void CropAndResizePerBox( + const float * image_data, + const int batch_size, + const int depth, + const int image_height, + const int image_width, + + const float * boxes_data, + const int * box_index_data, + const int start_box, + const int limit_box, + + float * corps_data, + const int crop_height, + const int crop_width, + const float extrapolation_value +) { + const int image_channel_elements = image_height * image_width; + const int image_elements = depth * image_channel_elements; + + const int channel_elements = crop_height * crop_width; + const int crop_elements = depth * channel_elements; + + int b; + #pragma omp parallel for + for (b = start_box; b < limit_box; ++b) { + const float * box = boxes_data + b * 4; + const float y1 = box[0]; + const float x1 = box[1]; + const float y2 = box[2]; + const float x2 = box[3]; + + const int b_in = box_index_data[b]; + if (b_in < 0 || b_in >= batch_size) { + printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); + exit(-1); + } + + const float height_scale = + (crop_height > 1) + ? (y2 - y1) * (image_height - 1) / (crop_height - 1) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) + : 0; + + for (int y = 0; y < crop_height; ++y) + { + const float in_y = (crop_height > 1) + ? y1 * (image_height - 1) + y * height_scale + : 0.5 * (y1 + y2) * (image_height - 1); + + if (in_y < 0 || in_y > image_height - 1) + { + for (int x = 0; x < crop_width; ++x) + { + for (int d = 0; d < depth; ++d) + { + // crops(b, y, x, d) = extrapolation_value; + corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; + } + } + continue; + } + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + for (int x = 0; x < crop_width; ++x) + { + const float in_x = (crop_width > 1) + ? x1 * (image_width - 1) + x * width_scale + : 0.5 * (x1 + x2) * (image_width - 1); + if (in_x < 0 || in_x > image_width - 1) + { + for (int d = 0; d < depth; ++d) + { + corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; + } + continue; + } + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + for (int d = 0; d < depth; ++d) + { + const float *pimage = image_data + b_in * image_elements + d * image_channel_elements; + + const float top_left = pimage[top_y_index * image_width + left_x_index]; + const float top_right = pimage[top_y_index * image_width + right_x_index]; + const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; + const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; + + const float top = top_left + (top_right - top_left) * x_lerp; + const float bottom = + bottom_left + (bottom_right - bottom_left) * x_lerp; + + corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = top + (bottom - top) * y_lerp; + } + } // end for x + } // end for y + } // end for b + +} + + +void crop_and_resize_forward( + THFloatTensor * image, + THFloatTensor * boxes, // [y1, x1, y2, x2] + THIntTensor * box_index, // range in [0, batch_size) + const float extrapolation_value, + const int crop_height, + const int crop_width, + THFloatTensor * crops +) { + const int batch_size = image->size[0]; + const int depth = image->size[1]; + const int image_height = image->size[2]; + const int image_width = image->size[3]; + + const int num_boxes = boxes->size[0]; + + // init output space + THFloatTensor_resize4d(crops, num_boxes, depth, crop_height, crop_width); + THFloatTensor_zero(crops); + + // crop_and_resize for each box + CropAndResizePerBox( + THFloatTensor_data(image), + batch_size, + depth, + image_height, + image_width, + + THFloatTensor_data(boxes), + THIntTensor_data(box_index), + 0, + num_boxes, + + THFloatTensor_data(crops), + crop_height, + crop_width, + extrapolation_value + ); + +} + + +void crop_and_resize_backward( + THFloatTensor * grads, + THFloatTensor * boxes, // [y1, x1, y2, x2] + THIntTensor * box_index, // range in [0, batch_size) + THFloatTensor * grads_image // resize to [bsize, c, hc, wc] +) +{ + // shape + const int batch_size = grads_image->size[0]; + const int depth = grads_image->size[1]; + const int image_height = grads_image->size[2]; + const int image_width = grads_image->size[3]; + + const int num_boxes = grads->size[0]; + const int crop_height = grads->size[2]; + const int crop_width = grads->size[3]; + + // n_elements + const int image_channel_elements = image_height * image_width; + const int image_elements = depth * image_channel_elements; + + const int channel_elements = crop_height * crop_width; + const int crop_elements = depth * channel_elements; + + // init output space + THFloatTensor_zero(grads_image); + + // data pointer + const float * grads_data = THFloatTensor_data(grads); + const float * boxes_data = THFloatTensor_data(boxes); + const int * box_index_data = THIntTensor_data(box_index); + float * grads_image_data = THFloatTensor_data(grads_image); + + for (int b = 0; b < num_boxes; ++b) { + const float * box = boxes_data + b * 4; + const float y1 = box[0]; + const float x1 = box[1]; + const float y2 = box[2]; + const float x2 = box[3]; + + const int b_in = box_index_data[b]; + if (b_in < 0 || b_in >= batch_size) { + printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); + exit(-1); + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) + : 0; + + for (int y = 0; y < crop_height; ++y) + { + const float in_y = (crop_height > 1) + ? y1 * (image_height - 1) + y * height_scale + : 0.5 * (y1 + y2) * (image_height - 1); + if (in_y < 0 || in_y > image_height - 1) + { + continue; + } + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + for (int x = 0; x < crop_width; ++x) + { + const float in_x = (crop_width > 1) + ? x1 * (image_width - 1) + x * width_scale + : 0.5 * (x1 + x2) * (image_width - 1); + if (in_x < 0 || in_x > image_width - 1) + { + continue; + } + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + for (int d = 0; d < depth; ++d) + { + float *pimage = grads_image_data + b_in * image_elements + d * image_channel_elements; + const float grad_val = grads_data[crop_elements * b + channel_elements * d + y * crop_width + x]; + + const float dtop = (1 - y_lerp) * grad_val; + pimage[top_y_index * image_width + left_x_index] += (1 - x_lerp) * dtop; + pimage[top_y_index * image_width + right_x_index] += x_lerp * dtop; + + const float dbottom = y_lerp * grad_val; + pimage[bottom_y_index * image_width + left_x_index] += (1 - x_lerp) * dbottom; + pimage[bottom_y_index * image_width + right_x_index] += x_lerp * dbottom; + } // end d + } // end x + } // end y + } // end b +} \ No newline at end of file diff --git a/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.h b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.h new file mode 100644 index 0000000..d494865 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize.h @@ -0,0 +1,16 @@ +void crop_and_resize_forward( + THFloatTensor * image, + THFloatTensor * boxes, // [y1, x1, y2, x2] + THIntTensor * box_index, // range in [0, batch_size) + const float extrapolation_value, + const int crop_height, + const int crop_width, + THFloatTensor * crops +); + +void crop_and_resize_backward( + THFloatTensor * grads, + THFloatTensor * boxes, // [y1, x1, y2, x2] + THIntTensor * box_index, // range in [0, batch_size) + THFloatTensor * grads_image // resize to [bsize, c, hc, wc] +); \ No newline at end of file diff --git a/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize_gpu.c b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize_gpu.c new file mode 100644 index 0000000..dd347c6 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize_gpu.c @@ -0,0 +1,68 @@ +#include +#include "cuda/crop_and_resize_kernel.h" + +extern THCState *state; + + +void crop_and_resize_gpu_forward( + THCudaTensor * image, + THCudaTensor * boxes, // [y1, x1, y2, x2] + THCudaIntTensor * box_index, // range in [0, batch_size) + const float extrapolation_value, + const int crop_height, + const int crop_width, + THCudaTensor * crops +) { + const int batch_size = THCudaTensor_size(state, image, 0); + const int depth = THCudaTensor_size(state, image, 1); + const int image_height = THCudaTensor_size(state, image, 2); + const int image_width = THCudaTensor_size(state, image, 3); + + const int num_boxes = THCudaTensor_size(state, boxes, 0); + + // init output space + THCudaTensor_resize4d(state, crops, num_boxes, depth, crop_height, crop_width); + THCudaTensor_zero(state, crops); + cudaStream_t stream = THCState_getCurrentStream(state); + CropAndResizeLaucher( + THCudaTensor_data(state, image), + THCudaTensor_data(state, boxes), + THCudaIntTensor_data(state, box_index), + num_boxes, batch_size, image_height, image_width, + crop_height, crop_width, depth, extrapolation_value, + THCudaTensor_data(state, crops), + stream + ); +} + + +void crop_and_resize_gpu_backward( + THCudaTensor * grads, + THCudaTensor * boxes, // [y1, x1, y2, x2] + THCudaIntTensor * box_index, // range in [0, batch_size) + THCudaTensor * grads_image // resize to [bsize, c, hc, wc] +) { + // shape + const int batch_size = THCudaTensor_size(state, grads_image, 0); + const int depth = THCudaTensor_size(state, grads_image, 1); + const int image_height = THCudaTensor_size(state, grads_image, 2); + const int image_width = THCudaTensor_size(state, grads_image, 3); + + const int num_boxes = THCudaTensor_size(state, grads, 0); + const int crop_height = THCudaTensor_size(state, grads, 2); + const int crop_width = THCudaTensor_size(state, grads, 3); + + // init output space + THCudaTensor_zero(state, grads_image); + + cudaStream_t stream = THCState_getCurrentStream(state); + CropAndResizeBackpropImageLaucher( + THCudaTensor_data(state, grads), + THCudaTensor_data(state, boxes), + THCudaIntTensor_data(state, box_index), + num_boxes, batch_size, image_height, image_width, + crop_height, crop_width, depth, + THCudaTensor_data(state, grads_image), + stream + ); +} \ No newline at end of file diff --git a/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize_gpu.h b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize_gpu.h new file mode 100644 index 0000000..c2a64cf --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/src/crop_and_resize_gpu.h @@ -0,0 +1,16 @@ +void crop_and_resize_gpu_forward( + THCudaTensor * image, + THCudaTensor * boxes, // [y1, x1, y2, x2] + THCudaIntTensor * box_index, // range in [0, batch_size) + const float extrapolation_value, + const int crop_height, + const int crop_width, + THCudaTensor * crops +); + +void crop_and_resize_gpu_backward( + THCudaTensor * grads, + THCudaTensor * boxes, // [y1, x1, y2, x2] + THCudaIntTensor * box_index, // range in [0, batch_size) + THCudaTensor * grads_image // resize to [bsize, c, hc, wc] +); \ No newline at end of file diff --git a/cuda_functions/roi_align_2D/roi_align/src/cuda/backup.cu b/cuda_functions/roi_align_2D/roi_align/src/cuda/backup.cu new file mode 100644 index 0000000..3a1ab8b --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/src/cuda/backup.cu @@ -0,0 +1,243 @@ +#include +#include +#include "crop_and_resize_kernel.h" + +#define CUDA_1D_KERNEL_LOOP(i, n) \ +for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +__global__ +void CropAndResizeKernel( + const int nthreads, const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float extrapolation_value, float *crops_ptr) +{ + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) + { + // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) + // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) + int idx = out_idx; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + idx /= crop_height; + const int d = idx % depth; + const int b = idx / depth; + + const float y1 = boxes_ptr[b * 4]; + const float x1 = boxes_ptr[b * 4 + 1]; + const float y2 = boxes_ptr[b * 4 + 2]; + const float x2 = boxes_ptr[b * 4 + 3]; + + // printf("INIT CUDA SCRIPT %f \n", idx); + + const int b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) + { + continue; + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; + + const float in_y = (crop_height > 1) + ? y1 * (image_height - 1) + y * height_scale + : 0.5 * (y1 + y2) * (image_height - 1); + if (in_y < 0 || in_y > image_height - 1) + { + crops_ptr[out_idx] = extrapolation_value; + continue; + } + + const float in_x = (crop_width > 1) + ? x1 * (image_width - 1) + x * width_scale + : 0.5 * (x1 + x2) * (image_width - 1); + if (in_x < 0 || in_x > image_width - 1) + { + crops_ptr[out_idx] = extrapolation_value; + continue; + } + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + const float *pimage = image_ptr + (b_in * depth + d) * image_height * image_width; + const float top_left = pimage[top_y_index * image_width + left_x_index]; + const float top_right = pimage[top_y_index * image_width + right_x_index]; + const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; + const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; + // if (top_left == 0){ + // const float top = top_right} + // elif (top_right == 0){ + // const float top = top_left} + // else{ + const float top = top_left + (top_right - top_left) * x_lerp; + //} + + //if (bottom_left == 0){ + // const float bottom = bottom_right} + // elif (bottom_right == 0){ + // const float bottom = bottom_left} + // else{ + const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; + //} + + //if (top == 0){ + // crops_ptr[out_idx] = bottom } + // elif (bottom == 0){ + // crops_ptr[out_idx] = top + //} + // else{ + crops_ptr[out_idx] = top + (bottom - top) * y_lerp; + //} + } +} + +__global__ +void CropAndResizeBackpropImageKernel( + const int nthreads, const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float *grads_image_ptr) +{ + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) + { + // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) + // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) + int idx = out_idx; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + idx /= crop_height; + const int d = idx % depth; + const int b = idx / depth; + + const float y1 = boxes_ptr[b * 4]; + const float x1 = boxes_ptr[b * 4 + 1]; + const float y2 = boxes_ptr[b * 4 + 2]; + const float x2 = boxes_ptr[b * 4 + 3]; + + const int b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) + { + continue; + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; + + const float in_y = (crop_height > 1) + ? y1 * (image_height - 1) + y * height_scale + : 0.5 * (y1 + y2) * (image_height - 1); + if (in_y < 0 || in_y > image_height - 1) + { + continue; + } + + const float in_x = (crop_width > 1) + ? x1 * (image_width - 1) + x * width_scale + : 0.5 * (x1 + x2) * (image_width - 1); + if (in_x < 0 || in_x > image_width - 1) + { + continue; + } + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + float *pimage = grads_image_ptr + (b_in * depth + d) * image_height * image_width; + const float dtop = (1 - y_lerp) * grads_ptr[out_idx]; + atomicAdd( + pimage + top_y_index * image_width + left_x_index, + (1 - x_lerp) * dtop + ); + atomicAdd( + pimage + top_y_index * image_width + right_x_index, + x_lerp * dtop + ); + + const float dbottom = y_lerp * grads_ptr[out_idx]; + atomicAdd( + pimage + bottom_y_index * image_width + left_x_index, + (1 - x_lerp) * dbottom + ); + atomicAdd( + pimage + bottom_y_index * image_width + right_x_index, + x_lerp * dbottom + ); + } +} + + +void CropAndResizeLaucher( + const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float extrapolation_value, float *crops_ptr, cudaStream_t stream) +{ + const int total_count = num_boxes * crop_height * crop_width * depth; + const int thread_per_block = 1024; + const int block_count = (total_count + thread_per_block - 1) / thread_per_block; + cudaError_t err; + + if (total_count > 0) + { + CropAndResizeKernel<<>>( + total_count, image_ptr, boxes_ptr, + box_ind_ptr, num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, extrapolation_value, crops_ptr); + + err = cudaGetLastError(); + if (cudaSuccess != err) + { + fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + } +} + + +void CropAndResizeBackpropImageLaucher( + const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float *grads_image_ptr, cudaStream_t stream) +{ + const int total_count = num_boxes * crop_height * crop_width * depth; + const int thread_per_block = 1024; + const int block_count = (total_count + thread_per_block - 1) / thread_per_block; + cudaError_t err; + + if (total_count > 0) + { + CropAndResizeBackpropImageKernel<<>>( + total_count, grads_ptr, boxes_ptr, + box_ind_ptr, num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, grads_image_ptr); + + err = cudaGetLastError(); + if (cudaSuccess != err) + { + fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + } +} \ No newline at end of file diff --git a/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.cu b/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.cu new file mode 100644 index 0000000..0702551 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.cu @@ -0,0 +1,250 @@ +#include +#include +#include "crop_and_resize_kernel.h" + +#define CUDA_1D_KERNEL_LOOP(i, n) \ +for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +__global__ +void CropAndResizeKernel( + const int nthreads, const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float extrapolation_value, float *crops_ptr) +{ + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) + { + // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) + // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) + int idx = out_idx; + //printf("start %i \n", idx); + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + idx /= crop_height; + const int d = idx % depth; + const int b = idx / depth; + + const float y1 = boxes_ptr[b * 4]; + const float x1 = boxes_ptr[b * 4 + 1]; + const float y2 = boxes_ptr[b * 4 + 2]; + const float x2 = boxes_ptr[b * 4 + 3]; + + const int b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) + { + continue; + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height) / (crop_height) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width) / (crop_width) : 0; + + + float tmp_in_y = (crop_height > 1) + ? y1 * (image_height ) + y * height_scale + height_scale/2 - 0.5 + : 0.5 * (y1 + y2) * (image_height); + + if (tmp_in_y > image_height - 1) + { + tmp_in_y = image_height - 1; + } + if (tmp_in_y < 0) + { + tmp_in_y = 0; + } + const float in_y = tmp_in_y; + + float tmp_in_x = (crop_width > 1) + ? x1 * (image_width ) + x * width_scale + width_scale/2 - 0.5 + : 0.5 * (x1 + x2) * (image_width ); + + if (tmp_in_x > image_width - 1) + { + tmp_in_x = image_width - 1; + } + if (tmp_in_x < 0) + { + tmp_in_x= 0; + } + const float in_x = tmp_in_x; + + //printf("height_scale %f \n", height_scale); + //printf("width_scale %f \n", width_scale); + //printf("in_x %f \n", in_x); + //printf("in_y %f \n", in_y); + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + const float *pimage = image_ptr + (b_in * depth + d) * image_height * image_width; + const float top_left = pimage[top_y_index * image_width + left_x_index]; + const float top_right = pimage[top_y_index * image_width + right_x_index]; + const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; + const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; + + const float top = top_left + (top_right - top_left) * x_lerp; + const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; + crops_ptr[out_idx] = top + (bottom - top) * y_lerp; + } +} + +__global__ +void CropAndResizeBackpropImageKernel( + const int nthreads, const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float *grads_image_ptr) +{ + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) + { + // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) + // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) + int idx = out_idx; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + idx /= crop_height; + const int d = idx % depth; + const int b = idx / depth; + + const float y1 = boxes_ptr[b * 4]; + const float x1 = boxes_ptr[b * 4 + 1]; + const float y2 = boxes_ptr[b * 4 + 2]; + const float x2 = boxes_ptr[b * 4 + 3]; + + const int b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) + { + continue; + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height ) / (crop_height ) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width ) / (crop_width ) : 0; + + float tmp_in_y = (crop_height > 1) + ? y1 * (image_height ) + y * height_scale + height_scale/2 - 0.5 + : 0.5 * (y1 + y2) * (image_height); + + if (tmp_in_y > image_height - 1) + { + tmp_in_y = image_height - 1; + } + if (tmp_in_y < 0) + { + tmp_in_y = 0; + } + const float in_y = tmp_in_y; + + float tmp_in_x = (crop_width > 1) + ? x1 * (image_width ) + x * width_scale + width_scale/2 - 0.5 + : 0.5 * (x1 + x2) * (image_width ); + + if (tmp_in_x > image_width - 1) + { + tmp_in_x = image_width - 1; + } + if (tmp_in_x < 0) + { + tmp_in_x= 0; + } + const float in_x = tmp_in_x; + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + float *pimage = grads_image_ptr + (b_in * depth + d) * image_height * image_width; + const float dtop = (1 - y_lerp) * grads_ptr[out_idx]; + atomicAdd( + pimage + top_y_index * image_width + left_x_index, + (1 - x_lerp) * dtop + ); + atomicAdd( + pimage + top_y_index * image_width + right_x_index, + x_lerp * dtop + ); + + const float dbottom = y_lerp * grads_ptr[out_idx]; + atomicAdd( + pimage + bottom_y_index * image_width + left_x_index, + (1 - x_lerp) * dbottom + ); + atomicAdd( + pimage + bottom_y_index * image_width + right_x_index, + x_lerp * dbottom + ); + } +} + + +void CropAndResizeLaucher( + const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float extrapolation_value, float *crops_ptr, cudaStream_t stream) +{ + const int total_count = num_boxes * crop_height * crop_width * depth; + const int thread_per_block = 1024; + const int block_count = (total_count + thread_per_block - 1) / thread_per_block; + cudaError_t err; + + if (total_count > 0) + { + CropAndResizeKernel<<>>( + total_count, image_ptr, boxes_ptr, + box_ind_ptr, num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, extrapolation_value, crops_ptr); + + err = cudaGetLastError(); + if (cudaSuccess != err) + { + fprintf(stderr, "cudaCheckError in Roi Align () failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + } +} + + +void CropAndResizeBackpropImageLaucher( + const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float *grads_image_ptr, cudaStream_t stream) +{ + const int total_count = num_boxes * crop_height * crop_width * depth; + const int thread_per_block = 1024; + const int block_count = (total_count + thread_per_block - 1) / thread_per_block; + cudaError_t err; + + if (total_count > 0) + { + CropAndResizeBackpropImageKernel<<>>( + total_count, grads_ptr, boxes_ptr, + box_ind_ptr, num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, grads_image_ptr); + + err = cudaGetLastError(); + if (cudaSuccess != err) + { + fprintf(stderr, "cudaCheckError() failed in Roi Align : %s\n", cudaGetErrorString(err)); + exit(-1); + } + } +} \ No newline at end of file diff --git a/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.cu.o b/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.cu.o new file mode 100644 index 0000000..2f1a1b9 Binary files /dev/null and b/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.cu.o differ diff --git a/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.h b/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.h new file mode 100644 index 0000000..893aee1 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/src/cuda/crop_and_resize_kernel.h @@ -0,0 +1,24 @@ +#ifndef _CropAndResize_Kernel +#define _CropAndResize_Kernel + +#ifdef __cplusplus +extern "C" { +#endif + +void CropAndResizeLaucher( + const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float extrapolation_value, float *crops_ptr, cudaStream_t stream); + +void CropAndResizeBackpropImageLaucher( + const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float *grads_image_ptr, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/cuda_functions/roi_align_2D/roi_align/src/cuda/fix.cu b/cuda_functions/roi_align_2D/roi_align/src/cuda/fix.cu new file mode 100644 index 0000000..6eea4a8 --- /dev/null +++ b/cuda_functions/roi_align_2D/roi_align/src/cuda/fix.cu @@ -0,0 +1,243 @@ +#include +#include +#include "crop_and_resize_kernel.h" + +#define CUDA_1D_KERNEL_LOOP(i, n) \ +for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +__global__ +void CropAndResizeKernel( + const int nthreads, const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float extrapolation_value, float *crops_ptr) +{ + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) + { + // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) + // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) + int idx = out_idx; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + idx /= crop_height; + const int d = idx % depth; + const int b = idx / depth; + + const float y1 = boxes_ptr[b * 4]; + const float x1 = boxes_ptr[b * 4 + 1]; + const float y2 = boxes_ptr[b * 4 + 2]; + const float x2 = boxes_ptr[b * 4 + 3]; + + // printf("INIT CUDA SCRIPT %f \n", idx); + + const int b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) + { + continue; + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height ) / (crop_height ) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width) / (crop_width ) : 0; + + const float in_y = (crop_height > 1) + ? y1 * (image_height ) + y * height_scale + height_scale/2 - 0.5 + : 0.5 * (y1 + y2) * (image_height ); + if (in_y < 0 || in_y > image_height ) + { + crops_ptr[out_idx] = extrapolation_value; + continue; + } + + const float in_x = (crop_width > 1) + ? x1 * (image_width ) + x * width_scale + width_scale/2 - 0.5 + : 0.5 * (x1 + x2) * (image_width ); + if (in_x < 0 || in_x > image_width ) + { + crops_ptr[out_idx] = extrapolation_value; + continue; + } + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + const float *pimage = image_ptr + (b_in * depth + d) * image_height * image_width; + const float top_left = pimage[top_y_index * image_width + left_x_index]; + const float top_right = pimage[top_y_index * image_width + right_x_index]; + const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; + const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; + // if (top_left == 0){ + // const float top = top_right} + // elif (top_right == 0){ + // const float top = top_left} + // else{ + const float top = top_left + (top_right - top_left) * x_lerp; + //} + + //if (bottom_left == 0){ + // const float bottom = bottom_right} + // elif (bottom_right == 0){ + // const float bottom = bottom_left} + // else{ + const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; + //} + + //if (top == 0){ + // crops_ptr[out_idx] = bottom } + // elif (bottom == 0){ + // crops_ptr[out_idx] = top + //} + // else{ + crops_ptr[out_idx] = top + (bottom - top) * y_lerp; + //} + } +} + +__global__ +void CropAndResizeBackpropImageKernel( + const int nthreads, const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float *grads_image_ptr) +{ + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) + { + // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) + // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) + int idx = out_idx; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + idx /= crop_height; + const int d = idx % depth; + const int b = idx / depth; + + const float y1 = boxes_ptr[b * 4]; + const float x1 = boxes_ptr[b * 4 + 1]; + const float y2 = boxes_ptr[b * 4 + 2]; + const float x2 = boxes_ptr[b * 4 + 3]; + + const int b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) + { + continue; + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height ) / (crop_height ) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width ) / (crop_width ) : 0; + + const float in_y = (crop_height > 1) + ? y1 * (image_height ) + y * height_scale + height_scale/2 - 0.5 + : 0.5 * (y1 + y2) * (image_height ); + if (in_y < 0 || in_y > image_height ) + { + continue; + } + + const float in_x = (crop_width > 1) + ? x1 * (image_width ) + x * width_scale + width_scale/2 - 0.5 + : 0.5 * (x1 + x2) * (image_width ); + if (in_x < 0 || in_x > image_width ) + { + continue; + } + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + float *pimage = grads_image_ptr + (b_in * depth + d) * image_height * image_width; + const float dtop = (1 - y_lerp) * grads_ptr[out_idx]; + atomicAdd( + pimage + top_y_index * image_width + left_x_index, + (1 - x_lerp) * dtop + ); + atomicAdd( + pimage + top_y_index * image_width + right_x_index, + x_lerp * dtop + ); + + const float dbottom = y_lerp * grads_ptr[out_idx]; + atomicAdd( + pimage + bottom_y_index * image_width + left_x_index, + (1 - x_lerp) * dbottom + ); + atomicAdd( + pimage + bottom_y_index * image_width + right_x_index, + x_lerp * dbottom + ); + } +} + + +void CropAndResizeLaucher( + const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float extrapolation_value, float *crops_ptr, cudaStream_t stream) +{ + const int total_count = num_boxes * crop_height * crop_width * depth; + const int thread_per_block = 1024; + const int block_count = (total_count + thread_per_block - 1) / thread_per_block; + cudaError_t err; + + if (total_count > 0) + { + CropAndResizeKernel<<>>( + total_count, image_ptr, boxes_ptr, + box_ind_ptr, num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, extrapolation_value, crops_ptr); + + err = cudaGetLastError(); + if (cudaSuccess != err) + { + fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + } +} + + +void CropAndResizeBackpropImageLaucher( + const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int crop_height, int crop_width, int depth, + float *grads_image_ptr, cudaStream_t stream) +{ + const int total_count = num_boxes * crop_height * crop_width * depth; + const int thread_per_block = 1024; + const int block_count = (total_count + thread_per_block - 1) / thread_per_block; + cudaError_t err; + + if (total_count > 0) + { + CropAndResizeBackpropImageKernel<<>>( + total_count, grads_ptr, boxes_ptr, + box_ind_ptr, num_boxes, batch, image_height, image_width, + crop_height, crop_width, depth, grads_image_ptr); + + err = cudaGetLastError(); + if (cudaSuccess != err) + { + fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + } +} \ No newline at end of file diff --git a/cuda_functions/roi_align_3D/__init__.py b/cuda_functions/roi_align_3D/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/roi_align_3D/roi_align/__init__.py b/cuda_functions/roi_align_3D/roi_align/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/roi_align_3D/roi_align/_ext/__init__.py b/cuda_functions/roi_align_3D/roi_align/_ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/._crop_and_resize.so.swp b/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/._crop_and_resize.so.swp new file mode 100644 index 0000000..3db0ea4 Binary files /dev/null and b/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/._crop_and_resize.so.swp differ diff --git a/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__init__.py b/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__init__.py new file mode 100644 index 0000000..4486c09 --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__init__.py @@ -0,0 +1,15 @@ + +from torch.utils.ffi import _wrap_function +from ._crop_and_resize import lib as _lib, ffi as _ffi + +__all__ = [] +def _import_symbols(locals): + for symbol in dir(_lib): + fn = getattr(_lib, symbol) + if callable(fn): + locals[symbol] = _wrap_function(fn, _ffi) + else: + locals[symbol] = fn + __all__.append(symbol) + +_import_symbols(locals()) diff --git a/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/_crop_and_resize.so b/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/_crop_and_resize.so new file mode 100755 index 0000000..81dc147 Binary files /dev/null and b/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/_crop_and_resize.so differ diff --git a/cuda_functions/roi_align_3D/roi_align/build.py b/cuda_functions/roi_align_3D/roi_align/build.py new file mode 100755 index 0000000..3798d82 --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/build.py @@ -0,0 +1,40 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/crop_and_resize.c'] +headers = ['src/crop_and_resize.h'] +defines = [] +with_cuda = False + +extra_objects = [] +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/crop_and_resize_gpu.c'] + headers += ['src/crop_and_resize_gpu.h'] + defines += [('WITH_CUDA', None)] + extra_objects += ['src/cuda/crop_and_resize_kernel.cu.o'] + with_cuda = True + +extra_compile_args = ['-fopenmp', '-std=c99'] + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +sources = [os.path.join(this_file, fname) for fname in sources] +headers = [os.path.join(this_file, fname) for fname in headers] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.crop_and_resize', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=extra_compile_args +) + +if __name__ == '__main__': + ffi.build() diff --git a/cuda_functions/roi_align_3D/roi_align/crop_and_resize.py b/cuda_functions/roi_align_3D/roi_align/crop_and_resize.py new file mode 100755 index 0000000..cff4e90 --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/crop_and_resize.py @@ -0,0 +1,69 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function + +from ._ext import crop_and_resize as _backend + + +class CropAndResizeFunction(Function): + + def __init__(self, crop_height, crop_width, crop_zdepth, extrapolation_value=0): + self.crop_height = crop_height + self.crop_width = crop_width + self.crop_zdepth = crop_zdepth + self.extrapolation_value = extrapolation_value + + def forward(self, image, boxes, box_ind): + crops = torch.zeros_like(image) + + if image.is_cuda: + _backend.crop_and_resize_gpu_forward( + image, boxes, box_ind, + self.extrapolation_value, self.crop_height, self.crop_width, self.crop_zdepth, crops) + else: + _backend.crop_and_resize_forward( + image, boxes, box_ind, + self.extrapolation_value, self.crop_height, self.crop_width, self.crop_zdepth, crops) + + # save for backward + self.im_size = image.size() + self.save_for_backward(boxes, box_ind) + + return crops + + def backward(self, grad_outputs): + boxes, box_ind = self.saved_tensors + + grad_outputs = grad_outputs.contiguous() + grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size) + + if grad_outputs.is_cuda: + _backend.crop_and_resize_gpu_backward( + grad_outputs, boxes, box_ind, grad_image + ) + else: + _backend.crop_and_resize_backward( + grad_outputs, boxes, box_ind, grad_image + ) + + return grad_image, None, None + + +class CropAndResize(nn.Module): + """ + Crop and resize ported from tensorflow + See more details on https://www.tensorflow.org/api_docs/python/tf/image/crop_and_resize + """ + + def __init__(self, crop_height, crop_width, crop_zdepth, extrapolation_value=0): + super(CropAndResize, self).__init__() + + self.crop_height = crop_height + self.crop_width = crop_width + self.crop_zdepth = crop_zdepth + self.extrapolation_value = extrapolation_value + + def forward(self, image, boxes, box_ind): + return CropAndResizeFunction(self.crop_height, self.crop_width, self.crop_zdepth, self.extrapolation_value)(image, boxes, box_ind) diff --git a/cuda_functions/roi_align_3D/roi_align/roi_align.py b/cuda_functions/roi_align_3D/roi_align/roi_align.py new file mode 100644 index 0000000..6931539 --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/roi_align.py @@ -0,0 +1,48 @@ +import torch +from torch import nn + +from .crop_and_resize import CropAndResizeFunction, CropAndResize + + +class RoIAlign(nn.Module): + + def __init__(self, crop_height, crop_width, extrapolation_value=0, transform_fpcoor=True): + super(RoIAlign, self).__init__() + + self.crop_height = crop_height + self.crop_width = crop_width + self.extrapolation_value = extrapolation_value + self.transform_fpcoor = transform_fpcoor + + def forward(self, featuremap, boxes, box_ind): + """ + RoIAlign based on crop_and_resize. + See more details on https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301 + :param featuremap: NxCxHxW + :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization** + :param box_ind: M + :return: MxCxoHxoW + """ + x1, y1, x2, y2 = torch.split(boxes, 1, dim=1) + image_height, image_width = featuremap.size()[2:4] + + if self.transform_fpcoor: + spacing_w = (x2 - x1) / float(self.crop_width) + spacing_h = (y2 - y1) / float(self.crop_height) + + nx0 = (x1 + spacing_w / 2 - 0.5) / float(image_width - 1) + ny0 = (y1 + spacing_h / 2 - 0.5) / float(image_height - 1) + nw = spacing_w * float(self.crop_width - 1) / float(image_width - 1) + nh = spacing_h * float(self.crop_height - 1) / float(image_height - 1) + + boxes = torch.cat((ny0, nx0, ny0 + nh, nx0 + nw), 1) + else: + x1 = x1 / float(image_width - 1) + x2 = x2 / float(image_width - 1) + y1 = y1 / float(image_height - 1) + y2 = y2 / float(image_height - 1) + boxes = torch.cat((y1, x1, y2, x2), 1) + + boxes = boxes.detach().contiguous() + box_ind = box_ind.detach() + return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind) diff --git a/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.c b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.c new file mode 100644 index 0000000..e1fce67 --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.c @@ -0,0 +1,252 @@ +#include +#include +#include + + +void CropAndResizePerBox( + const float * image_data, + const int batch_size, + const int depth, + const int image_height, + const int image_width, + + const float * boxes_data, + const int * box_index_data, + const int start_box, + const int limit_box, + + float * corps_data, + const int crop_height, + const int crop_width, + const float extrapolation_value +) { + const int image_channel_elements = image_height * image_width; + const int image_elements = depth * image_channel_elements; + + const int channel_elements = crop_height * crop_width; + const int crop_elements = depth * channel_elements; + + int b; + #pragma omp parallel for + for (b = start_box; b < limit_box; ++b) { + const float * box = boxes_data + b * 4; + const float y1 = box[0]; + const float x1 = box[1]; + const float y2 = box[2]; + const float x2 = box[3]; + + const int b_in = box_index_data[b]; + if (b_in < 0 || b_in >= batch_size) { + printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); + exit(-1); + } + + const float height_scale = + (crop_height > 1) + ? (y2 - y1) * (image_height - 1) / (crop_height - 1) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) + : 0; + + for (int y = 0; y < crop_height; ++y) + { + const float in_y = (crop_height > 1) + ? y1 * (image_height - 1) + y * height_scale + : 0.5 * (y1 + y2) * (image_height - 1); + + if (in_y < 0 || in_y > image_height - 1) + { + for (int x = 0; x < crop_width; ++x) + { + for (int d = 0; d < depth; ++d) + { + // crops(b, y, x, d) = extrapolation_value; + corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; + } + } + continue; + } + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + for (int x = 0; x < crop_width; ++x) + { + const float in_x = (crop_width > 1) + ? x1 * (image_width - 1) + x * width_scale + : 0.5 * (x1 + x2) * (image_width - 1); + if (in_x < 0 || in_x > image_width - 1) + { + for (int d = 0; d < depth; ++d) + { + corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; + } + continue; + } + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + for (int d = 0; d < depth; ++d) + { + const float *pimage = image_data + b_in * image_elements + d * image_channel_elements; + + const float top_left = pimage[top_y_index * image_width + left_x_index]; + const float top_right = pimage[top_y_index * image_width + right_x_index]; + const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; + const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; + + const float top = top_left + (top_right - top_left) * x_lerp; + const float bottom = + bottom_left + (bottom_right - bottom_left) * x_lerp; + + corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = top + (bottom - top) * y_lerp; + } + } // end for x + } // end for y + } // end for b + +} + + +void crop_and_resize_forward( + THFloatTensor * image, + THFloatTensor * boxes, // [y1, x1, y2, x2] + THIntTensor * box_index, // range in [0, batch_size) + const float extrapolation_value, + const int crop_height, + const int crop_width, + THFloatTensor * crops +) { + const int batch_size = image->size[0]; + const int depth = image->size[1]; + const int image_height = image->size[2]; + const int image_width = image->size[3]; + + const int num_boxes = boxes->size[0]; + + // init output space + THFloatTensor_resize4d(crops, num_boxes, depth, crop_height, crop_width); + THFloatTensor_zero(crops); + + // crop_and_resize for each box + CropAndResizePerBox( + THFloatTensor_data(image), + batch_size, + depth, + image_height, + image_width, + + THFloatTensor_data(boxes), + THIntTensor_data(box_index), + 0, + num_boxes, + + THFloatTensor_data(crops), + crop_height, + crop_width, + extrapolation_value + ); + +} + + +void crop_and_resize_backward( + THFloatTensor * grads, + THFloatTensor * boxes, // [y1, x1, y2, x2] + THIntTensor * box_index, // range in [0, batch_size) + THFloatTensor * grads_image // resize to [bsize, c, hc, wc] +) +{ + // shape + const int batch_size = grads_image->size[0]; + const int depth = grads_image->size[1]; + const int image_height = grads_image->size[2]; + const int image_width = grads_image->size[3]; + + const int num_boxes = grads->size[0]; + const int crop_height = grads->size[2]; + const int crop_width = grads->size[3]; + + // n_elements + const int image_channel_elements = image_height * image_width; + const int image_elements = depth * image_channel_elements; + + const int channel_elements = crop_height * crop_width; + const int crop_elements = depth * channel_elements; + + // init output space + THFloatTensor_zero(grads_image); + + // data pointer + const float * grads_data = THFloatTensor_data(grads); + const float * boxes_data = THFloatTensor_data(boxes); + const int * box_index_data = THIntTensor_data(box_index); + float * grads_image_data = THFloatTensor_data(grads_image); + + for (int b = 0; b < num_boxes; ++b) { + const float * box = boxes_data + b * 4; + const float y1 = box[0]; + const float x1 = box[1]; + const float y2 = box[2]; + const float x2 = box[3]; + + const int b_in = box_index_data[b]; + if (b_in < 0 || b_in >= batch_size) { + printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); + exit(-1); + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) + : 0; + + for (int y = 0; y < crop_height; ++y) + { + const float in_y = (crop_height > 1) + ? y1 * (image_height - 1) + y * height_scale + : 0.5 * (y1 + y2) * (image_height - 1); + if (in_y < 0 || in_y > image_height - 1) + { + continue; + } + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + for (int x = 0; x < crop_width; ++x) + { + const float in_x = (crop_width > 1) + ? x1 * (image_width - 1) + x * width_scale + : 0.5 * (x1 + x2) * (image_width - 1); + if (in_x < 0 || in_x > image_width - 1) + { + continue; + } + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + for (int d = 0; d < depth; ++d) + { + float *pimage = grads_image_data + b_in * image_elements + d * image_channel_elements; + const float grad_val = grads_data[crop_elements * b + channel_elements * d + y * crop_width + x]; + + const float dtop = (1 - y_lerp) * grad_val; + pimage[top_y_index * image_width + left_x_index] += (1 - x_lerp) * dtop; + pimage[top_y_index * image_width + right_x_index] += x_lerp * dtop; + + const float dbottom = y_lerp * grad_val; + pimage[bottom_y_index * image_width + left_x_index] += (1 - x_lerp) * dbottom; + pimage[bottom_y_index * image_width + right_x_index] += x_lerp * dbottom; + } // end d + } // end x + } // end y + } // end b +} \ No newline at end of file diff --git a/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.h b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.h new file mode 100644 index 0000000..d494865 --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize.h @@ -0,0 +1,16 @@ +void crop_and_resize_forward( + THFloatTensor * image, + THFloatTensor * boxes, // [y1, x1, y2, x2] + THIntTensor * box_index, // range in [0, batch_size) + const float extrapolation_value, + const int crop_height, + const int crop_width, + THFloatTensor * crops +); + +void crop_and_resize_backward( + THFloatTensor * grads, + THFloatTensor * boxes, // [y1, x1, y2, x2] + THIntTensor * box_index, // range in [0, batch_size) + THFloatTensor * grads_image // resize to [bsize, c, hc, wc] +); \ No newline at end of file diff --git a/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize_gpu.c b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize_gpu.c new file mode 100644 index 0000000..8e07b3d --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize_gpu.c @@ -0,0 +1,73 @@ +#include +#include "cuda/crop_and_resize_kernel.h" + +extern THCState *state; + + +void crop_and_resize_gpu_forward( + THCudaTensor * image, + THCudaTensor * boxes, // [y1, x1, y2, x2] + THCudaIntTensor * box_index, // range in [0, batch_size) + const float extrapolation_value, + const int crop_height, + const int crop_width, + const int crop_zdepth, + THCudaTensor * crops +) { + const int batch_size = THCudaTensor_size(state, image, 0); + const int depth = THCudaTensor_size(state, image, 1); + const int image_height = THCudaTensor_size(state, image, 2); + const int image_width = THCudaTensor_size(state, image, 3); + const int image_zdepth = THCudaTensor_size(state, image, 4); + + const int num_boxes = THCudaTensor_size(state, boxes, 0); + + // init output space + THCudaTensor_resize5d(state, crops, num_boxes, depth, crop_height, crop_width, crop_zdepth); + THCudaTensor_zero(state, crops); + + cudaStream_t stream = THCState_getCurrentStream(state); + CropAndResizeLaucher( + THCudaTensor_data(state, image), + THCudaTensor_data(state, boxes), + THCudaIntTensor_data(state, box_index), + num_boxes, batch_size, image_height, image_width, image_zdepth, + crop_height, crop_width, crop_zdepth, depth, extrapolation_value, + THCudaTensor_data(state, crops), + stream + ); +} + + +void crop_and_resize_gpu_backward( + THCudaTensor * grads, + THCudaTensor * boxes, // [y1, x1, y2, x2] + THCudaIntTensor * box_index, // range in [0, batch_size) + THCudaTensor * grads_image // resize to [bsize, c, hc, wc] +) { + // shape + const int batch_size = THCudaTensor_size(state, grads_image, 0); + const int depth = THCudaTensor_size(state, grads_image, 1); + const int image_height = THCudaTensor_size(state, grads_image, 2); + const int image_width = THCudaTensor_size(state, grads_image, 3); + const int image_zdepth = THCudaTensor_size(state, grads_image, 4); + + const int num_boxes = THCudaTensor_size(state, grads, 0); + const int crop_height = THCudaTensor_size(state, grads, 2); + const int crop_width = THCudaTensor_size(state, grads, 3); + const int crop_zdepth = THCudaTensor_size(state, grads, 4); + + // init output space + THCudaTensor_zero(state, grads_image); + + cudaStream_t stream = THCState_getCurrentStream(state); + CropAndResizeBackpropImageLaucher( + THCudaTensor_data(state, grads), + THCudaTensor_data(state, boxes), + THCudaIntTensor_data(state, box_index), + num_boxes, batch_size, image_height, image_width, image_zdepth, + crop_height, crop_width, crop_zdepth, depth, + THCudaTensor_data(state, grads_image), + stream + ); +} \ No newline at end of file diff --git a/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize_gpu.h b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize_gpu.h new file mode 100644 index 0000000..dd2eb5a --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/src/crop_and_resize_gpu.h @@ -0,0 +1,17 @@ +void crop_and_resize_gpu_forward( + THCudaTensor * image, + THCudaTensor * boxes, // [y1, x1, y2, x2] + THCudaIntTensor * box_index, // range in [0, batch_size) + const float extrapolation_value, + const int crop_height, + const int crop_width, + const int crop_zdepth, + THCudaTensor * crops +); + +void crop_and_resize_gpu_backward( + THCudaTensor * grads, + THCudaTensor * boxes, // [y1, x1, y2, x2] + THCudaIntTensor * box_index, // range in [0, batch_size) + THCudaTensor * grads_image // resize to [bsize, c, hc, wc] +); \ No newline at end of file diff --git a/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.cu b/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.cu new file mode 100644 index 0000000..e381dab --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.cu @@ -0,0 +1,361 @@ +#include +#include +#include "crop_and_resize_kernel.h" +#include + +#define CUDA_1D_KERNEL_LOOP(i, n) \ +for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + + +__global__ +void CropAndResizeKernel( + const int nthreads, const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int image_zdepth, int crop_height, int crop_width, int crop_zdepth, int depth, + float extrapolation_value, float *crops_ptr) +{ + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) // nthreads = total_count! + { + // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) position in out grid!!! + // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) NCYX yes seems like xy is exchanged! + // NCHWZ: out_idx = z + crop_zdepth * (w + crop_width * (h + crop_height * (d + depth * b))) z == last. + + int idx = out_idx; + + const int z = idx % crop_zdepth; + idx /= crop_zdepth; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + idx /= crop_height; + + const int d = idx % depth; + const int b = idx / depth; // batch + + const float y1 = boxes_ptr[b * 6]; // b = batch -> 0 // normalized coords!! + const float x1 = boxes_ptr[b * 6 + 1]; + const float y2 = boxes_ptr[b * 6 + 2]; + const float x2 = boxes_ptr[b * 6 + 3]; + const float z1 = boxes_ptr[b * 6 + 4]; + const float z2 = boxes_ptr[b * 6 + 5]; + + const int b_in = box_ind_ptr[b]; // == 0 in my case. + if (b_in < 0 || b_in >= batch) + { + continue; + } + + // e.g. (0.4-0.3)*100 = 10 / 7 = 1.3 ratio proposal_size / crops_size. one cell in crops has size 1.3 in_pixel. + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height ) / (crop_height ) : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width ) / (crop_width ) : 0; + + const float zdepth_scale = + (crop_zdepth > 1) ? (z2 - z1) * (image_zdepth ) / (crop_zdepth ) : 0; + + + // e.g. 0.3*100 + 5 * 1.3 . Which floating coordinate is going into cell? + // e.g. y: 30 (lower bound prop) + 7.5 (current crop position * scale) + + + float tmp_in_y = (crop_height > 1) + ? y1 * (image_height ) + y * height_scale + height_scale/2 - 0.5 + : 0.5 * (y1 + y2) * (image_height); + + if (tmp_in_y > image_height - 1) + { + tmp_in_y = image_height - 1; + } + if (tmp_in_y < 0) + { + tmp_in_y = 0; + } + const float in_y = tmp_in_y; + + + float tmp_in_x = (crop_width > 1) + ? x1 * (image_width ) + x * width_scale + width_scale/2 - 0.5 + : 0.5 * (x1 + x2) * (image_width ); + + if (tmp_in_x > image_width - 1) + { + tmp_in_x = image_width - 1; + } + if (tmp_in_x < 0) + { + tmp_in_x= 0; + } + const float in_x = tmp_in_x; + + + float tmp_in_z = (crop_zdepth > 1) + ? z1 * (image_zdepth ) + z * zdepth_scale + zdepth_scale/2 - 0.5 + : 0.5 * (z1 + z2) * (image_zdepth); + + if (tmp_in_z > image_zdepth - 1) + { + tmp_in_z = image_zdepth - 1; + } + if (tmp_in_z < 0) + { + tmp_in_z= 0; + } + const float in_z = tmp_in_z; + + // this is just rounding of the floating coord of grid cell. The distances to nearest grid points are + // memorized (lerp) to be used for bilinear interpolation later. + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; // + + const int front_z_index = floorf(in_z); + const int back_z_index = ceilf(in_z); + const float z_lerp = in_z - front_z_index; + + + // address of image + going to the right feature map. + const float *pimage = image_ptr + (b_in * depth + d) * image_height * image_width * image_zdepth; + + // 1D address of corner points of in_coords to grid cell. + // NCHWZ: out_idx = z + crop_zdepth * (w + crop_width * (h + crop_height * (d + depth * b))) z == last. + const float top_left_front = pimage[front_z_index + image_zdepth * (left_x_index + image_width * top_y_index)]; + const float top_right_front = pimage[front_z_index + image_zdepth * (right_x_index + image_width * top_y_index)]; + const float bottom_left_front = pimage[front_z_index + image_zdepth * (left_x_index + image_width * bottom_y_index)]; + const float bottom_right_front = pimage[front_z_index + image_zdepth * (right_x_index + image_width * bottom_y_index)]; + const float top_left_back = pimage[back_z_index + image_zdepth * (left_x_index + image_width * top_y_index)]; + const float top_right_back = pimage[back_z_index + image_zdepth * (right_x_index + image_width * top_y_index)]; + const float bottom_left_back = pimage[back_z_index + image_zdepth * (left_x_index + image_width * bottom_y_index)]; + const float bottom_right_back = pimage[back_z_index + image_zdepth * (right_x_index + image_width * bottom_y_index)]; + + // Bilinear Interpolation!! These are pixel values now! lerp is the interpolation distance! + // No Maxpool, only one point is sampled! + const float top_front = top_left_front + (top_right_front - top_left_front) * x_lerp; + const float bottom_front = bottom_left_front + (bottom_right_front - bottom_left_front) * x_lerp; + const float top_back = top_left_back + (top_right_back - top_left_back) * x_lerp; + const float bottom_back = bottom_left_back + (bottom_right_back - bottom_left_back) * x_lerp; + + const float front = top_front + (bottom_front - top_front) * y_lerp; + const float back = top_back + (bottom_back - top_back) * y_lerp; + + crops_ptr[out_idx] = front + (back - front) * z_lerp; // assign interpolated value to Grid cell! + + + } +} + +__global__ +void CropAndResizeBackpropImageKernel( + const int nthreads, const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int image_zdepth, int crop_height, int crop_width, int crop_zdepth, int depth, + float *grads_image_ptr) +{ + CUDA_1D_KERNEL_LOOP(out_idx, nthreads) + { + // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) + // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) + // NCHWZ: out_idx = z + crop_zdepth * (w + crop_width * (h + crop_height * (d + depth * b))) z == last. + int idx = out_idx; + + const int z = idx % crop_zdepth; + idx /= crop_zdepth; + const int x = idx % crop_width; + idx /= crop_width; + const int y = idx % crop_height; + idx /= crop_height; + const int d = idx % depth; + const int b = idx / depth; + + const float y1 = boxes_ptr[b * 6]; // b = batch -> 0 // normalized coords!! + const float x1 = boxes_ptr[b * 6 + 1]; + const float y2 = boxes_ptr[b * 6 + 2]; + const float x2 = boxes_ptr[b * 6 + 3]; + const float z1 = boxes_ptr[b * 6 + 4]; + const float z2 = boxes_ptr[b * 6 + 5]; + + + const int b_in = box_ind_ptr[b]; + if (b_in < 0 || b_in >= batch) + { + continue; + } + + const float height_scale = + (crop_height > 1) ? (y2 - y1) * (image_height ) / (crop_height ) + : 0; + const float width_scale = + (crop_width > 1) ? (x2 - x1) * (image_width ) / (crop_width ) : 0; + + const float zdepth_scale = + (crop_zdepth > 1) ? (z2 - z1) * (image_zdepth ) / (crop_zdepth ) : 0; + + + float tmp_in_y = (crop_height > 1) + ? y1 * (image_height ) + y * height_scale + height_scale/2 - 0.5 + : 0.5 * (y1 + y2) * (image_height); + if (tmp_in_y > image_height - 1) + { + tmp_in_y = image_height - 1; + } + if (tmp_in_y < 0) + { + tmp_in_y = 0; + } + const float in_y = tmp_in_y; + + + float tmp_in_x = (crop_width > 1) + ? x1 * (image_width ) + x * width_scale + width_scale/2 - 0.5 + : 0.5 * (x1 + x2) * (image_width ); + if (tmp_in_x > image_width - 1) + { + tmp_in_x = image_width - 1; + } + if (tmp_in_x < 0) + { + tmp_in_x= 0; + } + const float in_x = tmp_in_x; + + + float tmp_in_z = (crop_zdepth > 1) + ? z1 * (image_zdepth ) + z * zdepth_scale + zdepth_scale/2 - 0.5 + : 0.5 * (z1 + z2) * (image_zdepth); + if (tmp_in_z > image_zdepth - 1) + { + tmp_in_z = image_zdepth - 1; + } + if (tmp_in_z < 0) + { + tmp_in_z= 0; + } + const float in_z = tmp_in_z; + + const int top_y_index = floorf(in_y); + const int bottom_y_index = ceilf(in_y); + const float y_lerp = in_y - top_y_index; + + const int left_x_index = floorf(in_x); + const int right_x_index = ceilf(in_x); + const float x_lerp = in_x - left_x_index; + + const int front_z_index = floorf(in_z); + const int back_z_index = ceilf(in_z); + const float z_lerp = in_z - front_z_index; + + float *pimage = grads_image_ptr + (b_in * depth + d) * image_height * image_width * image_zdepth; + + // top left front + atomicAdd( + pimage + front_z_index + image_zdepth * (left_x_index + image_width * top_y_index), + (1 - x_lerp) * (1 - z_lerp) * (1 - y_lerp) * grads_ptr[out_idx] // THIS IS BACKWARD INTERPOL. + ); + + // top left back + atomicAdd( + pimage + back_z_index + image_zdepth * (left_x_index + image_width * top_y_index), + (1 - x_lerp) * (z_lerp) * (1 - y_lerp) * grads_ptr[out_idx] // THIS IS BACKWARD INTERPOL. + ); + + // top right front + atomicAdd( + pimage + front_z_index + image_zdepth * (right_x_index + image_width * top_y_index), + (x_lerp) * (1 - z_lerp) * (1 - y_lerp) * grads_ptr[out_idx] // THIS IS backward INTERPOL. + ); + + // top right back + atomicAdd( + pimage + back_z_index + image_zdepth * (right_x_index + image_width * top_y_index), + (x_lerp) * (z_lerp) * (1 - y_lerp) * grads_ptr[out_idx] // THIS IS backward INTERPOL. + ); + + // bottom left front + atomicAdd( + pimage + front_z_index + image_zdepth * (left_x_index + image_width * bottom_y_index), + (1 - x_lerp) * (1 - z_lerp) * (y_lerp) * grads_ptr[out_idx] // THIS IS backward INTERPOL. + ); + + // bottom left back + atomicAdd( + pimage + back_z_index + image_zdepth * (left_x_index + image_width * bottom_y_index), + (1 - x_lerp) * (z_lerp) * (y_lerp) * grads_ptr[out_idx] // THIS IS backward INTERPOL. + ); + + // bottom right front + atomicAdd( + pimage + front_z_index + image_zdepth * (right_x_index + image_width * bottom_y_index), + (x_lerp) * (1 - z_lerp) * (y_lerp) * grads_ptr[out_idx] // THIS IS backward INTERPOL. + ); + + // bottom right back + atomicAdd( + pimage + back_z_index + image_zdepth * (right_x_index + image_width * bottom_y_index), + (x_lerp) * (z_lerp) * (y_lerp) * grads_ptr[out_idx] // THIS IS backward INTERPOL. + ); + + } +} + + + +void CropAndResizeLaucher( + const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int image_zdepth, int crop_height, int crop_width, int crop_zdepth, int depth, + float extrapolation_value, float *crops_ptr, cudaStream_t stream) +{ + const int total_count = num_boxes * crop_height * crop_width * crop_zdepth * depth; + const int thread_per_block = 1024; + const int block_count = (total_count + thread_per_block - 1) / thread_per_block; + cudaError_t err; + + if (total_count > 0) + { + CropAndResizeKernel<<>>( + total_count, image_ptr, boxes_ptr, + box_ind_ptr, num_boxes, batch, image_height, image_width, image_zdepth, + crop_height, crop_width, crop_zdepth, depth, extrapolation_value, crops_ptr); + + err = cudaGetLastError(); + if (cudaSuccess != err) + { + fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); + exit(-1); + } + } +} + + +void CropAndResizeBackpropImageLaucher( + const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int image_zdepth, int crop_height, int crop_width, int crop_zdepth, int depth, + float *grads_image_ptr, cudaStream_t stream) +{ + const int total_count = num_boxes * crop_height * crop_width * crop_zdepth * depth; + const int thread_per_block = 1024; + const int block_count = (total_count + thread_per_block - 1) / thread_per_block; + cudaError_t err; + + if (total_count > 0) + { + CropAndResizeBackpropImageKernel<<>>( + total_count, grads_ptr, boxes_ptr, + box_ind_ptr, num_boxes, batch, image_height, image_width, image_zdepth, + crop_height, crop_width, crop_zdepth, depth, grads_image_ptr); + + err = cudaGetLastError(); + if (cudaSuccess != err) + { + fprintf(stderr, "cudaCheckError() failed in Roi Align : %s\n", cudaGetErrorString(err)); + exit(-1); + } + } +} \ No newline at end of file diff --git a/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.cu.o b/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.cu.o new file mode 100644 index 0000000..d488598 Binary files /dev/null and b/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.cu.o differ diff --git a/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.h b/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.h new file mode 100644 index 0000000..9244582 --- /dev/null +++ b/cuda_functions/roi_align_3D/roi_align/src/cuda/crop_and_resize_kernel.h @@ -0,0 +1,24 @@ +#ifndef _CropAndResize_Kernel +#define _CropAndResize_Kernel + +#ifdef __cplusplus +extern "C" { +#endif + +void CropAndResizeLaucher( + const float *image_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int image_zdepth, int crop_height, int crop_width, int crop_zdepth, int depth, + float extrapolation_value, float *crops_ptr, cudaStream_t stream); + +void CropAndResizeBackpropImageLaucher( + const float *grads_ptr, const float *boxes_ptr, + const int *box_ind_ptr, int num_boxes, int batch, int image_height, + int image_width, int image_zdepth, int crop_height, int crop_width, int crop_zdepth, int depth, + float *grads_image_ptr, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/data_manager.py b/data_manager.py new file mode 100644 index 0000000..82f529b --- /dev/null +++ b/data_manager.py @@ -0,0 +1,197 @@ + +import os +import warnings +import time +import subprocess +from multiprocessing import Pool + +import argparse +import shutil + +import numpy as np + + +def get_identifiers(folder, ending=".npz"): + identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith(ending)] + return identifiers + +def convert_to_npz(kwargs): + """ + :param kwargs: npy-file:path or name:namestr and destination:path + """ + assert "identifier" in kwargs.keys(), "you need to define at least a npyfile-identifier" + identifier = kwargs["identifier"] + if "folder" in kwargs.keys(): + folder = kwargs["folder"] + else: + folder = "" + dest = kwargs["destination"] + if "name" in kwargs.keys(): + name = kwargs["name"] + else: + name = identifier + + npy_file = os.path.join(folder, identifier+".npy") + data = np.load(npy_file) + np.savez_compressed(os.path.join(dest, identifier + ".npz"), **{name:data}) + if "verbose" in kwargs.keys() and kwargs["verbose"]: + print("converted file {} to npz".format(npy_file)) + +def pack_dataset(folder, destination=None, recursive=False, processes=None, verbose=True): + """call convert_to_npz parallely with "processes" processes on all npys in folder. + does not actually pack more than one file into an archive... + """ + if processes is None: + processes = os.cpu_count() + + p = Pool(processes) + + if destination is None: + destination = folder + if recursive: + folders = [root for (root, dir, file) in os.walk(folder)] + else: + folders = [folder] + + for fldr in folders: + identifiers = get_identifiers(fldr, ".npy") + if recursive: + cur_dest = os.path.join(destination, os.path.relpath(fldr, folder)) + else: + cur_dest = destination + if not os.path.isdir(cur_dest): + os.mkdir(cur_dest) + + kwargs = [{"folder":fldr, "identifier":ident, "destination":cur_dest, "verbose":verbose} for ident in identifiers] + p.map(convert_to_npz, kwargs) + print("converted folder {}.".format(fldr)) + p.close() + p.join() + + +def convert_to_npy(kwargs): + identifier = kwargs["identifier"] + folder = kwargs["folder"] + delete = kwargs["delete"] + npz_file = os.path.join(folder,identifier+".npz") + + if os.path.isfile(npz_file[:-4] + ".npy"): + print("{}.npy already exists, not overwriting.".format(npz_file[:-4])) + else: + data = np.load(npz_file)[identifier] # should be the only entry anyway + np.save(npz_file[:-4] + ".npy", data) + print("converted {} from npz to npy".format(npz_file[:-4])) + if delete: + os.remove(npz_file) + +def unpack_dataset(folder, recursive=False, delete=True, processes=None): + if processes is None: + processes = os.cpu_count() + + p = Pool(processes) + + if recursive: + folders = [root for (root, dir, file) in os.walk(folder)] + else: + folders = [folder] + + for fldr in folders: + identifiers = get_identifiers(fldr) + kwargs = [{"folder":fldr, "identifier":ident, "delete":delete} for ident in identifiers] + p.map(convert_to_npy, kwargs) + print("unpacked folder ", fldr) + p.close() + p.join() + +def delete_npy(folder, recursive=False): #not used + identifiers = get_identifiers(folder) + npy_files = [os.path.join(folder, i+".npy") for i in identifiers] + #should not be necessary since get_iden already returns only existing files: + npy_files = [i for i in npy_files if os.path.isfile(i)] + for n in npy_files: + os.remove(n) + +def copy(args, file_subset=None, verbose=True): + r"""copy and evtly unpack dataset (convert npz->npy) or pack dataset (npy->npz). + :param file_subset: copy only files whose names are in file_subset + """ + + source_path = args.source + dest_path = args.destination + assert dest_path is not None, "you need to define a copy destination" + start_time = time.time() + print("Destination: ", dest_path) + + rsync_opts = "-v " if verbose else "" + if args.recursive: + rsync_opts += r"-a --include '**/'" + if args.cp_only_npz: + rsync_opts+= r" --include '*.npz'" #to copy only npz files + + try: + rsync_opts_all = rsync_opts + if file_subset is not None: #ranks higher than only-npz criterium + #rsync include/exclude doesnt work with absolute paths for the files!! :/:/ + for file in file_subset: + if os.path.isabs(file): + file = os.path.relpath(file, source_path) + rsync_opts_all +=r" --include '{}'".format(file) + if args.cp_only_npz or file_subset is not None: + rsync_opts_all += r" --exclude '*'" #has to be added after all --includes + subprocess.call('rsync {} {} {}'.format(rsync_opts_all, + source_path, dest_path), shell=True) + except OSError: #in case argument list too long due to file subset + warnings.warn("OSError when trying to copy file_subset at once. Copying single files instead.") + if file_subset is not None: + for file in file_subset: + rsync_opts_file = rsync_opts+" --include '{}' --exclude '*'".format(file) + subprocess.call('rsync {} {} {}'.format(rsync_opts_file, + source_path, dest_path), shell=True) + else: + if args.cp_only_npz: + rsync_opts += r" --exclude '*'" + subprocess.call('rsync {} {} {}'.format(rsync_opts, + source_path, dest_path), shell=True) + #one would only need the part in exception catcher, but try part might be faster if feasible + + if not args.keep_packed: + unpack_dataset(dest_path, recursive=args.recursive, delete=args.del_after_unpack, processes=args.threads) + #elif pack_copied_dataset: + # pack_dataset(dest_path, recursive=args.recursive) + mins, secs = divmod((time.time() - start_time), 60) + t = "{:d}m:{:02d}s".format(int(mins), int(secs)) + print("copying and unpacking data set took: {}".format(t)) + try: + copied_files = [file for (root, dir, file) in os.walk(dest_path)] + print("nr of files in destination: {}".format(len(copied_files))) + except FileNotFoundError: #fails if destination is on a server + pass + + +if __name__=="__main__": + """ usage: create folder containing converted npys (i.e., npzs) and all other data that needs to be copied, + copy the folder, evtly unpack to npy again. + """ + parser = argparse.ArgumentParser() + parser.add_argument('-m', '--mode', type=str, help="convert, copy, or delete. convert: npy->npz. delete: rmtree dest.") + parser.add_argument('-s', '--source', type=str, help="Source path to folder containing data.") + parser.add_argument('-d', '--destination', type=str, default=None, help="Destination path") + parser.add_argument('--keep_packed', action='store_true', help="after copying, do not convert to npy.") + #parser.add_argument('--pack_after_copy', action='store_true', help="after copying, convert npy to npz.") + parser.add_argument('-r', '--recursive', action='store_true') + parser.add_argument('--cp_only_npz', action='store_true', help="whether to copy only .npz-files") + parser.add_argument('--del_after_unpack', action='store_true', help="whether to delete npz after unpacking them") + parser.add_argument('--threads', type=int, default=None, help="how many cpu threads to use for conversions") + + args = parser.parse_args() + mode = args.mode + + if mode == "convert": + #convert from npy to npz + pack_dataset(args.source, destination=args.destination, recursive=args.recursive, processes=args.threads) + elif mode == 'copy': + copy(args) + elif mode == 'delete': + shutil.rmtree(args.destination) + else: + 'cluster_data_manager: chosen mode not implemented.' diff --git a/datasets/cityscapes/configs.py b/datasets/cityscapes/configs.py new file mode 100644 index 0000000..ed2cdab --- /dev/null +++ b/datasets/cityscapes/configs.py @@ -0,0 +1,434 @@ +__author__ = '' +#credit Paul F. Jaeger + +######################### +# Example Config # +######################### + +import os +import sys + +import numpy as np +from collections import namedtuple + +sys.path.append('../') +from default_configs import DefaultConfigs + +class Configs(DefaultConfigs): + + def __init__(self, server_env=None): + super(Configs, self).__init__(server_env) + + self.dim = 2 + + ######################### + # I/O # + ######################### + + self.data_sourcedir = "/mnt/HDD2TB/Documents/data/cityscapes/cs_20190715/" + if server_env: + #self.source_dir = '/home/ramien/medicaldetectiontoolkit/' + self.data_sourcedir = '/datasets/data_ramien/cityscapes/cs_20190715_npz/' + #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/cityscapes/cs_6c_inst_only/" + + self.datapath = "leftImg8bit/" + self.targetspath = "gtFine/" + + self.cities = {'train':['dusseldorf', 'aachen', 'bochum', 'cologne', 'erfurt', + 'hamburg', 'hanover', 'jena', 'krefeld', 'monchengladbach', + 'strasbourg', 'stuttgart', 'tubingen', 'ulm', 'weimar', + 'zurich'], + 'val':['frankfurt', 'munster'], + 'test':['bremen', 'darmstadt', 'lindau'] } + self.set_splits = ["train", "val", "test"] # for training and val, mixed up + # test cities are not held out + + self.info_dict_name = 'city_info.pkl' + self.info_dict_path = os.path.join(self.data_sourcedir, self.info_dict_name) + self.config_path = os.path.realpath(__file__) + self.backbone_path = 'models/backbone.py' + + # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'detection_fpn']. + self.model = 'retina_unet' + self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') + self.model_path = os.path.join(self.source_dir, self.model_path) + + self.select_prototype_subset = None + + ######################### + # Preprocessing # + ######################### + self.prepro = { + + 'data_dir': '/mnt/HDD2TB/Documents/data/cityscapes_raw/', #raw files (input), needs to end with "/" + 'targettype': "gtFine_instanceIds", + 'set_splits': ["train", "val", "test"], + + 'img_target_size': np.array([256, 512])*4, #y,x + + 'output_directory': self.data_sourcedir, + + 'center_of_mass_crop': True, #not implemented + #'pre_crop_size': , #z,y,x + 'normalization': {'percentiles':[1., 99.]},#not implemented + 'interpolation': 'nearest', #not implemented + + 'info_dict_path': self.info_dict_path, + + 'npz_dir' : self.data_sourcedir[:-1]+"_npz" #if not None: convert to npz, copy data here + } + + ######################### + # Architecture # + ######################### + # 'class', 'regression', 'regression_ken_gal' + # 'class': standard object classification per roi, pairwise combinable with each of below tasks. + # 'class' is only option implemented for CityScapes data set. + self.prediction_tasks = ['class',] + self.start_filts = 52 + self.end_filts = self.start_filts * 4 + self.res_architecture = 'resnet101' # 'resnet101' , 'resnet50' + self.weight_init = None # 'kaiming', 'xavier' or None for pytorch default + self.norm = 'instance_norm' # 'batch_norm' # one of 'None', 'instance_norm', 'batch_norm' + self.relu = 'relu' + + ######################### + # Data Loader # + ######################### + + self.seed = 17 + self.n_workers = 16 if server_env else os.cpu_count() + + self.batch_size = 8 + self.n_cv_splits = 10 #at least 2 (train, val) + + self.num_classes = None #set below #for instance classification (excl background) + self.num_seg_classes = None #set below #incl background + + self.create_bounding_box_targets = True + self.class_specific_seg = True + + self.channels = [0,1,2] + self.pre_crop_size = self.prepro['img_target_size'] # y,x + self.crop_margin = [10,10] #has to be smaller than respective patch_size//2 + self.patch_size_2D = [256, 512] #self.pre_crop_size #would be better to save as tuple since should not be altered + self.patch_size_3D = self.patch_size_2D + [1] + self.patch_size = self.patch_size_2D + + self.balance_target = "class_targets" + # ratio of fully random patients drawn during batch generation + # resulting batch random count is rounded down to closest integer + self.batch_random_ratio = 0.2 + + self.observables_patient = [] + self.observables_rois = [] + + ######################### + # Data Augmentation # + ######################### + #the angle rotations are implemented incorrectly in batchgenerators! in 2D, + #the x-axis angle controls the z-axis angle. + self.do_aug = True + self.da_kwargs = { + 'mirror': True, + 'mirror_axes': (1,), #image axes, (batch and channel are ignored, i.e., actual tensor dims are +2) + 'random_crop': True, + 'rand_crop_dist': (self.patch_size[0] / 2., self.patch_size[1] / 2.), + 'do_elastic_deform': True, + 'alpha': (0., 1000.), + 'sigma': (28., 30.), + 'do_rotation': True, + 'angle_x': (-np.pi / 8., np.pi / 8.), + 'angle_y': (0.,0.), + 'angle_z': (0.,0.), + 'do_scale': True, + 'scale': (0.6, 1.4), + 'border_mode_data': 'constant', + 'gamma_range': (0.6, 1.4) + } + + ################################# + # Schedule / Selection / Optim # + ################################# + #mrcnn paper: ~2.56m samples seen during coco-dataset training + self.num_epochs = 400 + self.num_train_batches = 600 + + self.do_validation = True + # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) + # the former is morge accurate, while the latter is faster (depending on volume size) + self.val_mode = 'val_sampling' # one of 'val_sampling', 'val_patient' + # if 'all' iterates over entire val_set once. + self.num_val_batches = "all" # for val_sampling + + self.save_n_models = 3 + self.min_save_thresh = 1 # in epochs + self.model_selection_criteria = {"human_ap": 1., "vehicle_ap": 0.9} + self.warm_up = 0 + + self.learning_rate = [5*1e-4] * self.num_epochs + self.dynamic_lr_scheduling = True #with scheduler set in exec + self.lr_decay_factor = 0.5 + self.scheduling_patience = int(self.num_epochs//10) + self.weight_decay = 1e-6 + self.clip_norm = None # number or None + + ######################### + # Colors and Legends # + ######################### + self.plot_frequency = 5 + + #colors + self.color_palette = [self.red, self.blue, self.green, self.orange, self.aubergine, + self.yellow, self.gray, self.cyan, self.black] + + #legends + Label = namedtuple( 'Label' , [ + 'name' , # The identifier of this label, e.g. 'car', 'person', ... . + # We use them to uniquely name a class + 'ppId' , # An integer ID that is associated with this label. + # The IDs are used to represent the label in ground truth images + # An ID of -1 means that this label does not have an ID and thus + # is ignored when creating ground truth images (e.g. license plate). + # Do not modify these IDs, since exactly these IDs are expected by the + # evaluation server. + 'id' , # Feel free to modify these IDs as suitable for your method. + # Max value is 255! + 'category' , # The name of the category that this label belongs to + 'categoryId' , # The ID of this category. Used to create ground truth images + # on category level. + 'hasInstances', # Whether this label distinguishes between single instances or not + 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored + # during evaluations or not + 'color' , # The color of this label + ] ) + segLabel = namedtuple( "segLabel", ["name", "id", "color"]) + boxLabel = namedtuple( 'boxLabel', [ "name", "color"]) + + self.labels = [ + # name ppId id category catId hasInstances ignoreInEval color + Label( 'ignore' , 0 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), + Label( 'ego vehicle' , 1 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), + Label( 'rectification border' , 2 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), + Label( 'out of roi' , 3 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), + Label( 'static' , 4 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), + Label( 'dynamic' , 5 , 0 , 'void' , 0 , False , True , (0.44, 0.29, 0., 1.) ), + Label( 'ground' , 6 , 0 , 'void' , 0 , False , True , ( 0.32, 0., 0.32, 1.) ), + Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (0.5, 0.25, 0.5, 1.) ), + Label( 'sidewalk' , 8 , 0 , 'flat' , 1 , False , False , (0.96, 0.14, 0.5, 1.) ), + Label( 'parking' , 9 , 0 , 'flat' , 1 , False , True , (0.98, 0.67, 0.63, 1.) ), + Label( 'rail track' , 10 , 0 , 'flat' , 1 , False , True , ( 0.9, 0.59, 0.55, 1.) ), + Label( 'building' , 11 , 0 , 'construction' , 2 , False , False , ( 0.27, 0.27, 0.27, 1.) ), + Label( 'wall' , 12 , 0 , 'construction' , 2 , False , False , (0.4,0.4,0.61, 1.) ), + Label( 'fence' , 13 , 0 , 'construction' , 2 , False , False , (0.75,0.6,0.6, 1.) ), + Label( 'guard rail' , 14 , 0 , 'construction' , 2 , False , True , (0.71,0.65,0.71, 1.) ), + Label( 'bridge' , 15 , 0 , 'construction' , 2 , False , True , (0.59,0.39,0.39, 1.) ), + Label( 'tunnel' , 16 , 0 , 'construction' , 2 , False , True , (0.59,0.47, 0.35, 1.) ), + Label( 'pole' , 17 , 0 , 'object' , 3 , False , False , (0.6,0.6,0.6, 1.) ), + Label( 'polegroup' , 18 , 0 , 'object' , 3 , False , True , (0.6,0.6,0.6, 1.) ), + Label( 'traffic light' , 19 , 0 , 'object' , 3 , False , False , (0.98,0.67, 0.12, 1.) ), + Label( 'traffic sign' , 20 , 0 , 'object' , 3 , False , False , (0.86,0.86, 0., 1.) ), + Label( 'vegetation' , 21 , 0 , 'nature' , 4 , False , False , (0.42,0.56, 0.14, 1.) ), + Label( 'terrain' , 22 , 0 , 'nature' , 4 , False , False , (0.6, 0.98,0.6, 1.) ), + Label( 'sky' , 23 , 0 , 'sky' , 5 , False , False , (0.27,0.51,0.71, 1.) ), + Label( 'person' , 24 , 1 , 'human' , 6 , True , False , (0.86, 0.08, 0.24, 1.) ), + Label( 'rider' , 25 , 1 , 'human' , 6 , True , False , (1., 0., 0., 1.) ), + Label( 'car' , 26 , 2 , 'vehicle' , 7 , True , False , ( 0., 0.,0.56, 1.) ), + Label( 'truck' , 27 , 2 , 'vehicle' , 7 , True , False , ( 0., 0., 0.27, 1.) ), + Label( 'bus' , 28 , 2 , 'vehicle' , 7 , True , False , ( 0., 0.24,0.39, 1.) ), + Label( 'caravan' , 29 , 2 , 'vehicle' , 7 , True , True , ( 0., 0., 0.35, 1.) ), + Label( 'trailer' , 30 , 2 , 'vehicle' , 7 , True , True , ( 0., 0.,0.43, 1.) ), + Label( 'train' , 31 , 2 , 'vehicle' , 7 , True , False , ( 0., 0.31,0.39, 1.) ), + Label( 'motorcycle' , 32 , 2 , 'vehicle' , 7 , True , False , ( 0., 0., 0.9, 1.) ), + Label( 'bicycle' , 33 , 2 , 'vehicle' , 7 , True , False , (0.47, 0.04, 0.13, 1.) ), + Label( 'license plate' , -1 , 0 , 'vehicle' , 7 , False , True , ( 0., 0., 0.56, 1.) ), + Label( 'background' , -1 , 0 , 'void' , 0 , False , True , ( 0., 0., 0.0, 0.) ), + Label( 'vehicle' , 33 , 2 , 'vehicle' , 7 , True , False , (*self.aubergine, 1.) ), + Label( 'human' , 25 , 1 , 'human' , 6 , True , False , (*self.blue, 1.) ) + ] + # evtl problem: class-ids (trainIds) don't start with 0 for the first class, 0 is bg. + #WONT WORK: class ids need to start at 0 (excluding bg!) and be consecutively numbered + + self.ppId2id = { label.ppId : label.id for label in self.labels} + self.class_id2label = { label.id : label for label in self.labels} + self.class_cmap = {label.id : label.color for label in self.labels} + self.class_dict = {label.id : label.name for label in self.labels if label.id!=0} + #c_dict: only for evaluation, remove bg class. + + self.box_type2label = {label.name : label for label in self.box_labels} + self.box_color_palette = {label.name:label.color for label in self.box_labels} + + if self.class_specific_seg: + self.seg_labels = [label for label in self.class_id2label.values()] + else: + self.seg_labels = [ + # name id color + segLabel( "bg" , 0, (1.,1.,1.,0.) ), + segLabel( "fg" , 1, (*self.orange, .8)) + ] + + self.seg_id2label = {label.id : label for label in self.seg_labels} + self.cmap = {label.id : label.color for label in self.seg_labels} + + self.plot_prediction_histograms = True + self.plot_stat_curves = False + self.has_colorchannels = True + self.plot_class_ids = True + + self.num_classes = len(self.class_dict) + self.num_seg_classes = len(self.seg_labels) + + ######################### + # Testing # + ######################### + + self.test_aug_axes = None #None or list: choices are 2,3,(2,3) + self.held_out_test_set = False + self.max_test_patients = 'all' # 'all' for all + self.report_score_level = ['rois',] # choose list from 'patient', 'rois' + self.patient_class_of_interest = 1 + + self.metrics = ['ap', 'dice'] + self.ap_match_ious = [0.1] # threshold(s) for considering a prediction as true positive + # aggregation method for test and val_patient predictions. + # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, + # nms = standard non-maximum suppression, or None = no clustering + self.clustering = 'wbc' + # iou thresh (exclusive!) for regarding two preds as concerning the same ROI + self.clustering_iou = 0.1 # has to be larger than desired possible overlap iou of model predictions + + self.min_det_thresh = 0.06 + self.merge_2D_to_3D_preds = False + + self.n_test_plots = 1 #per fold and rankself.ap_match_ious = [0.1] #threshold(s) for considering a prediction as true positive + self.test_n_epochs = self.save_n_models + + + ######################### + # shared model settings # + ######################### + + # max number of roi candidates to identify per image and class (slice in 2D, volume in 3D) + self.n_roi_candidates = 100 + + ######################### + # Add model specifics # + ######################### + + {'mrcnn': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs + }[self.model]() + + def add_mrcnn_configs(self): + + self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) + self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' + + # number of classes for network heads: n_foreground_classes + 1 (background) + self.head_classes = self.num_classes + 1 + + # seg_classes here refers to the first stage classifier (RPN) reallY? + + # feed +/- n neighbouring slices into channel dimension. set to None for no context. + self.n_3D_context = None + + + self.frcnn_mode = False + + self.detect_while_training = True + # disable the re-sampling of mask proposals to original size for speed-up. + # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), + # mask outputs are optional. + self.return_masks_in_train = True + self.return_masks_in_val = True + self.return_masks_in_test = True + + # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. + self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} + # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale + # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) + self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} + # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. + self.pyramid_levels = [0, 1, 2, 3] + # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. + self.n_rpn_features = 512 if self.dim == 2 else 64 + + # anchor ratios and strides per position in feature maps. + self.rpn_anchor_ratios = [0.5, 1., 2.] + self.rpn_anchor_stride = 1 + # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION + self.rpn_nms_threshold = 0.7 + + # loss sampling settings. + self.rpn_train_anchors_per_image = 8 + self.train_rois_per_image = 10 # per batch_instance + self.roi_positive_ratio = 0.5 + self.anchor_matching_iou = 0.8 + + # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), + # where k<=#positive examples. + self.shem_poolsize = 3 + + self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) + self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) + self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) + + self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) + self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) + self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) + self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], + self.patch_size_3D[2], self.patch_size_3D[2]]) # y1,x1,y2,x2,z1,z2 + + if self.dim == 2: + self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] + self.bbox_std_dev = self.bbox_std_dev[:4] + self.window = self.window[:4] + self.scale = self.scale[:4] + + self.plot_y_max = 1.5 + self.n_plot_rpn_props = 5 # per batch_instance (slice in 2D / patient in 3D) + + # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. + self.pre_nms_limit = 3000 + + # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, + # since proposals of the entire batch are forwarded through second stage as one "batch". + self.roi_batch_size = 2500 + self.post_nms_rois_training = 500 + self.post_nms_rois_inference = 500 + + # Final selection of detections (refine_detections) + self.model_max_instances_per_batch_element = 50 # per batch element and class. + self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster. + self.model_min_confidence = 0.05 # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py + + if self.dim == 2: + self.backbone_shapes = np.array( + [[int(np.ceil(self.patch_size[0] / stride)), + int(np.ceil(self.patch_size[1] / stride))] + for stride in self.backbone_strides['xy']]) + else: + self.backbone_shapes = np.array( + [[int(np.ceil(self.patch_size[0] / stride)), + int(np.ceil(self.patch_size[1] / stride)), + int(np.ceil(self.patch_size[2] / stride_z))] + for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] + )]) + + if self.model == 'retina_net' or self.model == 'retina_unet': + # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 + self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in + self.rpn_anchor_scales['xy']] + self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in + self.rpn_anchor_scales['z']] + self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 + + self.n_rpn_features = 256 if self.dim == 2 else 64 + + # pre-selection of detections for NMS-speedup. per entire batch. + self.pre_nms_limit = 10000 if self.dim == 2 else 30000 + + # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 + self.anchor_matching_iou = 0.5 + + if self.model == 'retina_unet': + self.operate_stride1 = True \ No newline at end of file diff --git a/datasets/cityscapes/data_loader.py b/datasets/cityscapes/data_loader.py new file mode 100644 index 0000000..01a1a45 --- /dev/null +++ b/datasets/cityscapes/data_loader.py @@ -0,0 +1,452 @@ +import sys +sys.path.append('../') #works on cluster indep from where sbatch job is started +import plotting as plg + +import warnings +import os +import time +import pickle + + +import numpy as np +import pandas as pd +from PIL import Image as pil + +import torch +import torch.utils.data + +# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators +from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror +from batchgenerators.transforms.abstract_transforms import Compose +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.transforms.spatial_transforms import SpatialTransform +from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform +from batchgenerators.transforms.color_transforms import GammaTransform +#from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates + + +sys.path.append(os.path.dirname(os.path.realpath(__file__))) + +import utils.exp_utils as utils +import utils.dataloader_utils as dutils +from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates + +from configs import Configs +cf= configs() + + +warnings.filterwarnings("ignore", message="This figure includes Axes.*") + + +def load_obj(file_path): + with open(file_path, 'rb') as handle: + return pickle.load(handle) + +def save_to_npy(arr_out, array): + np.save(arr_out+".npy", array) + print("Saved binary .npy-file to {}".format(arr_out)) + return arr_out+".npy" + +def shape_small_first(shape): + if len(shape)<=2: #no changing dimensions if channel-dim is missing + return shape + smallest_dim = np.argmin(shape) + if smallest_dim!=0: #assume that smallest dim is color channel + new_shape = np.array(shape) #to support mask indexing + new_shape = (new_shape[smallest_dim], + *new_shape[(np.arange(len(shape),dtype=int)!=smallest_dim)]) + return new_shape + else: + return shape + +class Dataset(dutils.Dataset): + def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None): + super(Dataset, self).__init__(cf, data_sourcedir=data_sourcedir) + + info_dict = load_obj(cf.info_dict_path) + + if subset_ids is not None: + img_ids = subset_ids + if logger is None: + print('subset: selected {} instances from df'.format(len(pids))) + else: + logger.info('subset: selected {} instances from df'.format(len(pids))) + else: + img_ids = list(info_dict.keys()) + + #evtly copy data from data_rootdir to data_dir + if cf.server_env and not hasattr(cf, "data_dir"): + file_subset = [info_dict[img_id]['img'][:-3]+"*" for img_id in img_ids] + file_subset+= [info_dict[img_id]['seg'][:-3]+"*" for img_id in img_ids] + file_subset+= [cf.info_dict_path] + self.copy_data(cf, file_subset=file_subset) + cf.data_dir = self.data_dir + + img_paths = [os.path.join(self.data_dir, info_dict[img_id]['img']) for img_id in img_ids] + seg_paths = [os.path.join(self.data_dir, info_dict[img_id]['seg']) for img_id in img_ids] + + # load all subject files + self.data = {} + for i, img_id in enumerate(img_ids): + subj_data = {'img_id':img_id} + subj_data['img'] = img_paths[i] + subj_data['seg'] = seg_paths[i] + if 'class' in self.cf.prediction_tasks: + subj_data['class_targets'] = np.array(info_dict[img_id]['roi_classes']) + else: + subj_data['class_targets'] = np.ones_like(np.array(info_dict[img_id]['roi_classes'])) + + self.data[img_id] = subj_data + + cf.roi_items = cf.observables_rois[:] + cf.roi_items += ['class_targets'] + if 'regression' in cf.prediction_tasks: + cf.roi_items += ['regression_targets'] + + self.set_ids = list(self.data.keys()) + + self.df = None + +class BatchGenerator(dutils.BatchGenerator): + """ + create the training/validation batch generator. Randomly sample batch_size patients + from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array. + :param data: data dictionary as provided by 'load_dataset' + :param img_modalities: list of strings ['adc', 'b1500'] from config + :param batch_size: number of patients to sample for the batch + :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.) + :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array. + """ + def __init__(self, cf, data, n_batches=None, sample_pids_w_replace=True): + super(BatchGenerator, self).__init__(cf, data, n_batches) + self.dataset_length = len(self._data) + self.cf = cf + + self.sample_pids_w_replace = sample_pids_w_replace + self.eligible_pids = list(self._data.keys()) + + self.chans = cf.channels if cf.channels is not None else np.index_exp[:] + assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" + + self.p_fg = 0.5 + self.empty_samples_max_ratio = 0.33 + self.random_count = int(cf.batch_random_ratio * cf.batch_size) + + self.balance_target_distribution(plot=sample_pids_w_replace) + self.stats = {"roi_counts" : np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count" : 0} + + def generate_train_batch(self): + #everything done in here is per batch + #print statements in here get confusing due to multithreading + if self.sample_pids_w_replace: + # fully random patients + batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) + # target-balanced patients + batch_patient_ids += list(np.random.choice( + self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) + else: + batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size, replace=False) + if self.sample_pids_w_replace == False: + self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids] + if len(self.eligible_pids) < self.batch_size: + self.eligible_pids = self.dataset_pids + + batch_data, batch_segs, batch_class_targets = [], [], [] + # record roi count of classes in batch + batch_roi_counts, empty_samples_count = np.zeros((self.cf.num_classes,), dtype='uint32'), 0 + + for sample in range(self.batch_size): + + patient = self._data[batch_patient_ids[sample]] + + data = np.load(patient["img"], mmap_mode="r") + seg = np.load(patient['seg'], mmap_mode="r") + + (c,y,x) = data.shape + spatial_shp = data[0].shape + assert spatial_shp==seg.shape, "spatial shape incongruence betw. data {} and seg {}".format(spatial_shp, seg.shape) + + if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): + new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] + data = dutils.pad_nd_image(data, (len(data), *new_shape)) + seg = dutils.pad_nd_image(seg, new_shape) + + #eventual cropping to pre_crop_size: with prob self.p_fg sample pixel from random ROI and shift center, + #if possible, to that pixel, so that img still contains ROI after pre-cropping + dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] + if np.any(dim_cropflags): + #sample crop center regardless of ROIs, not guaranteed to be empty + def get_cropped_centercoords(dim): + return np.random.randint(low=self.cf.pre_crop_size[dim]//2, + high=spatial_shp[dim] - self.cf.pre_crop_size[dim]//2) + + sample_seg_center = {} + for dim in np.where(dim_cropflags)[0]: + sample_seg_center[dim] = get_cropped_centercoords(dim) + min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim]//2) + max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim]//2) + data = np.take(data, indices=range(min_, max_), axis=dim+1) #+1 for channeldim + seg = np.take(seg, indices=range(min_, max_), axis=dim) + + batch_data.append(data) + batch_segs.append(seg[np.newaxis]) + + batch_class_targets.append(patient['class_targets']) + + for cl in range(self.cf.num_classes): + batch_roi_counts[cl] += np.count_nonzero(patient['class_targets'][np.unique(seg[seg>0]) - 1] == cl) + if not np.any(seg): + empty_samples_count += 1 + + batch = {'data': np.array(batch_data).astype('float32'), 'seg': np.array(batch_segs).astype('uint8'), + 'pid': batch_patient_ids, 'class_targets': np.array(batch_class_targets), + 'roi_counts': batch_roi_counts, 'empty_samples_count': empty_samples_count} + return batch + +class PatientBatchIterator(dutils.PatientBatchIterator): + """ + creates a val/test generator. Step through the dataset and return dictionaries per patient. + For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions. + + This iterator/these batches are not intended to go through MTaugmenter afterwards + """ + + def __init__(self, cf, data): + super(PatientBatchIterator, self).__init__(cf, data) + + self.patch_size = cf.patch_size + + self.patient_ix = 0 # running index over all patients in set + + def generate_train_batch(self, pid=None): + + if self.patient_ix == len(self.dataset_pids): + self.patient_ix = 0 + if pid is None: + pid = self.dataset_pids[self.patient_ix] # + self.thread_id + patient = self._data[pid] + batch_class_targets = np.array([patient['class_targets']]) + + data = np.load(patient["img"], mmap_mode="r")[np.newaxis] + seg = np.load(patient['seg'], mmap_mode="r")[np.newaxis, np.newaxis] + (b, c, y, x) = data.shape + spatial_shp = data.shape[2:] + assert spatial_shp == seg.shape[2:], "spatial shape incongruence betw. data {} and seg {}".format(spatial_shp, + seg.shape) + if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): + new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] + data = dutils.pad_nd_image(data, (len(data), *new_shape)) + seg = dutils.pad_nd_image(seg, new_shape) + + batch = {'data': data, 'seg': seg, 'class_targets': batch_class_targets} + converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) + batch = converter(**batch) + batch.update({'patient_bb_target': batch['bb_target'], + 'patient_class_targets': batch['class_targets'], + 'original_img_shape': data.shape, + 'pid': np.array([pid] * len(data))}) + + # eventual tiling into patches + spatial_shp = batch["data"].shape[2:] + if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]): + patient_batch = batch + print("patientiterator produced patched batch!") + patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) + new_img_batch, new_seg_batch = [], [] + + for c in patch_crop_coords_list: + new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3]]) + seg_patch = seg[:, c[0]:c[1], c[2]: c[3]] + new_seg_batch.append(seg_patch) + + shps = [] + for arr in new_img_batch: + shps.append(arr.shape) + + data = np.array(new_img_batch) # (patches, c, x, y, z) + seg = np.array(new_seg_batch) + batch_class_targets = np.repeat(batch_class_targets, len(patch_crop_coords_list), axis=0) + + patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), + 'class_targets': batch_class_targets, + 'pid': np.array([pid] * data.shape[0])} + patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) + patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] + patch_batch['patient_class_targets'] = patient_batch['patient_class_targets'] + patch_batch['patient_data'] = patient_batch['data'] + patch_batch['patient_seg'] = patient_batch['seg'] + patch_batch['original_img_shape'] = patient_batch['original_img_shape'] + + converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) + patch_batch = converter(**patch_batch) + batch = patch_batch + + self.patient_ix += 1 + if self.patient_ix == len(self.dataset_pids): + self.patient_ix = 0 + + return batch + +def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True): + """ + create mutli-threaded train/val/test batch generation and augmentation pipeline. + :param patient_data: dictionary containing one dictionary per patient in the train/test subset + :param test_pids: (optional) list of test patient ids, calls the test generator. + :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing) + :return: multithreaded_generator + """ + data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace) + + my_transforms = [] + if do_aug: + if cf.da_kwargs["mirror"]: + mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) + my_transforms.append(mirror_transform) + spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], + patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2], + do_elastic_deform=cf.da_kwargs['do_elastic_deform'], + alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], + do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], + angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], + do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], + random_crop=cf.da_kwargs['random_crop'], + border_mode_data=cf.da_kwargs['border_mode_data']) + my_transforms.append(spatial_transform) + gamma_transform = GammaTransform(gamma_range=cf.da_kwargs["gamma_range"], invert_image=False, + per_channel=False, retain_stats=False) + my_transforms.append(gamma_transform) + + else: + my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) + + if cf.create_bounding_box_targets: + my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) + #batch receives entry 'bb_target' w bbox coordinates as [y1,x1,y2,x2,z1,z2]. + #my_transforms.append(ConvertSegToOnehotTransform(classes=range(cf.num_seg_classes))) + all_transforms = Compose(my_transforms) + #MTAugmenter creates iterator from data iterator data_gen after applying the composed transform all_transforms + multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, + seeds=np.random.randint(0,cf.n_workers*2,size=cf.n_workers)) + return multithreaded_generator + + +def get_train_generators(cf, logger, data_statistics=True): + """ + wrapper function for creating the training batch generator pipeline. returns the train/val generators + need to select cv folds on patient level, but be able to include both breasts of each patient. + """ + dataset = Dataset(cf) + dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) + dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) + set_splits = dataset.fg.splits + + test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) + train_ids = np.concatenate(set_splits, axis=0) + + if cf.held_out_test_set: + train_ids = np.concatenate((train_ids, test_ids), axis=0) + test_ids = [] + + train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} + val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} + + logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), + len(test_ids))) + if data_statistics: + dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, + plot_dir=os.path.join(cf.plot_dir, "data_stats_fold_"+str(cf.fold))) + + batch_gen = {} + batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=True) + batch_gen[cf.val_mode] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False) + batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches!="all" else len(val_data) + + return batch_gen + +def get_test_generator(cf, logger): + """ + if get_test_generators is called multiple times in server env, every time of + Dataset initiation rsync will check for copying the data; this should be okay + since rsync will not copy if files already exist in destination. + """ + + if cf.held_out_test_set: + sourcedir = cf.test_data_sourcedir + test_ids = None + else: + sourcedir = None + with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: + set_splits = pickle.load(handle) + test_ids = set_splits[cf.fold] + + + test_set = Dataset(cf, test_ids, data_sourcedir=sourcedir) + logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) + batch_gen = {} + batch_gen['test'] = PatientBatchIterator(cf, test_set.data) + batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else min(cf.max_test_patients, len(test_set.set_ids)) + + return batch_gen + +def main(): + total_stime = time.time() + times = {} + + CUDA = torch.cuda.is_available() + print("CUDA available: ", CUDA) + + + #cf.server_env = True + #cf.data_dir = "experiments/dev_data" + + cf.exp_dir = "experiments/dev/" + cf.plot_dir = cf.exp_dir+"plots" + os.makedirs(cf.exp_dir, exist_ok=True) + cf.fold = 0 + logger = utils.get_logger(cf.exp_dir) + + gens = get_train_generators(cf, logger) + train_loader = gens['train'] + + #for i in range(train_loader.dataset_length): + # print("batch", i) + stime = time.time() + ex_batch = next(train_loader) + # plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_extrainbatch.png", has_colorchannels=True, isRGB=True) + times["train_batch"] = time.time()-stime + + + val_loader = gens['val_sampling'] + stime = time.time() + ex_batch = next(val_loader) + times["val_batch"] = time.time()-stime + stime = time.time() + plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", has_colorchannels=True, isRGB=True, show_gt_boxes=False) + times["val_plot"] = time.time()-stime + + test_loader = get_test_generator(cf, logger)["test"] + stime = time.time() + ex_batch = next(test_loader) + times["test_batch"] = time.time()-stime + #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_expatientbatch.png", has_colorchannels=True, isRGB=True) + + print(ex_batch["data"].shape) + + + print("Times recorded throughout:") + for (k,v) in times.items(): + print(k, "{:.2f}".format(v)) + + mins, secs = divmod((time.time() - total_stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) + + + +if __name__=="__main__": + start_time = time.time() + + main() + + print("Program runtime in s: ", '{:.2f}'.format(time.time()-start_time)) \ No newline at end of file diff --git a/datasets/cityscapes/preprocessing.py b/datasets/cityscapes/preprocessing.py new file mode 100644 index 0000000..56c8c20 --- /dev/null +++ b/datasets/cityscapes/preprocessing.py @@ -0,0 +1,267 @@ +import sys +import os +from multiprocessing import Pool +import time +import pickle + +import numpy as np + +from PIL import Image as pil +from matplotlib import pyplot as plt + +sys.path.append("../") +import data_manager as dmanager + +from configs import Configs +cf = configs() + + +""" +""" + +def load_obj(file_path): + with open(file_path, 'rb') as handle: + return pickle.load(handle) + +def save_obj(obj, path): + """Pickle a python object.""" + with open(path, 'wb') as f: + pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + +def merge_labelids(target, cf=cf): + """relabel preprocessing id to training id according to config.labels + :param target: np.array hxw holding the annotation (labelids at pixel positions) + :cf: The configurations file + """ + for i in range(target.shape[0]): #Iterate over height. + for j in range(target.shape[1]): #Iterate over width + target[i][j] = cf.ppId2id[int(target[i][j])] + + return target + +def generate_detection_labels(target, cf=cf): + """labels suitable to be used with batchgenerators.ConvertSegToBoundingBoxCoordinates. + Flaw: cannot handle more than 2 segmentation classes (fg/bg). + --> seg-info is lost, but not interested in seg rn anyway. + :param target: expected as instanceIds img + The pixel values encode both, class and the individual instance. + The integer part of a division by 1000 of each ID provides the class ID, + as described in labels.py. The remainder is the instance ID. If a certain + annotation describes multiple instances, then the pixels have the regular + ID of that class. + """ + + unique_IDs = np.unique(target) + roi_classes = [] + + objs_in_img = 0 + for i, instanceID in enumerate(unique_IDs): + if instanceID > max(list(cf.ppId2id.keys())): + instance_classID = instanceID // 1000 + else: + # this is the group case (only class id assigned, no instance id) + instance_classID = instanceID + if cf.ppId2id[instance_classID]!=0: + #discard this whole sample since it has group instead of + #single instance annotations for a non-bg class + return None, None + + if cf.ppId2id[instance_classID]!=0: + #only pick reasonable objects, exclude road, sky, etc. + roi_classes.append(cf.ppId2id[instance_classID]) + objs_in_img+=1 #since 0 is bg + target[target==instanceID] = objs_in_img + else: + target[target==instanceID] = 0 + + return target, roi_classes + +class Preprocessor(): + + def __init__(self, cf, cities): + + self._cf = cf.prepro + + self.rootpath = cf.prepro['data_dir'] + self.set_splits = self._cf["set_splits"] + self.cities = cities + self.datapath = cf.datapath + self.targetspath = cf.targetspath + self.targettype = cf.prepro["targettype"] + + self.img_t_size = cf.prepro["img_target_size"] + self.target_t_size = self.img_t_size + + self.rootpath_out = cf.prepro["output_directory"] + + self.info_dict = {} + """info_dict: will hold {img_identifier: img_dict} with + img_dict = {id: img_identifier, img:img_path, seg:seg_path, + roi_classes:roiclasses} + """ + + def load_from_path_to_path(self, set_split, max_num=None): + """composes data and corresponding labels paths (to .png-files). + + assumes data tree structure: datapath-|-->city1-->img1.png,img2.png,... + |-->city2-->img1.png, ... + """ + data = [] + labels = [] + num=0 + for city in self.cities[set_split]: + path = os.path.join(self.rootpath, self.datapath, set_split, city) + lpath = os.path.join(self.rootpath,self.targetspath,set_split, city) + + files_in_dir = os.listdir(path) + for file in files_in_dir: + split = os.path.splitext(file) + if split[1].lower() == ".png": + num+=1 + filetag = file[:-(len(self.datapath)+3)] + data.append(os.path.join(path,file)) + labels.append(os.path.join(lpath,filetag+self.targettype+".png")) + + if num==max_num: + break + if num==max_num: + break + + return data, labels + + def prep_img(self, args): + """suited for multithreading. + :param args: (img_path, targ_path) + """ + + img_path, trg_path = args[0], args[1] + + img_rel_path = img_path[len(self.rootpath):] + trg_rel_path = trg_path[len(self.rootpath):] + + _path, img_name = os.path.split(img_path) + img_identifier = "".join(img_name.split("_")[:3]) + img_info_dict = {} #entry of img_identifier in full info_dict + + img, target = pil.open(img_path), pil.open(trg_path) + img, target = img.resize(self.img_t_size[::-1]), target.resize(self.target_t_size[::-1]) + img, target = np.array(img), np.array(target) #shapes y,x(,c) + img = np.transpose(img, axes=(2,0,1)) #shapes (c,)y,x + + target, roi_classes = generate_detection_labels(target) + if target is None: + return (img_identifier, target) + img_info_dict["roi_classes"] = roi_classes + + path = os.path.join(self.rootpath_out,*img_rel_path.split(os.path.sep)[:-1]) + os.makedirs(path, exist_ok=True) + + img_path = os.path.join(self.rootpath_out, img_rel_path[:-3]+"npy") + + #img.save(img_path) + img_info_dict["img"] = img_rel_path[:-3]+"npy" + np.save(img_path, img) + + path = os.path.join(self.rootpath_out,*trg_rel_path.split(os.path.sep)[:-1]) + os.makedirs(path, exist_ok=True) + t_path = os.path.join(self.rootpath_out, trg_rel_path)[:-3]+"npy" + #target.save(t_path) + img_info_dict["seg"] = trg_rel_path[:-3]+"npy" + np.save(t_path, target) + + print("\rSaved npy images and targets of shapes {}, {} to files\n {},\n {}". \ + format(img.shape, target.shape, img_path, t_path), flush=True, end="") + + return (img_identifier, img_info_dict) + + def prep_imgs(self, max_num=None, processes=4): + self.info_dict = {} + self.discarded = [] + os.makedirs(self.rootpath_out, exist_ok=True) + for set_split in self.set_splits: + data, targets = self.load_from_path_to_path(set_split, max_num=max_num) + + print(next(zip(data, targets))) + p = Pool(processes) + + img_info_dicts = p.map(self.prep_img, zip(data, targets)) + + p.close() + p.join() + + self.info_dict.update({id_:dict_ for (id_,dict_) in img_info_dicts if dict_ is not None}) + self.discarded += [id_ for (id_, dict_) in img_info_dicts if dict_ is None] + #list of samples discarded due to group instead of single instance annotation + + def finish(self): + total_items = len(self.info_dict)+len(self.discarded) + + print("\n\nSamples discarded: {}/{}={:.1f}%, identifiers:".format(len(self.discarded), + total_items, len(self.discarded)/total_items*100)) + for id_ in self.discarded: + print(id_) + + save_obj(self.info_dict, self._cf["info_dict_path"]) + + + def convert_copy_npz(self): + if not self._cf["npz_dir"]: + return + print("converting & copying to npz dir", self._cf['npz_dir']) + os.makedirs(self._cf['npz_dir'], exist_ok=True) + save_obj(self.info_dict, os.path.join(self._cf['npz_dir'], + self._cf['info_dict_path'].split("/")[-1])) + + dmanager.pack_dataset(self._cf["output_directory"], self._cf["npz_dir"], recursive=True, verbose=False) + + + def verification(self, max_num=None): + print("\n\n\nVerification\n") + for i, k in enumerate(self.info_dict): + if max_num is not None and i==max_num: + break + + subject = self.info_dict[k] + + seg = np.load(os.path.join(self.rootpath_out, subject["seg"])) + + #print("seg values", np.unique(seg)) + print("nr of objects", len(subject["roi_classes"])) + print("nr of objects should equal highest seg value, fulfilled?", + np.max(seg)==len(subject["roi_classes"])) + #print("roi_classes", subject["roi_classes"]) + + img = np.transpose(np.load(os.path.join(self.rootpath_out, subject["img"])), axes=(1,2,0)) + print("img shp", img.shape) + plt.imshow(img) + + +def main(): + #cf.set_splits = ["train"] + #cities = {'train':['dusseldorf'], 'val':['frankfurt']} #cf.cities + cities= cf.cities + + pp = Preprocessor(cf, cities) + pp.prep_imgs(max_num=None, processes=8) + pp.finish() + + #pp.convert_copy_npz() + + pp.verification(1) + + + + + + + return + +if __name__=="__main__": + stime = time.time() + + main() + + mins, secs = divmod((time.time() - stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("Prepro program runtime: {}".format(t)) diff --git a/datasets/legacy/convert_folds_ids.py b/datasets/legacy/convert_folds_ids.py new file mode 100644 index 0000000..ba16b34 --- /dev/null +++ b/datasets/legacy/convert_folds_ids.py @@ -0,0 +1,148 @@ +""" +Created at 28.05.19 16:46 +@author: gregor +""" + +import os +import sys +import subprocess + +import pickle +import numpy as np +import pandas as pd +from collections import OrderedDict + +import utils.exp_utils as utils + +def get_cf(dataset_name, exp_dir=""): + + cf_path = os.path.join('datasets', dataset_name, exp_dir, "configs.py") + cf_file = utils.import_module('configs', cf_path) + + return cf_file.Configs() + +def vector(item): + """ensure item is vector-like (list or array or tuple) + :param item: anything + """ + if not isinstance(item, (list, tuple, np.ndarray)): + item = [item] + return item + +def load_dataset(cf, subset_ixs=None): + """ + loads the dataset. if deployed in cloud also copies and unpacks the data to the working directory. + :param subset_ixs: subset indices to be loaded from the dataset. used e.g. for testing to only load the test folds. + :return: data: dictionary with one entry per patient (in this case per patient-breast, since they are treated as + individual images for training) each entry is a dictionary containing respective meta-info as well as paths to the preprocessed + numpy arrays to be loaded during batch-generation + """ + + p_df = pd.read_pickle(os.path.join(cf.pp_data_path, cf.input_df_name)) + + exclude_pids = ["0305a", "0447a"] # due to non-bg segmentation but bg mal label in nodules 5728, 8840 + p_df = p_df[~p_df.pid.isin(exclude_pids)] + + if cf.select_prototype_subset is not None: + prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset] + p_df = p_df[p_df.pid.isin(prototype_pids)] + logger.warning('WARNING: using prototyping data subset!!!') + if subset_ixs is not None: + subset_pids = [np.unique(p_df.pid.tolist())[ix] for ix in subset_ixs] + p_df = p_df[p_df.pid.isin(subset_pids)] + + print('subset: selected {} instances from df'.format(len(p_df))) + + pids = p_df.pid.tolist() + cf.data_dir = cf.pp_data_path + + + imgs = [os.path.join(cf.data_dir, '{}_img.npy'.format(pid)) for pid in pids] + segs = [os.path.join(cf.data_dir,'{}_rois.npz'.format(pid)) for pid in pids] + orig_class_targets = p_df['class_target'].tolist() + + data = OrderedDict() + for ix, pid in enumerate(pids): + data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} + data[pid]['fg_slices'] = np.array(p_df['fg_slices'].tolist()[ix]) + if 'class' in cf.prediction_tasks: + # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) + raise NotImplementedError + # todo need to consider bg + data[pid]['class_targets'] = np.array([ [2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]]) + else: + data[pid]['class_targets'] = np.array([ [1 if ii>0 else 0 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='uint8') + if any(['regression' in task for task in cf.prediction_tasks]): + data[pid]["regression_targets"] = np.array([ [vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix] ], dtype='float16') + data[pid]["rg_bin_targets"] = np.array([ [cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in data[pid]["regression_targets"]], dtype='uint8') + + cf.roi_items = cf.observables_rois[:] + cf.roi_items += ['class_targets'] + if any(['regression' in task for task in cf.prediction_tasks]): + cf.roi_items += ['regression_targets'] + cf.roi_items += ['rg_bin_targets'] + + return data + + +def get_patient_identifiers(cf, fold_lists): + + + all_data = load_dataset(cf) + all_pids_list = np.unique([v['pid'] for (k, v) in all_data.items()]) + + + verifier = [] #list of folds + for fold in range(cf.n_cv_splits): + train_ix, val_ix, test_ix, fold_nr = fold_lists[fold] + assert fold==fold_nr + test_ids = [all_pids_list[ix] for ix in test_ix] + for ix, arr in enumerate(verifier): + inter = np.intersect1d(test_ids, arr) + #print("intersect of fold {} with fold {}: {}".format(fold, ix, inter)) + assert len(inter)==0 + verifier.append(test_ids) + + + return verifier + +def convert_folds_ids(exp_dir): + import inference_analysis + cf = get_cf('lidc', exp_dir=exp_dir) + cf.exp_dir = exp_dir + with open(os.path.join(exp_dir, 'fold_ids.pickle'), 'rb') as f: + fids = pickle.load(f) + + pid_fold_splits = get_patient_identifiers(cf, fids) + + with open(os.path.join(exp_dir, 'fold_real_ids.pickle'), 'wb') as handle: + pickle.dump(pid_fold_splits, handle) + + + #inference_analysis.find_pid_in_splits('0811a', exp_dir=exp_dir) + return + + +def copy_to_new_exp_dir(old_dir, new_dir): + + + cp_ids = r"rsync {} {}".format(os.path.join(old_dir, 'fold_real_ids.pickle'), new_dir) + rn_ids = "mv {} {}".format(os.path.join(new_dir, 'fold_real_ids.pickle'), os.path.join(new_dir, 'fold_ids.pickle')) + cp_params = r"""rsync -a --include='*/' --include='*best_params.pth' --exclude='*' --prune-empty-dirs + {} {}""".format(old_dir, new_dir) + cp_ranking = r"""rsync -a --include='*/' --include='epoch_ranking.npy' --exclude='*' --prune-empty-dirs + {} {}""".format(old_dir, new_dir) + cp_results = r"""rsync -a --include='*/' --include='pred_results.pkl' --exclude='*' --prune-empty-dirs + {} {}""".format(old_dir, new_dir) + + for cmd in [cp_ids, rn_ids, cp_params, cp_ranking, cp_results]: + subprocess.call(cmd, shell=True) + print("Setup {} for inference with ids, params from {}".format(new_dir, old_dir)) + + + +if __name__=="__main__": + exp_dir = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8' + new_exp_dir = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_copiedparams' + #convert_folds_ids(exp_dir) + copy_to_new_exp_dir(exp_dir, new_exp_dir) \ No newline at end of file diff --git a/datasets/lidc/LIDC_XML_Documentation_1_Jan_2009.doc b/datasets/lidc/LIDC_XML_Documentation_1_Jan_2009.doc new file mode 100644 index 0000000..9ae6550 Binary files /dev/null and b/datasets/lidc/LIDC_XML_Documentation_1_Jan_2009.doc differ diff --git a/datasets/lidc/analyze_dataset.py b/datasets/lidc/analyze_dataset.py new file mode 100644 index 0000000..cc79b0c --- /dev/null +++ b/datasets/lidc/analyze_dataset.py @@ -0,0 +1,14 @@ +""" +Created at 29/03/2019 19:20 +@author: gregor +""" + + + +if __name__ == "__main__": + + + + + + pass \ No newline at end of file diff --git a/datasets/lidc/configs.py b/datasets/lidc/configs.py new file mode 100644 index 0000000..126300b --- /dev/null +++ b/datasets/lidc/configs.py @@ -0,0 +1,445 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import sys +import os +from collections import namedtuple +sys.path.append(os.path.dirname(os.path.realpath(__file__))) +import numpy as np +sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/../..") +from default_configs import DefaultConfigs + +# legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope +Label = namedtuple("Label", ['id', 'name', 'color', 'm_scores']) # m_scores = malignancy scores +binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'm_scores', 'bin_vals']) + + +class Configs(DefaultConfigs): + + def __init__(self, server_env=None): + super(Configs, self).__init__(server_env) + + ######################### + # Preprocessing # + ######################### + + self.root_dir = '/home/gregor/networkdrives/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI' + self.raw_data_dir = '{}/new_nrrd'.format(self.root_dir) + self.pp_dir = '/mnt/HDD2TB/Documents/data/lidc/pp_20190805' + # 'merged' for one gt per image, 'single_annotator' for four gts per image. + self.gts_to_produce = ["single_annotator", "merged"] + + self.target_spacing = (0.7, 0.7, 1.25) + + ######################### + # I/O # + ######################### + + # path to preprocessed data. + #self.pp_name = 'pp_20190318' + self.pp_name = 'pp_20190805' + + self.input_df_name = 'info_df.pickle' + self.data_sourcedir = '/mnt/HDD2TB/Documents/data/lidc/{}/'.format(self.pp_name) + + # settings for deployment on cluster. + if server_env: + # path to preprocessed data. + self.data_sourcedir = '/datasets/data_ramien/lidc/{}_npz/'.format(self.pp_name) + + # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn']. + self.model = 'retina_net' + self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') + self.model_path = os.path.join(self.source_dir, self.model_path) + + + ######################### + # Architecture # + ######################### + + # dimension the model operates in. one out of [2, 3]. + self.dim = 3 + + # 'class': standard object classification per roi, pairwise combinable with each of below tasks. + # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN. + # 'regression': regress some vector per each roi + # 'regression_ken_gal': use kendall-gal uncertainty sigma + # 'regression_bin': classify each roi into a bin related to a regression scale + self.prediction_tasks = ['class'] + + self.start_filts = 48 if self.dim == 2 else 18 + self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 + self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' + self.norm = None # one of None, 'instance_norm', 'batch_norm' + + # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') + self.weight_init = None + + self.regression_n_features = 1 + + ######################### + # Data Loader # + ######################### + + # distorted gt experiments: train on single-annotator gts in a random fashion to investigate network's + # handling of noisy gts. + # choose 'merged' for single, merged gt per image, or 'single_annotator' for four gts per image. + # validation is always performed on same gt kind as training, testing always on merged gt. + self.training_gts = "merged" + + # select modalities from preprocessed data + self.channels = [0] + self.n_channels = len(self.channels) + + # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. + self.pre_crop_size_2D = [320, 320] + self.patch_size_2D = [320, 320] + self.pre_crop_size_3D = [160, 160, 96] + self.patch_size_3D = [160, 160, 96] + + self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D + self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D + + # ratio of free sampled batch elements before class balancing is triggered + # (>0 to include "empty"/background patches.) + self.batch_random_ratio = 0.3 + self.balance_target = "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets' + + # set 2D network to match 3D gt boxes. + self.merge_2D_to_3D_preds = self.dim==2 + + self.observables_rois = [] + + #self.rg_map = {1:1, 2:2, 3:3, 4:4, 5:5} + + ######################### + # Colors and Legends # + ######################### + self.plot_frequency = 5 + + binary_cl_labels = [Label(1, 'benign', (*self.dark_green, 1.), (1, 2)), + Label(2, 'malignant', (*self.red, 1.), (3, 4, 5))] + quintuple_cl_labels = [Label(1, 'MS1', (*self.dark_green, 1.), (1,)), + Label(2, 'MS2', (*self.dark_yellow, 1.), (2,)), + Label(3, 'MS3', (*self.orange, 1.), (3,)), + Label(4, 'MS4', (*self.bright_red, 1.), (4,)), + Label(5, 'MS5', (*self.red, 1.), (5,))] + # choose here if to do 2-way or 5-way regression-bin classification + task_spec_cl_labels = quintuple_cl_labels + + self.class_labels = [ + # #id #name #color #malignancy score + Label( 0, 'bg', (*self.gray, 0.), (0,))] + if "class" in self.prediction_tasks: + self.class_labels += task_spec_cl_labels + + else: + self.class_labels += [Label(1, 'lesion', (*self.orange, 1.), (1,2,3,4,5))] + + if any(['regression' in task for task in self.prediction_tasks]): + self.bin_labels = [binLabel(0, 'MS0', (*self.gray, 1.), (0,), (0,))] + self.bin_labels += [binLabel(cll.id, cll.name, cll.color, cll.m_scores, + tuple([ms for ms in cll.m_scores])) for cll in task_spec_cl_labels] + self.bin_id2label = {label.id: label for label in self.bin_labels} + self.ms2bin_label = {ms: label for label in self.bin_labels for ms in label.m_scores} + bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] + self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} + self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] + + if self.class_specific_seg: + self.seg_labels = self.class_labels + else: + self.seg_labels = [ # id #name #color + Label(0, 'bg', (*self.gray, 0.)), + Label(1, 'fg', (*self.orange, 1.)) + ] + + self.class_id2label = {label.id: label for label in self.class_labels} + self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} + # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be + # evaluated in debugging + self.class_cmap = {label.id: label.color for label in self.class_labels} + + self.seg_id2label = {label.id: label for label in self.seg_labels} + self.cmap = {label.id: label.color for label in self.seg_labels} + + self.plot_prediction_histograms = True + self.plot_stat_curves = False + self.has_colorchannels = False + self.plot_class_ids = True + + self.num_classes = len(self.class_dict) # for instance classification (excl background) + self.num_seg_classes = len(self.seg_labels) # incl background + + + ######################### + # Data Augmentation # + ######################### + + self.da_kwargs={ + 'mirror': True, + 'mirror_axes': tuple(np.arange(0, self.dim, 1)), + 'do_elastic_deform': True, + 'alpha':(0., 1500.), + 'sigma':(30., 50.), + 'do_rotation':True, + 'angle_x': (0., 2 * np.pi), + 'angle_y': (0., 0), + 'angle_z': (0., 0), + 'do_scale': True, + 'scale':(0.8, 1.1), + 'random_crop':False, + 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), + 'border_mode_data': 'constant', + 'border_cval_data': 0, + 'order_data': 1} + + if self.dim == 3: + self.da_kwargs['do_elastic_deform'] = False + self.da_kwargs['angle_x'] = (0, 0.0) + self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!! + self.da_kwargs['angle_z'] = (0., 2 * np.pi) + + ################################# + # Schedule / Selection / Optim # + ################################# + + self.num_epochs = 130 if self.dim == 2 else 150 + self.num_train_batches = 200 if self.dim == 2 else 200 + self.batch_size = 20 if self.dim == 2 else 8 + + # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) + # the former is morge accurate, while the latter is faster (depending on volume size) + self.val_mode = 'val_sampling' # only 'val_sampling', 'val_patient' not implemented + if self.val_mode == 'val_patient': + raise NotImplementedError + if self.val_mode == 'val_sampling': + self.num_val_batches = 70 + + self.save_n_models = 4 + # set a minimum epoch number for saving in case of instabilities in the first phase of training. + self.min_save_thresh = 0 if self.dim == 2 else 0 + # criteria to average over for saving epochs, 'criterion':weight. + if "class" in self.prediction_tasks: + # 'criterion': weight + if len(self.class_labels)==3: + self.model_selection_criteria = {"benign_ap": 0.5, "malignant_ap": 0.5} + elif len(self.class_labels)==6: + self.model_selection_criteria = {str(label.name)+"_ap": 1./5 for label in self.class_labels if label.id!=0} + elif any("regression" in task for task in self.prediction_tasks): + self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8} + + self.weight_decay = 0 + self.clip_norm = 200 if 'regression_ken_gal' in self.prediction_tasks else None # number or None + + # int in [0, dataset_size]. select n patients from dataset for prototyping. If None, all data is used. + self.select_prototype_subset = None #self.batch_size + + ######################### + # Testing # + ######################### + + # set the top-n-epochs to be saved for temporal averaging in testing. + self.test_n_epochs = self.save_n_models + + self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x). + self.held_out_test_set = False + self.max_test_patients = "all" # "all" or number + + self.report_score_level = ['rois', 'patient'] # choose list from 'patient', 'rois' + self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1 + + self.metrics = ['ap', 'auc'] + if any(['regression' in task for task in self.prediction_tasks]): + self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', + 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] + if 'aleatoric' in self.model: + self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] + self.evaluate_fold_means = True + + self.ap_match_ious = [0.1] # list of ious to be evaluated for ap-scoring. + self.min_det_thresh = 0.1 # minimum confidence value to select predictions for evaluation. + + # aggregation method for test and val_patient predictions. + # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, + # nms = standard non-maximum suppression, or None = no clustering + self.clustering = 'wbc' + # iou thresh (exclusive!) for regarding two preds as concerning the same ROI + self.clustering_iou = 0.1 # has to be larger than desired possible overlap iou of model predictions + + self.plot_prediction_histograms = True + self.plot_stat_curves = False + self.n_test_plots = 1 + + ######################### + # Assertions # + ######################### + if not 'class' in self.prediction_tasks: + assert self.num_classes == 1 + + ######################### + # Add model specifics # + ######################### + + {'detection_fpn': self.add_det_fpn_configs, + 'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, + 'retina_net': self.add_mrcnn_configs, + 'retina_unet': self.add_mrcnn_configs, + }[self.model]() + + def rg_val_to_bin_id(self, rg_val): + return float(np.digitize(np.mean(rg_val), self.bin_edges)) + + def add_det_fpn_configs(self): + + self.learning_rate = [1e-4] * self.num_epochs + self.dynamic_lr_scheduling = False + + # RoI score assigned to aggregation from pixel prediction (connected component). One of ['max', 'median']. + self.score_det = 'max' + + # max number of roi candidates to identify per batch element and class. + self.n_roi_candidates = 10 if self.dim == 2 else 30 + + # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') + self.seg_loss_mode = 'wce' + + # if <1, false positive predictions in foreground are penalized less. + self.fp_dice_weight = 1 if self.dim == 2 else 1 + if len(self.class_labels)==3: + self.wce_weights = [1., 1., 1.] if self.seg_loss_mode=="dice_wce" else [0.1, 1., 1.] + elif len(self.class_labels)==6: + self.wce_weights = [1., 1., 1., 1., 1., 1.] if self.seg_loss_mode == "dice_wce" else [0.1, 1., 1., 1., 1., 1.] + else: + raise Exception("mismatch loss weights & nr of classes") + self.detection_min_confidence = self.min_det_thresh + + self.head_classes = self.num_seg_classes + + def add_mrcnn_configs(self): + + # learning rate is a list with one entry per epoch. + self.learning_rate = [1e-4] * self.num_epochs + self.dynamic_lr_scheduling = False + # disable the re-sampling of mask proposals to original size for speed-up. + # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), + # mask-outputs are optional. + self.return_masks_in_train = False + self.return_masks_in_val = True + self.return_masks_in_test = False + + # set number of proposal boxes to plot after each epoch. + self.n_plot_rpn_props = 5 if self.dim == 2 else 30 + + # number of classes for network heads: n_foreground_classes + 1 (background) + self.head_classes = self.num_classes + 1 + + self.frcnn_mode = False + + # feature map strides per pyramid level are inferred from architecture. + self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} + + # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale + # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) + self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]} + + # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. + self.pyramid_levels = [0, 1, 2, 3] + + # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. + self.n_rpn_features = 512 if self.dim == 2 else 128 + + # anchor ratios and strides per position in feature maps. + self.rpn_anchor_ratios = [0.5, 1, 2] + self.rpn_anchor_stride = 1 + + # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION + self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7 + + # loss sampling settings. + self.rpn_train_anchors_per_image = 6 #per batch element + self.train_rois_per_image = 6 #per batch element + self.roi_positive_ratio = 0.5 + self.anchor_matching_iou = 0.7 + + # factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). + # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples. + self.shem_poolsize = 10 + + self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) + self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) + self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) + + self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) + self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) + self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) + self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], + self.patch_size_3D[2], self.patch_size_3D[2]]) + if self.dim == 2: + self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] + self.bbox_std_dev = self.bbox_std_dev[:4] + self.window = self.window[:4] + self.scale = self.scale[:4] + + # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. + self.pre_nms_limit = 3000 if self.dim == 2 else 6000 + + # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, + # since proposals of the entire batch are forwarded through second stage in as one "batch". + self.roi_chunk_size = 2500 if self.dim == 2 else 600 + self.post_nms_rois_training = 500 if self.dim == 2 else 75 + self.post_nms_rois_inference = 500 + + # Final selection of detections (refine_detections) + self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30 # per batch element and class. + self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster. + self.model_min_confidence = 0.1 + + if self.dim == 2: + self.backbone_shapes = np.array( + [[int(np.ceil(self.patch_size[0] / stride)), + int(np.ceil(self.patch_size[1] / stride))] + for stride in self.backbone_strides['xy']]) + else: + self.backbone_shapes = np.array( + [[int(np.ceil(self.patch_size[0] / stride)), + int(np.ceil(self.patch_size[1] / stride)), + int(np.ceil(self.patch_size[2] / stride_z))] + for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] + )]) + + if self.model == 'retina_net' or self.model == 'retina_unet': + + self.focal_loss = True + + # implement extra anchor-scales according to retina-net publication. + self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in + self.rpn_anchor_scales['xy']] + self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in + self.rpn_anchor_scales['z']] + self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 + + self.n_rpn_features = 256 if self.dim == 2 else 128 + + # pre-selection of detections for NMS-speedup. per entire batch. + self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size + + # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 + self.anchor_matching_iou = 0.5 + + if self.model == 'retina_unet': + self.operate_stride1 = True + diff --git a/datasets/lidc/data_loader.py b/datasets/lidc/data_loader.py new file mode 100644 index 0000000..964e6fb --- /dev/null +++ b/datasets/lidc/data_loader.py @@ -0,0 +1,978 @@ +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +''' +Data Loader for the LIDC data set. This dataloader expects preprocessed data in .npy or .npz files per patient and +a pandas dataframe containing the meta info e.g. file paths, and some ground-truth info like labels, foreground slice ids. + +LIDC 4-fold annotations storage capacity problem: keep segmentation gts compressed (npz), unpack at each batch generation. + +''' + +import plotting as plg + +import os +import pickle +import time +import subprocess +from multiprocessing import Pool + +import numpy as np +import pandas as pd +from collections import OrderedDict + +# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators +from batchgenerators.dataloading.data_loader import SlimDataLoaderBase +from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror +from batchgenerators.transforms.abstract_transforms import Compose +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.dataloading import SingleThreadedAugmenter +from batchgenerators.transforms.spatial_transforms import SpatialTransform +from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform +#from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates + +import utils.dataloader_utils as dutils +from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates +import data_manager as dmanager + + + +def save_obj(obj, name): + """Pickle a python object.""" + with open(name + '.pkl', 'wb') as f: + pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + +def vector(item): + """ensure item is vector-like (list or array or tuple) + :param item: anything + """ + if not isinstance(item, (list, tuple, np.ndarray)): + item = [item] + return item + + +class Dataset(dutils.Dataset): + r"""Load a dict holding memmapped arrays and clinical parameters for each patient, + evtly subset of those. + If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to + cf.data_dest. + :param cf: config object. + :param logger: logger. + :param subset_ids: subset of patient/sample identifiers to load from whole set. + :param data_sourcedir: directory in which to find data, defaults to cf.data_sourcedir if None. + :return: dict with imgs, segs, pids, class_labels, observables + """ + + def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None, mode='train'): + super(Dataset,self).__init__(cf, data_sourcedir) + if mode == 'train' and not cf.training_gts == "merged": + self.gt_dir = "patient_gts_sa" + self.gt_kind = cf.training_gts + else: + self.gt_dir = "patient_gts_merged" + self.gt_kind = "merged" + if logger is not None: + logger.info("loading {} ground truths for {}".format(self.gt_kind, 'training and validation' if mode=='train' + else 'testing')) + + p_df = pd.read_pickle(os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)) + #exclude_pids = ["0305a", "0447a"] # due to non-bg segmentation but bg mal label in nodules 5728, 8840 + #p_df = p_df[~p_df.pid.isin(exclude_pids)] + + if subset_ids is not None: + p_df = p_df[p_df.pid.isin(subset_ids)] + if logger is not None: + logger.info('subset: selected {} instances from df'.format(len(p_df))) + if cf.select_prototype_subset is not None: + prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset] + p_df = p_df[p_df.pid.isin(prototype_pids)] + if logger is not None: + logger.warning('WARNING: using prototyping data subset of length {}!!!'.format(len(p_df))) + + pids = p_df.pid.tolist() + + # evtly copy data from data_sourcedir to data_dest + if cf.server_env and not hasattr(cf, 'data_dir') and hasattr(cf, "data_dest"): + # copy and unpack images + file_subset = ["{}_img.npz".format(pid) for pid in pids if not + os.path.isfile(os.path.join(cf.data_dest,'{}_img.npy'.format(pid)))] + file_subset += [os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)] + self.copy_data(cf, file_subset=file_subset, keep_packed=False, del_after_unpack=True) + # copy and do not unpack segmentations + file_subset = [os.path.join(self.gt_dir, "{}_rois.np*".format(pid)) for pid in pids] + keep_packed = not cf.training_gts == "merged" + self.copy_data(cf, file_subset=file_subset, keep_packed=keep_packed, del_after_unpack=(not keep_packed)) + else: + cf.data_dir = self.data_sourcedir + + ext = 'npy' if self.gt_kind == "merged" else 'npz' + imgs = [os.path.join(self.data_dir, '{}_img.npy'.format(pid)) for pid in pids] + segs = [os.path.join(self.data_dir, self.gt_dir, '{}_rois.{}'.format(pid, ext)) for pid in pids] + orig_class_targets = p_df['class_target'].tolist() + + data = OrderedDict() + + if self.gt_kind == 'merged': + for ix, pid in enumerate(pids): + data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} + data[pid]['fg_slices'] = np.array(p_df['fg_slices'].tolist()[ix]) + if 'class' in cf.prediction_tasks: + if len(cf.class_labels)==3: + # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) + data[pid]['class_targets'] = np.array([2 if ii >= 3 else 1 for ii in orig_class_targets[ix]], + dtype='uint8') + elif len(cf.class_labels)==6: + # classify each malignancy score + data[pid]['class_targets'] = np.array([1 if ii==0.5 else np.round(ii) for ii in orig_class_targets[ix]], dtype='uint8') + else: + raise Exception("mismatch class labels and data-loading implementations.") + else: + data[pid]['class_targets'] = np.ones_like(np.array(orig_class_targets[ix]), dtype='uint8') + if any(['regression' in task for task in cf.prediction_tasks]): + data[pid]["regression_targets"] = np.array([vector(v) for v in orig_class_targets[ix]], + dtype='float16') + data[pid]["rg_bin_targets"] = np.array( + [cf.rg_val_to_bin_id(v) for v in data[pid]["regression_targets"]], dtype='uint8') + else: + for ix, pid in enumerate(pids): + data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} + data[pid]['fg_slices'] = np.array(p_df['fg_slices'].values[ix]) + if 'class' in cf.prediction_tasks: + # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) + raise NotImplementedError + # todo need to consider bg + # data[pid]['class_targets'] = np.array( + # [[2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]]) + else: + data[pid]['class_targets'] = np.array( + [[1 if ii > 0 else 0 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='uint8') + if any(['regression' in task for task in cf.prediction_tasks]): + data[pid]["regression_targets"] = np.array( + [[vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='float16') + data[pid]["rg_bin_targets"] = np.array( + [[cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in data[pid]["regression_targets"]], dtype='uint8') + + cf.roi_items = cf.observables_rois[:] + cf.roi_items += ['class_targets'] + if any(['regression' in task for task in cf.prediction_tasks]): + cf.roi_items += ['regression_targets'] + cf.roi_items += ['rg_bin_targets'] + + self.data = data + self.set_ids = np.array(list(self.data.keys())) + self.df = None + +# merged GTs +class BatchGenerator_merged(dutils.BatchGenerator): + """ + creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) + from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. + Actual patch_size is obtained after data augmentation. + :param data: data dictionary as provided by 'load_dataset'. + :param batch_size: number of patients to sample for the batch + :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target + """ + def __init__(self, cf, data): + super(BatchGenerator_merged, self).__init__(cf, data) + + self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch. + self.p_fg = 0.5 + self.empty_samples_max_ratio = 0.6 + + self.random_count = int(cf.batch_random_ratio * cf.batch_size) + self.class_targets = {k: v["class_targets"] for (k, v) in self._data.items()} + + + self.balance_target_distribution(plot=True) + self.stats = {"roi_counts": np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count": 0} + + + def generate_train_batch(self): + + # samples patients towards equilibrium of foreground classes on a roi-level after sampling a random ratio + # fully random patients + batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) + # target-balanced patients + batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, + replace=False, p=self.p_probs)) + + batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], [] + batch_roi_items = {name: [] for name in self.cf.roi_items} + # record roi count of classes in batch + batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0 + # empty count for full bg samples (empty slices in 2D/patients in 3D) + + + for sample in range(self.batch_size): + patient = self._data[batch_patient_ids[sample]] + + data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] + seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) + batch_pids.append(patient['pid']) + (c, y, x, z) = data.shape + + if self.cf.dim == 2: + + elig_slices, choose_fg = [], False + if len(patient['fg_slices']) > 0: + if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand(1)<=self.p_fg: + # fg is to be picked + for tix in np.argsort(batch_roi_counts): + # pick slices of patient that have roi of sought-for target + # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix + elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( + patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0])-1] == + self.unique_ts[tix]) > 0] + if len(elig_slices) > 0: + choose_fg = True + break + else: + # pick bg + elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) + + if len(elig_slices)>0: + sl_pick_ix = np.random.choice(elig_slices, size=None) + else: + sl_pick_ix = np.random.choice(z, size=None) + + data = data[..., sl_pick_ix] + seg = seg[..., sl_pick_ix] + + # pad data if smaller than pre_crop_size. + if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): + new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] + data = dutils.pad_nd_image(data, new_shape, mode='constant') + seg = dutils.pad_nd_image(seg, new_shape, mode='constant') + + # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. + crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] + if len(crop_dims) > 0: + if self.cf.dim == 3: + choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg + if choose_fg and np.any(seg): + available_roi_ids = np.unique(seg)[1:] + for tix in np.argsort(batch_roi_counts): + elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]] + if len(elig_roi_ids)>0: + seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) + break + roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] + assert seg[tuple(roi_anchor_pixel)] > 0 + # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by + # distance to the desired ROI < patch_size /2. + # (here final patch size to account for center_crop after data augmentation). + sample_seg_center = {} + for ii in crop_dims: + low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) + high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, + roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) + # happens if lesion on the edge of the image. dont care about roi anymore, + # just make sure pre-crop is inside image. + if low >= high: + low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) + high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) + sample_seg_center[ii] = np.random.randint(low=low, high=high) + + else: + # not guaranteed to be empty. probability of emptiness depends on the data. + sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, + high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} + + for ii in crop_dims: + min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) + max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) + data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) + seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) + + batch_data.append(data) + batch_segs.append(seg[np.newaxis]) + for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item + batch_roi_items[o].append(patient[o]) + + if self.cf.dim == 3: + for tix in range(len(self.unique_ts)): + batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) + elif self.cf.dim == 2: + for tix in range(len(self.unique_ts)): + batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) + if not np.any(seg): + empty_samples_count += 1 + + + data = np.array(batch_data).astype(np.float16) + seg = np.array(batch_segs).astype(np.uint8) + batch = {'data': data, 'seg': seg, 'pid': batch_pids, + 'roi_counts':batch_roi_counts, 'empty_samples_count': empty_samples_count} + for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) + batch[key] = np.array(val) + + return batch + +class PatientBatchIterator_merged(dutils.PatientBatchIterator): + """ + creates a test generator that iterates over entire given dataset returning 1 patient per batch. + Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actualy evaluation (done in 3D), + if willing to accept speed-loss during training. + :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or + batch_size = n_2D_patches in 2D . + """ + + def __init__(self, cf, data): # threads in augmenter + super(PatientBatchIterator_merged, self).__init__(cf, data) + self.patient_ix = 0 + self.patch_size = cf.patch_size + [1] if cf.dim == 2 else cf.patch_size + + def generate_train_batch(self, pid=None): + + if pid is None: + pid = self.dataset_pids[self.patient_ix] + patient = self._data[pid] + + data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) + seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) + + # pad data if smaller than patch_size seen during training. + if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): + new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] + data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. + seg = dutils.pad_nd_image(seg, new_shape) + + # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. + if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: + out_data = data[np.newaxis, np.newaxis] + out_seg = seg[np.newaxis, np.newaxis] + batch_3D = {'data': out_data, 'seg': out_seg} + for o in self.cf.roi_items: + batch_3D[o] = np.array([patient[o]]) + converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) + batch_3D = converter(**batch_3D) + batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_3D["patient_" + o] = batch_3D[o] + + if self.cf.dim == 2: + out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, x, y ) + out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis] + + batch_2D = {'data': out_data, 'seg': out_seg} + for o in self.cf.roi_items: + batch_2D[o] = np.repeat(np.array([patient[o]]), out_data.shape[0], axis=0) + + converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) + batch_2D = converter(**batch_2D) + + if self.cf.merge_2D_to_3D_preds: + batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], + 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_2D["patient_" + o] = batch_3D[o] + else: + batch_2D.update({'patient_bb_target': batch_2D['bb_target'], + 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_2D["patient_" + o] = batch_2D[o] + + out_batch = batch_3D if self.cf.dim == 3 else batch_2D + out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) + + # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension. + # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. + if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): + patient_batch = out_batch + patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) + new_img_batch, new_seg_batch = [], [] + + for cix, c in enumerate(patch_crop_coords_list): + + seg_patch = seg[c[0]:c[1], c[2]: c[3], c[4]:c[5]] + new_seg_batch.append(seg_patch) + + tmp_c_5 = c[5] + + new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) + + data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) + seg = np.array(new_seg_batch)[:, np.newaxis] # (n_patches, 1, x, y, z) + if self.cf.dim == 2: + # all patches have z dimension 1 (slices). discard dimension + data = data[..., 0] + seg = seg[..., 0] + + patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), + 'pid': np.array([patient['pid']] * data.shape[0])} + for o in self.cf.roi_items: + patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) + # patient-wise (orig) batch info for putting the patches back together after prediction + for o in self.cf.roi_items: + patch_batch["patient_" + o] = patient_batch['patient_' + o] + if self.cf.dim == 2: + # this could also be named "unpatched_2d_roi_items" + patch_batch["patient_" + o + "_2d"] = patient_batch[o] + # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 + # and enables calculating test-dice/viewing patient-wise results in test + # remove, but also remove dice from metrics, when like to save memory + patch_batch['patient_data'] = patient_batch['data'] + patch_batch['patient_seg'] = patient_batch['seg'] + patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) + patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] + if self.cf.dim == 2: + patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] + patch_batch['original_img_shape'] = patient_batch['original_img_shape'] + + converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, + self.cf.class_specific_seg) + patch_batch = converter(**patch_batch) + out_batch = patch_batch + + self.patient_ix += 1 + if self.patient_ix == len(self.dataset_pids): + self.patient_ix = 0 + + return out_batch + +# single-annotator GTs +class BatchGenerator_sa(dutils.BatchGenerator): + """ + creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) + from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. + Actual patch_size is obtained after data augmentation. + :param data: data dictionary as provided by 'load_dataset'. + :param batch_size: number of patients to sample for the batch + :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target + """ + + # noinspection PyMethodOverriding + def balance_target_distribution(self, rater, plot=False): + """ + :param rater: for which rater slot to generate the distribution + :param self.targets: dic holding {patient_specifier : patient-wise-unique ROI targets} + :param plot: whether to plot the generated patient distributions + :return: probability distribution over all pids. draw without replace from this. + """ + # get unique foreground targets per patient, assign -1 to an "empty" patient (has no foreground) + patient_ts = [[roi[rater] for roi in patient_rois_lst] for patient_rois_lst in self.targets.values()] + # assign [-1] to empty patients + patient_ts = [np.unique(lst) if len([t for t in lst if np.any(t>0)])>0 else [-1] for lst in patient_ts] + #bg_mask = np.array([np.all(lst == [-1]) for lst in patient_ts]) + # sort out bg labels (are 0) + unique_ts, t_counts = np.unique([t for lst in patient_ts for t in lst if t>0], return_counts=True) + t_probs = t_counts.sum() / t_counts + t_probs /= t_probs.sum() + t_probs = {t : t_probs[ix] for ix, t in enumerate(unique_ts)} + t_probs[-1] = 0. + t_probs[0] = 0. + # fail if balance target is not a number (i.e., a vector) + p_probs = np.array([ max([t_probs[t] for t in lst]) for lst in patient_ts ]) + #normalize + p_probs /= p_probs.sum() + + if plot: + plg.plot_batchgen_distribution(self.cf, self.dataset_pids, p_probs, self.balance_target, + out_file=os.path.join(self.cf.plot_dir, + "train_gen_distr_"+str(self.cf.fold)+"_rater"+str(rater)+".png")) + return p_probs, unique_ts + + + + def __init__(self, cf, data): + super(BatchGenerator_sa, self).__init__(cf, data) + + self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. + self.p_fg = 0.5 + self.empty_samples_max_ratio = 0.6 + + self.random_count = int(cf.batch_random_ratio * cf.batch_size) + + self.rater_bsize = 4 + unique_ts_total = set() + self.rater_p_probs = [] + for r in range(self.rater_bsize): + p_probs, unique_ts = self.balance_target_distribution(r, plot=True) + self.rater_p_probs.append(p_probs) + unique_ts_total.update(unique_ts) + self.unique_ts = sorted(list(unique_ts_total)) + self.stats = {"roi_counts": np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count": 0} + + + def generate_train_batch(self): + + rater = np.random.randint(self.rater_bsize) + + # samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio batch_random_ratio). + # random patients + batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) + # target-balanced patients + batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False, + p=self.rater_p_probs[rater])) + + batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], [] + batch_roi_items = {name: [] for name in self.cf.roi_items} + # record roi count of classes in batch + batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0 + # empty count for full bg samples (empty slices in 2D/patients in 3D) + + + for sample in range(self.batch_size): + + patient = self._data[batch_patient_ids[sample]] + + patient_balance_ts = np.array([roi[rater] for roi in patient[self.balance_target]]) + data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] + seg = np.load(patient['seg'], mmap_mode='r') + seg = np.transpose(seg[list(seg.keys())[0]][rater], axes=(1, 2, 0)) + batch_pids.append(patient['pid']) + (c, y, x, z) = data.shape + + if self.cf.dim == 2: + + elig_slices, choose_fg = [], False + if len(patient['fg_slices']) > 0: + if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand( + 1) <= self.p_fg: + # fg is to be picked + for tix in np.argsort(batch_roi_counts): + # pick slices of patient that have roi of sought-for target + # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix + elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( + patient_balance_ts[np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == + self.unique_ts[tix]) > 0] + if len(elig_slices) > 0: + choose_fg = True + break + else: + # pick bg + elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'][rater]) + + if len(elig_slices) > 0: + sl_pick_ix = np.random.choice(elig_slices, size=None) + else: + sl_pick_ix = np.random.choice(z, size=None) + + data = data[..., sl_pick_ix] + seg = seg[..., sl_pick_ix] + + # pad data if smaller than pre_crop_size. + if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): + new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] + data = dutils.pad_nd_image(data, new_shape, mode='constant') + seg = dutils.pad_nd_image(seg, new_shape, mode='constant') + + # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. + crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] + if len(crop_dims) > 0: + if self.cf.dim == 3: + choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg + if choose_fg and np.any(seg): + available_roi_ids = np.unique(seg[seg>0]) + assert np.all(patient_balance_ts[available_roi_ids-1]>0), "trying to choose roi with rating 0" + for tix in np.argsort(batch_roi_counts): + elig_roi_ids = available_roi_ids[ patient_balance_ts[available_roi_ids-1] == self.unique_ts[tix] ] + if len(elig_roi_ids)>0: + seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) + roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] + break + + assert seg[tuple(roi_anchor_pixel)] > 0, "roi_anchor_pixel not inside roi: {}, pb_ts {}, elig ids {}".format(tuple(roi_anchor_pixel), patient_balance_ts, elig_roi_ids) + # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by + # distance to the desired ROI < patch_size /2. + # (here final patch size to account for center_crop after data augmentation). + sample_seg_center = {} + for ii in crop_dims: + low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) + high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, + roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) + # happens if lesion on the edge of the image. dont care about roi anymore, + # just make sure pre-crop is inside image. + if low >= high: + low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) + high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) + sample_seg_center[ii] = np.random.randint(low=low, high=high) + else: + # not guaranteed to be empty. probability of emptiness depends on the data. + sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, + high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} + + for ii in crop_dims: + min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) + max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) + data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) + seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) + + batch_data.append(data) + batch_segs.append(seg[np.newaxis]) + for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item + batch_roi_items[o].append([roi[rater] for roi in patient[o]]) + + if self.cf.dim == 3: + for tix in range(len(self.unique_ts)): + batch_roi_counts[tix] += np.count_nonzero(patient_balance_ts == self.unique_ts[tix]) + elif self.cf.dim == 2: + for tix in range(len(self.unique_ts)): + batch_roi_counts[tix] += np.count_nonzero(patient_balance_ts[np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) + if not np.any(seg): + empty_samples_count += 1 + + + data = np.array(batch_data).astype('float16') + seg = np.array(batch_segs).astype('uint8') + batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'rater_id': rater, + 'roi_counts':batch_roi_counts, 'empty_samples_count': empty_samples_count} + for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) + batch[key] = np.array(val) + + return batch + +class PatientBatchIterator_sa(dutils.PatientBatchIterator): + """ + creates a test generator that iterates over entire given dataset returning 1 patient per batch. + Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actual evaluation (done in 3D), + if willing to accept speed loss during training. + :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or + batch_size = n_2D_patches in 2D . + + This is the data & gt loader for the 4-fold single-annotator GTs: each data input has separate annotations of 4 annotators. + the way the pipeline is currently setup, the single-annotator GTs are only used if training with validation mode + val_patient; during testing the Iterator with the merged GTs is used. + # todo mode val_patient not implemented yet (since very slow). would need to sample from all available rater GTs. + """ + def __init__(self, cf, data): #threads in augmenter + super(PatientBatchIterator_sa, self).__init__(cf, data) + self.cf = cf + self.patient_ix = 0 + self.dataset_pids = list(self._data.keys()) + self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size + + self.rater_bsize = 4 + + + def generate_train_batch(self, pid=None): + + if pid is None: + pid = self.dataset_pids[self.patient_ix] + patient = self._data[pid] + + data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) + # all gts are 4-fold and npz! + seg = np.load(patient['seg'], mmap_mode='r') + seg = np.transpose(seg[list(seg.keys())[0]], axes=(0, 2, 3, 1)) + + # pad data if smaller than patch_size seen during training. + if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): + new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] + data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. + seg = dutils.pad_nd_image(seg, new_shape) + + # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. + if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: + out_data = data[np.newaxis, np.newaxis] + out_seg = seg[:, np.newaxis] + batch_3D = {'data': out_data, 'seg': out_seg} + + for item in self.cf.roi_items: + batch_3D[item] = [] + for r in range(self.rater_bsize): + for item in self.cf.roi_items: + batch_3D[item].append(np.array([roi[r] for roi in patient[item]])) + + converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) + batch_3D = converter(**batch_3D) + batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_3D["patient_" + o] = batch_3D[o] + + if self.cf.dim == 2: + out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, y, x ) + out_seg = np.transpose(seg, axes=(0, 3, 1, 2))[:, :, np.newaxis] # (n_raters, z, 1, y,x) + + batch_2D = {'data': out_data} + + for item in ["seg", "bb_target"]+self.cf.roi_items: + batch_2D[item] = [] + + converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) + for r in range(self.rater_bsize): + tmp_batch = {"seg": out_seg[r]} + for item in self.cf.roi_items: + tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), out_data.shape[0], axis=0) + tmp_batch = converter(**tmp_batch) + for item in ["seg", "bb_target"]+self.cf.roi_items: + batch_2D[item].append(tmp_batch[item]) + # for item in ["seg", "bb_target"]+self.cf.roi_items: + # batch_2D[item] = np.array(batch_2D[item]) + + if self.cf.merge_2D_to_3D_preds: + batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], + 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_2D["patient_" + o] = batch_3D[o] + else: + batch_2D.update({'patient_bb_target': batch_2D['bb_target'], + 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_2D["patient_" + o] = batch_2D[o] + + out_batch = batch_3D if self.cf.dim == 3 else batch_2D + out_batch.update({'pid': np.array([patient['pid']] * out_data.shape[0])}) + + # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension. + # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. + if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): + patient_batch = out_batch + patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) + new_img_batch = [] + new_seg_batch = [] + + for cix, c in enumerate(patch_crop_coords_list): + seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] + new_seg_batch.append(seg_patch) + tmp_c_5 = c[5] + + new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) + + data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) + seg = np.transpose(np.array(new_seg_batch), axes=(1,0,2,3,4))[:,:,np.newaxis] # (n_raters, n_patches, x, y, z) + + if self.cf.dim == 2: + # all patches have z dimension 1 (slices). discard dimension + data = data[..., 0] + seg = seg[..., 0] + + patch_batch = {'data': data.astype('float32'), + 'pid': np.array([patient['pid']] * data.shape[0])} + # for o in self.cf.roi_items: + # patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) + + converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, + self.cf.class_specific_seg) + + for item in ["seg", "bb_target"]+self.cf.roi_items: + patch_batch[item] = [] + # coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, + # IndexError: index 2 is out of bounds for axis 1 with size 2 + for r in range(self.rater_bsize): + tmp_batch = {"seg": seg[r]} + for item in self.cf.roi_items: + tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), len(patch_crop_coords_list), axis=0) + tmp_batch = converter(**tmp_batch) + for item in ["seg", "bb_target"]+self.cf.roi_items: + patch_batch[item].append(tmp_batch[item]) + + # patient-wise (orig) batch info for putting the patches back together after prediction + for o in self.cf.roi_items: + patch_batch["patient_" + o] = patient_batch['patient_'+o] + if self.cf.dim==2: + # this could also be named "unpatched_2d_roi_items" + patch_batch["patient_"+o+"_2d"] = patient_batch[o] + # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 + # and enables calculating test-dice/viewing patient-wise results in test + # remove, but also remove dice from metrics, if you like to save memory + patch_batch['patient_data'] = patient_batch['data'] + patch_batch['patient_seg'] = patient_batch['seg'] + patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) + patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] + if self.cf.dim==2: + patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] + patch_batch['original_img_shape'] = patient_batch['original_img_shape'] + + out_batch = patch_batch + + self.patient_ix += 1 + if self.patient_ix == len(self.dataset_pids): + self.patient_ix = 0 + + return out_batch + + +def create_data_gen_pipeline(cf, patient_data, is_training=True): + """ create multi-threaded train/val/test batch generation and augmentation pipeline. + :param cf: configs object. + :param patient_data: dictionary containing one dictionary per patient in the train/test subset. + :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) + :return: multithreaded_generator + """ + + data_gen = BatchGenerator_merged(cf, patient_data) if cf.training_gts=='merged' else BatchGenerator_sa(cf, patient_data) + + # add transformations to pipeline. + my_transforms = [] + if is_training: + if cf.da_kwargs["mirror"]: + mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) + my_transforms.append(mirror_transform) + spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], + patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], + do_elastic_deform=cf.da_kwargs['do_elastic_deform'], + alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], + do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], + angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], + do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], + random_crop=cf.da_kwargs['random_crop']) + + my_transforms.append(spatial_transform) + else: + my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) + + if cf.create_bounding_box_targets: + my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) + all_transforms = Compose(my_transforms) + + multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers)) + return multithreaded_generator + +def get_train_generators(cf, logger, data_statistics=True): + """ + wrapper function for creating the training batch generator pipeline. returns the train/val generators. + selects patients according to cv folds (generated by first run/fold of experiment): + splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) + If cf.held_out_test_set is True, adds the test split to the training data. + """ + dataset = Dataset(cf, logger) + + dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) + dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) + set_splits = dataset.fg.splits + + test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) + train_ids = np.concatenate(set_splits, axis=0) + + if cf.held_out_test_set: + train_ids = np.concatenate((train_ids, test_ids), axis=0) + test_ids = [] + + train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} + val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} + + logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), + len(test_ids))) + if data_statistics: + dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, + plot_dir=os.path.join(cf.plot_dir,"dataset")) + + batch_gen = {} + batch_gen['train'] = create_data_gen_pipeline(cf, train_data, is_training=True) + batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, is_training=False) + if cf.val_mode == 'val_patient': + assert cf.training_gts == 'merged', 'val_patient not yet implemented for sa gts' + batch_gen['val_patient'] = PatientBatchIterator_merged(cf, val_data) if cf.training_gts=='merged' \ + else PatientBatchIterator_sa(cf, val_data) + batch_gen['n_val'] = len(val_data) if cf.max_val_patients is None else cf.max_val_patients + else: + batch_gen['n_val'] = cf.num_val_batches + + return batch_gen + +def get_test_generator(cf, logger): + """ + wrapper function for creating the test batch generator pipeline. + selects patients according to cv folds (generated by first run/fold of experiment) + If cf.held_out_test_set is True, gets the data from an external folder instead. + """ + if cf.held_out_test_set: + sourcedir = cf.test_data_sourcedir + test_ids = None + else: + sourcedir = None + with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: + set_splits = pickle.load(handle) + test_ids = set_splits[cf.fold] + + test_data = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode="test").data + logger.info("data set loaded with: {} test patients".format(len(test_ids))) + batch_gen = {} + batch_gen['test'] = PatientBatchIterator_merged(cf, test_data) + batch_gen['n_test'] = len(test_ids) if cf.max_test_patients == "all" else min(cf.max_test_patients, len(test_ids)) + return batch_gen + + +if __name__ == "__main__": + import sys + sys.path.append('../') + import plotting as plg + import utils.exp_utils as utils + from configs import Configs + + cf = configs() + cf.batch_size = 3 + #dataset_path = os.path.dirname(os.path.realpath(__file__)) + #exp_path = os.path.join(dataset_path, "experiments/dev") + #cf = utils.prep_exp(dataset_path, exp_path, server_env=False, use_stored_settings=False, is_training=True) + cf.created_fold_id_pickle = False + total_stime = time.time() + times = {} + + # cf.server_env = True + # cf.data_dir = "experiments/dev_data" + + # dataset = Dataset(cf) + # patient = dataset['Master_00018'] + cf.exp_dir = "experiments/dev/" + cf.plot_dir = cf.exp_dir + "plots" + os.makedirs(cf.exp_dir, exist_ok=True) + cf.fold = 0 + logger = utils.get_logger(cf.exp_dir) + gens = get_train_generators(cf, logger) + train_loader = gens['train'] + + + + for i in range(1): + stime = time.time() + #ex_batch = next(train_loader) + print("train batch", i) + times["train_batch"] = time.time() - stime + #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True) + # + # # with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file: + # # train_loader.generator.print_stats(logger, file) + # + val_loader = gens['val_sampling'] + stime = time.time() + ex_batch = next(val_loader) + times["val_batch"] = time.time() - stime + stime = time.time() + #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, + # show_info=False) + times["val_plot"] = time.time() - stime + # + test_loader = get_test_generator(cf, logger)["test"] + stime = time.time() + ex_batch = test_loader.generate_train_batch() + times["test_batch"] = time.time() - stime + stime = time.time() + plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png")#, sample_picks=[0,1,2,3]) + times["test_patchbatch_plot"] = time.time() - stime + + # ex_batch['data'] = ex_batch['patient_data'] + # ex_batch['seg'] = ex_batch['patient_seg'] + # ex_batch['bb_target'] = ex_batch['patient_bb_target'] + # for item in cf.roi_items: + # ex_batch[] + # stime = time.time() + # #ex_batch = next(test_loader) + # ex_batch = next(test_loader) + # plg.view_batch(cf, ex_batch, show_gt_labels=False, show_gt_boxes=True, patient_items=True,# vol_slice_picks=[146,148, 218,220], + # out_file="experiments/dev/dev_expatientbatch.png") # , sample_picks=[0,1,2,3]) + # times["test_patient_batch_plot"] = time.time() - stime + + + + print("Times recorded throughout:") + for (k, v) in times.items(): + print(k, "{:.2f}".format(v)) + + mins, secs = divmod((time.time() - total_stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) diff --git a/datasets/lidc/preprocessing.py b/datasets/lidc/preprocessing.py new file mode 100644 index 0000000..2f5efd4 --- /dev/null +++ b/datasets/lidc/preprocessing.py @@ -0,0 +1,478 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +''' +This preprocessing script loads nrrd files obtained by the data conversion tool: https://github.com/MIC-DKFZ/LIDC-IDRI-processing/tree/v1.0.1 +After applying preprocessing, images are saved as numpy arrays and the meta information for the corresponding patient is stored +as a line in the dataframe saved as info_df.pickle. +''' + +import os +import sys +import shutil +import subprocess +import pickle +import time + +import SimpleITK as sitk +import numpy as np +from multiprocessing import Pool +import pandas as pd +import numpy.testing as npt +from skimage.transform import resize + +sys.path.append(os.path.dirname(os.path.realpath(__file__))) +sys.path.append('../..') +import data_manager as dmanager + +class AttributeDict(dict): + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + +def load_df(path): + df = pd.read_pickle(path) + print(df) + + return + +def resample_array(src_imgs, src_spacing, target_spacing): + """ Resample a numpy array. + :param src_imgs: source image. + :param src_spacing: source image's spacing. + :param target_spacing: spacing to resample source image to. + :return: + """ + src_spacing = np.round(src_spacing, 3) + target_shape = [int(src_imgs.shape[ix] * src_spacing[::-1][ix] / target_spacing[::-1][ix]) for ix in range(len(src_imgs.shape))] + for i in range(len(target_shape)): + try: + assert target_shape[i] > 0 + except: + raise AssertionError("AssertionError:", src_imgs.shape, src_spacing, target_spacing) + + img = src_imgs.astype('float64') + resampled_img = resize(img, target_shape, order=1, clip=True, mode='edge').astype('float32') + + return resampled_img + +class Preprocessor(object): + """Preprocessor for LIDC raw data. Set in config: which ground truths to produce, choices are + - "merged" for a single ground truth per input image, created by merging the given four rater annotations + into one. + - "single-annotator" for a four-fold ground truth per input image, created by leaving the each rater annotation + separately. + :param cf: config. + :param exclude_inconsistents: bool or tuple, list, np.array, exclude patients that show technical inconsistencies + in the raw files, likely due to file-naming mistakes. if bool and True: search for patients that have too many + ratings per lesion or other inconstencies, exclude findings. if param is list of pids: exclude given pids. + :param overwrite: look for patients that already exist in the pp dir. if overwrite is False, do not redo existing + patients, otherwise ignore any existing files. + :param max_count: maximum number of patients to preprocess. + :param pids_subset: subset of pids to preprocess. + """ + + def __init__(self, cf, exclude_inconsistents=True, overwrite=False, max_count=None, pids_subset=None): + + self.cf = cf + + assert len(self.cf.gts_to_produce)>0, "need to specify which gts to produce, choices: 'merged', 'single_annotator'" + + self.paths = [os.path.join(cf.raw_data_dir, ii) for ii in os.listdir(cf.raw_data_dir)] + if exclude_inconsistents: + if isinstance(exclude_inconsistents, bool): + exclude_paths = self.exclude_too_many_ratings() + exclude_paths += self.verify_seg_label_pairings() + else: + assert isinstance(exclude_inconsistents, (tuple,list,np.ndarray)) + exclude_paths = exclude_inconsistents + self.paths = [path for path in self.paths if path not in exclude_paths] + + + if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce: + self.pp_dir_sa = os.path.join(cf.pp_dir, "patient_gts_sa") + if 'merged' in self.cf.gts_to_produce: + self.pp_dir_merged = os.path.join(cf.pp_dir, "patient_gts_merged") + orig_count = len(self.paths) + # check if some patients already have ppd versions in destination dir + if os.path.exists(cf.pp_dir) and not overwrite: + fs_in_dir = os.listdir(cf.pp_dir) + already_done = [file.split("_")[0] for file in fs_in_dir if file.split("_")[1] == "img.npy"] + if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce: + ext = '.npy' if hasattr(self.cf, "save_sa_segs_as") and ( + self.cf.save_sa_segs_as == "npy" or self.cf.save_sa_segs_as == ".npy") else '.npz' + fs_in_dir = os.listdir(self.pp_dir_sa) + already_done = [ pid for pid in already_done if pid+"_rois"+ext in fs_in_dir and pid+"_meta_info.pickle" in fs_in_dir] + if 'merged' in self.cf.gts_to_produce: + fs_in_dir = os.listdir(self.pp_dir_merged) + already_done = [pid for pid in already_done if + pid + "_rois.npy" in fs_in_dir and pid+"_meta_info.pickle" in fs_in_dir] + + self.paths = [p for p in self.paths if not p.split(os.sep)[-1] in already_done] + if len(self.paths)!=orig_count: + print("Due to existing ppd files: Selected a subset of {} patients from originally {}".format(len(self.paths), orig_count)) + + if pids_subset: + self.paths = [p for p in self.paths if p.split(os.sep)[-1] in pids_subset] + if max_count is not None: + self.paths = self.paths[:max_count] + + if not os.path.exists(cf.pp_dir): + os.mkdir(cf.pp_dir) + if ('single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce) and \ + not os.path.exists(self.pp_dir_sa): + os.mkdir(self.pp_dir_sa) + if 'merged' in self.cf.gts_to_produce and not os.path.exists(self.pp_dir_merged): + os.mkdir(self.pp_dir_merged) + + + def exclude_too_many_ratings(self): + """exclude a patient's full path (the patient folder) from further processing if patient has nodules with + ratings of more than four raters (which is inconsistent with what the raw data is supposed to comprise, + also rater ids appear multiple times on the same nodule in these cases motivating the assumption that + the same rater issued more than one rating / mixed up files or annotations for a nodule). + :return: paths to be excluded. + """ + exclude_paths = [] + for path in self.paths: + roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii]) + found = False + for roi_id in roi_ids: + n_raters = len([ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii]) + # assert n_raters<=4, "roi {} in path {} has {} raters".format(roi_id, path, n_raters) + if n_raters > 4: + print("roi {} in path {} has {} raters".format(roi_id, path, n_raters)) + found = True + if found: + exclude_paths.append(path) + print("Patients excluded bc of too many raters:\n") + for p in exclude_paths: + print(p) + print() + + return exclude_paths + + def analyze_lesion(self, pid, nodule_id): + """print unique seg and counts of nodule nodule_id of patient pid. + """ + nodule_id = nodule_id.lstrip("0") + nodule_id_paths = [ii for ii in os.listdir(os.path.join(self.cf.raw_data_dir, pid)) if '.nii' in ii] + nodule_id_paths = [ii for ii in nodule_id_paths if ii.split('_')[2].lstrip("0")==nodule_id] + assert len(nodule_id_paths)==1 + nodule_path = nodule_id_paths[0] + + roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, nodule_path)) + roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8) + + print("pid {}, nodule {}, unique seg & counts: {}".format(pid, nodule_id, np.unique(roi_arr, return_counts=True))) + return + + def verify_seg_label_pairing(self, path): + """verifies that a nodule's segmentation has malignancy label > 0 if segmentation has foreground (>0 anywhere), + and vice-versa that it has only background (==0 everywhere) if no malignancy label (==label 0) assigned. + :param path: path to the patient folder. + :return: df containing eventual inconsistency findings. + """ + + pid = path.split('/')[-1] + + df = pd.read_csv(os.path.join(self.cf.root_dir, 'characteristics.csv'), sep=';') + df = df[df.PatientID == pid] + + findings_df = pd.DataFrame(columns=["problem", "pid", "roi_id", "nodule_id", "rater_ix", "seg_unique", "label"]) + + print('verifying {}'.format(pid)) + + roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii]) + + for roi_id in roi_ids: + roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii] + nodule_ids = [rp.split('_')[2].lstrip("0") for rp in roi_id_paths] + rater_ids = [rp.split('_')[1] for rp in roi_id_paths] + rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0] for ii in nodule_ids] + + # check double existence of nodule ids + uniq, counts = np.unique(nodule_ids, return_counts=True) + if np.any([count>1 for count in counts]): + finding = ("same nodule id exists more than once", pid, roi_id, nodule_ids, "N/A", "N/A", "N/A") + print("not unique nodule id", finding) + findings_df.loc[findings_df.shape[0]] = finding + + # check double gradings of single rater for single roi + uniq, counts = np.unique(rater_ids, return_counts=True) + if np.any([count>1 for count in counts]): + finding = ("same roi_id exists more than once for a single rater", pid, roi_id, nodule_ids, rater_ids, "N/A", rater_labels) + print("more than one grading per roi per single rater", finding) + findings_df.loc[findings_df.shape[0]] = finding + + + rater_segs = [] + for rp in roi_id_paths: + roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp)) + roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8) + + rater_segs.append(roi_arr) + rater_segs = np.array(rater_segs) + for r in range(rater_segs.shape[0]): + if np.sum(rater_segs[r])>0: + if rater_labels[r]<=0: + finding = ("non-empty seg w/ bg label", pid, roi_id, nodule_ids[r], rater_ids[r], np.unique(rater_segs[r]), rater_labels[r]) + print("{}: pid {}, nodule {}, rater {}, seg unique {}, label {}".format( + *finding)) + findings_df.loc[findings_df.shape[0]] = finding + else: + if rater_labels[r]>0: + finding = ("empty seg w/ fg label", pid, roi_id, nodule_ids[r], rater_ids[r], np.unique(rater_segs[r]), rater_labels[r]) + print("{}: pid {}, nodule {}, rater {}, seg unique {}, label {}".format( + *finding)) + findings_df.loc[findings_df.shape[0]] = finding + + return findings_df + + def verify_seg_label_pairings(self, processes=os.cpu_count()): + """wrapper to multi-process verification of seg-label pairings. + """ + + pool = Pool(processes=processes) + findings_dfs = pool.map(self.verify_seg_label_pairing, self.paths, chunksize=1) + pool.close() + pool.join() + + findings_df = pd.concat(findings_dfs, axis=0) + findings_df.to_pickle(os.path.join(self.cf.pp_dir, "verification_seg_label_pairings.pickle")) + findings_df.to_csv(os.path.join(self.cf.pp_dir, "verification_seg_label_pairings.csv")) + + return findings_df.pid.tolist() + + def produce_sa_gt(self, path, pid, df, img_spacing, img_arr_shape): + """ Keep annotations separate, i.e., every processed image has four final GTs. + Images are always saved as npy. For meeting hard-disk-memory constraints, segmentations can optionally be + saved as .npz instead of .npy. Dataloader is only implemented for reading .npz segs. + """ + + final_rois = np.zeros((4, *img_arr_shape), dtype='uint8') + patient_mal_labels = [] + roi_ids = list(set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii])) + roi_ids.sort() # just a precaution to have same order of lesions throughout separate runs + + rix = 1 + for roi_id in roi_ids: + roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii] + assert len(roi_id_paths)>0 and len(roi_id_paths)<=4, "pid {}: should find 0< n_rois <4, but found {}".format(pid, len(roi_id_paths)) + + """ not strictly necessary precaution: in theory, segmentations of different raters could overlap also for + *different* rois, i.e., a later roi of a rater could (partially) cover up / destroy the roi of another + rater. practically this is unlikely as overlapping lesions of different raters should be regarded as the + same lesion, but safety first. hence, the order of raters is maintained across rois, i.e., rater 0 + (marked as rater 0 in roi's file name) always has slot 0 in rater_labels and rater_segs, thereby rois + are certain to not overlap. + """ + rater_labels, rater_segs = np.zeros((4,), dtype='uint8'), np.zeros((4,*img_arr_shape), dtype="float32") + for ix, rp in enumerate(roi_id_paths): # one roi path per rater + nodule_id = rp.split('_')[2].lstrip("0") + assert not (nodule_id=="5728" or nodule_id=="8840"), "nodule ids {}, {} should be excluded due to seg-mal-label inconsistency.".format(5728, 8840) + rater = int(rp.split('_')[1]) + rater_label = df[df.NoduleID == int(nodule_id)].Malignancy.values[0] + rater_labels[rater] = rater_label + + roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp)) + for dim in range(len(img_arr_shape)): + npt.assert_almost_equal(roi.GetSpacing()[dim], img_spacing[dim]) + roi_arr = sitk.GetArrayFromImage(roi) + roi_arr = resample_array(roi_arr, roi.GetSpacing(), self.cf.target_spacing) + assert roi_arr.shape == img_arr_shape, [roi_arr.shape, img_arr_shape, pid, roi.GetSpacing()] + assert not np.any(rater_segs[rater]), "overwriting existing rater's seg with roi {}".format(rp) + rater_segs[rater] = roi_arr + rater_segs = np.array(rater_segs) + + # rename/remap the malignancy to be positive. + roi_mal_labels = [ii if ii > -1 else 0 for ii in rater_labels] + assert rater_segs.shape == final_rois.shape, "rater segs shape {}, final rois shp {}".format(rater_segs.shape, final_rois.shape) + + # assert non-zero rating has non-zero seg + for rater in range(4): + if roi_mal_labels[rater]>0: + assert np.any(rater_segs[rater]>0), "rater {} mal label {} but uniq seg {}".format(rater, roi_mal_labels[rater], np.unique(rater_segs[rater])) + + # add the roi to patient. i.e., write current lesion into final labels and seg of whole patient. + assert np.any(rater_segs), "empty segmentations for all raters should not exist in single-annotator mode, pid {}, rois: {}".format(pid, roi_id_paths) + patient_mal_labels.append(roi_mal_labels) + final_rois[rater_segs > 0] = rix + rix += 1 + + + fg_slices = [[ii for ii in np.unique(np.argwhere(final_rois[r] != 0)[:, 0])] for r in range(4)] + patient_mal_labels = np.array(patient_mal_labels) + roi_ids = np.unique(final_rois[final_rois>0]) + assert len(roi_ids) == len(patient_mal_labels), "mismatch {} rois in seg, {} rois in mal labels".format(len(roi_ids), len(patient_mal_labels)) + + if hasattr(self.cf, "save_sa_segs_as") and (self.cf.save_sa_segs_as=="npy" or self.cf.save_sa_segs_as==".npy"): + np.save(os.path.join(self.pp_dir_sa, '{}_rois.npy'.format(pid)), final_rois) + else: + np.savez_compressed(os.path.join(self.cf.pp_dir, 'patient_gts_sa', '{}_rois.npz'.format(pid)), seg=final_rois) + with open(os.path.join(self.pp_dir_sa, '{}_meta_info.pickle'.format(pid)), 'wb') as handle: + meta_info_dict = {'pid': pid, 'class_target': patient_mal_labels, 'spacing': img_spacing, + 'fg_slices': fg_slices} + pickle.dump(meta_info_dict, handle) + + def produce_merged_gt(self, path, pid, df, img_spacing, img_arr_shape): + """ process patient with merged annotations, i.e., only one final GT per image. save img and seg to npy, rest to + metadata. + annotations merging: + - segmentations: only regard a pixel as foreground if at least two raters found it be foreground. + - malignancy labels: average over all four rater votes. every rater who did not assign a finding or + assigned -1 to the RoI contributes to the average with a vote of 0. + + :param path: path to patient folder. + """ + + final_rois = np.zeros(img_arr_shape, dtype=np.uint8) + patient_mal_labels = [] + roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii]) + + rix = 1 + for roi_id in roi_ids: + roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii] + nodule_ids = [ii.split('_')[2].lstrip("0") for ii in roi_id_paths] + rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0] for ii in nodule_ids] + rater_labels.extend([0] * (4 - len(rater_labels))) + mal_label = np.mean([ii if ii > -1 else 0 for ii in rater_labels]) + rater_segs = [] + for rp in roi_id_paths: + roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp)) + for dim in range(len(img_arr_shape)): + npt.assert_almost_equal(roi.GetSpacing()[dim], img_spacing[dim]) + roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8) + roi_arr = resample_array(roi_arr, roi.GetSpacing(), self.cf.target_spacing) + assert roi_arr.shape == img_arr_shape, [roi_arr.shape, img_arr_shape, pid, roi.GetSpacing()] + rater_segs.append(roi_arr) + rater_segs.extend([np.zeros_like(rater_segs[-1])] * (4 - len(roi_id_paths))) + rater_segs = np.mean(np.array(rater_segs), axis=0) + # annotations merging: if less than two raters found fg, set segmentation to bg. + rater_segs[rater_segs < 0.5] = 0 + if np.sum(rater_segs) > 0: + patient_mal_labels.append(mal_label) + final_rois[rater_segs > 0] = rix + rix += 1 + else: + # indicate rois suppressed by majority voting of raters + print('suppressed roi!', roi_id_paths) + with open(os.path.join(self.pp_dir_merged, 'suppressed_rois.txt'), 'a') as handle: + handle.write(" ".join(roi_id_paths)) + + fg_slices = [ii for ii in np.unique(np.argwhere(final_rois != 0)[:, 0])] + patient_mal_labels = np.array(patient_mal_labels) + assert len(patient_mal_labels) + 1 == len(np.unique(final_rois)), [len(patient_mal_labels), np.unique(final_rois), pid] + assert final_rois.dtype == 'uint8' + np.save(os.path.join(self.pp_dir_merged, '{}_rois.npy'.format(pid)), final_rois) + + with open(os.path.join(self.pp_dir_merged, '{}_meta_info.pickle'.format(pid)), 'wb') as handle: + meta_info_dict = {'pid': pid, 'class_target': patient_mal_labels, 'spacing': img_spacing, + 'fg_slices': fg_slices} + pickle.dump(meta_info_dict, handle) + + def pp_patient(self, path): + + pid = path.split('/')[-1] + img = sitk.ReadImage(os.path.join(path, '{}_ct_scan.nrrd'.format(pid))) + img_arr = sitk.GetArrayFromImage(img) + print('processing {} with GT(s) {}, spacing {} and img shape {}.'.format( + pid, " and ".join(self.cf.gts_to_produce), img.GetSpacing(), img_arr.shape)) + img_arr = resample_array(img_arr, img.GetSpacing(), self.cf.target_spacing) + img_arr = np.clip(img_arr, -1200, 600) + #img_arr = (1200 + img_arr) / (600 + 1200) * 255 # a+x / (b-a) * (c-d) (c, d = new) + img_arr = img_arr.astype(np.float32) + img_arr = (img_arr - np.mean(img_arr)) / np.std(img_arr).astype('float16') + + df = pd.read_csv(os.path.join(self.cf.root_dir, 'characteristics.csv'), sep=';') + df = df[df.PatientID == pid] + + np.save(os.path.join(self.cf.pp_dir, '{}_img.npy'.format(pid)), img_arr) + if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce: + self.produce_sa_gt(path, pid, df, img.GetSpacing(), img_arr.shape) + if 'merged' in self.cf.gts_to_produce: + self.produce_merged_gt(path, pid, df, img.GetSpacing(), img_arr.shape) + + + def iterate_patients(self, processes=os.cpu_count()): + pool = Pool(processes=processes) + pool.map(self.pp_patient, self.paths, chunksize=1) + pool.close() + pool.join() + print("finished processing raw patient data") + + + def aggregate_meta_info(self): + self.dfs = {} + for gt_kind in self.cf.gts_to_produce: + kind_dir = self.pp_dir_merged if gt_kind == "merged" else self.pp_dir_sa + files = [os.path.join(kind_dir, f) for f in os.listdir(kind_dir) if 'meta_info.pickle' in f] + self.dfs[gt_kind] = pd.DataFrame(columns=['pid', 'class_target', 'spacing', 'fg_slices']) + for f in files: + with open(f, 'rb') as handle: + self.dfs[gt_kind].loc[len(self.dfs[gt_kind])] = pickle.load(handle) + + self.dfs[gt_kind].to_pickle(os.path.join(kind_dir, 'info_df.pickle')) + print("aggregated meta info to df with length", len(self.dfs[gt_kind])) + + def convert_copy_npz(self): + npz_dir = os.path.join(self.cf.pp_dir+'_npz') + print("converting to npz dir", npz_dir) + os.makedirs(npz_dir, exist_ok=True) + + dmanager.pack_dataset(self.cf.pp_dir, destination=npz_dir, recursive=True, verbose=False) + if hasattr(self, 'pp_dir_merged'): + subprocess.call('rsync -avh --exclude="*.npy" {} {}'.format(self.pp_dir_merged, npz_dir), shell=True) + if hasattr(self, 'pp_dir_sa'): + subprocess.call('rsync -avh --exclude="*.npy" {} {}'.format(self.pp_dir_sa, npz_dir), shell=True) + + +if __name__ == "__main__": + total_stime = time.time() + + import configs + cf = configs.configs() + + # analysis finding: the following patients have unclear annotations. some raters gave more than one judgement + # on the same roi. + patients_to_exclude = ["0137a", "0404a", "0204a", "0252a", "0366a", "0863a", "0815a", "0060a", "0249a", "0436a", "0865a"] + # further finding: the following patients contain nodules with segmentation-label inconsistencies + # running Preprocessor.verify_seg_label_pairings() produces a data frame with detailed findings. + patients_to_exclude += ["0305a", "0447a"] + exclude_paths = [os.path.join(cf.raw_data_dir, pid) for pid in patients_to_exclude] + # These pids are automatically found and excluded, when setting exclude_inconsistents=True at Preprocessor + # initialization instead of passing the pre-compiled list. + + + pp = Preprocessor(cf, overwrite=True, exclude_inconsistents=exclude_paths, max_count=None, pids_subset=None)#["0998a"]) + #pp.analyze_lesion("0305a", "5728") + #pp.analyze_lesion("0305a", "5741") + #pp.analyze_lesion("0447a", "8840") + + #pp.verify_seg_label_pairings() + #load_df(os.path.join(cf.pp_dir, "verification_seg_label_pairings.pickle")) + pp.iterate_patients(processes=8) + # for i in ["/mnt/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI/new_nrrd/0305a", + # "/mnt/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI/new_nrrd/0447a"]: #pp.paths[:1]: + # pp.pp_patient(i) + pp.aggregate_meta_info() + pp.convert_copy_npz() + + + + mins, secs = divmod((time.time() - total_stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) diff --git a/datasets/prostate/check_GSBx_Re.py b/datasets/prostate/check_GSBx_Re.py new file mode 100755 index 0000000..8f64ca1 --- /dev/null +++ b/datasets/prostate/check_GSBx_Re.py @@ -0,0 +1,120 @@ +""" +Created at 20/11/18 16:18 +@author: gregor +""" +import os +import numpy as np +import pandas as pd + + +class CombinedPrinter(object): + """combined print function. + prints to logger and/or file if given, to normal print if non given. + + """ + def __init__(self, logger=None, file=None): + + if logger is None and file is None: + self.out = [print] + elif logger is None: + self.out = [print, file.write] + elif file is None: + self.out = [print, logger.info] + else: + self.out = [print, logger.info, file.write] + + def __call__(self, string): + for fct in self.out: + fct(string) + +def spec_to_id(spec): + """Get subject id from string""" + return int(spec[-5:]) + + +def pat_roi_GS_histo_check(root_dir): + """ Check, in histo files, whether patient-wide Gleason Score equals maximum GS found in single lesions of patient. + """ + + histo_les_path = os.path.join(root_dir, "MasterHistoAll.csv") + histo_pat_path = os.path.join(root_dir, "MasterPatientbasedAll_clean.csv") + + with open(histo_les_path,mode="r") as les_file: + les_df = pd.read_csv(les_file, delimiter=",") + with open(histo_pat_path, mode="r") as pat_file: + pat_df = pd.read_csv(pat_file, delimiter=",") + + merged_df = les_df.groupby('Master_ID').agg({'Gleason': 'max', 'segmentationsNameADC': 'last'}) + + for pid in merged_df.index: + merged_df.set_value(pid, "GSBx", pat_df[pat_df.Master_ID_Short==pid].GSBx.unique().astype('uint32')) + + #print(merged_df) + print("All patient-wise GS are maximum of lesion-wise GS?", np.all(merged_df.Gleason == merged_df.GSBx), end="\n\n") + assert np.all(merged_df.Gleason == merged_df.GSBx) + + +def lesion_redone_check(root_dir, out_path=None): + """check how many les annotations without post_fix _Re exist and if exists what their GS is + """ + + histo_les_path = os.path.join(root_dir, "Dokumente/MasterHistoAll.csv") + with open(histo_les_path,mode="r") as les_file: + les_df = pd.read_csv(les_file, delimiter=",") + if out_path is not None: + out_file = open(out_path, "w") + else: + out_file = None + print_f = CombinedPrinter(file=out_file) + + data_dir = os.path.join(root_dir, "Daten") + + matches = {} + for patient in [dir for dir in os.listdir(data_dir) if dir.startswith("Master_") \ + and os.path.isdir(os.path.join(data_dir, dir))]: + matches[patient] = {} + pat_dir = os.path.join(data_dir,patient) + lesions = [os.path.splitext(file)[0] for file in os.listdir(pat_dir) if os.path.isfile(os.path.join(pat_dir,file)) and file.startswith("seg") and "LES" in file] + lesions_wo = [os.path.splitext(file)[0] for file in lesions if not "_Re" in file] + lesions_with = [file for file in lesions if "_Re" in file and not "registered" in file] + + matches[patient] = {les_wo : [] for les_wo in lesions_wo} + + for les_wo in matches[patient].keys(): + matches[patient][les_wo] += [les_with for les_with in lesions_with if les_with.startswith(les_wo)] + + missing_les_count = 0 + for patient, lesions in sorted(list(matches.items())): + pat_df = les_df[les_df.Master_ID==spec_to_id(patient)] + for les, les_matches in sorted(list(lesions.items())): + if len(les_matches)==0: + if "t2" in les.lower(): + les_GS = pat_df[pat_df.segmentationsNameT2==les]["Gleason"] + elif "adc" in les.lower(): + les_GS = pat_df[pat_df.segmentationsNameADC==les]["Gleason"] + if len(les_GS)==0: + les_GS = r"[no histo finding!]" + print_f("Patient {}, lesion {} with GS {} has no matches!\n".format(patient, les, les_GS)) + missing_les_count +=1 + else: + del matches[patient][les] + #elif len(les_matches) > 1: + # print("Patient {}, Lesion {} has {} matches: {}".format(patient, les, len(les_matches), les_matches)) + if len(matches[patient])==0: + del matches[patient] + + print_f("Total missing lesion matches: {} within {} patients".format(missing_les_count, len(matches))) + + out_file.close() + + +if __name__=="__main__": + + #root_dir = "/mnt/HDD2TB/Documents/data/prostate/data_di_ana_081118_ps384_gs71/histos/" + root_dir = "/mnt/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master/Dokumente" + pat_roi_GS_histo_check(root_dir) + + root_dir = "/mnt/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master" + out_path = os.path.join(root_dir,"lesion_redone_check.txt") + lesion_redone_check(root_dir, out_path=out_path) + diff --git a/datasets/prostate/configs.py b/datasets/prostate/configs.py new file mode 100644 index 0000000..2de02f3 --- /dev/null +++ b/datasets/prostate/configs.py @@ -0,0 +1,588 @@ +__author__ = '' +#credit Paul F. Jaeger + +######################### +# Example Config # +######################### + +import os +import sys +import pickle + +import numpy as np +import torch + +from collections import namedtuple + +from default_configs import DefaultConfigs + +def load_obj(file_path): + with open(file_path, 'rb') as handle: + return pickle.load(handle) + +# legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope +Label = namedtuple("Label", ['id', 'name', 'color', 'gleasons']) +binLabel = namedtuple("Label", ['id', 'name', 'color', 'gleasons', 'bin_vals']) + + +class Configs(DefaultConfigs): #todo change to Configs + + def __init__(self, server_env=None): + ######################### + # General # + ######################### + super(Configs, self).__init__(server_env) + + ######################### + # I/O # + ######################### + + self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_di_250519_ps384_gs6071/" + #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071/" + #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_analysis/" + + if server_env: + self.data_sourcedir = "/datasets/data_ramien/prostate/data_di_250519_ps384_gs6071_npz/" + #self.data_sourcedir = '/datasets/data_ramien/prostate/data_t2_250519_ps384_gs6071_npz/' + #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_di_ana_151118_ps384_gs60/" + + self.histo_dir = os.path.join(self.data_sourcedir,"histos/") + self.info_dict_name = 'master_info.pkl' + self.info_dict_path = os.path.join(self.data_sourcedir, self.info_dict_name) + + self.config_path = os.path.realpath(__file__) + + # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn']. + self.model = 'detection_fpn' + self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') + self.model_path = os.path.join(self.source_dir,self.model_path) + + self.select_prototype_subset = None + + ######################### + # Preprocessing # + ######################### + self.missing_pz_subjects = [#189, 196, 198, 205, 211, 214, 215, 217, 218, 219, 220, + #223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, + #234, 235, 236, 237, 238, 239, 240, 241, 242, 244, 258, + #261, 262, 264, 267, 268, 269, 270, 271, 273, 275, 276, + #277, 278, 283 + ] + self.no_bval_radval_subjects = [57] #this guy has master id 222 + + self.prepro = { + 'data_dir': '/home/gregor/networkdrives/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master/Daten/', + 'dir_spec': 'Master', + #'images': {'t2': 'T2TRA', 'adc': 'ADC1500', 'b50': 'BVAL50', 'b500': 'BVAL500', + # 'b1000': 'BVAL1000', 'b1500': 'BVAL1500'}, + #'images': {'adc': 'ADC1500', 'b50': 'BVAL50', 'b500': 'BVAL500', 'b1000': 'BVAL1000', 'b1500': 'BVAL1500'}, + 'images': {'t2': 'T2TRA'}, + 'anatomical_masks': ['seg_T2_PRO'], # try: 'seg_T2_PRO','seg_T2_PZ', 'seg_ADC_PRO', 'seg_ADC_PZ', + 'merge_mode' : 'union', #if registered data w/ two gts: take 'union' or 'adc' or 't2' of gt + 'rename_tags': {'seg_ADC_PRO':"pro", 'seg_T2_PRO':"pro", 'seg_ADC_PZ':"pz", 'seg_T2_PZ':"pz"}, + 'lesion_postfix': '_Re', #lesion files are tagged seg_MOD_LESx + 'img_postfix': "_resampled2", #"_resampled2_registered", + 'overall_postfix': ".nrrd", #including filetype ending! + + 'histo_dir': '/home/gregor/networkdrives/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master/Dokumente/', + 'histo_dir_out': self.histo_dir, + 'histo_lesion_based': 'MasterHistoAll.csv', + 'histo_patient_based': 'MasterPatientbasedAll_clean.csv', + 'histo_id_column_name': 'Master_ID', + 'histo_pb_id_column_name': 'Master_ID_Short', #for patient histo + + 'excluded_prisma_subjects': [], + 'excluded_radval_subjects': self.no_bval_radval_subjects, + 'excluded_master_subjects': self.missing_pz_subjects, + + 'seg_labels': {'tz': 0, 'pz': 0, 'lesions':'roi'}, + #set as hard label or 'roi' to have seg labels represent obj instance count + #if not given 'lesions' are numbered highest seg label +lesion-nr-in-histofile + 'class_labels': {'lesions':'gleason'}, #0 is not bg, but first fg class! + #i.e., prepro labels are shifted by -1 towards later training labels in gt, legends, dicts, etc. + #evtly set lesions to 'gleason' and check gleason remap in prepro + #'gleason_thresh': 71, + 'gleason_mapping': {0: -1, 60:0, 71:1, 72:1, 80:1, 90:1, 91:1, 92:1}, + 'gleason_map': self.gleason_map, #see below + 'color_palette': [self.green, self.red], + + 'output_directory': self.data_sourcedir, + + 'modalities2concat' : "all", #['t2', 'adc','b50','b500','b1000','b1500'], #will be concatenated on colorchannel + 'center_of_mass_crop': True, + 'mod_scaling' : (1,1,1), #z,y,x + 'pre_crop_size': [20, 384, 384], #z,y,x, z-cropping and non-square not implemented atm!! + 'swap_yx_to_xy': False, #change final spatial shape from z,y,x to z,x,y + 'normalization': {'percentiles':[1., 99.]}, + 'interpolation': 'nearest', + + 'observables_patient': ['Original_ID', 'GSBx', 'PIRADS2', 'PSA'], + 'observables_rois': ['lesion_gleasons'], + + 'info_dict_path': self.info_dict_path, + + 'npz_dir' : self.data_sourcedir[:-1]+"_npz" #if not None: convert to npz, copy data here + } + if self.prepro["modalities2concat"] == "all": + self.prepro["modalities2concat"] = list(self.prepro["images"].keys()) + + ######################### + # Architecture # + ######################### + + # dimension the model operates in. one out of [2, 3]. + self.dim = 2 + + # 'class': standard object classification per roi, pairwise combinable with each of below tasks. + # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN. + # 'regression': regress some vector per each roi + # 'regression_ken_gal': use kendall-gal uncertainty sigma + # 'regression_bin': classify each roi into a bin related to a regression scale + self.prediction_tasks = ['class',] + + self.start_filts = 48 if self.dim == 2 else 18 + self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 + self.res_architecture = 'resnet50' # 'resnet101' or 'resnet50' + self.weight_init = None #'kaiming_normal' #, 'xavier' or None-->pytorch standard, + self.norm = None #'instance_norm' # one of 'None', 'instance_norm', 'batch_norm' + self.relu = 'relu' # 'relu' or 'leaky_relu' + + self.regression_n_features = 1 #length of regressor target vector (always 1D) + + ######################### + # Data Loader # + ######################### + + self.seed = 17 + self.n_workers = 16 if server_env else os.cpu_count() + + self.batch_size = 10 if self.dim == 2 else 6 + + self.channels = [1, 2, 3, 4] # modalities2load, see prepo + self.n_channels = len(self.channels) # for compatibility, but actually redundant + # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels + self.plot_bg_chan = 0 + self.pre_crop_size = list(np.array(self.prepro['pre_crop_size'])[[1, 2, 0]]) # now y,x,z + self.crop_margin = [20, 20, 1] # has to be smaller than respective patch_size//2 + self.patch_size_2D = self.pre_crop_size[:2] #[288, 288] + self.patch_size_3D = self.pre_crop_size[:2] + [8] # only numbers divisible by 2 multiple times + # (at least 5 times for x,y, at least 3 for z)! + # otherwise likely to produce error in crop fct or net + self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D + + self.balance_target = "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets' + # ratio of fully random patients drawn during batch generation + # resulting batch random count is rounded down to closest integer + self.batch_random_ratio = 0.2 if self.dim==2 else 0.4 + + self.observables_patient = ['Original_ID', 'GSBx', 'PIRADS2'] + self.observables_rois = ['lesion_gleasons'] + + self.regression_target = "lesion_gleasons" # name of the info_dict entry holding regression targets + # linear mapping + self.rg_map = {0: 0, 60: 1, 71: 2, 72: 3, 80: 4, 90: 5, 91: 6, 92: 7, None: 0} + # non-linear mapping + #self.rg_map = {0: 0, 60: 1, 71: 6, 72: 7.5, 80: 9, 90: 10, 91: 10, 92: 10, None: 0} + + ######################### + # Colors and Legends # + ######################### + self.plot_frequency = 5 + + # colors + self.gravity_col_palette = [self.green, self.yellow, self.orange, self.bright_red, self.red, self.dark_red] + + self.gs_labels = [ + Label(0, 'bg', self.gray, (0,)), + Label(60, 'GS60', self.dark_green, (60,)), + Label(71, 'GS71', self.dark_yellow, (71,)), + Label(72, 'GS72', self.orange, (72,)), + Label(80, 'GS80', self.brighter_red,(80,)), + Label(90, 'GS90', self.bright_red, (90,)), + Label(91, 'GS91', self.red, (91,)), + Label(92, 'GS92', self.dark_red, (92,)) + ] + self.gs2label = {label.id: label for label in self.gs_labels} + + + binary_cl_labels = [Label(1, 'benign', (*self.green, 1.), (60,)), + Label(2, 'malignant', (*self.red, 1.), (71,72,80,90,91,92)), + #Label(3, 'pz', (*self.blue, 1.), (None,)), + #Label(4, 'tz', (*self.aubergine, 1.), (None,)) + ] + + self.class_labels = [ + #id #name #color #gleason score + Label( 0, 'bg', (*self.gray, 0.), (0,))] + if "class" in self.prediction_tasks: + self.class_labels += binary_cl_labels + # self.class_labels += [Label(cl, cl_dic["name"], cl_dic["color"], tuple(cl_dic["gleasons"])) + # for cl, cl_dic in + # load_obj(os.path.join(self.data_sourcedir, "pp_class_labels.pkl")).items()] + else: + self.class_labels += [Label( 1, 'lesion', (*self.red, 1.), (60,71,72,80,90,91,92))] + + if any(['regression' in task for task in self.prediction_tasks]): + self.bin_labels = [binLabel(0, 'bg', (*self.gray, 0.), (0,), (0,))] + self.bin_labels += [binLabel(cl, cl_dic["name"], cl_dic["color"], tuple(cl_dic["gleasons"]), + tuple([self.rg_map[gs] for gs in cl_dic["gleasons"]])) for cl, cl_dic in + sorted(load_obj(os.path.join(self.data_sourcedir, "pp_class_labels.pkl")).items())] + self.bin_id2label = {label.id: label for label in self.bin_labels} + self.gs2bin_label = {gs: label for label in self.bin_labels for gs in label.gleasons} + bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] + self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} + self.bin_edges = [(bins[i][1] + bins[i+1][0]) / 2 for i in range(len(bins)-1)] + self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0} + + + if self.class_specific_seg: + self.seg_labels = self.class_labels + else: + self.seg_labels = [ # id #name #color + Label(0, 'bg', (*self.white, 0.)), + Label(1, 'fg', (*self.orange, 1.)) + ] + + self.class_id2label = {label.id: label for label in self.class_labels} + self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} + # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be + # evaluated in debugging + self.class_cmap = {label.id: label.color for label in self.class_labels} + + self.seg_id2label = {label.id: label for label in self.seg_labels} + self.cmap = {label.id: label.color for label in self.seg_labels} + + self.plot_prediction_histograms = True + self.plot_stat_curves = False + self.plot_class_ids = True + + self.num_classes = len(self.class_dict) # for instance classification (excl background) + self.num_seg_classes = len(self.seg_labels) # incl background + + ######################### + # Data Augmentation # + ######################### + #the angle rotations are implemented incorrectly in batchgenerators! in 2D, + #the x-axis angle controls the z-axis angle. + if self.dim == 2: + angle_x = (-np.pi / 3., np.pi / 3.) + angle_z = (0.,0.) + rcd = (self.patch_size[0] / 2., self.patch_size[1] / 2.) + else: + angle_x = (0.,0.) + angle_z = (-np.pi / 2., np.pi / 2.) + rcd = (self.patch_size[0] / 2., self.patch_size[1] / 2., + self.patch_size[2] / 2.) + + self.do_aug = True + # DA settings for DWI + self.da_kwargs = { + 'mirror': True, + 'mirror_axes': tuple(np.arange(0, self.dim, 1)), + 'random_crop': True, + 'rand_crop_dist': rcd, + 'do_elastic_deform': self.dim==2, + 'alpha': (0., 1500.), + 'sigma': (25., 50.), + 'do_rotation': True, + 'angle_x': angle_x, + 'angle_y': (0., 0.), + 'angle_z': angle_z, + 'do_scale': True, + 'scale': (0.7, 1.3), + 'border_mode_data': 'constant', + 'gamma_transform': True, + 'gamma_range': (0.5, 2.) + } + # for T2 + # self.da_kwargs = { + # 'mirror': True, + # 'mirror_axes': tuple(np.arange(0, self.dim, 1)), + # 'random_crop': False, + # 'rand_crop_dist': rcd, + # 'do_elastic_deform': False, + # 'alpha': (0., 1500.), + # 'sigma': (25., 50.), + # 'do_rotation': True, + # 'angle_x': angle_x, + # 'angle_y': (0., 0.), + # 'angle_z': angle_z, + # 'do_scale': False, + # 'scale': (0.7, 1.3), + # 'border_mode_data': 'constant', + # 'gamma_transform': False, + # 'gamma_range': (0.5, 2.) + # } + + + ################################# + # Schedule / Selection / Optim # + ################################# + + # good guess: train for n_samples = 1.1m = epochs*n_train_bs*b_size + self.num_epochs = 270 + self.num_train_batches = 120 if self.dim == 2 else 140 + + self.val_mode = 'val_patient' # one of 'val_sampling', 'val_patient' + # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) + # the former is more accurate, while the latter is faster (depending on volume size) + self.num_val_batches = 200 if self.dim==2 else 40 # for val_sampling, number or "all" + self.max_val_patients = "all" #for val_patient, "all" takes whole split + + self.save_n_models = 6 + self.min_save_thresh = 3 if self.dim == 2 else 4 #=wait time in epochs + if "class" in self.prediction_tasks: + # 'criterion': weight + self.model_selection_criteria = {"benign_ap": 0.2, "malignant_ap": 0.8} + elif any("regression" in task for task in self.prediction_tasks): + self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8} + #self.model_selection_criteria = {"GS71-92_ap": 0.9, "GS60_ap": 0.1} # 'criterion':weight + #self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8} + #self.model_selection_criteria = {label.name+"_ap": 1. for label in self.class_labels if label.id!=0} + + self.scan_det_thresh = False + self.warm_up = 0 + + self.optimizer = "ADAM" + self.weight_decay = 1e-5 + self.clip_norm = None #number or None + + self.learning_rate = [1e-4] * self.num_epochs + self.dynamic_lr_scheduling = True + self.lr_decay_factor = 0.5 + self.scheduling_patience = int(self.num_epochs / 6) + + ######################### + # Testing # + ######################### + + self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x). + self.held_out_test_set = False + self.max_test_patients = "all" # "all" or number + self.report_score_level = ['rois', 'patient'] # 'patient' or 'rois' (incl) + self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1 + + + self.eval_bins_separately = "additionally" if not 'class' in self.prediction_tasks else False + self.patient_bin_of_interest = 2 + self.metrics = ['ap', 'auc', 'dice'] + if any(['regression' in task for task in self.prediction_tasks]): + self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', + 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] + if 'aleatoric' in self.model: + self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] + self.evaluate_fold_means = True + + self.min_det_thresh = 0.02 + + self.ap_match_ious = [0.1] # threshold(s) for considering a prediction as true positive + # aggregation method for test and val_patient predictions. + # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, + # nms = standard non-maximum suppression, or None = no clustering + self.clustering = 'wbc' + # iou thresh (exclusive!) for regarding two preds as concerning the same ROI + self.clustering_iou = 0.1 # has to be larger than desired possible overlap iou of model predictions + # 2D-3D merging is applied independently from clustering setting. + self.merge_2D_to_3D_preds = True if self.dim == 2 else False + self.merge_3D_iou = 0.1 + self.n_test_plots = 1 # per fold and rank + self.test_n_epochs = self.save_n_models # should be called n_test_ens, since is number of models to ensemble over during testing + # is multiplied by n_test_augs if test_aug + + ######################### + # shared model settings # + ######################### + + # max number of roi candidates to identify per image and class (slice in 2D, volume in 3D) + self.n_roi_candidates = 10 if self.dim == 2 else 15 + + ######################### + # assertions # + ######################### + if not 'class' in self.prediction_tasks: + assert self.num_classes == 1 + for mod in self.prepro['modalities2concat']: + assert mod in self.prepro['images'].keys(), "need to adapt mods2concat to chosen images" + + ######################### + # Add model specifics # + ######################### + + {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, + 'mrcnn_gan': self.add_mrcnn_configs, + 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, + 'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs + }[self.model]() + + def gleason_map(self, GS): + """gleason to class id + :param GS: gleason score as in histo file + """ + if "gleason_thresh" in self.prepro.keys(): + assert "gleason_mapping" not in self.prepro.keys(), "cant define both, thresh and map, for GS to classes" + # -1 == bg, 0 == benign, 1 == malignant + # before shifting, i.e., 0!=bg, but 0==first class + remapping = 0 if GS >= self.prepro["gleason_thresh"] else -1 + return remapping + elif "gleason_mapping" in self.prepro.keys(): + return self.prepro["gleason_mapping"][GS] + else: + raise Exception("Need to define some remapping, at least GS 0 -> background (class -1)") + + def rg_val_to_bin_id(self, rg_val): + return float(np.digitize(rg_val, self.bin_edges)) + + def add_det_fpn_configs(self): + self.scheduling_criterion = 'torch_loss' + self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' + + # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') + self.seg_loss_mode = 'wce' + self.wce_weights = [1]*self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1] + # if <1, false positive predictions in foreground are penalized less. + self.fp_dice_weight = 1 if self.dim == 2 else 1 + + + self.detection_min_confidence = 0.05 + #how to determine score of roi: 'max' or 'median' + self.score_det = 'max' + + self.cuda_benchmark = self.dim==3 + + def add_det_unet_configs(self): + self.scheduling_criterion = "torch_loss" + self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' + + # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') + self.seg_loss_mode = 'wce' + self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1] + # if <1, false positive predictions in foreground are penalized less. + self.fp_dice_weight = 1 if self.dim == 2 else 1 + + self.detection_min_confidence = 0.05 + #how to determine score of roi: 'max' or 'median' + self.score_det = 'max' + + self.init_filts = 32 + self.kernel_size = 3 #ks for horizontal, normal convs + self.kernel_size_m = 2 #ks for max pool + self.pad = "same" # "same" or integer, padding of horizontal convs + + self.cuda_benchmark = True + + def add_mrcnn_configs(self): + + self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) + self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' + + # number of classes for network heads: n_foreground_classes + 1 (background) + self.head_classes = self.num_classes + 1 + # + # feed +/- n neighbouring slices into channel dimension. set to None for no context. + self.n_3D_context = None + if self.n_3D_context is not None and self.dim == 2: + self.n_channels *= (self.n_3D_context * 2 + 1) + + self.frcnn_mode = False + # disable the re-sampling of mask proposals to original size for speed-up. + # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), + # mask outputs are optional. + self.return_masks_in_train = True + self.return_masks_in_val = True + self.return_masks_in_test = True + + # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. + self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} + # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale + # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) + self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} + # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. + self.pyramid_levels = [0, 1, 2, 3] + # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. + self.n_rpn_features = 512 if self.dim == 2 else 128 + + # anchor ratios and strides per position in feature maps. + self.rpn_anchor_ratios = [0.5,1.,2.] + self.rpn_anchor_stride = 1 + # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION + self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7 + + # loss sampling settings. + self.rpn_train_anchors_per_image = 6 + self.train_rois_per_image = 6 #per batch_instance + self.roi_positive_ratio = 0.5 + self.anchor_matching_iou = 0.7 + + # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), + # where k<=#positive examples. + self.shem_poolsize = 3 + + self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) + self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) + self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) + + self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) + self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) + self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) + self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], + self.patch_size_3D[2], self.patch_size_3D[2]]) #y1,x1,y2,x2,z1,z2 + + if self.dim == 2: + self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] + self.bbox_std_dev = self.bbox_std_dev[:4] + self.window = self.window[:4] + self.scale = self.scale[:4] + + self.plot_y_max = 1.5 + self.n_plot_rpn_props = 5 if self.dim == 2 else 30 #per batch_instance (slice in 2D / patient in 3D) + + # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. + self.pre_nms_limit = 3000 if self.dim == 2 else 6000 + + # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, + # since proposals of the entire batch are forwarded through second stage in as one "batch". + self.roi_chunk_size = 2000 if self.dim == 2 else 400 + self.post_nms_rois_training = 250 * (self.head_classes-1) if self.dim == 2 else 500 + self.post_nms_rois_inference = 250 * (self.head_classes-1) + + # Final selection of detections (refine_detections) + self.model_max_instances_per_batch_element = self.n_roi_candidates # per batch element and class. + # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py, otherwise all predictions are one cluster. + self.detection_nms_threshold = 1e-5 + # detection score threshold in refine_detections() + self.model_min_confidence = 0.05 #self.min_det_thresh/2 + + if self.dim == 2: + self.backbone_shapes = np.array( + [[int(np.ceil(self.patch_size[0] / stride)), + int(np.ceil(self.patch_size[1] / stride))] + for stride in self.backbone_strides['xy']]) + else: + self.backbone_shapes = np.array( + [[int(np.ceil(self.patch_size[0] / stride)), + int(np.ceil(self.patch_size[1] / stride)), + int(np.ceil(self.patch_size[2] / stride_z))] + for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] + )]) + + self.operate_stride1 = False + + if self.model == 'retina_net' or self.model == 'retina_unet': + self.cuda_benchmark = self.dim == 3 + #implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 + self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in + self.rpn_anchor_scales['xy']] + self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in + self.rpn_anchor_scales['z']] + self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 + + self.n_rpn_features = 256 if self.dim == 2 else 64 + + # pre-selection of detections for NMS-speedup. per entire batch. + self.pre_nms_limit = (1000 if self.dim == 2 else 6250) * self.batch_size + + # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 + self.anchor_matching_iou = 0.5 + + if self.model == 'retina_unet': + self.operate_stride1 = True \ No newline at end of file diff --git a/datasets/prostate/data_loader.py b/datasets/prostate/data_loader.py new file mode 100644 index 0000000..69c53e6 --- /dev/null +++ b/datasets/prostate/data_loader.py @@ -0,0 +1,716 @@ +__author__ = '' +#credit derives from Paul Jaeger, Simon Kohl + +import os +import time +import warnings + +from collections import OrderedDict +import pickle + +import numpy as np +import pandas as pd + +# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators +from batchgenerators.augmentations.utils import resize_image_by_padding, center_crop_2D_image, center_crop_3D_image +from batchgenerators.dataloading.data_loader import SlimDataLoaderBase +from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror +from batchgenerators.transforms.abstract_transforms import Compose +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.dataloading import SingleThreadedAugmenter +from batchgenerators.transforms.spatial_transforms import SpatialTransform +from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform +#from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates +from batchgenerators.transforms import AbstractTransform +from batchgenerators.transforms.color_transforms import GammaTransform + +#sys.path.append(os.path.dirname(os.path.realpath(__file__))) + +#import utils.exp_utils as utils +import utils.dataloader_utils as dutils +from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates +import data_manager as dmanager + + +def load_obj(file_path): + with open(file_path, 'rb') as handle: + return pickle.load(handle) + +def id_to_spec(id, base_spec): + """Construct subject specifier from base string and an integer subject number.""" + num_zeros = 5 - len(str(id)) + assert num_zeros>=0, "id_to_spec: patient id too long to fit into 5 figures" + return base_spec + '_' + ('').join(['0'] * num_zeros) + str(id) + +def convert_3d_to_2d_generator(data_dict, shape="bcxyz"): + """Fold/Shape z-dimension into color-channel. + :param shape: bcxyz or bczyx + :return: shape b(c*z)xy or b(c*z)yx + """ + if shape=="bcxyz": + data_dict['data'] = np.transpose(data_dict['data'], axes=(0,1,4,3,2)) + data_dict['seg'] = np.transpose(data_dict['seg'], axes=(0,1,4,3,2)) + elif shape=="bczyx": + pass + else: + raise Exception("unknown datashape {} in 3d_to_2d transform converter".format(shape)) + + shp = data_dict['data'].shape + data_dict['orig_shape_data'] = shp + seg_shp = data_dict['seg'].shape + data_dict['orig_shape_seg'] = seg_shp + + data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4])) + data_dict['seg'] = data_dict['seg'].reshape((seg_shp[0], seg_shp[1] * seg_shp[2], seg_shp[3], seg_shp[4])) + + return data_dict + +def convert_2d_to_3d_generator(data_dict, shape="bcxyz"): + """Unfold z-dimension from color-channel. + data needs to be in shape bcxy or bcyx, x,y dims won't be swapped relative to each other. + :param shape: target shape, bcxyz or bczyx + """ + shp = data_dict['orig_shape_data'] + cur_shape = data_dict['data'].shape + seg_shp = data_dict['orig_shape_seg'] + cur_shape_seg = data_dict['seg'].shape + + data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1], shp[2], cur_shape[-2], cur_shape[-1])) + data_dict['seg'] = data_dict['seg'].reshape((seg_shp[0], seg_shp[1], seg_shp[2], cur_shape_seg[-2], cur_shape_seg[-1])) + + if shape=="bcxyz": + data_dict['data'] = np.transpose(data_dict['data'], axes=(0,1,4,3,2)) + data_dict['seg'] = np.transpose(data_dict['seg'], axes=(0,1,4,3,2)) + return data_dict + +class Convert3DTo2DTransform(AbstractTransform): + def __init__(self): + pass + + def __call__(self, **data_dict): + return convert_3d_to_2d_generator(data_dict) + +class Convert2DTo3DTransform(AbstractTransform): + def __init__(self): + pass + + def __call__(self, **data_dict): + return convert_2d_to_3d_generator(data_dict) + +def vector(item): + """ensure item is vector-like (list or array or tuple) + :param item: anything + """ + if not isinstance(item, (list, tuple, np.ndarray)): + item = [item] + return item + +class Dataset(dutils.Dataset): + r"""Load a dict holding memmapped arrays and clinical parameters for each patient, + evtly subset of those. + If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to + cf.data_dest. + :param cf: config file + :param data_dir: directory in which to find data, defaults to cf.data_dir if None. + :return: dict with imgs, segs, pids, class_labels, observables + """ + + def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None): + super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir) + + info_dict = load_obj(cf.info_dict_path) + + if subset_ids is not None: + pids = subset_ids + if logger is None: + print('subset: selected {} instances from df'.format(len(pids))) + else: + logger.info('subset: selected {} instances from df'.format(len(pids))) + else: + pids = list(info_dict.keys()) + + #evtly copy data from data_rootdir to data_dir + if cf.server_env and not hasattr(cf, "data_dir"): + file_subset = [info_dict[pid]['img'][:-3]+"*" for pid in pids] + file_subset+= [info_dict[pid]['seg'][:-3]+"*" for pid in pids] + file_subset += [cf.info_dict_path] + self.copy_data(cf, file_subset=file_subset) + cf.data_dir = self.data_dir + + img_paths = [os.path.join(self.data_dir, info_dict[pid]['img']) for pid in pids] + seg_paths = [os.path.join(self.data_dir, info_dict[pid]['seg']) for pid in pids] + + # load all subject files + self.data = OrderedDict() + for i, pid in enumerate(pids): + subj_spec = id_to_spec(pid, cf.prepro['dir_spec']) + subj_data = {'pid':pid, "spec":subj_spec} + subj_data['img'] = img_paths[i] + subj_data['seg'] = seg_paths[i] + #read, add per-roi labels + for obs in cf.observables_patient+cf.observables_rois: + subj_data[obs] = np.array(info_dict[pid][obs]) + if 'class' in self.cf.prediction_tasks: + subj_data['class_targets'] = np.array(info_dict[pid]['roi_classes'], dtype='uint8') + 1 + else: + subj_data['class_targets'] = np.ones_like(np.array(info_dict[pid]['roi_classes']), dtype='uint8') + if any(['regression' in task for task in self.cf.prediction_tasks]): + if hasattr(cf, "rg_map"): + subj_data["regression_targets"] = np.array([vector(cf.rg_map[v]) for v in info_dict[pid][cf.regression_target]], dtype='float16') + else: + subj_data["regression_targets"] = np.array([vector(v) for v in info_dict[pid][cf.regression_target]], dtype='float16') + subj_data["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in subj_data["regression_targets"]], dtype='uint8') + subj_data['fg_slices'] = info_dict[pid]['fg_slices'] + + self.data[pid] = subj_data + + cf.roi_items = cf.observables_rois[:] + cf.roi_items += ['class_targets'] + if any(['regression' in task for task in self.cf.prediction_tasks]): + cf.roi_items += ['regression_targets'] + cf.roi_items += ['rg_bin_targets'] + #cf.patient_items = cf.observables_patient[:] + #patient-wise items not used currently + self.set_ids = np.array(list(self.data.keys())) + + self.df = None + +class BatchGenerator(dutils.BatchGenerator): + """ + create the training/validation batch generator. Randomly sample batch_size patients + from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array. + :param data: data dictionary as provided by 'load_dataset' + :param img_modalities: list of strings ['adc', 'b1500'] from config + :param batch_size: number of patients to sample for the batch + :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.) + :param sample_pids_w_replace: whether to randomly draw pids from dataset for batch generation. if False, step through whole dataset + before repition. + :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array. + """ + def __init__(self, cf, data, n_batches=None, sample_pids_w_replace=True): + super(BatchGenerator, self).__init__(cf, data, n_batches) + self.dataset_length = len(self._data) + self.cf = cf + + self.sample_pids_w_replace = sample_pids_w_replace + self.eligible_pids = list(self._data.keys()) + + self.chans = cf.channels if cf.channels is not None else np.index_exp[:] + assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" + + self.p_fg = 0.5 + self.empty_samples_max_ratio = 0.6 + self.random_count = int(cf.batch_random_ratio * cf.batch_size) + + self.balance_target_distribution(plot=sample_pids_w_replace) + self.stats = {"roi_counts" : np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count" : 0} + + def generate_train_batch(self): + #everything done in here is per batch + #print statements in here get confusing due to multithreading + if self.sample_pids_w_replace: + # fully random patients + batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) + # target-balanced patients + batch_patient_ids += list(np.random.choice( + self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) + else: + batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size, + replace=False) + if self.sample_pids_w_replace == False: + self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids] + if len(self.eligible_pids) < self.batch_size: + self.eligible_pids = self.dataset_pids + + batch_data, batch_segs, batch_patient_specs = [], [], [] + batch_roi_items = {name: [] for name in self.cf.roi_items} + #record roi count of classes in batch + batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0 + #empty count for full bg samples (empty slices in 2D/patients in 3D) + + for sample in range(self.batch_size): + + patient = self._data[batch_patient_ids[sample]] + + #swap dimensions from (c,)z,y,x to (c,)y,x,z or h,w,d to ease 2D/3D-case handling + data = np.transpose(np.load(patient['img'], mmap_mode='r'), axes=(0, 2, 3, 1))[self.chans] + seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) + (c,y,x,z) = data.shape + + #original data is 3D MRIs, so need to pick (e.g. randomly) single slice to make it 2D, + #consider batch roi-class balance + if self.cf.dim == 2: + elig_slices, choose_fg = [], False + if self.sample_pids_w_replace and len(patient['fg_slices']) > 0: + if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand( + 1) <= self.p_fg: + # fg is to be picked + for tix in np.argsort(batch_roi_counts): + # pick slices of patient that have roi of sought-for target + # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix + elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( + patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == + self.unique_ts[tix]) > 0] + if len(elig_slices) > 0: + choose_fg = True + break + else: + # pick bg + elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) + if len(elig_slices) == 0: + elig_slices = z + sl_pick_ix = np.random.choice(elig_slices, size=None) + data = data[..., sl_pick_ix] + seg = seg[..., sl_pick_ix] + + spatial_shp = data[0].shape + assert spatial_shp==seg.shape, "spatial shape incongruence betw. data and seg" + + if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): + new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] + data = dutils.pad_nd_image(data, (len(data), *new_shape)) + seg = dutils.pad_nd_image(seg, new_shape) + + #eventual cropping to pre_crop_size: with prob self.p_fg sample pixel from random ROI and shift center, + #if possible, to that pixel, so that img still contains ROI after pre-cropping + dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] + if np.any(dim_cropflags): + print("dim crop applied") + # sample pixel from random ROI and shift center, if possible, to that pixel + if self.cf.dim==3: + choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg + if self.sample_pids_w_replace and choose_fg and np.any(seg): + available_roi_ids = np.unique(seg)[1:] + for tix in np.argsort(batch_roi_counts): + elig_roi_ids = available_roi_ids[ + patient[self.balance_target][available_roi_ids - 1] == self.unique_ts[tix]] + if len(elig_roi_ids) > 0: + seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) + break + roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] + assert seg[tuple(roi_anchor_pixel)] > 0 + + # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and + # distance to the selected ROI < patch_size /2 + def get_cropped_centercoords(dim): + low = np.max((self.cf.pre_crop_size[dim]//2, + roi_anchor_pixel[dim] - (self.cf.patch_size[dim]//2 - self.cf.crop_margin[dim]))) + high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim]//2, + roi_anchor_pixel[dim] + (self.cf.patch_size[dim]//2 - self.cf.crop_margin[dim]))) + if low >= high: #happens if lesion on the edge of the image. + #print('correcting low/high:', low, high, spatial_shp, roi_anchor_pixel, dim) + low = self.cf.pre_crop_size[dim] // 2 + high = spatial_shp[dim] - self.cf.pre_crop_size[dim]//2 + + assert low0]) - 1] == self.unique_ts[tix]) + if not np.any(seg): + empty_samples_count += 1 + + #self.stats['roi_counts'] += batch_roi_counts #DOESNT WORK WITH MULTITHREADING! do outside + #self.stats['empty_samples_count'] += empty_samples_count + + batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'), + 'pid': batch_patient_ids, 'spec': batch_patient_specs, + 'roi_counts':batch_roi_counts, 'empty_samples_count': empty_samples_count} + for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) + batch[key] = np.array(val) + + return batch + +class PatientBatchIterator(dutils.PatientBatchIterator): + """ + creates a val/test generator. Step through the dataset and return dictionaries per patient. + 2D is a special case of 3D patching with patch_size[2] == 1 (slices) + Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets. + Appends patient targets anyway for evaluation. + For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions. + + This iterator/these batches are not intended to go through MTaugmenter afterwards + """ + + def __init__(self, cf, data): + super(PatientBatchIterator, self).__init__(cf, data) + + self.patient_ix = 0 #running index over all patients in set + + self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size + self.chans = cf.channels if cf.channels is not None else np.index_exp[:] + assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" + + def generate_train_batch(self, pid=None): + + if self.patient_ix == len(self.dataset_pids): + self.patient_ix = 0 + if pid is None: + pid = self.dataset_pids[self.patient_ix] # + self.thread_id + patient = self._data[pid] + + #swap dimensions from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling + data = np.transpose(np.load(patient['img'], mmap_mode='r'), axes=(0, 2, 3, 1)) + seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] + data_shp_raw = data.shape + plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None + data = data[self.chans] + discarded_chans = len( + [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan]) + spatial_shp = data[0].shape # spatial dims need to be in order x,y,z + assert spatial_shp==seg[0].shape, "spatial shape incongruence betw. data and seg" + + if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]): + new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))] + data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. + seg = dutils.pad_nd_image(seg, new_shape) + if plot_bg is not None: + plot_bg = dutils.pad_nd_image(plot_bg, new_shape) + + if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: + #adds the batch dim here bc won't go through MTaugmenter + out_data = data[np.newaxis] + out_seg = seg[np.newaxis] + if plot_bg is not None: + out_plot_bg = plot_bg[np.newaxis] + #data and seg shape: (1,c,x,y,z), where c=1 for seg + batch_3D = {'data': out_data, 'seg': out_seg} + for o in self.cf.roi_items: + batch_3D[o] = np.array([patient[o]]) + converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) + batch_3D = converter(**batch_3D) + batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_3D["patient_" + o] = batch_3D[o] + + if self.cf.dim == 2: + out_data = np.transpose(data, axes=(3,0,1,2)) #(c,y,x,z) to (b=z,c,x,y), use z=b as batchdim + out_seg = np.transpose(seg, axes=(3,0,1,2)).astype('uint8') #(c,y,x,z) to (b=z,c,x,y) + + batch_2D = {'data': out_data, 'seg': out_seg} + for o in self.cf.roi_items: + batch_2D[o] = np.repeat(np.array([patient[o]]), len(out_data), axis=0) + + converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) + batch_2D = converter(**batch_2D) + + if plot_bg is not None: + out_plot_bg = np.transpose(plot_bg, axes=(2,0,1)).astype('float32') + + if self.cf.merge_2D_to_3D_preds: + batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], + 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_2D["patient_" + o] = batch_3D['patient_'+o] + else: + batch_2D.update({'patient_bb_target': batch_2D['bb_target'], + 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_2D["patient_" + o] = batch_2D[o] + + out_batch = batch_3D if self.cf.dim == 3 else batch_2D + out_batch.update({'pid': np.array([patient['pid']] * len(out_data)), + 'spec':np.array([patient['spec']] * len(out_data))}) + + if self.cf.plot_bg_chan in self.chans and discarded_chans>0: + assert plot_bg is None + plot_bg = int(self.cf.plot_bg_chan - discarded_chans) + out_plot_bg = plot_bg + if plot_bg is not None: + out_batch['plot_bg'] = out_plot_bg + + #eventual tiling into patches + spatial_shp = out_batch["data"].shape[2:] + if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]): + patient_batch = out_batch + #print("patientiterator produced patched batch!") + patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) + new_img_batch, new_seg_batch = [], [] + + for c in patch_crop_coords_list: + new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]]) + seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] + new_seg_batch.append(seg_patch) + shps = [] + for arr in new_img_batch: + shps.append(arr.shape) + + data = np.array(new_img_batch) # (patches, c, x, y, z) + seg = np.array(new_seg_batch) + if self.cf.dim == 2: + # all patches have z dimension 1 (slices). discard dimension + data = data[..., 0] + seg = seg[..., 0] + patch_batch = {'data': data, 'seg': seg.astype('uint8'), + 'pid': np.array([patient['pid']] * data.shape[0]), + 'spec':np.array([patient['spec']] * data.shape[0])} + for o in self.cf.roi_items: + patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) + # patient-wise (orig) batch info for putting the patches back together after prediction + for o in self.cf.roi_items: + patch_batch["patient_"+o] = patient_batch['patient_'+o] + patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) + patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] + #patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels'] + patch_batch['patient_data'] = patient_batch['data'] + patch_batch['patient_seg'] = patient_batch['seg'] + patch_batch['original_img_shape'] = patient_batch['original_img_shape'] + if plot_bg is not None: + patch_batch['patient_plot_bg'] = patient_batch['plot_bg'] + + converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) + + patch_batch = converter(**patch_batch) + out_batch = patch_batch + + self.patient_ix += 1 + # todo raise stopiteration when in test mode + if self.patient_ix == len(self.dataset_pids): + self.patient_ix = 0 + + return out_batch + + +def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True): + """ + create mutli-threaded train/val/test batch generation and augmentation pipeline. + :param patient_data: dictionary containing one dictionary per patient in the train/test subset + :param test_pids: (optional) list of test patient ids, calls the test generator. + :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing) + :return: multithreaded_generator + """ + data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace) + + my_transforms = [] + if do_aug: + if cf.da_kwargs["mirror"]: + mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) + my_transforms.append(mirror_transform) + if cf.da_kwargs["gamma_transform"]: + gamma_transform = GammaTransform(gamma_range=cf.da_kwargs["gamma_range"], invert_image=False, + per_channel=False, retain_stats=True) + my_transforms.append(gamma_transform) + if cf.dim == 3: + # augmentations with desired effect on z-dimension + spatial_transform = SpatialTransform(patch_size=cf.patch_size, + patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], + do_elastic_deform=False, + do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], + angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], + do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], + random_crop=cf.da_kwargs['random_crop'], + border_mode_data=cf.da_kwargs['border_mode_data']) + my_transforms.append(spatial_transform) + # augmentations that are only meant to affect x-y + my_transforms.append(Convert3DTo2DTransform()) + spatial_transform = SpatialTransform(patch_size=cf.patch_size[:2], + patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2], + do_elastic_deform=cf.da_kwargs['do_elastic_deform'], + alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], + do_rotation=False, + do_scale=False, + random_crop=False, + border_mode_data=cf.da_kwargs['border_mode_data']) + my_transforms.append(spatial_transform) + my_transforms.append(Convert2DTo3DTransform()) + + else: + spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], + patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2], + do_elastic_deform=cf.da_kwargs['do_elastic_deform'], + alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], + do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], + angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], + do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], + random_crop=cf.da_kwargs['random_crop'], + border_mode_data=cf.da_kwargs['border_mode_data']) + my_transforms.append(spatial_transform) + else: + my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) + + if cf.create_bounding_box_targets: + my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) + #batch receives entry 'bb_target' w bbox coordinates as [y1,x1,y2,x2,z1,z2]. + #my_transforms.append(ConvertSegToOnehotTransform(classes=range(cf.num_seg_classes))) + all_transforms = Compose(my_transforms) + #MTAugmenter creates iterator from data iterator data_gen after applying the composed transform all_transforms + multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, + seeds=list(np.random.randint(0,cf.n_workers*2,size=cf.n_workers))) + return multithreaded_generator + +def get_train_generators(cf, logger, data_statistics=True): + """ + wrapper function for creating the training batch generator pipeline. returns the train/val generators + need to select cv folds on patient level, but be able to include both breasts of each patient. + """ + dataset = Dataset(cf, logger) + + dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) + dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) + set_splits = dataset.fg.splits + + test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold-1) + train_ids = np.concatenate(set_splits, axis=0) + + if cf.held_out_test_set: + train_ids = np.concatenate((train_ids, test_ids), axis=0) + test_ids = [] + + train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} + val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} + + logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) + if data_statistics: + dataset.calc_statistics(subsets={"train":train_ids, "val":val_ids, "test":test_ids}, + plot_dir=os.path.join(cf.plot_dir,"dataset")) + + batch_gen = {} + batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug) + batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False) + + if cf.val_mode == 'val_patient': + batch_gen['val_patient'] = PatientBatchIterator(cf, val_data) + batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else cf.max_val_patients + elif cf.val_mode == 'val_sampling': + batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches!="all" else len(val_ids) + + return batch_gen + +def get_test_generator(cf, logger): + """ + if get_test_generators is called multiple times in server env, every time of + Dataset initiation rsync will check for copying the data; this should be okay + since rsync will not copy if files already exist in destination. + """ + + if cf.held_out_test_set: + sourcedir = cf.test_data_sourcedir + test_ids = None + else: + sourcedir = None + with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: + set_splits = pickle.load(handle) + test_ids = set_splits[cf.fold] + + test_set = Dataset(cf, logger, test_ids, data_sourcedir=sourcedir) + logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) + batch_gen = {} + batch_gen['test'] = PatientBatchIterator(cf, test_set.data) + batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else min(cf.max_test_patients, len(test_set.set_ids)) + + return batch_gen + + +if __name__=="__main__": + import sys + sys.path.append('../') # works on cluster indep from where sbatch job is started + import plotting as plg + import utils.exp_utils as utils + from configs import Configs + cf = configs() + + total_stime = time.time() + times = {} + + #cf.server_env = True + #cf.data_dir = "experiments/dev_data" + + #dataset = Dataset(cf) + #patient = dataset['Master_00018'] + cf.exp_dir = "experiments/dev/" + cf.plot_dir = cf.exp_dir+"plots" + os.makedirs(cf.exp_dir, exist_ok=True) + cf.fold = 0 + logger = utils.get_logger(cf.exp_dir) + gens = get_train_generators(cf, logger) + train_loader = gens['train'] + + #for i in range(train_loader.dataset_length): + # print("batch", i) + stime = time.time() + ex_batch = next(train_loader) + #ex_batch = next(train_loader) + times["train_batch"] = time.time()-stime + plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True) + + #with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file: + # train_loader.generator.print_stats(logger, file) + + + val_loader = gens['val_sampling'] + stime = time.time() + ex_batch = next(val_loader) + times["val_batch"] = time.time()-stime + stime = time.time() + plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, show_info=False) + times["val_plot"] = time.time()-stime + + test_loader = get_test_generator(cf, logger)["test"] + stime = time.time() + ex_batch = test_loader.generate_train_batch() + print(ex_batch["data"].shape) + times["test_batch"] = time.time()-stime + stime = time.time() + plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/ex_patchbatch.png", show_gt_boxes=False, show_info=False, dpi=400, sample_picks=[2,5], plot_mods=False) + times["test_patchbatch_plot"] = time.time()-stime + + #stime = time.time() + #ex_batch['data'] = ex_batch['patient_data'] + #ex_batch['seg'] = ex_batch['patient_seg'] + #if 'patient_plot_bg' in ex_batch.keys(): + # ex_batch['plot_bg'] = ex_batch['patient_plot_bg'] + #plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png") + #times["test_patientbatch_plot"] = time.time() - stime + + + #print("patch batch keys", ex_batch.keys()) + #print("patch batch les gle", ex_batch["lesion_gleasons"].shape) + #print("patch batch gsbx", ex_batch["GSBx"].shape) + #print("patch batch class_targ", ex_batch["class_targets"].shape) + #print("patient b roi labels", ex_batch["patient_roi_labels"].shape) + #print("patient les gleas", ex_batch["patient_lesion_gleasons"].shape) + #print("patch&patient batch pid", ex_batch["pid"], len(ex_batch["pid"])) + #print("unique patient_seg", np.unique(ex_batch["patient_seg"])) + #print("pb patient roi labels", len(ex_batch["patient_roi_labels"]), ex_batch["patient_roi_labels"]) + #print("pid", ex_batch["pid"]) + + #patient_batch = {k[len("patient_"):]:v for (k,v) in ex_batch.items() if k.lower().startswith("patient")} + #patient_batch["pid"] = ex_batch["pid"] + #stime = time.time() + #plg.view_batch(cf, patient_batch, out_file="experiments/dev_expatientbatch") + #times["test_plot"] = time.time()-stime + + + print("Times recorded throughout:") + for (k,v) in times.items(): + print(k, "{:.2f}".format(v)) + + mins, secs = divmod((time.time() - total_stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/datasets/prostate/data_preprocessing.py b/datasets/prostate/data_preprocessing.py new file mode 100644 index 0000000..ca97532 --- /dev/null +++ b/datasets/prostate/data_preprocessing.py @@ -0,0 +1,809 @@ +__author__ = "Simon Kohl, Gregor Ramien" + + +# subject-wise extractor that does not depend on Prisma/Radval and that checks for geometry miss-alignments +# (corrects them if applicable), images and masks should be stored separately, each in its own memmap +# at run-time, the data-loaders will assemble dicts using the histo csvs +import os +import sys +from multiprocessing import Pool +import warnings +import time +import shutil + +import pandas as pd +import numpy as np +import pickle + +import SimpleITK as sitk +from scipy.ndimage.measurements import center_of_mass + +sys.path.append("../") +import plotting as plg +import data_manager as dmanager + +def save_obj(obj, name): + """Pickle a python object.""" + with open(name + '.pkl', 'wb') as f: + pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + +def load_array(path): + """Load an image as a numpy array.""" + img = sitk.ReadImage(path) + return sitk.GetArrayFromImage(img) + +def id_to_spec(id, base_spec): + """Construct subject specifier from base string and an integer subject number.""" + num_zeros = 5 - len(str(id)) + assert num_zeros>=0, "id_to_spec: patient id too long to fit into 5 figures" + return base_spec + '_' + ('').join(['0'] * num_zeros) + str(id) + +def spec_to_id(spec): + """Get subject id from string""" + return int(spec[-5:]) + +def has_equal_geometry(img1, img2, precision=0.001): + """Check whether geometries of 2 images match within a given precision.""" + equal = True + + # assert equal image extentions + delta = [abs((img1.GetSize()[i] - img2.GetSize()[i])) < precision for i in range(3)] + if not np.all(delta): + equal = False + + # assert equal origins + delta = [abs((img1.GetOrigin()[i] - img2.GetOrigin()[i])) < precision for i in range(3)] + if not np.all(delta): + equal = False + + # assert equal spacings + delta = [abs((img1.GetSpacing()[i] - img2.GetSpacing()[i])) < precision for i in range(3)] + if not np.all(delta): + equal = False + + return equal + +def resample_to_reference(ref_img, img, interpolation): + """ + Resample an sitk image to a reference image, the size, spacing, + origin and direction of the reference image will be used + :param ref_img: + :param img: + :param interpolation: + :return: interpolated SITK image + """ + if interpolation == 'nearest': + interpolator = sitk.sitkNearestNeighbor #these are just integers + elif interpolation == 'linear': + interpolator = sitk.sitkLinear + elif interpolation == 'bspline': + # basis spline of order 3 + interpolator = sitk.sitkBSpline + else: + raise NotImplementedError('Interpolation of type {} not implemented!'.format(interpolation)) + + img = sitk.Cast(img, sitk.sitkFloat64) + + rif = sitk.ResampleImageFilter() + # set the output size, origin, spacing and direction to that of the provided image + rif.SetReferenceImage(ref_img) + rif.SetInterpolator(interpolator) + + return rif.Execute(img) + +def rescale(img, scaling, interpolation=sitk.sitkBSpline, out_fpath=None): + """ + :param scaling: tuple (z_scale, y_scale, x_scale) of scaling factors + :param out_fpath: filepath (incl filename), if set will write .nrrd (uncompressed) + to that location + + sitk/nrrd images spacing: imgs are treated as physical objects. When resampling, + a given image is re-evaluated (resampled) at given gridpoints, the physical + properties of the image don't change. Hence, if the resampling-grid has a smaller + spacing than the original image(grid), the image is sampled more often than before. + Since every sampling produces one pixel, the resampled image will have more pixels + (when sampled at undefined points of the image grid, the sample values will be + interpolated). I.e., for an upsampling of an image, we need to set a smaller + spacing for the resampling grid and a larger (pixel)size for the resampled image. + """ + (z,y,x) = scaling + + old_size = np.array(img.GetSize()) + old_spacing = np.array(img.GetSpacing()) + + + new_size = (int(old_size[0]*x), int(old_size[1]*y), int(old_size[2]*z)) + new_spacing = old_spacing * (old_size/ new_size) + + rif = sitk.ResampleImageFilter() + + rif.SetReferenceImage(img) + rif.SetInterpolator(interpolation) + rif.SetOutputSpacing(new_spacing) + rif.SetSize(new_size) + + new_img = rif.Execute(img) + + if not out_fpath is None: + writer = sitk.ImageFileWriter() + writer.SetFileName(out_fpath) + writer.SetUseCompression(True) + writer.Execute(new_img) + + return new_img + +def get_valid_z_range(arr): + """ + check which z-slices of an image array aren't constant + :param arr: + :return: min and max valid slice found; under the assumption that invalid + slices occur never inbetween valid slices + """ + + valid_z_slices = [] + for z in range(arr.shape[0]): + if np.var(arr[z]) != 0: + valid_z_slices.append(z) + return valid_z_slices[0], valid_z_slices[-1] + +def convert_to_arrays(data): + """convert to numpy arrays. + sitk.Images have shape (x,y,z), but GetArrayFromImage returns shape (z,y,x) + """ + for mod in data['img'].keys(): + data['img'][mod] = sitk.GetArrayFromImage(data['img'][mod]).astype(np.float32) + + for mask in data['anatomical_masks'].keys(): + data['anatomical_masks'][mask] = sitk.GetArrayFromImage(data['anatomical_masks'][mask]).astype(np.uint8) + + for mask in data['lesions'].keys(): + data['lesions'][mask] = sitk.GetArrayFromImage(data['lesions'][mask]).astype(np.uint8) + return data + +def merge_crossmod_masks(data, rename_tags, mode="union"): + """if data has multiple ground truths (e.g. after registration), merge + masks by mode. class labels (leason gleason) are assumed to be naturally registered (no ambiguity) + :param rename_tags: usually from prepro_cf['rename_tags'] + :param mode: 'union' or name of mod ('adc', 't2') to consider only one gt + """ + + if 'adc' in data['img'].keys() and 't2' in data['img'].keys(): + if mode=='union': + #print("Merging gts of T2, ADC mods. Assuming data is registered!") + tags = list(data["anatomical_masks"].keys()) + for tag in tags: + tags.remove(tag) + merge_with = [mtag for mtag in tags\ + if mtag.lower().split("_")[2]==tag.lower().split("_")[2]] + assert len(merge_with)==1, "attempted to merge {} ground truths".format(len(merge_with)) + merge_with = merge_with[0] + tags.remove(merge_with) + #masks are binary + #will throw error if masks dont have same shape + data["anatomical_masks"][tag] = np.logical_or(data["anatomical_masks"][tag].astype(np.uint8), + data["anatomical_masks"].pop(merge_with).astype(np.uint8)).astype(np.uint8) + + tags = list(data["lesions"].keys()) + for tag in tags: + tags.remove(tag) + merge_with = [mtag for mtag in tags\ + if mtag.lower().split("_")[2]==tag.lower().split("_")[2]] + assert len(merge_with)==1, "attempted to merge {} ground truths".format(len(merge_with)) + merge_with = merge_with[0] + tags.remove(merge_with) + data["lesions"][tag] = np.logical_or(data["lesions"][tag], + data["lesions"].pop(merge_with)).astype(np.uint8) + + elif mode=='adc' or mode=='t2': + data["anatomical_masks"] = {tag:v for tag,v in data["anatomical_masks"].items() if + tag.lower().split("_")[1]==mode} + data["lesions"] = {tag: v for tag, v in data["lesions"].items() if tag.lower().split("_")[1] == mode} + + else: + raise Exception("cross-mod gt merge mode {} not implemented".format(mode)) + + for tag in list(data["anatomical_masks"]): + data["anatomical_masks"][rename_tags[tag]] = data["anatomical_masks"].pop(tag) + #del data["anatomical_masks"][tag] + for tag in list(data["lesions"]): + new_tag = "seg_REG_"+"".join(tag.split("_")[2:]) + data["lesions"][new_tag] = data["lesions"].pop(tag) + data["lesion_gleasons"][new_tag] = data["lesion_gleasons"].pop(tag) + + return data + +def crop_3D(data, pre_crop_size, center_of_mass_crop=True): + pre_crop_size = np.array(pre_crop_size) + # restrain z-ranges to where ADC has valid entries + if 'adc' in data['img'].keys(): + ref_mod = 'adc' + comp_mod = 't2' + else: + ref_mod = 't2' + comp_mod = 'adc' + min_z, max_z = get_valid_z_range(data['img'][ref_mod]) + if comp_mod in data['img'].keys(): + assert (min_z, max_z) == get_valid_z_range(data['img'][comp_mod]), "adc, t2 different valid z range" + + if center_of_mass_crop: + # cut the arrays to the desired x_y_crop_size around the center-of-mass of the PRO segmentation + pro_com = center_of_mass(data['anatomical_masks']['pro']) + center = [int(np.round(i, 0)) for i in pro_com] + else: + center = [data['img'][ref_mod].shape[i] // 2 for i in range(3)] + + + l = pre_crop_size // 2 + #z_low, z_up = max(min_z, center[0] - l[0]), min(max_z + 1, center[0] + l[0]) + z_low, z_up = center[0] - l[0], center[0] + l[0] + while z_lowmax_z+1: + if z_lowmax_z+1: + warnings.warn("could not crop patient {}'s z-dim to demanded size.".format(data['Original_ID'])) + if z_up>max_z+1: + z_low -= 1 + z_up -= 1 + if z_low=0),\ + "Precropsize too large for image dimensions by {} pixels in patient {}".format(d, data['Original_ID']) + + for mod in data['img'].keys(): + data['img'][mod] = data['img'][mod][z_low:z_up, center[1]-l[1]: center[1] + l[1], center[2]-l[2]: center[2]+l[2]] + vals_lst = list(data['img'].values()) + assert np.all([mod.shape==vals_lst[0].shape for mod in vals_lst]),\ + "produced modalities for same subject with different shapes" + + for mask in data['anatomical_masks'].keys(): + data['anatomical_masks'][mask] = data['anatomical_masks'][mask] \ + [z_low:z_up, center[1]-l[1]: center[1]+l[1], center[2]-l[2]: center[2]+l[2]] + + for mask in data['lesions'].keys(): + data['lesions'][mask] = data['lesions'][mask] \ + [z_low:z_up, center[1]-l[1]: center[1]+l[1], center[2]-l[2]: center[2]+l[2]] + return data + +def add_transitional_zone_mask(data): + if 'pro' in data['anatomical_masks'] and 'pz' in data['anatomical_masks']: + intersection = data['anatomical_masks']['pro'] & data['anatomical_masks']['pz'] + data['anatomical_masks']['tz'] = data['anatomical_masks']['pro'] - intersection + return data + +def generate_labels(data, seg_labels, class_labels, gleason_map, observables_rois): + """merge individual binary labels to an integer label mask and create class labels from Gleason score. + if seg_labels has seg_label 'roi': seg label will be roi count. + """ + anatomical_masks2label = [l for l in data['anatomical_masks'].keys() if l in seg_labels.keys()] + + data['seg'] = np.zeros(shape=data['anatomical_masks']['pro'].shape, dtype=np.uint8) + data['roi_classes'] = [] + #data['roi_observables']: dict, each entry is one list of length final roi_count in this patient + data['roi_observables'] = {obs:[] for obs in observables_rois} + roi_count = 0 + + for mask in anatomical_masks2label: + ixs = np.where(data['anatomical_masks'][mask]) + roi_class = class_labels[mask] + if len(ixs)>0 and roi_class!=-1: + roi_count+=1 + label = seg_labels[mask] + if label=='roi': + label = roi_count + data['seg'][ixs] = label + data['roi_classes'].append(roi_class) + for obs in observables_rois: + obs_val = data[obs][mask] if mask in data[obs].keys() else None + data['roi_observables'][obs].append(obs_val) + #print("appended mask lab", class_labels[mask]) + + if "lesions" in seg_labels.keys(): + for lesion_key, lesion_mask in data['lesions'].items(): + ixs = np.where(lesion_mask) + roi_class = class_labels['lesions'] + if roi_class == "gleason": + roi_class = gleason_map(data['lesion_gleasons'][lesion_key]) + # roi_class = data['lesion_gleasons'][lesion_key] + if len(ixs)>0 and roi_class!=-1: + roi_count+=1 + label = seg_labels['lesions'] + if label=='roi': + label = roi_count + data['seg'][ixs] = label + #segs have form: slices x h x w, i.e., one channel per z-slice, each lesion has its own label + data['roi_classes'].append(roi_class) + for obs in observables_rois: + obs_val = data[obs][lesion_key] if lesion_key in data[obs].keys() else None + data['roi_observables'][obs].append(obs_val) + + # data['lesion_gleasons'][label] = data['lesion_gleasons'].pop(lesion_key) + for obs in data['roi_observables'].keys(): + del data[obs] + return data + +def normalize_image(data, normalization_dict): + """normalize the full image.""" + percentiles = normalization_dict['percentiles'] + for mod in data['img'].keys(): + p = np.percentile(data['img'][mod], percentiles[0]) + q = np.percentile(data['img'][mod], percentiles[1]) + masked_img = data['img'][mod][(data['img'][mod] > p) & (data['img'][mod] < q)] + data['img'][mod] = (data['img'][mod] - np.median(masked_img)) / np.std(masked_img) + return data + +def concat_mods(data, mods2concat): + """concat modalities on new first channel + """ + concat_on_channel = [] #holds tmp data to be concatenated on the same channel + for mod in mods2concat: + mod_img = data['img'][mod][np.newaxis] + concat_on_channel.append(mod_img) + data['img'] = np.concatenate(concat_on_channel, axis=0) + + return data + +def swap_yx(data, apply_flag): + """swap x and y axes in img and seg + """ + if apply_flag: + data["img"] = np.swapaxes(data["img"], -1,-2) + data["seg"] = np.swapaxes(data["seg"], -1,-2) + + return data + +def get_fg_z_indices(seg): + """return z-indices of array at which the x-y-arrays have labels!=0, 0 is background + """ + fg_slices = np.argwhere(seg.astype(int))[:,0] + fg_slices = np.unique(fg_slices) + return fg_slices + + +class Preprocessor(): + + def __init__(self, config): + + self._config_path = config.config_path + self.full_cf = config + self._cf = config.prepro + + def get_excluded_master_ids(self): + """Get the Master IDs that are excluded from their corresponding Prisma/Radval/Master IDs.""" + + excluded_prisma = self._cf['excluded_prisma_subjects'] + excluded_radval = self._cf['excluded_radval_subjects'] + excluded_master = self._cf['excluded_master_subjects'] + histo = self._histo_patient_based + + excluded_master_ids = [] + + if len(excluded_prisma) > 0: + for prisma_id in excluded_prisma: + master_spec = histo['Master_ID'][histo['Original_ID'] == id_to_spec(prisma_id, 'Prisma')].values[0] + excluded_master_ids.append(spec_to_id(master_spec)) + + if len(excluded_radval) > 0: + for radval_id in excluded_radval: + master_spec = histo['Master_ID'][histo['Original_ID'] == id_to_spec(radval_id, 'Radiology')].values[0] + excluded_master_ids.append(spec_to_id(master_spec)) + + excluded_master_ids += excluded_master + + return excluded_master_ids + + + def prepare_filenames(self): + """check whether histology-backed subjects and lesions are available in the data and + yield dict of subject file-paths.""" + + # assemble list of histology-backed subject ids and check that corresponding images are available + self._histo_lesion_based = pd.read_csv(os.path.join(self._cf['histo_dir'], self._cf['histo_lesion_based'])) + self._histo_patient_based = pd.read_csv(os.path.join(self._cf['histo_dir'], self._cf['histo_patient_based'])) + + excluded_master_ids = self.get_excluded_master_ids() + self._subj_ids = np.unique(self._histo_lesion_based[self._cf['histo_id_column_name']].values) + self._subj_ids = [s for s in self._subj_ids.tolist() if + s not in excluded_master_ids] + + # get subject directory paths from + img_paths = os.listdir(self._cf['data_dir']) + self._img_paths = [p for p in img_paths if 'Master' in p and len(p) == len('Master') + 6] + + # check that all images of subjects with histology are available + available_subj_ids = np.array([spec_to_id(s) for s in self._img_paths]) + self._missing_image_ids = np.setdiff1d(self._subj_ids, available_subj_ids) + + assert len(self._missing_image_ids)== 0,\ + 'Images of subjs {} are not available.'.format(self._missing_image_ids) + + # make dict holding relevant paths to data of each subject + self._paths_by_subject = {} + for s in self._subj_ids: + self._paths_by_subject[s] = self.load_subject_paths(s) + + + def load_subject_paths(self, subject_id): + """Make dict holding relevant paths to data of a given subject.""" + dir_spec = self._cf['dir_spec'] + s_dict = {} + + # iterate images + images_paths = {} + for kind, filename in self._cf['images'].items(): + filename += self._cf['img_postfix']+self._cf['overall_postfix'] + images_paths[kind] = os.path.join(self._cf['data_dir'], id_to_spec(subject_id, dir_spec), filename) + s_dict['images'] = images_paths + + # iterate anatomical structures + anatomical_masks_paths = {} + for tag in self._cf['anatomical_masks']: + filename = tag + self._cf['overall_postfix'] + anatomical_masks_paths[tag] = os.path.join(self._cf['data_dir'], id_to_spec(subject_id, dir_spec), filename) + s_dict['anatomical_masks'] = anatomical_masks_paths + + # iterate lesions + lesion_names = [] + if 'adc' in self._cf['images']: + lesion_names.extend(self._histo_lesion_based[self._histo_lesion_based[self._cf['histo_id_column_name']]\ + == subject_id]['segmentationsNameADC'].dropna()) + if 't2' in self._cf['images']: + lesion_names.extend(self._histo_lesion_based[self._histo_lesion_based[self._cf['histo_id_column_name']]\ + == subject_id]['segmentationsNameT2'].dropna()) + lesion_paths = {} + for l in lesion_names: + lesion_path = os.path.join(self._cf['data_dir'], id_to_spec(subject_id, dir_spec), + l+self._cf['lesion_postfix']+self._cf['overall_postfix']) + assert os.path.isfile(lesion_path), 'Lesion mask not found under {}!'.format(lesion_path) + + lesion_paths[l] = lesion_path + + s_dict['lesions'] = lesion_paths + return s_dict + + + def load_subject_data(self, subject_id): + """load img data, masks, histo data for a single subject.""" + subj_paths = self._paths_by_subject[subject_id] + data = {} + + # iterate images + data['img'] = {} + for mod in subj_paths['images']: + data['img'][mod] = sitk.ReadImage(subj_paths['images'][mod]) + + # iterate anatomical masks + data['anatomical_masks'] = {} + for tag in subj_paths['anatomical_masks']: + data['anatomical_masks'][tag] = sitk.ReadImage(subj_paths['anatomical_masks'][tag]) + + # iterate lesions, include gleason score + data['lesions'] = {} + data['lesion_gleasons'] = {} + idcol = self._cf['histo_id_column_name'] + subj_histo = self._histo_lesion_based[self._histo_lesion_based[idcol]==subject_id] + for l in subj_paths['lesions']: + #print("subjpaths lesions l ", l) + data['lesions'][l] = sitk.ReadImage(subj_paths['lesions'][l]) + + try: + gleason = subj_histo[subj_histo["segmentationsNameADC"]==l]["Gleason"].tolist()[0] + except IndexError: + gleason = subj_histo[subj_histo["segmentationsNameT2"]==l]["Gleason"].tolist()[0] + + data['lesion_gleasons'][l] = gleason + + # add other subj-specific histo and id data + idcol = self._cf['histo_pb_id_column_name'] + subj_histo = self._histo_patient_based[self._histo_patient_based[idcol]==subject_id] + for d in self._cf['observables_patient']: + data[d] = subj_histo[d].values + + return data + + def analyze_subject_data(self, data): + """record post-alignment geometries.""" + + ref_mods = data['img'].keys() + geos = {} + for ref_mod in ref_mods: + geos[ref_mod] = {'size': data['img'][ref_mod].GetSize(), 'origin': data['img'][ref_mod].GetOrigin(), + 'spacing': data['img'][ref_mod].GetSpacing()} + + return geos + + def process_subject_data(self, data): + """evtly rescale images, check for geometry miss-alignments and perform crop.""" + + if not self._cf['mod_scaling'] == (1,1,1): + for img_name in data['img']: + res_img = rescale(data["img"][img_name], self._cf['mod_scaling']) + data['img'][img_name] = res_img + + #----check geometry alignment between masks and image--- + for tag in self._cf['anatomical_masks']: + if tag.lower().startswith("seg_adc"): + ref_mod = 'adc' + elif tag.lower().startswith("seg_t2"): + ref_mod = 't2' + if not has_equal_geometry(data['img'][ref_mod], data['anatomical_masks'][tag]): + #print("bef", np.unique(sitk.GetArrayFromImage(data['anatomical_masks'][tag]))) + #print('Geometry mismatch: {}, {} is resampled to its image geometry!'.format(data["Original_ID"], tag)) + data['anatomical_masks'][tag] =\ + resample_to_reference(data['img'][ref_mod], data['anatomical_masks'][tag], + interpolation=self._cf['interpolation']) + #print("aft", np.unique(sitk.GetArrayFromImage(data['anatomical_masks'][tag]))) + + for tag in data['lesions'].keys(): + if tag.lower().startswith("seg_adc"): + ref_mod = 'adc' + elif tag.lower().startswith("seg_t2"): + ref_mod = 't2' + if not has_equal_geometry(data['img'][ref_mod], data['lesions'][tag]): + #print('Geometry mismatch: {}, {} is resampled to its image geometry!'.format(data["Original_ID"], tag)) + #print("pre-sampling data type: {}".format(data['lesions'][tag])) + data['lesions'][tag] = resample_to_reference(data['img'][ref_mod], data['lesions'][tag], + interpolation=self._cf['interpolation']) + + + data = convert_to_arrays(data) + data = merge_crossmod_masks(data, self._cf['rename_tags'], mode=self._cf['merge_mode']) + data = crop_3D(data, self._cf['pre_crop_size'], self._cf['center_of_mass_crop']) + data = add_transitional_zone_mask(data) + data = generate_labels(data, self._cf['seg_labels'], self._cf['class_labels'], self._cf['gleason_map'], + self._cf['observables_rois']) + data = normalize_image(data, self._cf['normalization']) + data = concat_mods(data, self._cf['modalities2concat']) + data = swap_yx(data, self._cf["swap_yx_to_xy"]) + + data['fg_slices'] = get_fg_z_indices(data['seg']) + + return data + + def write_subject_arrays(self, data, subject_spec): + """Write arrays to disk and save file names in dict.""" + + out_dir = self._cf['output_directory'] + os.makedirs(out_dir, exist_ok=True) #might throw error if restrictive permissions + + out_dict = {} + + # image(s) + name = subject_spec + '_imgs.npy' + np.save(os.path.join(out_dir, name), data['img']) + out_dict['img'] = name + + # merged labels + name = subject_spec + '_merged_seg.npy' + np.save(os.path.join(out_dir, name), data['seg']) + out_dict['seg'] = name + + # anatomical masks separately + #for mask in list(data['anatomical_masks'].keys()) + (['tz'] if 'tz' in data.keys() else []): + # name = subject_spec + '_{}.npy'.format(mask) + # np.save(os.path.join(out_dir, name), data['anatomical_masks'][mask]) + # out_dict[mask] = name + + # lesion masks and lesion classes separately + #out_dict['lesion_gleasons'] = {} + #for mask in data['lesions'].keys(): + # name = subject_spec + '_{}.npy'.format(mask) + # np.save(os.path.join(out_dir, name), data['lesions'][mask]) + # out_dict[mask] = name + # out_dict['lesion_gleasons'][int(mask[-1])] = data['lesion_gleasons'][int(mask[-1])] + + # roi classes + out_dict['roi_classes'] = data['roi_classes'] + + + # fg_slices info + out_dict['fg_slices'] = data['fg_slices'] + + # other observables + for obs in self._cf['observables_patient']: + out_dict[obs] = data[obs] + for obs in data['roi_observables'].keys(): + out_dict[obs] = data['roi_observables'][obs] + #print("subj outdict ", out_dict.keys()) + return out_dict + + def subject_iteration(self, subj_id): #single iteration, wrapped for pooling + data = self.load_subject_data(subj_id) + data = self.process_subject_data(data) + subj_out_dict = self.write_subject_arrays(data, id_to_spec(subj_id, self._cf['dir_spec'])) + + print('Processed subject {}.'.format(id_to_spec(subj_id, self._cf['dir_spec']))) + + return (subj_id, subj_out_dict) + + def iterate_subjects(self, ids_subset=None, processes=6): + """process all subjects.""" + + if ids_subset is None: + ids_subset = self._subj_ids + else: + ids_subset = np.array(ids_subset) + id_check = np.array([id in self._subj_ids for id in ids_subset]) + assert np.all(id_check), "pids {} not in eligible pids".format(ids_subset[np.invert(id_check)]) + + p = Pool(processes) + subj_out_dicts = p.map(self.subject_iteration, ids_subset) + """note on Pool.map: only takes one arg, pickles the function for execution --> + cannot write to variables defined outside local scope --> cannot write to + self.variables, therefore need to return single subj_out_dicts and join after; + however p.map can access object methods via self.method(). + Is a bit complicated, but speedup is huge. + """ + p.close() + p.join() + assert len(subj_out_dicts)==len(ids_subset), "produced less subject dicts than demanded" + self._info_dict = {id:dic for (id, dic) in subj_out_dicts} + + return + + def subject_analysis(self, subj_id): # single iteration, wrapped for pooling + data = self.load_subject_data(subj_id) + analysis = self.analyze_subject_data(data) + + print('Analyzed subject {}.'.format(id_to_spec(subj_id, self._cf['dir_spec']))) + + return (subj_id, analysis) + + def analyze_subjects(self, ids_subset=None, processes=os.cpu_count()): + """process all subjects.""" + + if ids_subset is None: + ids_subset = self._subj_ids + else: + ids_subset = np.array(ids_subset) + id_check = np.array([id in self._subj_ids for id in ids_subset]) + assert np.all(id_check), "pids {} not in eligible pids".format(ids_subset[np.invert(id_check)]) + + p = Pool(processes) + subj_analyses = p.map(self.subject_analysis, ids_subset) + """note on Pool.map: only takes one arg, pickles the function for execution --> + cannot write to variables defined outside local scope --> cannot write to + self.variables, therefore need to return single subj_out_dicts and join after; + however p.map can access object methods via self.method(). + Is a bit complicated, but speedup is huge. + """ + p.close() + p.join() + + df = pd.DataFrame(columns=['id', 'mod', 'size', 'origin', 'spacing']) + for subj_id, analysis in subj_analyses: + for mod, geo in analysis.items(): + df.loc[len(df)] = [subj_id, mod, np.array(geo['size']), np.array(geo['origin']), np.array(geo['spacing'])] + + os.makedirs(self._cf['output_directory'], exist_ok=True) + df.to_csv(os.path.join(self._cf['output_directory'], "analysis_df")) + + print("\nOver all mods") + print("Size mean {}\u00B1{}".format(df['size'].mean(), np.std(df['size'].values))) + print("Origin mean {}\u00B1{}".format(df['origin'].mean(), np.std(df['origin'].values))) + print("Spacing mean {}\u00B1{}".format(df['spacing'].mean(), np.std(df['spacing'].values))) + print("-----------------------------------------\n") + + for mod in df['mod'].unique(): + print("\nModality: {}".format(mod)) + mod_df = df[df['mod']==mod] + print("Size mean {}\u00B1{}".format(mod_df['size'].mean(), np.std(mod_df['size'].values))) + print("Origin mean {}\u00B1{}".format(mod_df['origin'].mean(), np.std(mod_df['origin'].values))) + print("Spacing mean {}\u00B1{}".format(mod_df['spacing'].mean(), np.std(mod_df['spacing'].values))) + print("-----------------------------------------\n") + return + + + def dump_class_labels(self, out_dir): + """save used GS mapping and class labels to file. + will likely not work if non-lesion classes (anatomy) are contained + """ + #if "gleason_thresh" in self._cf.keys(): + possible_gs = {gs for p_dict in self._info_dict.values() for gs in p_dict['lesion_gleasons']} + gs_mapping_inv = [(self._cf["gleason_map"](gs)+1, gs) for gs in possible_gs] + #elif "gleason_mapping" in self._cf.keys(): + #gs_mapping_inv = [(val + 1, key) for (key, val) in self._cf["gleason_mapping"].items() if val != -1] + classes = {pair[0] for pair in gs_mapping_inv} + groups = [[pair[1] for pair in gs_mapping_inv if pair[0]==cl] for cl in classes] + gr_names = [ "GS{}-{}".format(min(gr), max(gr)) if len(gr)>1 else "GS"+str(*gr) for gr in groups ] + if "color_palette" in self._cf.keys(): + class_labels = {cl: {"gleasons": groups[ix], "name": gr_names[ix], "color": self._cf["color_palette"][ix]} + for ix, cl in enumerate(classes) } + else: + class_labels = {cl: {"gleasons": groups[ix], "name": gr_names[ix], "color": self.full_cf.color_palette[ix]} + for ix, cl in enumerate(classes)} + + save_obj(class_labels, os.path.join(out_dir,"pp_class_labels")) + + + + def save_and_finish(self): + """copy config and used code to out_dir.""" + + out_dir = self._cf['output_directory'] + + # save script + current_script = os.path.realpath(__file__) + shutil.copyfile(current_script, os.path.join(out_dir, 'applied_preprocessing.py')) + + # save config + if self._config_path[-1] == 'c': + self._config_path = self._config_path[:-1] + shutil.copyfile(self._config_path, os.path.join(out_dir, 'applied_config.py')) + + #copy histo data to local dir + lbased = self._cf['histo_lesion_based'] + pbased = self._cf['histo_patient_based'] + os.makedirs(self._cf['histo_dir_out'], exist_ok=True) + shutil.copyfile(self._cf['histo_dir']+lbased, self._cf['histo_dir_out']+lbased) + shutil.copyfile(self._cf['histo_dir']+pbased, self._cf['histo_dir_out']+pbased) + + # save info dict + #print("info dict ", self._info_dict) + save_obj(self._info_dict, self._cf['info_dict_path'][:-4]) + self.dump_class_labels(out_dir) + + return + + def convert_copy_npz(self): + if not self._cf["npz_dir"]: + return + print("npz dir", self._cf['npz_dir']) + os.makedirs(self._cf['npz_dir'], exist_ok=True) + save_obj(self._info_dict, os.path.join(self._cf['npz_dir'], + self._cf['info_dict_path'].split("/")[-1][:-4])) + lbased = self._cf['histo_lesion_based'] + pbased = self._cf['histo_patient_based'] + histo_out = os.path.join(self._cf['npz_dir'], "histos/") + print("histo dir", histo_out) + os.makedirs(histo_out, exist_ok=True) + shutil.copyfile(self._cf['histo_dir']+lbased, histo_out+lbased) + shutil.copyfile(self._cf['histo_dir']+pbased, histo_out+pbased) + shutil.copyfile(os.path.join(self._cf['output_directory'], 'applied_config.py'), + os.path.join(self._cf['npz_dir'], 'applied_config.py')) + shutil.copyfile(os.path.join(self._cf['output_directory'], 'applied_preprocessing.py'), + os.path.join(self._cf['npz_dir'], 'applied_preprocessing.py')) + shutil.copyfile(os.path.join(self._cf['output_directory'], 'pp_class_labels.pkl'), + os.path.join(self._cf['npz_dir'], 'pp_class_labels.pkl')) + + dmanager.pack_dataset(self._cf["output_directory"], self._cf["npz_dir"], recursive=True) + + + + + +if __name__ == "__main__": + + stime = time.time() + + from configs import Configs + cf = configs() + + + pp = Preprocessor(config=cf) + pp.prepare_filenames() + #pp.analyze_subjects(ids_subset=None)#[1,2,3]) + pp.iterate_subjects(ids_subset=None, processes=os.cpu_count()) + pp.save_and_finish() + pp.convert_copy_npz() + + + #patient_id = 17 + #data = pp.load_subject_data(patient_id) + #data = pp.process_subject_data(data) + + #img = data['img'] + #print("img shape ", img.shape) + #print("seg shape ", data['seg'].shape) + #label_remap = {0:0} + #label_remap.update({roi_id : 1 for roi_id in range(1,5)}) + #plg.view_slices(cf, img[0], data['seg'], instance_labels=True, + # out_dir="experiments/dev/ex_slices.png") + + mins, secs = divmod((time.time() - stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("Prepro program runtime: {}".format(t)) diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py new file mode 100644 index 0000000..9c39db3 --- /dev/null +++ b/datasets/toy/configs.py @@ -0,0 +1,495 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import sys +import os +sys.path.append(os.path.dirname(os.path.realpath(__file__))) +import numpy as np +from default_configs import DefaultConfigs +from collections import namedtuple + +boxLabel = namedtuple('boxLabel', ["name", "color"]) +Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion']) +binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals']) + +class Configs(DefaultConfigs): + + def __init__(self, server_env=None): + super(Configs, self).__init__(server_env) + + ######################### + # Prepro # + ######################### + + self.pp_rootdir = os.path.join('/mnt/HDD2TB/Documents/data/toy', "cyl1ps_dev") + self.pp_npz_dir = self.pp_rootdir+"_npz" + + self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now) + self.min_2d_radius = 6 #in pixels + self.n_train_samples, self.n_test_samples = 80, 80 + + # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes. + self.pp_create_ohe_seg = False + self.pp_empty_samples_ratio = 0.1 + + self.pp_place_radii_mid_bin = True + self.pp_only_distort_2d = True + # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing. + # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity. + self.pp_blur_min_intensity = 0.2 + + self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d) + self.max_instances_per_class = self.max_instances_per_sample # how many max instances per image per class + self.noise_scale = 0. # std-dev of gaussian noise + + self.ambigs_sampling = "gaussian" #"gaussian" or "uniform" + """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from + image by distinguishing it from the rest of the object. + blurring width around edge will be shifted so that symmetric rel to orig radius. + blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale + since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius. + if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2. + """ + self.ambiguities = { + #set which classes to apply which ambs to below in class labels + #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'. + #kind #probability #scale (gaussian std, relative to unperturbed value) + #"outer_radius": (1., 0.5), + #"outer_radius_xy": (1., 0.5), + #"inner_radius": (0.5, 0.1), + #"radii_relations": (0.5, 0.1), + "radius_calib": (1., 1./6) + } + + # shape choices: 'cylinder', 'block' + self.pp_classes = [Label(1, 'cylinder', 'cylinder', ((6,6,1),(40,40,8)), (*self.blue, 1.), "radius_2d", (), ('radius_calib',)), + #Label(2, 'block', 'block', ((6,6,1),(40,40,8)), (*self.aubergine,1.), "radii_2d", (), ('radius_calib',)) + ] + + + ######################### + # I/O # + ######################### + + #self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_dev' + self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_exact' + #self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_ambig_beyond_bin' + + if server_env: + #self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_exact_npz' + self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_ambig_beyond_bin_npz' + + self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test') + self.data_sourcedir = os.path.join(self.data_sourcedir, "train") + + self.info_df_name = 'info_df.pickle' + + # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn']. + self.model = 'retina_net' + self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') + self.model_path = os.path.join(self.source_dir, self.model_path) + + + ######################### + # Architecture # + ######################### + + # one out of [2, 3]. dimension the model operates in. + self.dim = 2 + + # 'class', 'regression', 'regression_bin', 'regression_ken_gal' + # currently only tested mode is a single-task at a time (i.e., only one task in below list) + # but, in principle, tasks could be combined (e.g., object classes and regression per class) + self.prediction_tasks = ['class',] + + self.start_filts = 48 if self.dim == 2 else 18 + self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 + self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' + self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm' + self.relu = 'relu' + # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') + self.weight_init = None + + self.regression_n_features = 1 # length of regressor target vector + + + ######################### + # Data Loader # + ######################### + + self.num_epochs = 32 + self.num_train_batches = 120 if self.dim == 2 else 80 + self.batch_size = 16 if self.dim == 2 else 8 + + self.n_cv_splits = 4 + # select modalities from preprocessed data + self.channels = [0] + self.n_channels = len(self.channels) + + # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels + self.plot_bg_chan = 0 + self.crop_margin = [20, 20, 1] # has to be smaller than respective patch_size//2 + self.patch_size_2D = self.pre_crop_size[:2] + self.patch_size_3D = self.pre_crop_size[:2]+[8] + + # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. + self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D + + # ratio of free sampled batch elements before class balancing is triggered + # (>0 to include "empty"/background patches.) + self.batch_random_ratio = 0.2 + self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets" + + self.observables_patient = [] + self.observables_rois = [] + + self.seed = 3 #for generating folds + + ############################# + # Colors, Classes, Legends # + ############################# + self.plot_frequency = 1 + + binary_bin_labels = [binLabel(1, 'r<=25', (*self.green, 1.), (1,25)), + binLabel(2, 'r>25', (*self.red, 1.), (25,))] + quintuple_bin_labels = [binLabel(1, 'r2-10', (*self.green, 1.), (2,10)), + binLabel(2, 'r10-20', (*self.yellow, 1.), (10,20)), + binLabel(3, 'r20-30', (*self.orange, 1.), (20,30)), + binLabel(4, 'r30-40', (*self.bright_red, 1.), (30,40)), + binLabel(5, 'r>40', (*self.red, 1.), (40,))] + + # choose here if to do 2-way or 5-way regression-bin classification + task_spec_bin_labels = quintuple_bin_labels + + self.class_labels = [ + # regression: regression-task label, either value or "(x,y,z)_radius" or "radii". + # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables! + # gt_distortion: name of ambig to apply to gt only; needs to be iterable! + # #id #name #shape #radius #color #regression #ambiguities #gt_distortion + Label( 0, 'bg', None, (0, 0, 0), (*self.white, 0.), (0, 0, 0), (), ())] + if "class" in self.prediction_tasks: + self.class_labels += self.pp_classes + else: + self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))] + + + if any(['regression' in task for task in self.prediction_tasks]): + self.bin_labels = [binLabel(0, 'bg', (*self.white, 1.), (0,))] + self.bin_labels += task_spec_bin_labels + self.bin_id2label = {label.id: label for label in self.bin_labels} + bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] + self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} + self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] + self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0} + + if self.class_specific_seg: + self.seg_labels = self.class_labels + + self.box_type2label = {label.name: label for label in self.box_labels} + self.class_id2label = {label.id: label for label in self.class_labels} + self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} + + self.seg_id2label = {label.id: label for label in self.seg_labels} + self.cmap = {label.id: label.color for label in self.seg_labels} + + self.plot_prediction_histograms = True + self.plot_stat_curves = False + self.has_colorchannels = False + self.plot_class_ids = True + + self.num_classes = len(self.class_dict) + self.num_seg_classes = len(self.seg_labels) + + ######################### + # Data Augmentation # + ######################### + self.do_aug = True + self.da_kwargs = { + 'mirror': True, + 'mirror_axes': tuple(np.arange(0, self.dim, 1)), + 'do_elastic_deform': False, + 'alpha': (500., 1500.), + 'sigma': (40., 45.), + 'do_rotation': False, + 'angle_x': (0., 2 * np.pi), + 'angle_y': (0., 0), + 'angle_z': (0., 0), + 'do_scale': False, + 'scale': (0.8, 1.1), + 'random_crop': False, + 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), + 'border_mode_data': 'constant', + 'border_cval_data': 0, + 'order_data': 1 + } + + if self.dim == 3: + self.da_kwargs['do_elastic_deform'] = False + self.da_kwargs['angle_x'] = (0, 0.0) + self.da_kwargs['angle_y'] = (0, 0.0) # must be 0!! + self.da_kwargs['angle_z'] = (0., 2 * np.pi) + + ######################### + # Schedule / Selection # + ######################### + + # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) + # the former is morge accurate, while the latter is faster (depending on volume size) + self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient' + if self.val_mode == 'val_patient': + self.max_val_patients = 220 # if 'all' iterates over entire val_set once. + if self.val_mode == 'val_sampling': + self.num_val_batches = 25 if self.dim==2 else 15 + + self.save_n_models = 2 + self.min_save_thresh = 1 if self.dim == 2 else 1 # =wait time in epochs + if "class" in self.prediction_tasks: + self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()} + elif any("regression" in task for task in self.prediction_tasks): + self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()} + self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()}) + + self.lr_decay_factor = 0.5 + self.scheduling_patience = int(self.num_epochs / 5) + self.weight_decay = 1e-5 + self.clip_norm = None # number or None + + ######################### + # Testing / Plotting # + ######################### + + self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) + self.held_out_test_set = True + self.max_test_patients = "all" # number or "all" for all + + self.test_against_exact_gt = not 'exact' in self.data_sourcedir + self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario. + self.report_score_level = ['rois'] # 'patient' or 'rois' (incl) + self.patient_class_of_interest = 1 + self.patient_bin_of_interest = 2 + + self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False + self.metrics = ['ap', 'auc', 'dice'] + if any(['regression' in task for task in self.prediction_tasks]): + self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', + 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] + if 'aleatoric' in self.model: + self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] + self.evaluate_fold_means = True + + self.ap_match_ious = [0.5] # threshold(s) for considering a prediction as true positive + self.min_det_thresh = 0.3 + + self.model_max_iou_resolution = 0.2 + + # aggregation method for test and val_patient predictions. + # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, + # nms = standard non-maximum suppression, or None = no clustering + self.clustering = 'wbc' + # iou thresh (exclusive!) for regarding two preds as concerning the same ROI + self.clustering_iou = self.model_max_iou_resolution # has to be larger than desired possible overlap iou of model predictions + + self.merge_2D_to_3D_preds = False + self.merge_3D_iou = self.model_max_iou_resolution + self.n_test_plots = 1 # per fold and rank + + self.test_n_epochs = self.save_n_models # should be called n_test_ens, since is number of models to ensemble over during testing + # is multiplied by (1 + nr of test augs) + + #self.losses_to_monitor += ['class_loss', 'rg_loss'] + + ######################### + # Assertions # + ######################### + if not 'class' in self.prediction_tasks: + assert self.num_classes == 1 + + ######################### + # Add model specifics # + ######################### + + {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, + 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, + 'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs + }[self.model]() + + def rg_val_to_bin_id(self, rg_val): + #only meant for isotropic radii!! + # only 2D radii (x and y dims) or 1D (x or y) are expected + return np.round(np.digitize(rg_val, self.bin_edges).mean()) + + + def add_det_fpn_configs(self): + + self.learning_rate = [5 * 1e-4] * self.num_epochs + self.dynamic_lr_scheduling = True + self.scheduling_criterion = 'torch_loss' + self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' + + self.n_roi_candidates = 4 if self.dim == 2 else 6 + # max number of roi candidates to identify per image (slice in 2D, volume in 3D) + + # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') + self.seg_loss_mode = 'wce' + self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1] + + self.fp_dice_weight = 1 if self.dim == 2 else 1 + # if <1, false positive predictions in foreground are penalized less. + + self.detection_min_confidence = 0.05 + # how to determine score of roi: 'max' or 'median' + self.score_det = 'max' + + def add_det_unet_configs(self): + + self.learning_rate = [5 * 1e-4] * self.num_epochs + self.dynamic_lr_scheduling = True + self.scheduling_criterion = "torch_loss" + self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' + + # max number of roi candidates to identify per image (slice in 2D, volume in 3D) + self.n_roi_candidates = 4 if self.dim == 2 else 6 + + # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') + self.seg_loss_mode = 'wce' + self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1] + # if <1, false positive predictions in foreground are penalized less. + self.fp_dice_weight = 1 if self.dim == 2 else 1 + + self.detection_min_confidence = 0.05 + # how to determine score of roi: 'max' or 'median' + self.score_det = 'max' + + self.init_filts = 32 + self.kernel_size = 3 # ks for horizontal, normal convs + self.kernel_size_m = 2 # ks for max pool + self.pad = "same" # "same" or integer, padding of horizontal convs + + def add_mrcnn_configs(self): + + self.learning_rate = [1e-4] * self.num_epochs + self.dynamic_lr_scheduling = True # with scheduler set in exec + self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) + self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' + + # number of classes for network heads: n_foreground_classes + 1 (background) + self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2 + + # feed +/- n neighbouring slices into channel dimension. set to None for no context. + self.n_3D_context = None + if self.n_3D_context is not None and self.dim == 2: + self.n_channels *= (self.n_3D_context * 2 + 1) + + self.detect_while_training = True + # disable the re-sampling of mask proposals to original size for speed-up. + # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), + # mask outputs are optional. + self.return_masks_in_train = True + self.return_masks_in_val = True + self.return_masks_in_test = True + + # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. + self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} + # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale + # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) + self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} + # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. + self.pyramid_levels = [0, 1, 2, 3] + # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. + self.n_rpn_features = 512 if self.dim == 2 else 64 + + # anchor ratios and strides per position in feature maps. + self.rpn_anchor_ratios = [0.5, 1., 2.] + self.rpn_anchor_stride = 1 + # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION + self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution) + + # loss sampling settings. + self.rpn_train_anchors_per_image = 4 + self.train_rois_per_image = 6 # per batch_instance + self.roi_positive_ratio = 0.5 + self.anchor_matching_iou = 0.8 + + # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), + # where k<=#positive examples. + self.shem_poolsize = 2 + + self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) + self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) + self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) + + self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) + self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) + self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) + self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], + self.patch_size_3D[2], self.patch_size_3D[2]]) # y1,x1,y2,x2,z1,z2 + + if self.dim == 2: + self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] + self.bbox_std_dev = self.bbox_std_dev[:4] + self.window = self.window[:4] + self.scale = self.scale[:4] + + self.plot_y_max = 1.5 + self.n_plot_rpn_props = 5 if self.dim == 2 else 30 # per batch_instance (slice in 2D / patient in 3D) + + # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. + self.pre_nms_limit = 2000 if self.dim == 2 else 4000 + + # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, + # since proposals of the entire batch are forwarded through second stage as one "batch". + self.roi_chunk_size = 1300 if self.dim == 2 else 500 + self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400 + self.post_nms_rois_inference = 200 * (self.head_classes-1) + + # Final selection of detections (refine_detections) + self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class. + self.detection_nms_threshold = self.model_max_iou_resolution # needs to be > 0, otherwise all predictions are one cluster. + self.model_min_confidence = 0.2 # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py + + if self.dim == 2: + self.backbone_shapes = np.array( + [[int(np.ceil(self.patch_size[0] / stride)), + int(np.ceil(self.patch_size[1] / stride))] + for stride in self.backbone_strides['xy']]) + else: + self.backbone_shapes = np.array( + [[int(np.ceil(self.patch_size[0] / stride)), + int(np.ceil(self.patch_size[1] / stride)), + int(np.ceil(self.patch_size[2] / stride_z))] + for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] + )]) + + if self.model == 'retina_net' or self.model == 'retina_unet': + # whether to use focal loss or SHEM for loss-sample selection + self.focal_loss = True + # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 + self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in + self.rpn_anchor_scales['xy']] + self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in + self.rpn_anchor_scales['z']] + self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 + + #self.n_rpn_features = 256 if self.dim == 2 else 64 + + # pre-selection of detections for NMS-speedup. per entire batch. + self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size + + # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 + self.anchor_matching_iou = 0.7 + + if self.model == 'retina_unet': + self.operate_stride1 = True diff --git a/datasets/toy/data_loader.py b/datasets/toy/data_loader.py new file mode 100644 index 0000000..6a59948 --- /dev/null +++ b/datasets/toy/data_loader.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import sys +sys.path.append('../') #works on cluster indep from where sbatch job is started +import plotting as plg + +import numpy as np +import os +from collections import OrderedDict +import pandas as pd +import pickle +import time + +# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators +from batchgenerators.dataloading.data_loader import SlimDataLoaderBase +from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror +from batchgenerators.transforms.abstract_transforms import Compose +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.dataloading import SingleThreadedAugmenter +from batchgenerators.transforms.spatial_transforms import SpatialTransform +from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform +#from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates + +sys.path.append(os.path.dirname(os.path.realpath(__file__))) +import utils.exp_utils as utils +import utils.dataloader_utils as dutils +from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates + + +def load_obj(file_path): + with open(file_path, 'rb') as handle: + return pickle.load(handle) + +class Dataset(dutils.Dataset): + r""" Load a dict holding memmapped arrays and clinical parameters for each patient, + evtly subset of those. + If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to + cf.data_dir. + :param cf: config file + :param folds: number of folds out of @params n_cv folds to include + :param n_cv: number of total folds + :return: dict with imgs, segs, pids, class_labels, observables + """ + + def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'): + super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir) + + load_exact_gts = (mode=='test' or cf.val_mode=="val_patient") and self.cf.test_against_exact_gt + + p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name)) + + if subset_ids is not None: + p_df = p_df[p_df.pid.isin(subset_ids)] + logger.info('subset: selected {} instances from df'.format(len(p_df))) + + pids = p_df.pid.tolist() + #evtly copy data from data_sourcedir to data_dest + if cf.server_env and not hasattr(cf, "data_dir"): + file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids] + file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids] + file_subset += [cf.info_df_name] + if load_exact_gts: + file_subset += [os.path.join(self.data_dir, '{}_exact_seg.*'.format(pid)) for pid in pids] + self.copy_data(cf, file_subset=file_subset) + + img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids] + seg_paths = [os.path.join(self.data_dir, '{}_seg.npy'.format(pid)) for pid in pids] + if load_exact_gts: + exact_seg_paths = [os.path.join(self.data_dir, '{}_exact_seg.npy'.format(pid)) for pid in pids] + + class_targets = p_df['class_ids'].tolist() + rg_targets = p_df['regression_vectors'].tolist() + if load_exact_gts: + exact_rg_targets = p_df['undistorted_rg_vectors'].tolist() + fg_slices = p_df['fg_slices'].tolist() + + self.data = OrderedDict() + for ix, pid in enumerate(pids): + self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid, + 'fg_slices': np.array(fg_slices[ix])} + if load_exact_gts: + self.data[pid]['exact_seg'] = exact_seg_paths[ix] + if 'class' in self.cf.prediction_tasks: + self.data[pid]['class_targets'] = np.array(class_targets[ix], dtype='uint8') + else: + self.data[pid]['class_targets'] = np.ones_like(np.array(class_targets[ix]), dtype='uint8') + if load_exact_gts: + self.data[pid]['exact_class_targets'] = self.data[pid]['class_targets'] + if any(['regression' in task for task in self.cf.prediction_tasks]): + self.data[pid]['regression_targets'] = np.array(rg_targets[ix], dtype='float16') + self.data[pid]["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in rg_targets[ix]], dtype='uint8') + if load_exact_gts: + self.data[pid]['exact_regression_targets'] = np.array(exact_rg_targets[ix], dtype='float16') + self.data[pid]["exact_rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in exact_rg_targets[ix]], + dtype='uint8') + + + cf.roi_items = cf.observables_rois[:] + cf.roi_items += ['class_targets'] + if any(['regression' in task for task in self.cf.prediction_tasks]): + cf.roi_items += ['regression_targets'] + cf.roi_items += ['rg_bin_targets'] + + self.set_ids = np.array(list(self.data.keys())) + self.df = None + +class BatchGenerator(dutils.BatchGenerator): + """ + creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) + from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. + Actual patch_size is obtained after data augmentation. + :param data: data dictionary as provided by 'load_dataset'. + :param batch_size: number of patients to sample for the batch + :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target + """ + def __init__(self, cf, data, sample_pids_w_replace=True): + super(BatchGenerator, self).__init__(cf, data) + + self.chans = cf.channels if cf.channels is not None else np.index_exp[:] + assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" + + + self.sample_pids_w_replace = sample_pids_w_replace + self.eligible_pids = list(self._data.keys()) + + self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. + self.p_fg = 0.5 + self.empty_samples_max_ratio = 0.6 + self.random_count = int(cf.batch_random_ratio * cf.batch_size) + + self.balance_target_distribution(plot=sample_pids_w_replace) + self.stats = {"roi_counts": np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count": 0} + + + def generate_train_batch(self): + # everything done in here is per batch + # print statements in here get confusing due to multithreading + if self.sample_pids_w_replace: + # fully random patients + batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) + # target-balanced patients + batch_patient_ids += list(np.random.choice( + self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) + else: + batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size, + replace=False) + if self.sample_pids_w_replace == False: + self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids] + if len(self.eligible_pids) < self.batch_size: + self.eligible_pids = self.dataset_pids + + batch_data, batch_segs, batch_patient_targets = [], [], [] + batch_roi_items = {name: [] for name in self.cf.roi_items} + # record roi count of classes in batch + # empty count for full bg samples (empty slices in 2D/patients in 3D) in slot num_classes (last) + batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0 + + for b in range(self.batch_size): + patient = self._data[batch_patient_ids[b]] + + data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis] + seg = np.load(patient['seg'], mmap_mode='r').astype('uint8') + + (c, y, x, z) = data.shape + if self.cf.dim == 2: + elig_slices, choose_fg = [], False + if len(patient['fg_slices']) > 0: + if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand( + 1) <= self.p_fg: + # fg is to be picked + for tix in np.argsort(batch_roi_counts): + # pick slices of patient that have roi of sought-for target + # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix + elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( + patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == + self.unique_ts[tix]) > 0] + if len(elig_slices) > 0: + choose_fg = True + break + else: + # pick bg + elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) + if len(elig_slices) > 0: + sl_pick_ix = np.random.choice(elig_slices, size=None) + else: + sl_pick_ix = np.random.choice(z, size=None) + data = data[..., sl_pick_ix] + seg = seg[..., sl_pick_ix] + + spatial_shp = data[0].shape + assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg" + if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): + new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] + data = dutils.pad_nd_image(data, (len(data), *new_shape)) + seg = dutils.pad_nd_image(seg, new_shape) + + # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center, + # if possible, to that pixel, so that img still contains ROI after pre-cropping + dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] + if np.any(dim_cropflags): + # sample pixel from random ROI and shift center, if possible, to that pixel + if self.cf.dim==3: + choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg + if choose_fg and np.any(seg): + available_roi_ids = np.unique(seg)[1:] + for tix in np.argsort(batch_roi_counts): + elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]] + if len(elig_roi_ids)>0: + seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) + break + roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] + assert seg[tuple(roi_anchor_pixel)] > 0 + + # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and + # distance to the selected ROI < patch_size /2 + def get_cropped_centercoords(dim): + low = np.max((self.cf.pre_crop_size[dim] // 2, + roi_anchor_pixel[dim] - ( + self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim]))) + high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2, + roi_anchor_pixel[dim] + ( + self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim]))) + if low >= high: # happens if lesion on the edge of the image. + low = self.cf.pre_crop_size[dim] // 2 + high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2 + + assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format( + dim, + spatial_shp, patient['pid'], low, high) + return np.random.randint(low=low, high=high) + else: + # sample crop center regardless of ROIs, not guaranteed to be empty + def get_cropped_centercoords(dim): + return np.random.randint(low=self.cf.pre_crop_size[dim] // 2, + high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2) + + sample_seg_center = {} + for dim in np.where(dim_cropflags)[0]: + sample_seg_center[dim] = get_cropped_centercoords(dim) + min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2) + max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2) + data = np.take(data, indices=range(min_, max_), axis=dim + 1) # +1 for channeldim + seg = np.take(seg, indices=range(min_, max_), axis=dim) + + batch_data.append(data) + batch_segs.append(seg[np.newaxis]) + + for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable + batch_roi_items[o].append(patient[o]) + + if self.cf.dim == 3: + for tix in range(len(self.unique_ts)): + batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) + elif self.cf.dim == 2: + for tix in range(len(self.unique_ts)): + batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) + if not np.any(seg): + empty_samples_count += 1 + + batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'), + 'pid': batch_patient_ids, + 'roi_counts': batch_roi_counts, 'empty_samples_count': empty_samples_count} + for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic + batch[key] = np.array(val) + + return batch + +class PatientBatchIterator(dutils.PatientBatchIterator): + """ + creates a test generator that iterates over entire given dataset returning 1 patient per batch. + Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D), + if willing to accept speed-loss during training. + Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are + exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by + the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs. + + :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or + batch_size = n_2D_patches in 2D . + """ + + def __init__(self, cf, data, mode='test'): + super(PatientBatchIterator, self).__init__(cf, data) + + self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D + self.chans = cf.channels if cf.channels is not None else np.index_exp[:] + assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" + + if (mode=="validation" and hasattr(self.cf, 'val_against_exact_gt') and self.cf.val_against_exact_gt) or \ + (mode == 'test' and self.cf.test_against_exact_gt): + self.gt_prefix = 'exact_' + print("PatientIterator: Loading exact Ground Truths.") + else: + self.gt_prefix = '' + + self.patient_ix = 0 # running index over all patients in set + + def generate_train_batch(self, pid=None): + + if pid is None: + pid = self.dataset_pids[self.patient_ix] + patient = self._data[pid] + + # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling + data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis] + seg = np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis] + + data_shp_raw = data.shape + plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None + data = data[self.chans] + discarded_chans = len( + [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan]) + spatial_shp = data[0].shape # spatial dims need to be in order x,y,z + assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg" + + if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]): + new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))] + data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. + seg = dutils.pad_nd_image(seg, new_shape) + if plot_bg is not None: + plot_bg = dutils.pad_nd_image(plot_bg, new_shape) + + if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: + # adds the batch dim here bc won't go through MTaugmenter + out_data = data[np.newaxis] + out_seg = seg[np.newaxis] + if plot_bg is not None: + out_plot_bg = plot_bg[np.newaxis] + # data and seg shape: (1,c,x,y,z), where c=1 for seg + + batch_3D = {'data': out_data, 'seg': out_seg} + for o in self.cf.roi_items: + batch_3D[o] = np.array([patient[self.gt_prefix+o]]) + converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) + batch_3D = converter(**batch_3D) + batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_3D["patient_" + o] = batch_3D[o] + + if self.cf.dim == 2: + out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32') # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim + out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8') # (c,y,x,z) to (b=z,c,x,y) + + batch_2D = {'data': out_data, 'seg': out_seg} + for o in self.cf.roi_items: + batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0) + converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) + batch_2D = converter(**batch_2D) + + if plot_bg is not None: + out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32') + + if self.cf.merge_2D_to_3D_preds: + batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], + 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_2D["patient_" + o] = batch_3D[o] + else: + batch_2D.update({'patient_bb_target': batch_2D['bb_target'], + 'original_img_shape': out_data.shape}) + for o in self.cf.roi_items: + batch_2D["patient_" + o] = batch_2D[o] + + out_batch = batch_3D if self.cf.dim == 3 else batch_2D + out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) + + if self.cf.plot_bg_chan in self.chans and discarded_chans > 0: # len(self.chans[:self.cf.plot_bg_chan]) self.patch_size[ix] for ix in range(len(spatial_shp))]): + patient_batch = out_batch + print("patientiterator produced patched batch!") + patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) + new_img_batch, new_seg_batch = [], [] + + for c in patch_crop_coords_list: + new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]]) + seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] + new_seg_batch.append(seg_patch) + shps = [] + for arr in new_img_batch: + shps.append(arr.shape) + + data = np.array(new_img_batch) # (patches, c, x, y, z) + seg = np.array(new_seg_batch) + if self.cf.dim == 2: + # all patches have z dimension 1 (slices). discard dimension + data = data[..., 0] + seg = seg[..., 0] + patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), + 'pid': np.array([patient['pid']] * data.shape[0])} + for o in self.cf.roi_items: + patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0) + #patient-wise (orig) batch info for putting the patches back together after prediction + for o in self.cf.roi_items: + patch_batch["patient_"+o] = patient_batch["patient_"+o] + if self.cf.dim == 2: + # this could also be named "unpatched_2d_roi_items" + patch_batch["patient_" + o + "_2d"] = patient_batch[o] + patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) + patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] + if self.cf.dim == 2: + patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] + patch_batch['patient_data'] = patient_batch['data'] + patch_batch['patient_seg'] = patient_batch['seg'] + patch_batch['original_img_shape'] = patient_batch['original_img_shape'] + if plot_bg is not None: + patch_batch['patient_plot_bg'] = patient_batch['plot_bg'] + + converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False, + class_specific_seg=self.cf.class_specific_seg) + + patch_batch = converter(**patch_batch) + out_batch = patch_batch + + self.patient_ix += 1 + if self.patient_ix == len(self.dataset_pids): + self.patient_ix = 0 + + return out_batch + + +def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True): + """ + create mutli-threaded train/val/test batch generation and augmentation pipeline. + :param patient_data: dictionary containing one dictionary per patient in the train/test subset. + :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) + :return: multithreaded_generator + """ + + # create instance of batch generator as first element in pipeline. + data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace) + + my_transforms = [] + if do_aug: + if cf.da_kwargs["mirror"]: + mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) + my_transforms.append(mirror_transform) + + spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], + patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], + do_elastic_deform=cf.da_kwargs['do_elastic_deform'], + alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], + do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], + angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], + do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], + random_crop=cf.da_kwargs['random_crop']) + + my_transforms.append(spatial_transform) + else: + my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) + + my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) + all_transforms = Compose(my_transforms) + # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms) + multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers)) + return multithreaded_generator + +def get_train_generators(cf, logger, data_statistics=False): + """ + wrapper function for creating the training batch generator pipeline. returns the train/val generators. + selects patients according to cv folds (generated by first run/fold of experiment): + splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) + If cf.hold_out_test_set is True, adds the test split to the training data. + """ + dataset = Dataset(cf, logger) + dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) + dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) + set_splits = dataset.fg.splits + + test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) + train_ids = np.concatenate(set_splits, axis=0) + + if cf.held_out_test_set: + train_ids = np.concatenate((train_ids, test_ids), axis=0) + test_ids = [] + + train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids} + val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids} + + logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), + len(test_ids))) + if data_statistics: + dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir= + os.path.join(cf.plot_dir,"dataset")) + + batch_gen = {} + batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True) + batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False) + + if cf.val_mode == 'val_patient': + batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation') + batch_gen['n_val'] = len(val_ids) if cf.max_val_patients is None else cf.max_val_patients + elif cf.val_mode == 'val_sampling': + batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches != "all" else len(val_data) + + return batch_gen + +def get_test_generator(cf, logger): + """ + if get_test_generators is possibly called multiple times in server env, every time of + Dataset initiation rsync will check for copying the data; this should be okay + since rsync will not copy if files already exist in destination. + """ + + if cf.held_out_test_set: + sourcedir = cf.test_data_sourcedir + test_ids = None + else: + sourcedir = None + with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: + set_splits = pickle.load(handle) + test_ids = set_splits[cf.fold] + + test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test') + logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) + batch_gen = {} + batch_gen['test'] = PatientBatchIterator(cf, test_set.data) + batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \ + min(cf.max_test_patients, len(test_set.set_ids)) + + return batch_gen + + +if __name__=="__main__": + + import utils.exp_utils as utils + from configs import Configs + + cf = configs() + + total_stime = time.time() + times = {} + + # cf.server_env = True + # cf.data_dir = "experiments/dev_data" + + cf.exp_dir = "experiments/dev/" + cf.plot_dir = cf.exp_dir + "plots" + os.makedirs(cf.exp_dir, exist_ok=True) + cf.fold = 0 + logger = utils.get_logger(cf.exp_dir) + gens = get_train_generators(cf, logger) + train_loader = gens['train'] + for i in range(1): + stime = time.time() + print("producing training batch nr ", i) + ex_batch = next(train_loader) + times["train_batch"] = time.time() - stime + #experiments/dev/dev_exbatch_{}.png".format(i) + plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False) + + + val_loader = gens['val_sampling'] + stime = time.time() + for i in range(0): + ex_batch = next(val_loader) + times["val_batch"] = time.time() - stime + stime = time.time() + #"experiments/dev/dev_exvalbatch_{}.png" + plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True) + times["val_plot"] = time.time() - stime + # + test_loader = get_test_generator(cf, logger)["test"] + stime = time.time() + ex_batch = test_loader.generate_train_batch(pid=None) + times["test_batch"] = time.time() - stime + stime = time.time() + plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0) + times["test_patchbatch_plot"] = time.time() - stime + + + + print("Times recorded throughout:") + for (k, v) in times.items(): + print(k, "{:.2f}".format(v)) + + mins, secs = divmod((time.time() - total_stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/datasets/toy/generate_toys.py b/datasets/toy/generate_toys.py new file mode 100644 index 0000000..7d430d5 --- /dev/null +++ b/datasets/toy/generate_toys.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" Generate a data set of toy examples. Examples can be cylinders, spheres, blocks, diamonds. + Distortions may be applied, e.g., noise to the radius ground truths. + Settings are configured in configs file. +""" + +import plotting as plg +import os +import shutil +import warnings +import time +from multiprocessing import Pool + +import numpy as np +import pandas as pd + +import data_manager as dmanager + + +for msg in ["RuntimeWarning: divide by zero encountered in true_divide.*",]: + warnings.filterwarnings("ignore", msg) + + +class ToyGenerator(object): + """ Generator of toy data set. + A train and a test split with certain nr of samples are created and saved to disk. Samples can contain varying + number of objects. Objects have shapes cylinder or block (diamond, ellipsoid, torus not fully implemented). + + self.mp_args holds image split and id, objects are then randomly drawn into each image. Multi-processing is + enabled for parallel creation of images, final .npy-files can then be converted to .npz. + """ + def __init__(self, cf): + """ + :param cf: configs file holding object specifications and output directories. + """ + + self.cf = cf + + self.n_train, self.n_test = cf.n_train_samples, cf.n_test_samples + self.sample_size = cf.pre_crop_size + self.dim = len(self.sample_size) + self.class_radii = np.array([label.radius for label in self.cf.pp_classes if label.id!=0]) + self.class_id2label = {label.id: label for label in self.cf.pp_classes} + + self.mp_args = [] + # count sample ids consecutively over train, test splits within on dataset (one shape kind) + self.last_s_id = 0 + for split in ["train", "test"]: + self.set_splits_info(split) + + def set_splits_info(self, split): + """ Set info for data set splits, i.e., directory and nr of samples. + :param split: name of split, in {"train", "test"}. + """ + out_dir = os.path.join(self.cf.pp_rootdir, split) + os.makedirs(out_dir, exist_ok=True) + + n_samples = self.n_train if "train" in split else self.n_test + + self.mp_args+= [[out_dir, self.last_s_id+running_id] for running_id in range(n_samples)] + self.last_s_id+= n_samples + + def generate_sample_radii(self, class_ids, shapes): + + # the radii set in labels are ranges to sample from in the form [(min_x,min_y,min_z), (max_x,max_y,max_z)] + all_radii = [] + for ix, cl_radii in enumerate([self.class_radii[cl_id - 1].transpose() for cl_id in class_ids]): + if "cylinder" in shapes[ix] or "block" in shapes[ix]: + # maintain 2D aspect ratio + sample_radii = [np.random.uniform(*cl_radii[0])] * 2 + assert len(sample_radii) == 2, "upper sr {}, cl_radii {}".format(sample_radii, cl_radii) + if self.cf.pp_place_radii_mid_bin: + bef_conv_r = np.copy(sample_radii) + bin_id = self.cf.rg_val_to_bin_id(bef_conv_r) + assert np.isscalar(bin_id) + sample_radii = self.cf.bin_id2rg_val[bin_id]*2 + assert len(sample_radii) == 2, "mid before sr {}, sr {}, rgv2bid {}, cl_radii {}, bid2rgval {}".format(bef_conv_r, sample_radii, bin_id, cl_radii, + self.cf.bin_id2rg_val[bin_id]) + else: + raise NotImplementedError("requested object shape {}".format(shapes[ix])) + if self.dim == 3: + assert len(sample_radii) == 2, "lower sr {}, cl_radii {}".format(sample_radii, cl_radii) + #sample_radii += [np.random.uniform(*cl_radii[2])] + sample_radii = np.concatenate((sample_radii, np.random.uniform(*cl_radii[2], size=1))) + all_radii.append(sample_radii) + + return all_radii + + def apply_gt_distort(self, class_id, radii, radii_divs, outer_min_radii=None, outer_max_radii=None): + """ Apply a distortion to the ground truth (gt). This is motivated by investigating the effects of noisy labels. + GTs that can be distorted are the object radii and ensuing GT quantities like segmentation and regression + targets. + :param class_id: class id of object. + :param radii: radii of object. This is in the abstract sense, s.t. for a block-shaped object radii give the side + lengths. + :param radii_divs: radii divisors, i.e., fractions to take from radii to get inner radii of hole-shaped objects, + like a torus. + :param outer_min_radii: min radii assignable when distorting gt. + :param outer_max_radii: max radii assignable when distorting gt. + :return: + """ + applied_gt_distort = False + for ambig in self.class_id2label[class_id].gt_distortion: + if self.cf.ambiguities[ambig][0] > np.random.rand(): + if ambig == "outer_radius": + radii = radii * abs(np.random.normal(1., self.cf.ambiguities["outer_radius"][1])) + applied_gt_distort = True + if ambig == "radii_relations": + radii = radii * abs(np.random.normal(1.,self.cf.ambiguities["radii_relations"][1],size=len(radii))) + applied_gt_distort = True + if ambig == "inner_radius": + radii_divs = radii_divs * abs(np.random.normal(1., self.cf.ambiguities["inner_radius"][1])) + applied_gt_distort = True + if ambig == "radius_calib": + if self.cf.ambigs_sampling=="uniform": + radii = abs(np.random.uniform(outer_min_radii, outer_max_radii)) + elif self.cf.ambigs_sampling=="gaussian": + distort = abs(np.random.normal(1, scale=self.cf.ambiguities["radius_calib"][1], size=None)) + assert len(radii) == self.dim, "radii {}".format(radii) + radii *= [distort, distort, 1.] if self.cf.pp_only_distort_2d else distort + applied_gt_distort = True + return radii, radii_divs, applied_gt_distort + + def draw_object(self, img, seg, undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort, + roi_ix, class_id, shape, radii, center): + """ Draw a single object into the given image and add it to the corresponding ground truths. + :param img: image (volume) to hold the object. + :param seg: pixel-wise labelling of the image, possibly distorted if gt distortions are applied. + :param undistorted_seg: certainly undistorted, i.e., exact segmentation of object. + :param ics: indices which mark the positions within the image. + :param regress_targets: regression targets (e.g., 2D radii of object), evtly distorted. + :param undistorted_rg_targets: undistorted regression targets. + :param applied_gt_distort: boolean, whether or not gt distortion was applied. + :param roi_ix: running index of object in whole image. + :param class_id: class id of object. + :param shape: shape of object (e.g., whether to draw a cylinder, or block, or ...). + :param radii: radii of object (in an abstract sense, i.e., radii are side lengths in case of block shape). + :param center: center of object in image coordinates. + :return: img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort, which are now + extended are amended to reflect the new object. + """ + + radii_blur = hasattr(self.cf, "ambiguities") and hasattr(self.class_id2label[class_id], + "gt_distortion") and 'radius_calib' in \ + self.class_id2label[class_id].gt_distortion + + if radii_blur: + blur_width = self.cf.ambiguities['radius_calib'][1] + if self.cf.ambigs_sampling == "uniform": + blur_width *= np.sqrt(12) + if self.cf.pp_only_distort_2d: + outer_max_radii = np.concatenate((radii[:2] + blur_width * radii[:2], [radii[2]])) + outer_min_radii = np.concatenate((radii[:2] - blur_width * radii[:2], [radii[2]])) + #print("belt width ", outer_max_radii - outer_min_radii) + else: + outer_max_radii = radii + blur_width * radii + outer_min_radii = radii - blur_width * radii + else: + outer_max_radii, outer_min_radii = radii, radii + + if "ellipsoid" in shape or "torus" in shape: + # sphere equation: (x-h)**2 + (y-k)**2 - (z-l)**2 = r**2 + # ellipsoid equation: ((x-h)/a)**2+((y-k)/b)**2+((z-l)/c)**2 <= 1; a, b, c the "radii"/ half-length of principal axes + obj = ((ics - center) / radii) ** 2 + elif "diamond" in shape: + # diamond equation: (|x-h|)/a+(|y-k|)/b+(|z-l|)/c <= 1 + obj = abs(ics - center) / radii + elif "cylinder" in shape: + # cylinder equation:((x-h)/a)**2 + ((y-k)/b)**2 <= 1 while |z-l| <= c + obj = ((ics - center).astype("float64") / radii) ** 2 + # set z values s.t. z slices outside range are sorted out + obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1) + if radii_blur: + inner_obj = ((ics - center).astype("float64") / outer_min_radii) ** 2 + inner_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_min_radii[2], 0., 1.1) + outer_obj = ((ics - center).astype("float64") / outer_max_radii) ** 2 + outer_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_max_radii[2], 0., 1.1) + # radial dists: sqrt( (x-h)**2 + (y-k)**2 + (z-l)**2 ) + obj_radial_dists = np.sqrt(np.sum((ics - center).astype("float64")**2, axis=1)) + elif "block" in shape: + # block equation: (|x-h|)/a+(|y-k|)/b <= 1 while |z-l| <= c + obj = abs(ics - center) / radii + obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1) + if radii_blur: + inner_obj = abs(ics - center) / outer_min_radii + inner_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_min_radii[2], 0., 1.1) + outer_obj = abs(ics - center) / outer_max_radii + outer_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_max_radii[2], 0., 1.1) + obj_radial_dists = np.sum(abs(ics - center), axis=1).astype("float64") + else: + raise Exception("Invalid object shape '{}'".format(shape)) + + # create the "original" GT, i.e., the actually true object and draw it into undistorted seg. + obj = (np.sum(obj, axis=1) <= 1) + obj = obj.reshape(seg[0].shape) + slices_to_discard = np.where(np.count_nonzero(np.count_nonzero(obj, axis=0), axis=0) <= self.cf.min_2d_radius)[0] + obj[..., slices_to_discard] = 0 + undistorted_radii = np.copy(radii) + undistorted_seg[class_id][obj] = roi_ix + 1 + obj = obj.astype('float64') + + if radii_blur: + inner_obj = np.sum(inner_obj, axis=1) <= 1 + outer_obj = (np.sum(outer_obj, axis=1) <= 1) & ~inner_obj + obj_radial_dists[outer_obj] = obj_radial_dists[outer_obj] / max(obj_radial_dists[outer_obj]) + intensity_slope = self.cf.pp_blur_min_intensity - 1. + # intensity(r) = (i(r_max)-i(0))/r_max * r + i(0), where i(0)==1. + obj_radial_dists[outer_obj] = obj_radial_dists[outer_obj] * intensity_slope + 1. + inner_obj = inner_obj.astype('float64') + #outer_obj, obj_radial_dists = outer_obj.reshape(seg[0].shape), obj_radial_dists.reshape(seg[0].shape) + inner_obj += np.where(outer_obj, obj_radial_dists, 0.) + obj = inner_obj.reshape(seg[0].shape) + if not np.any(obj): + print("An object was completely discarded due to min 2d radius requirement, discarded slices: {}.".format( + slices_to_discard)) + # draw the evtly blurred obj into image. + img += obj * (class_id + 1.) + + if hasattr(self.cf, "ambiguities") and hasattr(self.class_id2label[class_id], "gt_distortion"): + radii_divs = [None] # dummy since not implemented yet + radii, radii_divs, applied_gt_distort = self.apply_gt_distort(class_id, radii, radii_divs, + outer_min_radii, outer_max_radii) + if applied_gt_distort: + if "ellipsoid" in shape or "torus" in shape: + obj = ((ics - center) / radii) ** 2 + elif 'diamond' in shape: + obj = abs(ics - center) / radii + elif "cylinder" in shape: + obj = ((ics - center) / radii) ** 2 + obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1) + elif "block" in shape: + obj = abs(ics - center) / radii + obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1) + obj = (np.sum(obj, axis=1) <= 1).reshape(seg[0].shape) + obj[..., slices_to_discard] = False + + if self.class_id2label[class_id].regression == "radii": + regress_targets.append(radii) + undistorted_rg_targets.append(undistorted_radii) + elif self.class_id2label[class_id].regression == "radii_2d": + regress_targets.append(radii[:2]) + undistorted_rg_targets.append(undistorted_radii[:2]) + elif self.class_id2label[class_id].regression == "radius_2d": + regress_targets.append(radii[:1]) + undistorted_rg_targets.append(undistorted_radii[:1]) + else: + regress_targets.append(self.class_id2label[class_id].regression) + undistorted_rg_targets.append(self.class_id2label[class_id].regression) + + seg[class_id][obj.astype('bool')] = roi_ix + 1 + + return img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort + + def create_sample(self, args): + """ Create a single sample and save to file. One sample is one image (volume) containing none, one, or multiple + objects. + :param args: out_dir: directory where to save sample, s_id: id of the sample. + :return: specs that identify this single created image + """ + out_dir, s_id = args + + print('processing {} {}'.format(out_dir, s_id)) + img = np.random.normal(loc=0.0, scale=self.cf.noise_scale, size=self.sample_size) + img[img<0.] = 0. + # one-hot-encoded seg + seg = np.zeros((self.cf.num_classes+1, *self.sample_size)).astype('uint8') + undistorted_seg = np.copy(seg) + applied_gt_distort = False + + if hasattr(self.cf, "pp_empty_samples_ratio") and self.cf.pp_empty_samples_ratio >= np.random.rand(): + # generate fully empty sample + class_ids, regress_targets, undistorted_rg_targets = [], [], [] + else: + class_choices = np.repeat(np.arange(1, self.cf.num_classes+1), self.cf.max_instances_per_class) + n_insts = np.random.randint(1, self.cf.max_instances_per_sample + 1) + class_ids = np.random.choice(class_choices, size=n_insts, replace=False) + shapes = np.array([self.class_id2label[cl_id].shape for cl_id in class_ids]) + all_radii = self.generate_sample_radii(class_ids, shapes) + + # reorder s.t. larger objects are drawn first (in order to not fully cover smaller objects) + order = np.argsort(-1*np.prod(all_radii,axis=1)) + class_ids = class_ids[order]; all_radii = np.array(all_radii)[order]; shapes = shapes[order] + + regress_targets, undistorted_rg_targets = [], [] + # indices ics equal positions within img/volume + ics = np.argwhere(np.ones(seg[0].shape)) + for roi_ix, class_id in enumerate(class_ids): + radii = all_radii[roi_ix] + # enforce distance between object center and image edge relative to radii. + margin_r_divisor = (2, 2, 4) + center = [np.random.randint(radii[dim] / margin_r_divisor[dim], img.shape[dim] - + radii[dim] / margin_r_divisor[dim]) for dim in range(len(img.shape))] + + img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort = \ + self.draw_object(img, seg, undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort, + roi_ix, class_id, shapes[roi_ix], radii, center) + + fg_slices = np.where(np.sum(np.sum(np.sum(seg,axis=0), axis=0), axis=0))[0] + if self.cf.pp_create_ohe_seg: + img = img[np.newaxis] + else: + # choosing rois to keep by smaller radius==higher prio needs to be ensured during roi generation, + # smaller objects need to be drawn later (==higher roi id) + seg = seg.max(axis=0) + seg_ids = np.unique(seg) + if len(seg_ids) != len(class_ids) + 1: + # in this case an object was completely covered by a succeeding object + print("skipping corrupt sample") + print("seg ids {}, class_ids {}".format(seg_ids, class_ids)) + return None + if not applied_gt_distort: + assert np.all(np.flatnonzero(img>0) == np.flatnonzero(seg>0)) + assert np.all(np.array(regress_targets).flatten()==np.array(undistorted_rg_targets).flatten()) + + out_path = os.path.join(out_dir, '{}.npy'.format(s_id)) + np.save(out_path, img.astype('float16')); np.save(os.path.join(out_dir, '{}_seg.npy'.format(s_id)), seg) + if hasattr(self.cf, 'ambiguities') and \ + np.any([hasattr(label, "gt_distortion") and len(label.gt_distortion)>0 for label in self.class_id2label.values()]): + undist_out_path = os.path.join(out_dir, '{}_exact_seg.npy'.format(s_id)) + if not self.cf.pp_create_ohe_seg: + undistorted_seg = undistorted_seg.max(axis=0) + np.save(undist_out_path, undistorted_seg) + + return [out_dir, out_path, class_ids, regress_targets, fg_slices, undistorted_rg_targets, str(s_id)] + + def create_sets(self, processes=os.cpu_count()): + """ Create whole training and test set, save to files under given directory cf.out_dir. + :param processes: nr of parallel processes. + """ + print('starting creation of {} images'.format(len(self.mp_args))) + shutil.copyfile("configs.py", os.path.join(self.cf.pp_rootdir, 'applied_configs.py')) + pool = Pool(processes=processes) + imgs_info = pool.map(self.create_sample, self.mp_args, chunksize=1) + pool.close() + pool.join() + imgs_info = [img for img in imgs_info if img is not None] + print("created a total of {} samples.".format(len(imgs_info))) + self.df = pd.DataFrame.from_records(imgs_info, columns=['out_dir', 'path', 'class_ids', 'regression_vectors', + 'fg_slices', 'undistorted_rg_vectors', 'pid']) + + for out_dir, group_df in self.df.groupby("out_dir"): + group_df.to_pickle(os.path.join(out_dir, 'info_df.pickle')) + + + def convert_copy_npz(self): + """ Convert a copy of generated .npy-files to npz and save in .npz-directory given in configs. + """ + if hasattr(self.cf, "pp_npz_dir") and self.cf.pp_npz_dir: + for out_dir, group_df in self.df.groupby("out_dir"): + rel_dir = os.path.relpath(out_dir, self.cf.pp_rootdir).split(os.sep) + npz_out_dir = os.path.join(self.cf.pp_npz_dir, str(os.sep).join(rel_dir)) + print("npz out dir: ", npz_out_dir) + os.makedirs(npz_out_dir, exist_ok=True) + group_df.to_pickle(os.path.join(npz_out_dir, 'info_df.pickle')) + dmanager.pack_dataset(out_dir, npz_out_dir, recursive=True, verbose=False) + else: + print("Did not convert .npy-files to .npz because npz directory not set in configs.") + + +if __name__ == '__main__': + import configs as cf + cf = cf.configs() + total_stime = time.time() + + toy_gen = ToyGenerator(cf) + toy_gen.create_sets() + toy_gen.convert_copy_npz() + + + mins, secs = divmod((time.time() - total_stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) diff --git a/default_configs.py b/default_configs.py new file mode 100644 index 0000000..c2d16e2 --- /dev/null +++ b/default_configs.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Default Configurations script. Avoids changing configs of all experiments if general settings are to be changed.""" + +import os +from collections import namedtuple + +boxLabel = namedtuple('boxLabel', ["name", "color"]) + +class DefaultConfigs: + + def __init__(self, server_env=None, dim=2): + self.server_env = server_env + self.cuda_benchmark = True + ######################### + # I/O # + ######################### + + self.dim = dim + # int [0 < dataset_size]. select n patients from dataset for prototyping. + self.select_prototype_subset = None + + # some default paths. + self.source_dir = os.path.dirname(os.path.realpath(__file__)) # current dir. + self.backbone_path = os.path.join(self.source_dir, 'models/backbone.py') + self.input_df_name = 'info_df.pickle' + + + if server_env: + self.select_prototype_subset = None + + ######################### + # Colors/legends # + ######################### + + # in part from solarized theme. + self.black = (0.1, 0.05, 0.) + self.gray = (0.514, 0.580, 0.588) + self.beige = (1., 1., 0.85) + self.white = (0.992, 0.965, 0.890) + + self.green = (0.659, 0.792, 0.251) # [168, 202, 64] + self.dark_green = (0.522, 0.600, 0.000) # [133.11, 153. , 0. ] + self.cyan = (0.165, 0.631, 0.596) # [ 42.075, 160.905, 151.98 ] + self.bright_blue = (0.85, 0.95, 1.) + self.blue = (0.149, 0.545, 0.824) # [ 37.995, 138.975, 210.12 ] + self.dkfz_blue = (0, 75. / 255, 142. / 255) + self.dark_blue = (0.027, 0.212, 0.259) # [ 6.885, 54.06 , 66.045] + self.purple = (0.424, 0.443, 0.769) # [108.12 , 112.965, 196.095] + self.aubergine = (0.62, 0.21, 0.44) # [ 157, 53 , 111] + self.magenta = (0.827, 0.212, 0.510) # [210.885, 54.06 , 130.05 ] + self.coral = (1., 0.251, 0.4) # [255,64,102] + self.bright_red = (1., 0.15, 0.1) # [255, 38.25, 25.5] + self.brighter_red = (0.863, 0.196, 0.184) # [220.065, 49.98 , 46.92 ] + self.red = (0.87, 0.05, 0.01) # [ 223, 13, 2] + self.dark_red = (0.6, 0.04, 0.005) + self.orange = (0.91, 0.33, 0.125) # [ 232.05 , 84.15 , 31.875] + self.dark_orange = (0.796, 0.294, 0.086) #[202.98, 74.97, 21.93] + self.yellow = (0.95, 0.9, 0.02) # [ 242.25, 229.5 , 5.1 ] + self.dark_yellow = (0.710, 0.537, 0.000) # [181.05 , 136.935, 0. ] + + + self.color_palette = [self.blue, self.dark_blue, self.aubergine, self.green, self.yellow, self.orange, self.red, + self.cyan, self.black] + + self.box_labels = [ + # name color + boxLabel("det", self.blue), + boxLabel("prop", self.gray), + boxLabel("pos_anchor", self.cyan), + boxLabel("neg_anchor", self.cyan), + boxLabel("neg_class", self.green), + boxLabel("pos_class", self.aubergine), + boxLabel("gt", self.red) + ] # neg and pos in a medical sense, i.e., pos=positive diagnostic finding + + self.box_type2label = {label.name: label for label in self.box_labels} + self.box_color_palette = {label.name: label.color for label in self.box_labels} + + # whether the input data is mono-channel or RGB/rgb + self.has_colorchannels = False + + ######################### + # Data Loader # + ######################### + + #random seed for fold_generator and batch_generator. + self.seed = 0 + + #number of threads for multithreaded tasks like batch generation, wcs, merge2dto3d + self.n_workers = 16 if server_env else os.cpu_count() + + self.create_bounding_box_targets = True + self.class_specific_seg = True # False if self.model=="mrcnn" else True + ######################### + # Architecture # + ######################### + + self.prediction_tasks = ["class"] # 'class', 'regression_class', 'regression_kendall', 'regression_feindt' + + self.weight_decay = 0.0 + + # nonlinearity to be applied after convs with nonlinearity. one of 'relu' or 'leaky_relu' + self.relu = 'relu' + + # if True initializes weights as specified in model script. else use default Pytorch init. + self.weight_init = None + + # if True adds high-res decoder levels to feature pyramid: P1 + P0. (e.g. set to true in retina_unet configs) + self.operate_stride1 = False + + ######################### + # Optimization # + ######################### + + self.optimizer = "ADAM" # "ADAM" or "SGD" or implemented additionals + + ######################### + # Schedule # + ######################### + + # number of folds in cross validation. + self.n_cv_splits = 5 + + ######################### + # Testing / Plotting # + ######################### + + # perform mirroring at test time. (only XY. Z not done to not blow up predictions times). + self.test_aug = True + + # if True, test data lies in a separate folder and is not part of the cross validation. + self.held_out_test_set = False + # if hold-out test set: eval each fold's parameters separately on the test set + self.eval_test_fold_wise = True + + # if held_out_test_set provided, ensemble predictions over models of all trained cv-folds. + self.ensemble_folds = False + + # what metrics to evaluate + self.metrics = ['ap'] + # whether to evaluate fold means when evaluating over more than one fold + self.evaluate_fold_means = False + + # how often (in nr of epochs) to plot example batches during train/val + self.plot_frequency = 1 + + # color specifications for all box_types in prediction_plot. + self.box_color_palette = {'det': 'b', 'gt': 'r', 'neg_class': 'purple', + 'prop': 'w', 'pos_class': 'g', 'pos_anchor': 'c', 'neg_anchor': 'c'} + + # scan over confidence score in evaluation to optimize it on the validation set. + self.scan_det_thresh = False + + # plots roc-curves / prc-curves in evaluation. + self.plot_stat_curves = False + + # if True: evaluate average precision per patient id and average over per-pid results, + # instead of computing one ap over whole data set. + self.per_patient_ap = False + + # threshold for clustering 2D box predictions to 3D Cubes. Overlap is computed in XY. + self.merge_3D_iou = 0.1 + + ######################### + # MRCNN # + ######################### + + # if True, mask loss is not applied. used for data sets, where no pixel-wise annotations are provided. + self.frcnn_mode = False + + + + + self.return_masks_in_train = False + # if True, unmolds masks in Mask R-CNN to full-res for plotting/monitoring. + self.return_masks_in_val = False + self.return_masks_in_test = False # needed if doing instance segmentation. evaluation not yet implemented. + + # add P6 to Feature Pyramid Network. + self.sixth_pooling = False + + + ######################### + # RetinaNet # + ######################### + self.focal_loss = False + self.focal_loss_gamma = 2. diff --git a/evaluator.py b/evaluator.py new file mode 100644 index 0000000..cf93f5b --- /dev/null +++ b/evaluator.py @@ -0,0 +1,971 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +from multiprocessing import Pool +import pickle +import time + +import numpy as np +import pandas as pd +from sklearn.metrics import roc_auc_score, average_precision_score +from sklearn.metrics import roc_curve, precision_recall_curve +from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score +import torch + +import utils.model_utils as mutils +import plotting as plg + +import warnings + + +def get_roi_ap_from_df(inputs): + ''' + :param df: data frame. + :param det_thresh: min_threshold for filtering out low confidence predictions. + :param per_patient_ap: boolean flag. evaluate average precision per patient id and average over per-pid results, + instead of computing one ap over whole data set. + :return: average_precision (float) + ''' + + df, det_thresh, per_patient_ap = inputs + + if per_patient_ap: + pids_list = df.pid.unique() + aps = [] + for match_iou in df.match_iou.unique(): + iou_df = df[df.match_iou == match_iou] + for pid in pids_list: + pid_df = iou_df[iou_df.pid == pid] + all_p = len(pid_df[pid_df.class_label == 1]) + pid_df = pid_df[(pid_df.det_type == 'det_fp') | (pid_df.det_type == 'det_tp')].sort_values('pred_score', ascending=False) + pid_df = pid_df[pid_df.pred_score > det_thresh] + if (len(pid_df) ==0 and all_p == 0): + pass + elif (len(pid_df) > 0 and all_p == 0): + aps.append(0) + else: + aps.append(compute_roi_ap(pid_df, all_p)) + return np.mean(aps) + + else: + aps = [] + for match_iou in df.match_iou.unique(): + iou_df = df[df.match_iou == match_iou] + # it's important to not apply the threshold before counting all_p in order to not lose the fn! + all_p = len(iou_df[(iou_df.det_type == 'det_tp') | (iou_df.det_type == 'det_fn')]) + # sorting out all entries that are not fp or tp or have confidence(=pred_score) <= detection_threshold + iou_df = iou_df[(iou_df.det_type == 'det_fp') | (iou_df.det_type == 'det_tp')].sort_values('pred_score', ascending=False) + iou_df = iou_df[iou_df.pred_score > det_thresh] + if all_p>0: + aps.append(compute_roi_ap(iou_df, all_p)) + return np.mean(aps) + +def compute_roi_ap(df, all_p): + """ + adapted from: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py + :param df: dataframe containing class labels of predictions sorted in descending manner by their prediction score. + :param all_p: number of all ground truth objects. (for denominator of recall.) + :return: + """ + tp = df.class_label.values + fp = (tp == 0) * 1 + #recall thresholds, where precision will be measured + R = np.linspace(0., 1., np.round((1. - 0.) / .01).astype(int) + 1, endpoint=True) + tp_sum = np.cumsum(tp) + fp_sum = np.cumsum(fp) + n_dets = len(tp) + rc = tp_sum / all_p + pr = tp_sum / (fp_sum + tp_sum) + + # initialize precision array over recall steps (q=queries). + q = [0. for _ in range(len(R))] + # numpy is slow without cython optimization for accessing elements + # python array gets significant speed improvement + pr = pr.tolist() + + for i in range(n_dets - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + #--> pr[i]<=pr[i-1] for all i since we want to consider the maximum + #precision value for a queried interval + + # discretize empiric recall steps with given bins. + assert np.all(rc[:-1]<=rc[1:]), "recall not sorted ascendingly" + inds = np.searchsorted(rc, R, side='left') + try: + for rc_ix, pr_ix in enumerate(inds): + q[rc_ix] = pr[pr_ix] + except IndexError: #now q is filled with pr values up to first non-available index + pass + + return np.mean(q) + +def roi_avp(inputs): + ''' + :param df: data frame. + :param det_thresh: min_threshold for filtering out low confidence predictions. + :param per_patient_ap: boolean flag. evaluate average precision per patient id and average over per-pid results, + instead of computing one ap over whole data set. + :return: average_precision (float) + ''' + + df, det_thresh, per_patient_ap = inputs + + if per_patient_ap: + pids_list = df.pid.unique() + aps = [] + for match_iou in df.match_iou.unique(): + iou_df = df[df.match_iou == match_iou] + for pid in pids_list: + pid_df = iou_df[iou_df.pid == pid] + all_p = len(pid_df[pid_df.class_label == 1]) + mask = ((pid_df.rg_bins == pid_df.rg_bin_target) & (pid_df.det_type == 'det_tp')) | (pid_df.det_type == 'det_fp') + pid_df = pid_df[mask].sort_values('pred_score', ascending=False) + pid_df = pid_df[pid_df.pred_score > det_thresh] + if (len(pid_df) ==0 and all_p == 0): + pass + elif (len(pid_df) > 0 and all_p == 0): + aps.append(0) + else: + aps.append(compute_roi_ap(pid_df, all_p)) + return np.mean(aps) + + else: + aps = [] + for match_iou in df.match_iou.unique(): + iou_df = df[df.match_iou == match_iou] + #it's important to not apply the threshold before counting all_positives! + all_p = len(iou_df[(iou_df.det_type == 'det_tp') | (iou_df.det_type == 'det_fn')]) + # filtering out tps which don't match rg_bin target at this point is same as reclassifying them as fn. + # also sorting out all entries that are not fp or have confidence(=pred_score) <= detection_threshold + mask = ((iou_df.rg_bins == iou_df.rg_bin_target) & (iou_df.det_type == 'det_tp')) | (iou_df.det_type == 'det_fp') + iou_df = iou_df[mask].sort_values('pred_score', ascending=False) + iou_df = iou_df[iou_df.pred_score > det_thresh] + if all_p>0: + aps.append(compute_roi_ap(iou_df, all_p)) + + return np.mean(aps) + +def compute_prc(df): + """compute precision-recall curve with maximum precision per recall interval. + :param df: + :param all_p: # of all positive samples in data. + :return: array: [precisions, recall query values] + """ + assert (df.class_label==1).any(), "cannot compute prc when no positives in data." + all_p = len(df[(df.det_type == 'det_tp') | (df.det_type == 'det_fn')]) + df = df[(df.det_type=="det_tp") | (df.det_type=="det_fp")] + df = df.sort_values("pred_score", ascending=False) + # recall thresholds, where precision will be measured + scores = df.pred_score.values + labels = df.class_label.values + n_dets = len(scores) + + pr = np.zeros((n_dets,)) + rc = pr.copy() + for rank in range(n_dets): + tp = np.count_nonzero(labels[:rank+1]==1) + fp = np.count_nonzero(labels[:rank+1]==0) + + pr[rank] = tp/(tp+fp) + rc[rank] = tp/all_p + + #after obj detection convention/ coco-dataset template: take maximum pr within intervals: + # --> pr[i]<=pr[i-1] for all i since we want to consider the maximum + # precision value for a queried interval + for i in range(n_dets - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + + R = np.linspace(0., 1., np.round((1. - 0.) / .01).astype(int) + 1, endpoint=True)#precision queried at R points + inds = np.searchsorted(rc, R, side='left') + queries = np.zeros((len(R),)) + try: + for q_ix, rank in enumerate(inds): + queries[q_ix] = pr[rank] + except IndexError: + pass + return np.array((queries, R)) + +def RMSE(y_true, y_pred, weights=None): + if len(y_true)>0: + return np.sqrt(mean_squared_error(y_true, y_pred, sample_weight=weights)) + else: + return np.nan + +def MAE_w_std(y_true, y_pred, weights=None): + if len(y_true)>0: + y_true, y_pred = np.array(y_true), np.array(y_pred) + deltas = np.abs(y_true-y_pred) + mae = np.average(deltas, weights=weights, axis=0).item() + skmae = mean_absolute_error(y_true, y_pred, sample_weight=weights) + assert np.allclose(mae, skmae, atol=1e-6), "mae {}, sklearn mae {}".format(mae, skmae) + std = np.std(weights*deltas) + return mae, std + + else: + return np.nan, np.nan + +def MAE(y_true, y_pred, weights=None): + if len(y_true)>0: + return mean_absolute_error(y_true, y_pred, sample_weight=weights) + else: + return np.nan + +def accuracy(y_true, y_pred, weights=None): + if len(y_true)>0: + return accuracy_score(y_true, y_pred, sample_weight=weights) + else: + return np.nan + + +# noinspection PyCallingNonCallable +class Evaluator(): + """ Evaluates given results dicts. Can return results as updated monitor_metrics. Can save test data frames to + file. + """ + + def __init__(self, cf, logger, mode='test'): + """ + :param mode: either 'train', 'val_sampling', 'val_patient' or 'test'. handles prediction lists of different forms. + """ + self.cf = cf + self.logger = logger + self.mode = mode + + self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks]) + + self.plot_dir = self.cf.plot_dir if not self.mode == "test" else self.cf.test_dir + if self.cf.plot_prediction_histograms: + self.hist_dir = os.path.join(self.plot_dir, 'histograms') + os.makedirs(self.hist_dir, exist_ok=True) + if self.cf.plot_stat_curves: + self.curves_dir = os.path.join(self.plot_dir, 'stat_curves') + os.makedirs(self.curves_dir, exist_ok=True) + + + def eval_losses(self, batch_res_dicts): + if hasattr(self.cf, "losses_to_monitor"): + loss_names = self.cf.losses_to_monitor + else: + loss_names = {name for b_res_dict in batch_res_dicts for name in b_res_dict if 'loss' in name} + self.epoch_losses = {l_name: torch.tensor([b_res_dict[l_name] for b_res_dict in batch_res_dicts if l_name + in b_res_dict.keys()]).mean().item() for l_name in loss_names} + + def eval_segmentations(self, batch_res_dicts, pid_list): + + batch_dices = [b_res_dict['batch_dices'] for b_res_dict in batch_res_dicts if + 'batch_dices' in b_res_dict.keys()] # shape (n_batches, n_seg_classes) + if len(batch_dices) > 0: + batch_dices = np.array(batch_dices) # dims n_batches x 1 in sampling / n_test_epochs x n_classes + assert batch_dices.shape[1] == self.cf.num_seg_classes, "bdices shp {}, n seg cl {}, pid lst len {}".format( + batch_dices.shape, self.cf.num_seg_classes, len(pid_list)) + self.seg_df = pd.DataFrame() + for seg_id in range(batch_dices.shape[1]): + self.seg_df[self.cf.seg_id2label[seg_id].name + "_dice"] = batch_dices[:, + seg_id] # one row== one batch, one column== one class + # self.seg_df[self.cf.seg_id2label[seg_id].name+"_dice"] = np.concatenate(batch_dices[:,:,seg_id]) + self.seg_df['fold'] = self.cf.fold + if self.mode == "val_patient" or self.mode == "test": + # need to make it more conform between sampling and patient-mode + self.seg_df["pid"] = [pid for pix, pid in enumerate(pid_list)] # for b_inst in batch_inst_boxes[pix]] + else: + self.seg_df["pid"] = np.nan + + def eval_boxes(self, batch_res_dicts, pid_list, obj_cl_dict, + obj_cl_identifiers={"gt":'class_targets', "pred":'box_pred_class_id'}): + """ + + :param batch_res_dicts: + :param pid_list: [pid_0, pid_1, ...] + :return: + """ + if self.mode == 'train' or self.mode == 'val_sampling': + # one pid per batch element + # batch_size > 1, with varying patients across batch: + # [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] + # -> [results_0, results_1, ..] + batch_inst_boxes = [b_res_dict['boxes'] for b_res_dict in batch_res_dicts] # len: nr of batches in epoch + batch_inst_boxes = [[b_inst_boxes] for whole_batch_boxes in batch_inst_boxes for b_inst_boxes in + whole_batch_boxes] # len: batch instances of whole epoch + assert np.all(len(b_boxes_list) == self.cf.batch_size for b_boxes_list in batch_inst_boxes) + elif self.mode == "val_patient" or self.mode == "test": + # patient processing, one element per batch = one patient. + # [[results_0, pid_0], [results_1, pid_1], ...] -> [results_0, results_1, ..] + # in patientbatchiterator there is only one pid per batch + batch_inst_boxes = [b_res_dict['boxes'] for b_res_dict in batch_res_dicts] + # in patient mode not actually per batch instance, but per whole batch! + if hasattr(self.cf, "eval_test_separately") and self.cf.eval_test_separately: + """ you could write your own routines to add GTs to raw predictions for evaluation. + implemented standard is: cf.eval_test_separately = False or not set --> GTs are saved at same time + and in same file as raw prediction results. + """ + raise NotImplementedError + assert len(batch_inst_boxes) == len(pid_list) + + df_list_preds = [] + df_list_labels = [] + df_list_class_preds = [] + df_list_pids = [] + df_list_type = [] + df_list_match_iou = [] + df_list_n_missing = [] + df_list_regressions = [] + df_list_rg_targets = [] + df_list_rg_bins = [] + df_list_rg_bin_targets = [] + df_list_rg_uncs = [] + + for match_iou in self.cf.ap_match_ious: + self.logger.info('evaluating with ap_match_iou: {}'.format(match_iou)) + for cl in list(obj_cl_dict.keys()): + for pix, pid in enumerate(pid_list): + len_df_list_before_patient = len(df_list_pids) + # input of each batch element is a list of boxes, where each box is a dictionary. + for b_inst_ix, b_boxes_list in enumerate(batch_inst_boxes[pix]): + + b_tar_boxes = [] + b_cand_boxes, b_cand_scores, b_cand_n_missing = [], [], [] + if self.regress_flag: + b_tar_regs, b_tar_rg_bins = [], [] + b_cand_regs, b_cand_rg_bins, b_cand_rg_uncs = [], [], [] + for box in b_boxes_list: + # each box is either gt or detection or proposal/anchor + # we need all gts in the same order & all dets in same order + if box['box_type'] == 'gt' and box[obj_cl_identifiers["gt"]] == cl: + b_tar_boxes.append(box["box_coords"]) + if self.regress_flag: + b_tar_regs.append(np.array(box['regression_targets'], dtype='float32')) + b_tar_rg_bins.append(box['rg_bin_targets']) + + if box['box_type'] == 'det' and box[obj_cl_identifiers["pred"]] == cl: + b_cand_boxes.append(box["box_coords"]) + b_cand_scores.append(box["box_score"]) + b_cand_n_missing.append(box["cluster_n_missing"] if 'cluster_n_missing' in box.keys() else np.nan) + if self.regress_flag: + b_cand_regs.append(box["regression"]) + b_cand_rg_bins.append(box["rg_bin"]) + b_cand_rg_uncs.append(box["rg_uncertainty"] if 'rg_uncertainty' in box.keys() else np.nan) + b_tar_boxes = np.array(b_tar_boxes) + b_cand_boxes, b_cand_scores, b_cand_n_missing = np.array(b_cand_boxes), np.array(b_cand_scores), np.array(b_cand_n_missing) + if self.regress_flag: + b_tar_regs, b_tar_rg_bins = np.array(b_tar_regs), np.array(b_tar_rg_bins) + b_cand_regs, b_cand_rg_bins, b_cand_rg_uncs = np.array(b_cand_regs), np.array(b_cand_rg_bins), np.array(b_cand_rg_uncs) + + # check if predictions and ground truth boxes exist and match them according to match_iou. + if not 0 in b_cand_boxes.shape and not 0 in b_tar_boxes.shape: + assert np.all(np.round(b_cand_scores,6) <= 1.), "there is a box score>1: {}".format(b_cand_scores[~(b_cand_scores<=1.)]) + #coords_check = np.array([len(coords)==self.cf.dim*2 for coords in b_cand_boxes]) + #assert np.all(coords_check), "cand box with wrong bcoords dim: {}, mode: {}".format(b_cand_boxes[~coords_check], self.mode) + expected_dim = len(b_cand_boxes[0]) + assert np.all([len(coords) == expected_dim for coords in b_tar_boxes]), \ + "gt/cand box coords mismatch, expected dim: {}.".format(expected_dim) + + # overlaps: shape len(cand_boxes) x len(tar_boxes) + overlaps = mutils.compute_overlaps(b_cand_boxes, b_tar_boxes) + + # match_cand_ixs: shape (nr_of_matches,) + # theses indices are the indices of b_cand_boxes + match_cand_ixs = np.argwhere(np.max(overlaps, axis=1) > match_iou)[:, 0] + + non_match_cand_ixs = np.argwhere(np.max(overlaps, 1) <= match_iou)[:, 0] + # the corresponding gt assigned to the pred boxes by highest iou overlap, + # i.e., match_gt_ixs holds index into b_tar_boxes for each entry in match_cand_ixs, + # i.e., gt_ixs and cand_ixs are paired via their position in their list + # (cand_ixs[j] corresponds to gt_ixs[j]) + match_gt_ixs = np.argmax(overlaps[match_cand_ixs, :], axis=1) if \ + not 0 in match_cand_ixs.shape else np.array([]) + assert len(match_gt_ixs)==len(match_cand_ixs) + + #match_gt_ixs: shape (nr_of_matches,) or 0 + non_match_gt_ixs = np.array( + [ii for ii in np.arange(b_tar_boxes.shape[0]) if ii not in match_gt_ixs]) + unique, counts = np.unique(match_gt_ixs, return_counts=True) + + # check for double assignments, i.e. two predictions having been assigned to the same gt. + # according to the COCO-metrics, only one prediction counts as true positive, the rest counts as + # false positive. This case is supposed to be avoided by the model itself by, + # e.g. using a low enough NMS threshold. + if np.any(counts > 1): + double_match_gt_ixs = unique[np.argwhere(counts > 1)[:, 0]] + keep_max = [] + double_match_list = [] + for dg in double_match_gt_ixs: + double_match_cand_ixs = match_cand_ixs[np.argwhere(match_gt_ixs == dg)] + keep_max.append(double_match_cand_ixs[np.argmax(b_cand_scores[double_match_cand_ixs])]) + double_match_list += [ii for ii in double_match_cand_ixs] + + fp_ixs = np.array([ii for ii in match_cand_ixs if + (ii in double_match_list and ii not in keep_max)]) + # count as fp: boxes that match gt above match_iou threshold but have not highest class confidence score + match_gt_ixs = np.array([gt_ix for ii, gt_ix in enumerate(match_gt_ixs) if match_cand_ixs[ii] not in fp_ixs]) + match_cand_ixs = np.array([cand_ix for cand_ix in match_cand_ixs if cand_ix not in fp_ixs]) + assert len(match_gt_ixs) == len(match_cand_ixs) + + df_list_preds += [ii for ii in b_cand_scores[fp_ixs]] + df_list_labels += [0] * fp_ixs.shape[0] # means label==gt==0==bg for all these fp_ixs + df_list_class_preds += [cl] * fp_ixs.shape[0] + df_list_n_missing += [n for n in b_cand_n_missing[fp_ixs]] + if self.regress_flag: + df_list_regressions += [r for r in b_cand_regs[fp_ixs]] + df_list_rg_bins += [r for r in b_cand_rg_bins[fp_ixs]] + df_list_rg_uncs += [r for r in b_cand_rg_uncs[fp_ixs]] + df_list_rg_targets += [[0.]*self.cf.regression_n_features] * fp_ixs.shape[0] + df_list_rg_bin_targets += [0.] * fp_ixs.shape[0] + df_list_pids += [pid] * fp_ixs.shape[0] + df_list_type += ['det_fp'] * fp_ixs.shape[0] + + # matched/tp: + if not 0 in match_cand_ixs.shape: + df_list_preds += list(b_cand_scores[match_cand_ixs]) + df_list_labels += [1] * match_cand_ixs.shape[0] + df_list_class_preds += [cl] * match_cand_ixs.shape[0] + df_list_n_missing += list(b_cand_n_missing[match_cand_ixs]) + if self.regress_flag: + df_list_regressions += list(b_cand_regs[match_cand_ixs]) + df_list_rg_bins += list(b_cand_rg_bins[match_cand_ixs]) + df_list_rg_uncs += list(b_cand_rg_uncs[match_cand_ixs]) + assert len(match_cand_ixs)==len(match_gt_ixs) + df_list_rg_targets += list(b_tar_regs[match_gt_ixs]) + df_list_rg_bin_targets += list(b_tar_rg_bins[match_gt_ixs]) + df_list_pids += [pid] * match_cand_ixs.shape[0] + df_list_type += ['det_tp'] * match_cand_ixs.shape[0] + # rest fp: + if not 0 in non_match_cand_ixs.shape: + df_list_preds += list(b_cand_scores[non_match_cand_ixs]) + df_list_labels += [0] * non_match_cand_ixs.shape[0] + df_list_class_preds += [cl] * non_match_cand_ixs.shape[0] + df_list_n_missing += list(b_cand_n_missing[non_match_cand_ixs]) + if self.regress_flag: + df_list_regressions += list(b_cand_regs[non_match_cand_ixs]) + df_list_rg_bins += list(b_cand_rg_bins[non_match_cand_ixs]) + df_list_rg_uncs += list(b_cand_rg_uncs[non_match_cand_ixs]) + df_list_rg_targets += [[0.]*self.cf.regression_n_features] * non_match_cand_ixs.shape[0] + df_list_rg_bin_targets += [0.] * non_match_cand_ixs.shape[0] + df_list_pids += [pid] * non_match_cand_ixs.shape[0] + df_list_type += ['det_fp'] * non_match_cand_ixs.shape[0] + # fn: + if not 0 in non_match_gt_ixs.shape: + df_list_preds += [0] * non_match_gt_ixs.shape[0] + df_list_labels += [1] * non_match_gt_ixs.shape[0] + df_list_class_preds += [cl] * non_match_gt_ixs.shape[0] + df_list_n_missing += [np.nan] * non_match_gt_ixs.shape[0] + if self.regress_flag: + df_list_regressions += [[0.]*self.cf.regression_n_features] * non_match_gt_ixs.shape[0] + df_list_rg_bins += [0.] * non_match_gt_ixs.shape[0] + df_list_rg_uncs += [np.nan] * non_match_gt_ixs.shape[0] + df_list_rg_targets += list(b_tar_regs[non_match_gt_ixs]) + df_list_rg_bin_targets += list(b_tar_rg_bins[non_match_gt_ixs]) + df_list_pids += [pid] * non_match_gt_ixs.shape[0] + df_list_type += ['det_fn'] * non_match_gt_ixs.shape[0] + # only fp: + if not 0 in b_cand_boxes.shape and 0 in b_tar_boxes.shape: + # means there is no gt in all samples! any preds have to be fp. + df_list_preds += list(b_cand_scores) + df_list_labels += [0] * b_cand_boxes.shape[0] + df_list_class_preds += [cl] * b_cand_boxes.shape[0] + df_list_n_missing += list(b_cand_n_missing) + if self.regress_flag: + df_list_regressions += list(b_cand_regs) + df_list_rg_bins += list(b_cand_rg_bins) + df_list_rg_uncs += list(b_cand_rg_uncs) + df_list_rg_targets += [[0.]*self.cf.regression_n_features] * b_cand_boxes.shape[0] + df_list_rg_bin_targets += [0.] * b_cand_boxes.shape[0] + df_list_pids += [pid] * b_cand_boxes.shape[0] + df_list_type += ['det_fp'] * b_cand_boxes.shape[0] + # only fn: + if 0 in b_cand_boxes.shape and not 0 in b_tar_boxes.shape: + df_list_preds += [0] * b_tar_boxes.shape[0] + df_list_labels += [1] * b_tar_boxes.shape[0] + df_list_class_preds += [cl] * b_tar_boxes.shape[0] + df_list_n_missing += [np.nan] * b_tar_boxes.shape[0] + if self.regress_flag: + df_list_regressions += [[0.]*self.cf.regression_n_features] * b_tar_boxes.shape[0] + df_list_rg_bins += [0.] * b_tar_boxes.shape[0] + df_list_rg_uncs += [np.nan] * b_tar_boxes.shape[0] + df_list_rg_targets += list(b_tar_regs) + df_list_rg_bin_targets += list(b_tar_rg_bins) + df_list_pids += [pid] * b_tar_boxes.shape[0] + df_list_type += ['det_fn'] * b_tar_boxes.shape[0] + + # empty patient with 0 detections needs empty patient score, in order to not disappear from stats. + # filtered out for roi-level evaluation later. During training (and val_sampling), + # tn are assigned per sample independently of associated patients. + # i.e., patient_tn is also meant as sample_tn if a list of samples is evaluated instead of whole patient + if len(df_list_pids) == len_df_list_before_patient: + df_list_preds += [0] + df_list_labels += [0] + df_list_class_preds += [cl] + df_list_n_missing += [np.nan] + if self.regress_flag: + df_list_regressions += [[0.]*self.cf.regression_n_features] + df_list_rg_bins += [0.] + df_list_rg_uncs += [np.nan] + df_list_rg_targets += [[0.]*self.cf.regression_n_features] + df_list_rg_bin_targets += [0.] + df_list_pids += [pid] + df_list_type += ['patient_tn'] # true negative: no ground truth boxes, no detections. + + df_list_match_iou += [match_iou] * (len(df_list_preds) - len(df_list_match_iou)) + + self.test_df = pd.DataFrame() + self.test_df['pred_score'] = df_list_preds + self.test_df['class_label'] = df_list_labels + # class labels are gt, 0,1, only indicate neg/pos (or bg/fg) remapped from all classes + self.test_df['pred_class'] = df_list_class_preds # can be diff than 0,1 + self.test_df['pid'] = df_list_pids + self.test_df['det_type'] = df_list_type + self.test_df['fold'] = self.cf.fold + self.test_df['match_iou'] = df_list_match_iou + self.test_df['cluster_n_missing'] = df_list_n_missing + if self.regress_flag: + self.test_df['regressions'] = df_list_regressions + self.test_df['rg_targets'] = df_list_rg_targets + self.test_df['rg_uncertainties'] = df_list_rg_uncs + self.test_df['rg_bins'] = df_list_rg_bins + # super weird error: pandas does not properly add an attribute if column is named "rg_bin_targets" ... ?!? + self.test_df['rg_bin_target'] = df_list_rg_bin_targets + assert hasattr(self.test_df, "rg_bin_target") + + #fn_df = self.test_df[self.test_df["det_type"] == "det_fn"] + + pass + + def evaluate_predictions(self, results_list, monitor_metrics=None): + """ + Performs the matching of predicted boxes and ground truth boxes. Loops over list of matching IoUs and foreground classes. + Resulting info of each prediction is stored as one line in an internal dataframe, with the keys: + det_type: 'tp' (true positive), 'fp' (false positive), 'fn' (false negative), 'tn' (true negative) + pred_class: foreground class which the object predicts. + pid: corresponding patient-id. + pred_score: confidence score [0, 1] + fold: corresponding fold of CV. + match_iou: utilized IoU for matching. + :param results_list: list of model predictions. Either from train/val_sampling (patch processing) for monitoring with form: + [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] + Or from val_patient/testing (patient processing), with form: [[results_0, pid_0], [results_1, pid_1], ...]) + :param monitor_metrics (optional): dict of dicts with all metrics of previous epochs. + :return monitor_metrics: if provided (during training), return monitor_metrics now including results of current epoch. + """ + # gets results_list = [[batch_instances_box_lists], [batch_instances_pids]]*n_batches + # we want to evaluate one batch_instance (= 2D or 3D image) at a time. + + + self.logger.info('evaluating in mode {}'.format(self.mode)) + + batch_res_dicts = [batch[0] for batch in results_list] # len: nr of batches in epoch + if self.mode == 'train' or self.mode=='val_sampling': + # one pid per batch element + # [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] + # -> [pid_0, pid_1, ...] + # additional list wrapping to make conform with below per-patient batches, where one pid is linked to more than one batch instance + pid_list = [batch_instance_pid for batch in results_list for batch_instance_pid in batch[1]] + elif self.mode == "val_patient" or self.mode=="test": + # [[results_0, pid_0], [results_1, pid_1], ...] -> [pid_0, pid_1, ...] + # in patientbatchiterator there is only one pid per batch + pid_list = [np.unique(batch[1]) for batch in results_list] + assert np.all([len(pid)==1 for pid in pid_list]), "pid list in patient-eval mode, should only contain a single scalar per patient: {}".format(pid_list) + pid_list = [pid[0] for pid in pid_list] + else: + raise Exception("undefined run mode encountered") + + self.eval_losses(batch_res_dicts) + self.eval_segmentations(batch_res_dicts, pid_list) + self.eval_boxes(batch_res_dicts, pid_list, self.cf.class_dict) + + if monitor_metrics is not None: + # return all_stats, updated monitor_metrics + return self.return_metrics(self.test_df, self.cf.class_dict, monitor_metrics) + + def return_metrics(self, df, obj_cl_dict, monitor_metrics=None, boxes_only=False): + """ + Calculates metric scores for internal data frame. Called directly from evaluate_predictions during training for + monitoring, or from score_test_df during inference (for single folds or aggregated test set). + Loops over foreground classes and score_levels ('roi' and/or 'patient'), gets scores and stores them. + Optionally creates plots of prediction histograms and ROC/PR curves. + :param df: Data frame that holds evaluated predictions. + :param obj_cl_dict: Dict linking object-class ids to object-class names. E.g., {1: "bikes", 2 : "cars"}. Set in + configs as cf.class_dict. + :param monitor_metrics: dict of dicts with all metrics of previous epochs. This function adds metrics for + current epoch and returns the same object. + :param boxes_only: whether to produce metrics only for the boxes, not the segmentations. + :return: all_stats: list. Contains dicts with resulting scores for each combination of foreground class and + score_level. + :return: monitor_metrics + """ + + # -------------- monitoring independent of class, score level ------------ + if monitor_metrics is not None: + for l_name in self.epoch_losses: + monitor_metrics[l_name] = [self.epoch_losses[l_name]] + + # -------------- metrics calc dependent on class, score level ------------ + + all_stats = [] # all_stats: one entry per score_level per class + for cl in list(obj_cl_dict.keys()): # bg eval is neglected + cl_name = obj_cl_dict[cl] + cl_df = df[df.pred_class == cl] + if hasattr(self, "seg_df") and not boxes_only: + dice_col = self.cf.seg_id2label[cl].name+"_dice" + seg_cl_df = self.seg_df.loc[:,['pid', dice_col, 'fold']] + + for score_level in self.cf.report_score_level: + + stats_dict = {} + stats_dict['name'] = 'fold_{} {} {}'.format(self.cf.fold, score_level, cl_name) + + # -------------- RoI-based ----------------- + if score_level == 'rois': + + stats_dict['auc'] = np.nan + stats_dict['roc'] = np.nan + + if monitor_metrics is not None: + tn = len(cl_df[cl_df.det_type == "patient_tn"]) + tp = len(cl_df[(cl_df.det_type == "det_tp")&(cl_df.pred_score>self.cf.min_det_thresh)]) + fp = len(cl_df[(cl_df.det_type == "det_fp")&(cl_df.pred_score>self.cf.min_det_thresh)]) + fn = len(cl_df[cl_df.det_type == "det_fn"]) + sens = np.divide(tp, (fn + tp)) + monitor_metrics.update({"Bin_Stats/" + cl_name + "_fp": [fp], "Bin_Stats/" + cl_name + "_tp": [tp], + "Bin_Stats/" + cl_name + "_fn": [fn], "Bin_Stats/" + cl_name + "_tn": [tn], + "Bin_Stats/" + cl_name + "_sensitivity": [sens]}) + # list wrapping only needed bc other metrics are recorded over all epochs; + + spec_df = cl_df[cl_df.det_type != 'patient_tn'] + if self.regress_flag: + # filter false negatives out for regression-only eval since regressor didn't predict + truncd_df = spec_df[(((spec_df.det_type == "det_fp") | ( + spec_df.det_type == "det_tp")) & spec_df.pred_score > self.cf.min_det_thresh)] + truncd_df_tp = truncd_df[truncd_df.det_type == "det_tp"] + weights, weights_tp = truncd_df.pred_score.tolist(), truncd_df_tp.pred_score.tolist() + + y_true, y_pred = truncd_df.rg_targets.tolist(), truncd_df.regressions.tolist() + stats_dict["rg_RMSE"] = RMSE(y_true, y_pred) + stats_dict["rg_MAE"] = MAE(y_true, y_pred) + stats_dict["rg_RMSE_weighted"] = RMSE(y_true, y_pred, weights) + stats_dict["rg_MAE_weighted"] = MAE(y_true, y_pred, weights) + y_true, y_pred = truncd_df_tp.rg_targets.tolist(), truncd_df_tp.regressions.tolist() + stats_dict["rg_MAE_weighted_tp"] = MAE(y_true, y_pred, weights_tp) + stats_dict["rg_MAE_w_std_weighted_tp"] = MAE_w_std(y_true, y_pred, weights_tp) + + y_true, y_pred = truncd_df.rg_bin_target.tolist(), truncd_df.rg_bins.tolist() + stats_dict["rg_bin_accuracy"] = accuracy(y_true, y_pred) + stats_dict["rg_bin_accuracy_weighted"] = accuracy(y_true, y_pred, weights) + + y_true, y_pred = truncd_df_tp.rg_bin_target.tolist(), truncd_df_tp.rg_bins.tolist() + stats_dict["rg_bin_accuracy_weighted_tp"] = accuracy(y_true, y_pred, weights_tp) + if np.any(~truncd_df.rg_uncertainties.isna()): + # det_fn are expected to be NaN so they drop out in means + stats_dict.update({"rg_uncertainty": truncd_df.rg_uncertainties.mean(), + "rg_uncertainty_tp": truncd_df_tp.rg_uncertainties.mean(), + "rg_uncertainty_tp_weighted": (truncd_df_tp.rg_uncertainties * truncd_df_tp.pred_score).sum() + / truncd_df_tp.pred_score.sum() + }) + + if (spec_df.class_label==1).any(): + stats_dict['ap'] = get_roi_ap_from_df((spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap)) + stats_dict['prc'] = precision_recall_curve(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) + if self.regress_flag: + stats_dict['avp'] = roi_avp((spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap)) + else: + stats_dict['ap'] = np.nan + stats_dict['prc'] = np.nan + stats_dict['avp'] = np.nan + # np.nan is formattable by __format__ as a float, None-type is not + + if hasattr(self, "seg_df") and not boxes_only: + stats_dict["dice"] = seg_cl_df.loc[:,dice_col].mean() # mean per all rois in this epoch + stats_dict["dice_std"] = seg_cl_df.loc[:,dice_col].std() + + # for the aggregated test set case, additionally get the scores of averaging over fold results. + if self.cf.evaluate_fold_means and len(df.fold.unique()) > 1: + aps = [] + for fold in df.fold.unique(): + fold_df = spec_df[spec_df.fold == fold] + if (fold_df.class_label==1).any(): + aps.append(get_roi_ap_from_df((fold_df, self.cf.min_det_thresh, self.cf.per_patient_ap))) + + stats_dict['ap_folds_mean'] = np.mean(aps) if len(aps)>0 else np.nan + stats_dict['ap_folds_std'] = np.std(aps) if len(aps)>0 else np.nan + stats_dict['auc_folds_mean'] = np.nan + stats_dict['auc_folds_std'] = np.nan + if self.regress_flag: + avps, accuracies, MAEs = [], [], [] + for fold in df.fold.unique(): + fold_df = spec_df[spec_df.fold == fold] + if (fold_df.class_label == 1).any(): + avps.append(roi_avp((fold_df, self.cf.min_det_thresh, self.cf.per_patient_ap))) + truncd_df_tp = fold_df[((fold_df.det_type == "det_tp") & fold_df.pred_score > self.cf.min_det_thresh)] + weights_tp = truncd_df_tp.pred_score.tolist() + y_true, y_pred = truncd_df_tp.rg_bin_target.tolist(), truncd_df_tp.rg_bins.tolist() + accuracies.append(accuracy(y_true, y_pred, weights_tp)) + y_true, y_pred = truncd_df_tp.rg_targets.tolist(), truncd_df_tp.regressions.tolist() + MAEs.append(MAE_w_std(y_true, y_pred, weights_tp)) + + stats_dict['avp_folds_mean'] = np.mean(avps) if len(avps) > 0 else np.nan + stats_dict['avp_folds_std'] = np.std(avps) if len(avps) > 0 else np.nan + stats_dict['rg_bin_accuracy_weighted_tp_folds_mean'] = np.mean(accuracies) if len(accuracies) > 0 else np.nan + stats_dict['rg_bin_accuracy_weighted_tp_folds_std'] = np.std(accuracies) if len(accuracies) > 0 else np.nan + stats_dict['rg_MAE_w_std_weighted_tp_folds_mean'] = np.mean(MAEs, axis=0) if len(MAEs) > 0 else np.nan + stats_dict['rg_MAE_w_std_weighted_tp_folds_std'] = np.std(MAEs, axis=0) if len(MAEs) > 0 else np.nan + + if hasattr(self, "seg_df") and not boxes_only and self.cf.evaluate_fold_means and len(seg_cl_df.fold.unique()) > 1: + fold_means = seg_cl_df.groupby(['fold'], as_index=True).agg({dice_col:"mean"}) + stats_dict["dice_folds_mean"] = fold_means.mean().item() + stats_dict["dice_folds_std"] = fold_means.std().item() + + # -------------- patient-based ----------------- + # on patient level, aggregate predictions per patient (pid): The patient predicted score is the highest + # confidence prediction for this class. The patient class label is 1 if roi of this class exists in patient, else 0. + if score_level == 'patient': + #this is the critical part in patient scoring: only the max gt and max pred score are taken per patient! + #--> does mix up values from separate detections + spec_df = cl_df.groupby(['pid'], as_index=False) + agg_args = {'class_label': 'max', 'pred_score': 'max', 'fold': 'first'} + if self.regress_flag: + # pandas throws error if aggregated value is np.array, not if is list. + agg_args.update({'regressions': lambda series: list(series.iloc[np.argmax(series.apply(np.linalg.norm).values)]), + 'rg_targets': lambda series: list(series.iloc[np.argmax(series.apply(np.linalg.norm).values)]), + 'rg_bins': 'max', 'rg_bin_target': 'max', + 'rg_uncertainties': 'max' + }) + if hasattr(cl_df, "cluster_n_missing"): + agg_args.update({'cluster_n_missing': 'mean'}) + spec_df = spec_df.agg(agg_args) + + if len(spec_df.class_label.unique()) > 1: + stats_dict['auc'] = roc_auc_score(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) + stats_dict['roc'] = roc_curve(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) + else: + stats_dict['auc'] = np.nan + stats_dict['roc'] = np.nan + + if (spec_df.class_label == 1).any(): + patient_cl_labels = spec_df.class_label.tolist() + stats_dict['ap'] = average_precision_score(patient_cl_labels, spec_df.pred_score.tolist()) + stats_dict['prc'] = precision_recall_curve(patient_cl_labels, spec_df.pred_score.tolist()) + if self.regress_flag: + avp_scores = spec_df[spec_df.rg_bins == spec_df.rg_bin_target].pred_score.tolist() + avp_scores += [0.] * (len(patient_cl_labels) - len(avp_scores)) + stats_dict['avp'] = average_precision_score(patient_cl_labels, avp_scores) + else: + stats_dict['ap'] = np.nan + stats_dict['prc'] = np.nan + stats_dict['avp'] = np.nan + if self.regress_flag: + y_true, y_pred = spec_df.rg_targets.tolist(), spec_df.regressions.tolist() + stats_dict["rg_RMSE"] = RMSE(y_true, y_pred) + stats_dict["rg_MAE"] = MAE(y_true, y_pred) + stats_dict["rg_bin_accuracy"] = accuracy(spec_df.rg_bin_target.tolist(), spec_df.rg_bins.tolist()) + stats_dict["rg_uncertainty"] = spec_df.rg_uncertainties.mean() + if hasattr(self, "seg_df") and not boxes_only: + seg_cl_df = seg_cl_df.groupby(['pid'], as_index=False).agg( + {dice_col: "mean", "fold": "first"}) # mean of all rois per patient in this epoch + stats_dict["dice"] = seg_cl_df.loc[:,dice_col].mean() #mean of all patients + stats_dict["dice_std"] = seg_cl_df.loc[:, dice_col].std() + + + # for the aggregated test set case, additionally get the scores for averaging over fold results. + if self.cf.evaluate_fold_means and len(df.fold.unique()) > 1 and self.mode in ["test", "analysis"]: + aucs = [] + aps = [] + for fold in df.fold.unique(): + fold_df = spec_df[spec_df.fold == fold] + if (fold_df.class_label==1).any(): + aps.append( + average_precision_score(fold_df.class_label.tolist(), fold_df.pred_score.tolist())) + if len(fold_df.class_label.unique())>1: + aucs.append(roc_auc_score(fold_df.class_label.tolist(), fold_df.pred_score.tolist())) + stats_dict['auc_folds_mean'] = np.mean(aucs) + stats_dict['auc_folds_std'] = np.std(aucs) + stats_dict['ap_folds_mean'] = np.mean(aps) + stats_dict['ap_folds_std'] = np.std(aps) + if hasattr(self, "seg_df") and not boxes_only and self.cf.evaluate_fold_means and len(seg_cl_df.fold.unique()) > 1: + fold_means = seg_cl_df.groupby(['fold'], as_index=True).agg({dice_col:"mean"}) + stats_dict["dice_folds_mean"] = fold_means.mean().item() + stats_dict["dice_folds_std"] = fold_means.std().item() + + all_stats.append(stats_dict) + + # -------------- monitoring, visualisation ----------------- + # fill new results into monitor_metrics dict. for simplicity, only one class (of interest) is monitored on patient level. + patient_interests = [self.cf.class_dict[self.cf.patient_class_of_interest],] + if hasattr(self.cf, "bin_dict"): + patient_interests += [self.cf.bin_dict[self.cf.patient_bin_of_interest]] + if monitor_metrics is not None and (score_level != 'patient' or cl_name in patient_interests): + name = 'patient_'+cl_name if score_level == 'patient' else cl_name + for metric in self.cf.metrics: + if metric in stats_dict.keys(): + monitor_metrics[name + '_'+metric].append(stats_dict[metric]) + else: + print("WARNING: skipped monitor metric {}_{} since not avail".format(name, metric)) + + # histograms + if self.cf.plot_prediction_histograms: + out_filename = os.path.join(self.hist_dir, 'pred_hist_{}_{}_{}_{}'.format( + self.cf.fold, self.mode, score_level, cl_name)) + plg.plot_prediction_hist(self.cf, spec_df, out_filename) + + # analysis of the hyper-parameter cf.min_det_thresh, for optimization on validation set. + if self.cf.scan_det_thresh and "val" in self.mode: + conf_threshs = list(np.arange(0.8, 1, 0.02)) + pool = Pool(processes=self.cf.n_workers) + mp_inputs = [[spec_df, ii, self.cf.per_patient_ap] for ii in conf_threshs] + aps = pool.map(get_roi_ap_from_df, mp_inputs, chunksize=1) + pool.close() + pool.join() + self.logger.info('results from scanning over det_threshs: {}'.format([[i, j] for i, j in zip(conf_threshs, aps)])) + + if self.cf.plot_stat_curves: + out_filename = os.path.join(self.curves_dir, '{}_{}_stat_curves'.format(self.cf.fold, self.mode)) + plg.plot_stat_curves(self.cf, all_stats, out_filename) + if self.cf.plot_prediction_histograms and hasattr(df, "cluster_n_missing") and df.cluster_n_missing.notna().any(): + out_filename = os.path.join(self.hist_dir, 'n_missing_hist_{}_{}.png'.format(self.cf.fold, self.mode)) + plg.plot_wbc_n_missing(self.cf, df, outfile=out_filename) + + return all_stats, monitor_metrics + + + def score_test_df(self, max_fold=None, internal_df=True): + """ + Writes out resulting scores to text files: First checks for class-internal-df (typically current) fold, + gets resulting scores, writes them to a text file and pickles data frame. Also checks if data-frame pickles of + all folds of cross-validation exist in exp_dir. If true, loads all dataframes, aggregates test sets over folds, + and calculates and writes out overall metrics. + """ + # this should maybe be extended to auc, ap stds. + metrics_to_score = self.cf.metrics # + [ m+ext for m in self.cf.metrics if "dice" in m for ext in ["_std"]] + + if internal_df: + + self.test_df.to_pickle(os.path.join(self.cf.test_dir, '{}_test_df.pkl'.format(self.cf.fold))) + if hasattr(self, "seg_df"): + self.seg_df.to_pickle(os.path.join(self.cf.test_dir, '{}_test_seg_df.pkl'.format(self.cf.fold))) + stats, _ = self.return_metrics(self.test_df, self.cf.class_dict) + + with open(os.path.join(self.cf.test_dir, 'results.txt'), 'a') as handle: + handle.write('\n****************************\n') + handle.write('\nresults for fold {}, {} \n'.format(self.cf.fold, time.strftime("%d/%m/%y %H:%M:%S"))) + handle.write('\n****************************\n') + handle.write('\nfold df shape {}\n \n'.format(self.test_df.shape)) + for s in stats: + for metric in metrics_to_score: + if metric in s.keys(): #needed as long as no dice on patient level poss + if "accuracy" in metric: + handle.write('{} {:0.4f} '.format(metric, s[metric])) + else: + handle.write('{} {:0.3f} '.format(metric, s[metric])) + else: + print("WARNING: skipped metric {} since not avail".format(metric)) + handle.write('{} \n'.format(s['name'])) + + fold_df_paths = sorted([ii for ii in os.listdir(self.cf.test_dir) if 'test_df.pkl' in ii]) + fold_seg_df_paths = sorted([ii for ii in os.listdir(self.cf.test_dir) if 'test_seg_df.pkl' in ii]) + for paths in [fold_df_paths, fold_seg_df_paths]: + assert len(paths)<= self.cf.n_cv_splits, "found {} > nr of cv splits results dfs in {}".format(len(paths), self.cf.test_dir) + if max_fold is None: + max_fold = self.cf.n_cv_splits-1 + if self.cf.fold == max_fold: + print("max fold/overall stats triggered") + if self.cf.evaluate_fold_means: + metrics_to_score += [m + ext for m in self.cf.metrics for ext in ("_folds_mean", "_folds_std")] + + with open(os.path.join(self.cf.test_dir, 'results.txt'), 'a') as handle: + + self.cf.fold = 'overall' + dfs_list = [pd.read_pickle(os.path.join(self.cf.test_dir, ii)) for ii in fold_df_paths] + seg_dfs_list = [pd.read_pickle(os.path.join(self.cf.test_dir, ii)) for ii in fold_seg_df_paths] + + self.test_df = pd.concat(dfs_list, sort=True) + if len(seg_dfs_list)>0: + self.seg_df = pd.concat(seg_dfs_list, sort=True) + stats, _ = self.return_metrics(self.test_df, self.cf.class_dict) + + handle.write('\n****************************\n') + handle.write('\nOVERALL RESULTS \n') + handle.write('\n****************************\n') + handle.write('\ndf shape \n \n'.format(self.test_df.shape)) + for s in stats: + for metric in metrics_to_score: + if metric in s.keys(): + handle.write('{} {:0.3f} '.format(metric, s[metric])) + handle.write('{} \n'.format(s['name'])) + + results_table_path = os.path.join(self.cf.test_dir,"../../", 'results_table.csv') + with open(results_table_path, 'a') as handle: + #---column headers--- + handle.write('\n{},'.format("Experiment Name")) + handle.write('{},'.format("Time Stamp")) + handle.write('{},'.format("Samples Seen")) + handle.write('{},'.format("Spatial Dim")) + handle.write('{},'.format("Patch Size")) + handle.write('{},'.format("CV Folds")) + handle.write('{},'.format("{}-clustering IoU".format(self.cf.clustering))) + handle.write('{},'.format("Merge-2D-to-3D IoU")) + if hasattr(self.cf, "test_against_exact_gt"): + handle.write('{},'.format('Exact GT')) + for s in stats: + assert "overall" in s['name'].split(" ")[0] + if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name']: + for metric in metrics_to_score: + if metric in s.keys() and not np.isnan(s[metric]): + if metric=='ap': + handle.write('{}_{} : {}_{},'.format(*s['name'].split(" ")[1:], metric, int(np.mean(self.cf.ap_match_ious)*100))) + elif not "folds_std" in metric: + handle.write('{}_{} : {},'.format(*s['name'].split(" ")[1:], metric)) + else: + print("WARNING: skipped metric {} since not avail".format(metric)) + handle.write('\n') + + #--- columns content--- + handle.write('{},'.format(self.cf.exp_dir.split(os.sep)[-1])) + handle.write('{},'.format(time.strftime("%d%b%y %H:%M:%S"))) + handle.write('{},'.format(self.cf.num_epochs*self.cf.num_train_batches*self.cf.batch_size)) + handle.write('{}D,'.format(self.cf.dim)) + handle.write('{},'.format("x".join([str(self.cf.patch_size[i]) for i in range(self.cf.dim)]))) + handle.write('{},'.format(str(self.test_df.fold.unique().tolist()).replace(",", ""))) + handle.write('{},'.format(self.cf.clustering_iou if self.cf.clustering else str("N/A"))) + handle.write('{},'.format(self.cf.merge_3D_iou if self.cf.merge_2D_to_3D_preds else str("N/A"))) + if hasattr(self.cf, "test_against_exact_gt"): + handle.write('{},'.format(self.cf.test_against_exact_gt)) + for s in stats: + if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name']: + for metric in metrics_to_score: + if metric in s.keys() and not np.isnan(s[metric]): # needed as long as no dice on patient level possible + if "folds_mean" in metric: + handle.write('{:0.3f}\u00B1{:0.3f}, '.format(s[metric], s["_".join((*metric.split("_")[:-1], "std"))])) + elif not "folds_std" in metric: + handle.write('{:0.3f}, '.format(s[metric])) + + handle.write('\n') + + with open(os.path.join(self.cf.test_dir, 'results_extr_scores.txt'), 'w') as handle: + handle.write('\n****************************\n') + handle.write('\nextremal scores for fold {} \n'.format(self.cf.fold)) + handle.write('\n****************************\n') + # want: pid & fold (&other) of highest scoring tp & fp in test_df + for cl in self.cf.class_dict.keys(): + print("\nClass {}".format(self.cf.class_dict[cl]), file=handle) + cl_df = self.test_df[self.test_df.pred_class == cl] #.dropna(axis=1) + for det_type in ['det_tp', 'det_fp']: + filtered_df = cl_df[cl_df.det_type==det_type] + print("\nHighest scoring {} of class {}".format(det_type, self.cf.class_dict[cl]), file=handle) + if len(filtered_df)>0: + print(filtered_df.loc[filtered_df.pred_score.idxmax()], file=handle) + else: + print("No detections of type {} for class {} in this df".format(det_type, self.cf.class_dict[cl]), file=handle) + handle.write('\n****************************\n') diff --git a/exec.py b/exec.py new file mode 100644 index 0000000..413db13 --- /dev/null +++ b/exec.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" execution script. this where all routines come together and the only script you need to call. + refer to parse args below to see options for execution. +""" + +import plotting as plg + +import os +import warnings +import argparse +import time + +import torch + +import utils.exp_utils as utils +from evaluator import Evaluator +from predictor import Predictor + + +for msg in ["Attempting to set identical bottom==top results", + "This figure includes Axes that are not compatible with tight_layout", + "Data has no positive values, and therefore cannot be log-scaled.", + ".*invalid value encountered in true_divide.*"]: + warnings.filterwarnings("ignore", msg) + + +def train(cf, logger): + """ + performs the training routine for a given fold. saves plots and selected parameters to the experiment dir + specified in the configs. + + """ + logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format( + cf.dim, cf.fold, cf.exp_dir, cf.model)) + logger.time("train_val") + + # -------------- inits and settings ----------------- + net = model.net(cf, logger).cuda() + if cf.optimizer == "ADAM": + optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) + elif cf.optimizer == "SGD": + optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3) + if cf.dynamic_lr_scheduling: + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, + patience=cf.scheduling_patience) + model_selector = utils.ModelSelector(cf, logger) + + starting_epoch = 1 + if cf.resume_from_checkpoint: + starting_epoch = utils.load_checkpoint(cf.resume_from_checkpoint, net, optimizer) + logger.info('resumed from checkpoint {} at epoch {}'.format(cf.resume_from_checkpoint, starting_epoch)) + + # prepare monitoring + monitor_metrics = utils.prepare_monitoring(cf) + + logger.info('loading dataset and initializing batch generators...') + batch_gen = data_loader.get_train_generators(cf, logger) + + # -------------- training ----------------- + for epoch in range(starting_epoch, cf.num_epochs + 1): + + logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs)) + logger.time("train_epoch") + + net.train() + + train_results_list = [] + train_evaluator = Evaluator(cf, logger, mode='train') + + for i in range(cf.num_train_batches): + logger.time("train_batch_loadfw") + batch = next(batch_gen['train']) + batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts'] + batch_gen['train'].generator.stats['empty_samples_count'] += batch['empty_samples_count'] + + logger.time("train_batch_loadfw") + logger.time("train_batch_netfw") + results_dict = net.train_forward(batch) + logger.time("train_batch_netfw") + logger.time("train_batch_bw") + optimizer.zero_grad() + results_dict['torch_loss'].backward() + if cf.clip_norm: + torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) #gradient clipping + optimizer.step() + train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) #slim res dict + if not cf.server_env: + print("\rFinished training batch " + + "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches, + logger.get_time("train_batch_loadfw")+ + logger.get_time("train_batch_netfw") + +logger.time("train_batch_bw"), + logger.get_time("train_batch_loadfw",reset=True), + logger.get_time("train_batch_netfw", reset=True), + logger.get_time("train_batch_bw", reset=True)), end="", flush=True) + print() + + #--------------- train eval ---------------- + if (epoch-1)%cf.plot_frequency==0: + # view an example batch + plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, + out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold))) + + + logger.time("evals") + _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train']) + #np_loss, torch_loss = train_loss_running_mean / cf.num_train_batches, monitor_metrics['train']["loss"][-1] + #assert np_loss/torch_loss-1<0.005, "{} vs {}".format(np_loss, torch_loss) + logger.time("evals") + logger.time("train_epoch", toggle=False) + del train_results_list + #----------- validation ------------ + logger.info('starting validation in mode {}.'.format(cf.val_mode)) + logger.time("val_epoch") + with torch.no_grad(): + net.eval() + val_results_list = [] + val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) + val_predictor = Predictor(cf, net, logger, mode='val') + + for i in range(batch_gen['n_val']): + logger.time("val_batch") + batch = next(batch_gen[cf.val_mode]) + if cf.val_mode == 'val_patient': + results_dict = val_predictor.predict_patient(batch) + elif cf.val_mode == 'val_sampling': + results_dict = net.train_forward(batch, is_validation=True) + val_results_list.append([results_dict, batch["pid"]]) + if not cf.server_env: + print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch', + i + 1, batch_gen['n_val'], + logger.time("val_batch")), end="", flush=True) + print() + + #------------ val eval ------------- + logger.time("val_plot") + if (epoch - 1) % cf.plot_frequency == 0: + plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, + out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold))) + logger.time("val_plot") + + logger.time("evals") + _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val']) + + model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) + del val_results_list + #----------- monitoring ------------- + monitor_metrics.update({"lr": + {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}}) + logger.metrics2tboard(monitor_metrics, global_step=epoch) + logger.time("evals") + + logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format( + epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"), + logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"), + logger.get_time("val_epoch", reset=True)/batch_gen["n_val"])) + logger.info("time for evals: {:.2f}s, val plot {:.2f}s".format(logger.get_time("evals", reset=True), logger.get_time("val_plot", reset=True))) + + #-------------- scheduling ----------------- + if not cf.dynamic_lr_scheduling: + for param_group in optimizer.param_groups: + param_group['lr'] = cf.learning_rate[epoch-1] + else: + scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) + + logger.time("train_val") + logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True))) + batch_gen['train'].generator.print_stats(logger, plot=True) + +def test(cf, logger, max_fold=None): + """performs testing for a given fold (or held out set). saves stats in evaluator. + """ + logger.time("test_fold") + logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir)) + net = model.net(cf, logger).cuda() + batch_gen = data_loader.get_test_generator(cf, logger) + + test_predictor = Predictor(cf, net, logger, mode='test') + test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr( + cf, "eval_test_separately") or not cf.eval_test_separately) + + if test_results_list is not None: + test_evaluator = Evaluator(cf, logger, mode='test') + test_evaluator.evaluate_predictions(test_results_list) + test_evaluator.score_test_df(max_fold=max_fold) + + mins, secs = divmod(logger.get_time("test_fold"), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + + logger.info('Testing of fold {} took {}.'.format(cf.fold, t)) + + +if __name__ == '__main__': + stime = time.time() + + parser = argparse.ArgumentParser() + parser.add_argument('-m', '--mode', type=str, default='train_test', help='one out of: create_exp, analysis, train, train_test, or test') + parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.') + parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/medicaldetectiontoolkit/datasets/prostate/experiments/dev', + help='path to experiment dir. will be created if non existent.') + parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.') + parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config") + parser.add_argument('--use_stored_settings', default=False, action='store_true', + help='load configs from existing exp_dir instead of source dir. always done for testing, ' + 'but can be set to true to do the same for training. useful in job scheduler environment, ' + 'where source code might change before the job actually runs.') + parser.add_argument('--resume_from_checkpoint', type=str, default=None, + help='path to checkpoint. if resuming from checkpoint, the desired fold still needs to be parsed via --folds.') + parser.add_argument('--dataset_name', type=str, default='prostate', help="path to the dataset-specific code in source_dir/datasets") + parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything") + + args = parser.parse_args() + args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name + folds = args.folds + resume_from_checkpoint = None if args.resume_from_checkpoint in ['None', 'none'] else args.resume_from_checkpoint + + if args.mode == 'create_exp': + cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False) + logger = utils.get_logger(cf.exp_dir, cf.server_env) + logger.info('created experiment directory at {}'.format(args.exp_dir)) + + elif args.mode == 'train' or args.mode == 'train_test': + cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings) + if args.dev: + folds = [0,1] + cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 1 + cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 7, 1, 1 + cf.test_n_epochs = cf.save_n_models + cf.max_test_patients = 1 + torch.backends.cudnn.benchmark = cf.dim==3 + else: + torch.backends.cudnn.benchmark = cf.cuda_benchmark + if args.data_dest is not None: + cf.data_dest = args.data_dest + + logger = utils.get_logger(cf.exp_dir, cf.server_env) + data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) + model = utils.import_module('model', cf.model_path) + logger.info("loaded model from {}".format(cf.model_path)) + if folds is None: + folds = range(cf.n_cv_splits) + + for fold in folds: + """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation, + one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the + splits. k==folds, fold in [0,folds) says which split is used for testing. + """ + cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)) + cf.fold, logger.fold = fold, fold + cf.resume_from_checkpoint = resume_from_checkpoint + if not os.path.exists(cf.fold_dir): + os.mkdir(cf.fold_dir) + train(cf, logger) + cf.resume_from_checkpoint = None + if args.mode == 'train_test': + test(cf, logger) + + elif args.mode == 'test': + cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) + if args.data_dest is not None: + cf.data_dest = args.data_dest + logger = utils.get_logger(cf.exp_dir, cf.server_env) + data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) + model = utils.import_module('model', cf.model_path) + logger.info("loaded model from {}".format(cf.model_path)) + + fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if + os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) + if folds is None: + folds = range(cf.n_cv_splits) + if args.dev: + folds = folds[:2] + cf.batch_size, cf.num_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2 + else: + torch.backends.cudnn.benchmark = cf.cuda_benchmark + for fold in folds: + cf.fold = fold + cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) + if cf.fold_dir in fold_dirs: + test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs])) + else: + logger.info("Skipping fold {} since no model parameters found.".format(fold)) + # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation. + elif args.mode == 'analysis': + """ analyse already saved predictions. + """ + cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) + logger = utils.get_logger(cf.exp_dir, cf.server_env) + + if cf.held_out_test_set and not cf.eval_test_fold_wise: + predictor = Predictor(cf, net=None, logger=logger, mode='analysis') + results_list = predictor.load_saved_predictions() + logger.info('starting evaluation...') + cf.fold = 0 + evaluator = Evaluator(cf, logger, mode='test') + evaluator.evaluate_predictions(results_list) + evaluator.score_test_df(max_fold=0) + else: + fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if + os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) + if args.dev: + fold_dirs = fold_dirs[:1] + if folds is None: + folds = range(cf.n_cv_splits) + for fold in folds: + cf.fold = fold + cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) + + if cf.fold_dir in fold_dirs: + predictor = Predictor(cf, net=None, logger=logger, mode='analysis') + results_list = predictor.load_saved_predictions() + # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds + # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x + logger.info('starting evaluation...') + evaluator = Evaluator(cf, logger, mode='test') + evaluator.evaluate_predictions(results_list) + max_fold = max([int(f[-1]) for f in fold_dirs]) + evaluator.score_test_df(max_fold=max_fold) + else: + logger.info("Skipping fold {} since no model parameters found.".format(fold)) + else: + raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode)) + + mins, secs = divmod((time.time() - stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t)) + del logger + torch.cuda.empty_cache() + + + diff --git a/graphics_generation.py b/graphics_generation.py new file mode 100644 index 0000000..6c59a0c --- /dev/null +++ b/graphics_generation.py @@ -0,0 +1,1932 @@ +""" +Created at 07/03/19 11:42 +@author: gregor +""" +import plotting as plg +import matplotlib.lines as mlines + +import os +import sys +import multiprocessing +from copy import deepcopy +import logging +import time + +import numpy as np +import pandas as pd +from scipy.stats import norm +from sklearn.metrics import confusion_matrix + +import utils.exp_utils as utils +import utils.model_utils as mutils +import utils.dataloader_utils as dutils +from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates + +import predictor as predictor_file +import evaluator as evaluator_file + + + +class NoDaemonProcess(multiprocessing.Process): + # make 'daemon' attribute always return False + def _get_daemon(self): + return False + def _set_daemon(self, value): + pass + daemon = property(_get_daemon, _set_daemon) + +# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool +# because the latter is only a wrapper function, not a proper class. +class NoDaemonProcessPool(multiprocessing.pool.Pool): + Process = NoDaemonProcess + +class AttributeDict(dict): + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + +def get_cf(dataset_name, exp_dir=""): + + cf_path = os.path.join('datasets', dataset_name, exp_dir, "configs.py") + cf_file = utils.import_module('configs', cf_path) + + return cf_file.Configs() + + +def prostate_results_static(plot_dir=None): + cf = get_cf('prostate', '') + if plot_dir is None: + plot_dir = os.path.join('datasets', 'prostate', 'misc') + + text_fs = 18 + fig = plg.plt.figure(figsize=(6, 3)) #w,h + grid = plg.plt.GridSpec(1, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c + + groups = ["b values", "ADC + b values", "T2"] + splits = ["Det. U-Net", "Mask R-CNN", "Faster R-CNN+"] + values = {"detu": [(0.296, 0.031), (0.312, 0.045), (0.090, 0.040)], + "mask": [(0.393, 0.051), (0.382, 0.047), (0.136, 0.016)], + "fast": [(0.424, 0.083), (0.390, 0.086), (0.036, 0.013)]} + bar_values = [[v[0] for v in split] for split in values.values()] + errors = [[v[1] for v in split] for split in values.values()] + ax = fig.add_subplot(grid[0,0]) + colors = [cf.aubergine, cf.blue, cf.dark_blue] + plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=True, + title="Prostate Main Results (3D)", ylabel=r"Performance as $\mathrm{AP}_{10}$", xlabel="Input Modalities") + plg.plt.tight_layout() + plg.plt.savefig(os.path.join(plot_dir, 'prostate_main_results.png'), dpi=600) + +def prostate_GT_examples(exp_dir='', plot_dir=None, pid=8., z_ix=None): + + import datasets.prostate.data_loader as dl + cf = get_cf('prostate', exp_dir) + cf.exp_dir = exp_dir + cf.fold = 0 + cf.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_di_250519_ps384_gs6071/" + dataset = dl.Dataset(cf) + dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) + dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) + set_splits = dataset.fg.splits + + test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) + train_ids = np.concatenate(set_splits, axis=0) + + if cf.held_out_test_set: + train_ids = np.concatenate((train_ids, test_ids), axis=0) + test_ids = [] + print("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), + len(test_ids))) + + + if plot_dir is None: + plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc') + + text_fs = 18 + fig = plg.plt.figure(figsize=(10, 7.7)) #w,h + grid = plg.plt.GridSpec(3, 4, wspace=0.0, hspace=0.0, figure=fig) #r,c + text_x, text_y = 0.1, 0.8 + + # ------- DWI ------- + if z_ix is None: + z_ix_dwi = np.random.choice(dataset[pid]["fg_slices"]) + img = np.load(dataset[pid]["img"])[:,z_ix_dwi] # mods, z,y,x + seg = np.load(dataset[pid]["seg"])[z_ix_dwi] # z,y,x + ax = fig.add_subplot(grid[0,0]) + ax.imshow(img[0], cmap='gray') + ax.text(text_x, text_y, "ADC", size=text_fs, color=cf.white, transform=ax.transAxes, + bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) + ax.axis('off') + ax = fig.add_subplot(grid[0,1]) + ax.imshow(img[0], cmap='gray') + cmap = cf.class_cmap + for r_ix in np.unique(seg[seg>0]): + seg[seg==r_ix] = dataset[pid]["class_targets"][r_ix-1] + ax.imshow(plg.to_rgba(seg, cmap), alpha=1) + ax.text(text_x, text_y, "DWI GT", size=text_fs, color=cf.white, transform=ax.transAxes, + bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) + ax.axis('off') + for b_ix, b in enumerate([50,500,1000,1500]): + ax = fig.add_subplot(grid[1, b_ix]) + ax.imshow(img[b_ix+1], cmap='gray') + ax.text(text_x, text_y, r"{}{}".format("$b=$" if b_ix == 0 else "", b), size=text_fs, color=cf.white, + transform=ax.transAxes, + bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) + ax.axis('off') + + # ----- T2 ----- + cf.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071/" + dataset = dl.Dataset(cf) + if z_ix is None: + if z_ix_dwi in dataset[pid]["fg_slices"]: + z_ix_t2 = z_ix_dwi + else: + z_ix_t2 = np.random.choice(dataset[pid]["fg_slices"]) + img = np.load(dataset[pid]["img"])[:,z_ix_t2] # mods, z,y,x + seg = np.load(dataset[pid]["seg"])[z_ix_t2] # z,y,x + ax = fig.add_subplot(grid[2,0]) + ax.imshow(img[0], cmap='gray') + ax.text(text_x, text_y, "T2w", size=text_fs, color=cf.white, transform=ax.transAxes, + bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) + ax.axis('off') + ax = fig.add_subplot(grid[2,1]) + ax.imshow(img[0], cmap='gray') + cmap = cf.class_cmap + for r_ix in np.unique(seg[seg>0]): + seg[seg==r_ix] = dataset[pid]["class_targets"][r_ix-1] + ax.imshow(plg.to_rgba(seg, cmap), alpha=1) + ax.text(text_x, text_y, "T2 GT", size=text_fs, color=cf.white, transform=ax.transAxes, + bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) + ax.axis('off') + + #grid.tight_layout(fig) + plg.plt.tight_layout() + plg.plt.savefig(os.path.join(plot_dir, 'prostate_gt_examples.png'), dpi=600) + + +def prostate_dataset_stats(exp_dir='', plot_dir=None, show_splits=True,): + + import datasets.prostate.data_loader as dl + cf = get_cf('prostate', exp_dir) + cf.exp_dir = exp_dir + cf.fold = 0 + dataset = dl.Dataset(cf) + dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) + dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) + set_splits = dataset.fg.splits + + test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) + train_ids = np.concatenate(set_splits, axis=0) + + if cf.held_out_test_set: + train_ids = np.concatenate((train_ids, test_ids), axis=0) + test_ids = [] + + print("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), + len(test_ids))) + + df, labels = dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=None) + + if plot_dir is None: + plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc') + + if show_splits: + fig = plg.plt.figure(figsize=(6, 6)) # w, h + grid = plg.plt.GridSpec(2, 2, wspace=0.05, hspace=0.15, figure=fig) # rows, cols + else: + fig = plg.plt.figure(figsize=(6, 3.)) + grid = plg.plt.GridSpec(1, 1, wspace=0.0, hspace=0.15, figure=fig) + + ax = fig.add_subplot(grid[0,0]) + ax = plg.plot_data_stats(cf, df, labels, ax=ax) + ax.set_xlabel("") + ax.set_xticklabels(df.columns, rotation='horizontal', fontsize=11) + ax.set_title("") + if show_splits: + ax.text(0.05,0.95, 'a)', horizontalalignment='center', verticalalignment='center', transform = ax.transAxes, weight='bold') + ax.text(0, 25, "GS$=6$", horizontalalignment='center', verticalalignment='center', bbox=dict(facecolor=(*cf.white, 0.8), edgecolor=cf.dark_green, pad=3)) + ax.text(1, 25, "GS$\geq 7a$", horizontalalignment='center', verticalalignment='center', bbox=dict(facecolor=(*cf.white, 0.8), edgecolor=cf.red, pad=3)) + ax.margins(y=0.1) + + if show_splits: + ax = fig.add_subplot(grid[:, 1]) + ax = plg.plot_fold_stats(cf, df, labels, ax=ax) + ax.set_xlabel("") + ax.set_title("") + ax.text(0.05, 0.98, 'c)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold') + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + ax.margins(y=0.1) + + ax = fig.add_subplot(grid[1, 0]) + cf.balance_target = "lesion_gleasons" + dataset.df = None + df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True) + ax = plg.plot_data_stats(cf, df, labels, ax=ax) + ax.set_xlabel("") + ax.set_title("") + ax.text(0.05, 0.95, 'b)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold') + ax.margins(y=0.1) + # rename GS according to names in thesis + renamer = {'GS60':'GS 6', 'GS71':'GS 7a', 'GS72':'GS 7b', 'GS80':'GS 8', 'GS90': 'GS 9', 'GS91':'GS 9a', 'GS92':'GS 9b'} + x_ticklabels = [str(l.get_text()) for l in ax.xaxis.get_ticklabels()] + ax.xaxis.set_ticklabels([renamer[l] for l in x_ticklabels]) + + plg.plt.tight_layout() + plg.plt.savefig(os.path.join(plot_dir, 'data_stats_prostate.png'), dpi=600) + + return + +def lidc_merged_sa_joint_plot(exp_dir='', plot_dir=None): + import datasets.lidc.data_loader as dl + cf = get_cf('lidc', exp_dir) + cf.balance_target = "regression_targets" + + if plot_dir is None: + plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc') + + cf.training_gts = 'merged' + dataset = dl.Dataset(cf, mode='train') + df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True) + + fig = plg.plt.figure(figsize=(4, 5.6)) #w, h + # fig.subplots_adjust(hspace=0, wspace=0) + grid = plg.plt.GridSpec(3, 1, wspace=0.0, hspace=0.7, figure=fig) #rows, cols + fs = 9 + + ax = fig.add_subplot(grid[0, 0]) + + labels = [AttributeDict({ 'name': rg_val, 'color': cf.bin_id2label[cf.rg_val_to_bin_id(rg_val)].color}) for rg_val + in df.columns] + ax = plg.plot_data_stats(cf, df, labels, ax=ax, fs=fs) + ax.set_xlabel("averaged multi-rater malignancy scores (ms)", fontsize=fs) + ax.set_title("") + ax.text(0.05, 0.91, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, + weight='bold', fontsize=fs) + ax.margins(y=0.2) + + #----- single annotator ------- + cf.training_gts = 'sa' + dataset = dl.Dataset(cf, mode='train') + df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True) + + ax = fig.add_subplot(grid[1, 0]) + labels = [AttributeDict({ 'name': '{:.0f}'.format(rg_val), 'color': cf.bin_id2label[cf.rg_val_to_bin_id(rg_val)].color}) for rg_val + in df.columns] + mapper = {rg_val:'{:.0f}'.format(rg_val) for rg_val in df.columns} + df = df.rename(mapper, axis=1) + ax = plg.plot_data_stats(cf, df, labels, ax=ax, fs=fs) + ax.set_xlabel("unaggregrated single-rater malignancy scores (ms)", fontsize=fs) + ax.set_title("") + ax.text(0.05, 0.91, 'b)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, + weight='bold', fontsize=fs) + ax.margins(y=0.45) + + #------ binned dissent ----- + #cf.balance_target = "regression_targets" + all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()] + non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0] + + mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions]) + distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions] + distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion]) + + binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()] + for l_ix, mean_std in enumerate(mean_std_per_lesion): + bin_id = cf.rg_val_to_bin_id(mean_std[0]) + bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix]) + binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max)) + + ax = fig.add_subplot(grid[2, 0]) + plg.plot_binned_rater_dissent(cf, binned_stats, ax=ax, fs=fs) + ax.set_title("") + ax.text(0.05, 0.91, 'c)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, + weight='bold', fontsize=fs) + ax.margins(y=0.2) + + + plg.plt.savefig(os.path.join(plot_dir, 'data_stats_lidc_solarized.png'), bbox_inches='tight', dpi=600) + + return + +def lidc_dataset_stats(exp_dir='', plot_dir=None): + + import datasets.lidc.data_loader as dl + cf = get_cf('lidc', exp_dir) + cf.data_rootdir = cf.pp_data_path + cf.balance_target = "regression_targets" + + dataset = dl.Dataset(cf, data_dir=cf.data_rootdir) + if plot_dir is None: + plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc') + + df, labels = dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True) + + return df, labels + +def lidc_sa_dataset_stats(exp_dir='', plot_dir=None): + + import datasets.lidc_sa.data_loader as dl + cf = get_cf('lidc_sa', exp_dir) + #cf.data_rootdir = cf.pp_data_path + cf.balance_target = "regression_targets" + + dataset = dl.Dataset(cf) + if plot_dir is None: + plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc_sa', 'misc') + + dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True) + + all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()] + empty_patients = [pid for (pid, lesions) in all_patients if len(lesions) == 0] + non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0] + full_consent_patients = [(pid, lesions) for (pid, lesions) in non_empty_patients if np.all([np.unique(roi).size == 1 for roi in lesions])] + all_lesions = [roi for (pid, lesions) in non_empty_patients for roi in lesions] + two_vote_min = [roi for (pid, lesions) in non_empty_patients for roi in lesions if np.count_nonzero(roi) > 1] + three_vote_min = [roi for (pid, lesions) in non_empty_patients for roi in lesions if np.count_nonzero(roi) > 2] + mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions]) + avg_mean_std_pl = np.mean(mean_std_per_lesion, axis=0) + # call std dev per lesion disconsent from now on + disconsent_std = np.std(mean_std_per_lesion[:, 1]) + + distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions] + distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion]) + + mean_max_delta = abs(mean_std_per_lesion[:, 0] - distribution_max_per_lesion) + + binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()] + for l_ix, mean_std in enumerate(mean_std_per_lesion): + bin_id = cf.rg_val_to_bin_id(mean_std[0]) + bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix]) + binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max)) + + plg.plot_binned_rater_dissent(cf, binned_stats, out_file=os.path.join(plot_dir, "binned_dissent.png")) + + + mean_max_bin_divergence = [[] for bin_id in cf.bin_id2rg_val.keys()] + for bin_id, bin_stats in enumerate(binned_stats): + mean_max_bin_divergence[bin_id].append([roi for roi in bin_stats if roi[3] != 0]) + mean_max_bin_divergence[bin_id].insert(0,len(mean_max_bin_divergence[bin_id][0])) + + + return + +def lidc_annotator_confusion(exp_dir='', plot_dir=None, normalize=None, dataset=None, plot=True): + """ + :param exp_dir: + :param plot_dir: + :param normalize: str or None. str in ['truth', 'pred'] + :param dataset: + :param plot: + :return: + """ + if dataset is None: + import datasets.lidc.data_loader as dl + cf = get_cf('lidc', exp_dir) + # cf.data_rootdir = cf.pp_data_path + cf.training_gts = "sa" + cf.balance_target = "regression_targets" + dataset = dl.Dataset(cf) + else: + cf = dataset.cf + + if plot_dir is None: + plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc') + + dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True) + + all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()] + non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0] + + y_true, y_pred = [], [] + for (pid, lesions) in non_empty_patients: + for roi in lesions: + true_bin = cf.rg_val_to_bin_id(np.mean(roi)) + y_true.extend([true_bin] * len(roi)) + y_pred.extend(roi) + cm = confusion_matrix(y_true, y_pred) + if normalize in ["truth", "row"]: + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + elif normalize in ["pred", "prediction", "column", "col"]: + cm = cm.astype('float') / cm.sum(axis=0)[:, np.newaxis] + + if plot: + plg.plot_confusion_matrix(cf, cm, out_file=os.path.join(plot_dir, "annotator_confusion.pdf")) + + return cm + +def plot_lidc_dissent_and_example(confusion_matrix=True, bin_stds=False, plot_dir=None, numbering=True, example_title="Example"): + import datasets.lidc.data_loader as dl + dataset_name = 'lidc' + exp_dir1 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rg_bs8' + exp_dir2 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_bs8' + #exp_dir1 = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rg_bs8' + #exp_dir2 = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8' + cf = get_cf(dataset_name, exp_dir1) + #file_names = [f_name for f_name in os.listdir(os.path.join(exp_dir, 'inference_analysis')) if f_name.endswith('.pkl')] + # file_names = [os.path.join(exp_dir, "inference_analysis", f_name) for f_name in file_names] + file_names = ["bytes_merged_boxes_fold_0_pid_0811a.pkl",] + z_ics = [194,] + plot_files = [ + {'files': [os.path.join(exp_dir, "inference_analysis", f_name) for exp_dir in [exp_dir1, exp_dir2]], + 'z_ix': z_ix} for (f_name, z_ix) in zip(file_names, z_ics) + ] + + cf.training_gts = 'sa' + info_df_path = '/mnt/HDD2TB/Documents/data/lidc/pp_20190805/patient_gts_{}/info_df.pickle'.format(cf.training_gts) + info_df = pd.read_pickle(info_df_path) + + cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois + + text_fs = 14 + title_fs = text_fs + text_x, text_y = 0.06, 0.92 + fig = plg.plt.figure(figsize=(8.6, 3)) #w, h + #fig.subplots_adjust(hspace=0, wspace=0) + grid = plg.plt.GridSpec(1, 4, wspace=0.0, hspace=0.0, figure=fig) #rows, cols + cf.plot_class_ids = True + + f_ix = 0 + z_ix = plot_files[f_ix]['z_ix'] + for model_ix in range(2)[::-1]: + print("f_ix, m_ix", f_ix, model_ix) + plot_file = utils.load_obj(plot_files[f_ix]['files'][model_ix]) + batch = plot_file["batch"] + pid = batch["pid"][0] + batch['patient_rg_bin_targets_sa'] = info_df[info_df.pid == pid]['class_target'].tolist() + # apply same filter as with merged GTs: need at least two non-zero votes to consider a RoI. + batch['patient_rg_bin_targets_sa'] = [[four_votes.astype("uint8") for four_votes in batch_el if + np.count_nonzero(four_votes>0)>=2] for batch_el in + batch['patient_rg_bin_targets_sa']] + results_dict = plot_file["res_dict"] + + # pred + ax = fig.add_subplot(grid[0, model_ix+2]) + plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=False, sample_picks=None, fontsize=text_fs*1.3, + vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.2, plot_mods=False, + seg_cmap="rg", show_cl_ids=False, + out_file=None, dpi=600, patient_items=True, return_fig=False, axes={'pred': ax}) + + #ax.set_title("{}".format("Reg R-CNN" if model_ix==0 else "Mask R-CNN"), size=title_fs) + ax.set_title("") + ax.set_xlabel("{}".format("Reg R-CNN" if model_ix == 0 else "Mask R-CNN"), size=title_fs) + if numbering: + ax.text(text_x, text_y, chr(model_ix+99)+")", horizontalalignment='center', verticalalignment='center', + transform=ax.transAxes, weight='bold', color=cf.white, fontsize=title_fs) + #ax.axis("off") + ax.axis("on") + plg.suppress_axes_lines(ax) + + # GT + if model_ix==0: + ax.set_title(example_title, fontsize=title_fs) + ax = fig.add_subplot(grid[0, 1]) + # ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') + # ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8) + plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, fontsize=text_fs*1.3, + vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.13, plot_mods=False, seg_cmap="rg", + out_file=None, dpi=600, patient_items=True, return_fig=False, axes={'gt':ax}) + if numbering: + ax.text(text_x, text_y, "b)", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, + weight='bold', color=cf.white, fontsize=title_fs) + #ax.set_title("Ground Truth", size=title_fs) + ax.set_title("") + ax.set_xlabel("Ground Truth", size=title_fs) + plg.suppress_axes_lines(ax) + #ax.axis('off') + #----- annotator dissent plot(s) ------ + + cf.training_gts = 'sa' + cf.balance_targets = 'rg_bin_targets' + dataset = dl.Dataset(cf, mode='train') + + if bin_stds: + #------ binned dissent ----- + #cf = get_cf('lidc', "") + + #cf.balance_target = "regression_targets" + all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()] + non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0] + + mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions]) + distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions] + distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion]) + + binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()] + for l_ix, mean_std in enumerate(mean_std_per_lesion): + bin_id = cf.rg_val_to_bin_id(mean_std[0]) + bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix]) + binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max)) + + ax = fig.add_subplot(grid[0, 0]) + plg.plot_binned_rater_dissent(cf, binned_stats, ax=ax, fs=text_fs) + if numbering: + ax.text(text_x, text_y, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, + weight='bold', fontsize=title_fs) + ax.margins(y=0.2) + ax.set_xlabel("Malignancy-Score Bins", fontsize=title_fs) + #ax.yaxis.set_label_position("right") + #ax.yaxis.tick_right() + ax.set_yticklabels([]) + #ax.xaxis.set_label_position("top") + #ax.xaxis.tick_top() + ax.set_title("Average Rater Dissent", fontsize=title_fs) + + if confusion_matrix: + #------ confusion matrix ------- + cm = lidc_annotator_confusion(dataset=dataset, plot=False, normalize="truth") + ax = fig.add_subplot(grid[0, 0]) + cmap = plg.make_colormap([(1,1,1), cf.dkfz_blue]) + plg.plot_confusion_matrix(cf, cm, ax=ax, fs=text_fs, color_bar=False, cmap=cmap )#plg.plt.cm.Purples) + ax.set_xticks(np.arange(cm.shape[1])) + if numbering: + ax.text(-0.16, text_y, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, + weight='bold', fontsize=title_fs) + ax.margins(y=0.2) + ax.set_title("Annotator Dissent", fontsize=title_fs) + + #fig.suptitle(" Example", fontsize=title_fs) + #fig.text(0.63, 1.03, "Example", va="center", ha="center", size=title_fs, transform=fig.transFigure) + + #fig_patches = fig_leg.get_patches() + #patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if label.id!=0] + #fig.legends.append(fig_leg) + #plg.plt.figlegend(handles=patches, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0., + # ncol=len(patches), bbox_transform=fig.transFigure, title="Binned Malignancy Score", fontsize= text_fs) + plg.plt.tight_layout() + if plot_dir is None: + plot_dir = "datasets/lidc/misc" + out_file = os.path.join(plot_dir, "regrcnn_lidc_diss_example.png") + if out_file is not None: + plg.plt.savefig(out_file, dpi=600, bbox_inches='tight') + +def lidc_annotator_dissent_images(exp_dir='', plot_dir=None): + if plot_dir is None: + plot_dir = "datasets/lidc/misc" + + import datasets.lidc.data_loader as dl + cf = get_cf('lidc', exp_dir) + cf.training_gts = "sa" + + dataset = dl.Dataset(cf, mode='train') + + pids = {'0069a': 132, '0493a':125, '1008a': 164}#, '0355b': 138, '0484a': 86} # pid : (z_ix to show) + # add_pids = dataset.set_ids[65:80] + # for pid in add_pids: + # try: + # + # pids[pid] = int(np.median(dataset.data[pid]['fg_slices'][0])) + # + # except (IndexError, ValueError): + # print("pid {} has no foreground".format(pid)) + + if not os.path.exists(plot_dir): + os.mkdir(plot_dir) + out_file = os.path.join(plot_dir, "lidc_example_rater_dissent.png") + + #cf.training_gts = 'sa' + cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois + + title_fs = 14 + text_fs = 14 + fig = plg.plt.figure(figsize=(10, 5.9)) #w, h + #fig.subplots_adjust(hspace=0, wspace=0) + grid = plg.plt.GridSpec(len(pids.keys()), 5, wspace=0.0, hspace=0.0, figure=fig) #rows, cols + cf.plot_class_ids = True + cmap = {id : (label.color if id!=0 else (0.,0.,0.)) for id, label in cf.bin_id2label.items()} + legend_handles = set() + window_size = (250,250) + + for p_ix, (pid, z_ix) in enumerate(pids.items()): + try: + print("plotting pid, z_ix", pid, z_ix) + patient = dataset[pid] + img = np.load(patient['data'], mmap_mode='r')[z_ix] # z,y,x --> y,x + seg = np.load(patient['seg'], mmap_mode='r')['seg'][:,z_ix] # rater,z,y,x --> rater,y,x + rg_bin_targets = patient['rg_bin_targets'] + + contours = np.nonzero(seg[0]) + center_y, center_x = np.median(contours[0]), np.median(contours[1]) + #min_y, min_x = np.min(contours[0]), np.min(contours[1]) + #max_y, max_x = np.max(contours[0]), np.max(contours[1]) + #buffer_y, buffer_x = int(seg.shape[1]*0.5), int(seg.shape[2]*0.5) + #y_range = np.arange(max(min_y-buffer_y, 0), min(min_y+buffer_y, seg.shape[1])) + #x_range = np.arange(max(min_x-buffer_x, 0), min(min_x+buffer_x, seg.shape[2])) + y_range = np.arange(max(int(center_y-window_size[0]/2), 0), min(int(center_y+window_size[0]/2), seg.shape[1])) + + min_x = int(center_x-window_size[1]/2) + max_x = int(center_x+window_size[1]/2) + if min_x<0: + max_x += abs(min_x) + elif max_x>seg.shape[2]: + min_x -= max_x-seg.shape[2] + x_range = np.arange(max(min_x, 0), min(max_x, seg.shape[2])) + img = img[y_range][:,x_range] + seg = seg[:, y_range][:,:,x_range] + # data + ax = fig.add_subplot(grid[p_ix, 0]) + ax.imshow(img, cmap='gray') + + plg.suppress_axes_lines(ax) + # key = "spec" if "spec" in batch.keys() else "pid" + ylabel = str(pid) + "/" + str(z_ix) + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + if p_ix == 0: + ax.set_title("Image", fontsize=title_fs) + + # raters + for r_ix in range(seg.shape[0]): + rater_bin_targets = rg_bin_targets[:,r_ix] + for roi_ix, rating in enumerate(rater_bin_targets): + seg[r_ix][seg[r_ix]==roi_ix+1] = rating + ax = fig.add_subplot(grid[p_ix, r_ix+1]) + ax.imshow(seg[r_ix], cmap='gray') + ax.imshow(plg.to_rgba(seg[r_ix], cmap), alpha=0.8) + ax.axis('off') + if p_ix == 0: + ax.set_title("Rating {}".format(r_ix+1), fontsize=title_fs) + legend_handles.update([cf.bin_id2label[id] for id in np.unique(seg[r_ix]) if id!=0]) + except: + print("failed pid", pid) + pass + + legend_handles = [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in legend_handles] + legend_handles = sorted(legend_handles, key=lambda h: h._label) + fig.suptitle("LIDC Single-Rater Annotations", fontsize=title_fs) + #patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if label.id!=0] + + legend = fig.legend(handles=legend_handles, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0, fontsize=text_fs, + bbox_transform=fig.transFigure, ncol=len(legend_handles), title="Malignancy Score") + plg.plt.setp(legend.get_title(), fontsize=title_fs) + #grid.tight_layout(fig) + #plg.plt.tight_layout(rect=[0, 0.00, 1, 1.5]) + if out_file is not None: + plg.plt.savefig(out_file, dpi=600, bbox_inches='tight') + + + + return + +def lidc_results_static(xlabels=None, plot_dir=None, in_percent=True): + cf = get_cf('lidc', '') + if plot_dir is None: + plot_dir = os.path.join('datasets', 'lidc', 'misc') + + text_fs = 18 + fig = plg.plt.figure(figsize=(3, 2.5)) #w,h + grid = plg.plt.GridSpec(2, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c + + #--- LIDC 3D ----- + + + splits = ["Reg R-CNN", "Mask R-CNN"]#, "Reg R-CNN 2D", "Mask R-CNN 2D"] + values = {"reg3d": [(0.259, 0.035), (0.628, 0.038), (0.477, 0.035)], + "mask3d": [(0.235, 0.027), (0.622, 0.029), (0.411, 0.026)],} + groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."] + if in_percent: + bar_values = [[v[0]*100 for v in split] for split in values.values()] + errors = [[v[1]*100 for v in split] for split in values.values()] + else: + bar_values = [[v[0] for v in split] for split in values.values()] + errors = [[v[1] for v in split] for split in values.values()] + + ax = fig.add_subplot(grid[0,0]) + colors = [cf.blue, cf.dkfz_blue] + plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}", + title="LIDC Results", ylabel=r"3D Perf. (%)", xlabel="Metric", yticklabels=[], ylim=(0,80 if in_percent else 0.8)) + #------ LIDC 2D ------- + + splits = ["Reg R-CNN", "Mask R-CNN"] + values = {"reg2d": [(0.148, 0.046), (0.414, 0.052), (0.468, 0.057)], + "mask2d": [(0.127, 0.034), (0.406, 0.040), (0.447, 0.018)]} + groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."] + if in_percent: + bar_values = [[v[0]*100 for v in split] for split in values.values()] + errors = [[v[1]*100 for v in split] for split in values.values()] + else: + bar_values = [[v[0] for v in split] for split in values.values()] + errors = [[v[1] for v in split] for split in values.values()] + ax = fig.add_subplot(grid[1,0]) + colors = [cf.blue, cf.dkfz_blue] + plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}", + title="", ylabel=r"2D Perf.", xlabel="Metric", xticklabels=xlabels, yticklabels=[], ylim=(None,60 if in_percent else 0.6)) + plg.plt.tight_layout() + plg.plt.savefig(os.path.join(plot_dir, 'lidc_static_results.png'), dpi=700) + +def toy_results_static(xlabels=None, plot_dir=None, in_percent=True): + cf = get_cf('toy', '') + if plot_dir is None: + plot_dir = os.path.join('datasets', 'toy', 'misc') + + text_fs = 18 + fig = plg.plt.figure(figsize=(3, 2.5)) #w,h + grid = plg.plt.GridSpec(2, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c + + #--- Toy 3D ----- + groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."] + splits = ["Reg R-CNN", "Mask R-CNN"]#, "Reg R-CNN 2D", "Mask R-CNN 2D"] + values = {"reg3d": [(0.881, 0.014), (0.998, 0.004), (0.887, 0.014)], + "mask3d": [(0.822, 0.070), (1.0, 0.0), (0.826, 0.069)],} + if in_percent: + bar_values = [[v[0]*100 for v in split] for split in values.values()] + errors = [[v[1]*100 for v in split] for split in values.values()] + else: + bar_values = [[v[0] for v in split] for split in values.values()] + errors = [[v[1] for v in split] for split in values.values()] + ax = fig.add_subplot(grid[0,0]) + colors = [cf.blue, cf.dkfz_blue] + plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=True, label_format="{:.1f}", + title="Toy Results", ylabel=r"3D Perf. (%)", xlabel="Metric", yticklabels=[], ylim=(0,130 if in_percent else .3)) + #------ Toy 2D ------- + groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."] + splits = ["Reg R-CNN", "Mask R-CNN"] + values = {"reg2d": [(0.859, 0.021), (1., 0.0), (0.860, 0.021)], + "mask2d": [(0.748, 0.022), (1., 0.0), (0.748, 0.021)]} + if in_percent: + bar_values = [[v[0]*100 for v in split] for split in values.values()] + errors = [[v[1]*100 for v in split] for split in values.values()] + else: + bar_values = [[v[0] for v in split] for split in values.values()] + errors = [[v[1] for v in split] for split in values.values()] + ax = fig.add_subplot(grid[1,0]) + colors = [cf.blue, cf.dkfz_blue] + plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}", + title="", ylabel=r"2D Perf.", xlabel="Metric", xticklabels=xlabels, yticklabels=[], ylim=(None,130 if in_percent else 1.3)) + plg.plt.tight_layout() + plg.plt.savefig(os.path.join(plot_dir, 'toy_static_results.png'), dpi=700) + +def analyze_test_df(dataset_name, exp_dir='', cf=None, logger=None, plot_dir=None): + evaluator_file = utils.import_module('evaluator', "evaluator.py") + if cf is None: + cf = get_cf(dataset_name, exp_dir) + cf.exp_dir = exp_dir + cf.test_dir = os.path.join(exp_dir, 'test') + if logger is None: + logger = utils.get_logger(cf.exp_dir, False) + evaluator = evaluator_file.Evaluator(cf, logger, mode='test') + + fold_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_df.pkl' in ii]) + fold_seg_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_seg_df.pkl' in ii]) + metrics_to_score = ['ap', 'auc']#, 'patient_ap', 'patient_auc', 'patient_dice'] #'rg_bin_accuracy_weighted_tp', 'rg_MAE_w_std_weighted_tp'] #cf.metrics + if cf.evaluate_fold_means: + means_to_score = [m for m in metrics_to_score] #+ ['rg_MAE_w_std_weighted_tp'] + #metrics_to_score += ['rg_MAE_std'] + metrics_to_score = [] + + + cf.fold = 'overall' + dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_df_paths] + evaluator.test_df = pd.concat(dfs_list, sort=True) + + seg_dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_seg_df_paths] + if len(seg_dfs_list) > 0: + evaluator.seg_df = pd.concat(seg_dfs_list, sort=True) + + # stats, _ = evaluator.return_metrics(evaluator.test_df, cf.class_dict) + # results_table_path = os.path.join(cf.exp_dir, "../", "semi_man_summary.csv") + # # ---column headers--- + # col_headers = ["Experiment Name", "CV Folds", "Spatial Dim", "Clustering Kind", "Clustering IoU", "Merge-2D-to-3D IoU"] + # if hasattr(cf, "test_against_exact_gt"): + # col_headers.append('Exact GT') + # for s in stats: + # assert "overall" in s['name'].split(" ")[0] + # if cf.class_dict[cf.patient_class_of_interest] in s['name']: + # for metric in metrics_to_score: + # #if metric in s.keys() and not np.isnan(s[metric]): + # col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], metric)) + # for mean in means_to_score: + # if mean == "rg_MAE_w_std_weighted_tp": + # col_headers.append('(MAE_fold_mean\u00B1std_fold_mean)\u00B1fold_mean_std\u00B1fold_std_std)'.format(*s['name'].split(" ")[1:], mean)) + # elif mean in s.keys() and not np.isnan(s[mean]): + # col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], mean)) + # else: + # print("skipping {}".format(mean)) + # with open(results_table_path, 'a') as handle: + # with open(results_table_path, 'r') as doublehandle: + # last_header = doublehandle.readlines() + # if len(last_header)==0 or len(col_headers)!=len(last_header[1].split(",")[:-1]) or \ + # not all([col_headers[ix]==lhix for ix, lhix in enumerate(last_header[1].split(",")[:-1])]): + # handle.write('\n') + # for head in col_headers: + # handle.write(head+',') + # handle.write('\n') + # + # # --- columns content--- + # handle.write('{},'.format(cf.exp_dir.split(os.sep)[-1])) + # handle.write('{},'.format(str(evaluator.test_df.fold.unique().tolist()).replace(",", ""))) + # handle.write('{}D,'.format(cf.dim)) + # handle.write('{},'.format(cf.clustering)) + # handle.write('{},'.format(cf.clustering_iou if cf.clustering else str("N/A"))) + # handle.write('{},'.format(cf.merge_3D_iou if cf.merge_2D_to_3D_preds else str("N/A"))) + # if hasattr(cf, "test_against_exact_gt"): + # handle.write('{},'.format(cf.test_against_exact_gt)) + # for s in stats: + # if cf.class_dict[cf.patient_class_of_interest] in s['name']: + # for metric in metrics_to_score: + # #if metric in s.keys() and not np.isnan(s[metric]): # needed as long as no dice on patient level poss + # handle.write('{:0.3f}, '.format(s[metric])) + # for mean in means_to_score: + # #if metric in s.keys() and not np.isnan(s[metric]): + # if mean=="rg_MAE_w_std_weighted_tp": + # handle.write('({:0.3f}\u00B1{:0.3f})\u00B1({:0.3f}\u00B1{:0.3f}),'.format(*s[mean + "_folds_mean"], *s[mean + "_folds_std"])) + # elif mean in s.keys() and not np.isnan(s[mean]): + # handle.write('{:0.3f}\u00B1{:0.3f},'.format(s[mean+"_folds_mean"], s[mean+"_folds_std"])) + # else: + # print("skipping {}".format(mean)) + # + # handle.write('\n') + + return evaluator.test_df + +def cluster_results_to_df(dataset_name, exp_dir='', overall_df=None, cf=None, logger=None, plot_dir=None): + evaluator_file = utils.import_module('evaluator', "evaluator.py") + if cf is None: + cf = get_cf(dataset_name, exp_dir) + cf.exp_dir = exp_dir + cf.test_dir = os.path.join(exp_dir, 'test') + if logger is None: + logger = utils.get_logger(cf.exp_dir, False) + evaluator = evaluator_file.Evaluator(cf, logger, mode='test') + cf.fold = 'overall' + metrics_to_score = ['ap', 'auc']#, 'patient_ap', 'patient_auc', 'patient_dice'] #'rg_bin_accuracy_weighted_tp', 'rg_MAE_w_std_weighted_tp'] #cf.metrics + if cf.evaluate_fold_means: + means_to_score = [m for m in metrics_to_score] #+ ['rg_MAE_w_std_weighted_tp'] + #metrics_to_score += ['rg_MAE_std'] + metrics_to_score = [] + + # use passed overall_df or, if not given, read dfs from file + if overall_df is None: + fold_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_df.pkl' in ii]) + fold_seg_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_seg_df.pkl' in ii]) + for paths in [fold_df_paths, fold_seg_df_paths]: + assert len(paths) <= cf.n_cv_splits, "found {} > nr of cv splits results dfs in {}".format(len(paths), cf.test_dir) + dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_df_paths] + evaluator.test_df = pd.concat(dfs_list, sort=True) + + # seg_dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_seg_df_paths] + # if len(seg_dfs_list) > 0: + # evaluator.seg_df = pd.concat(seg_dfs_list, sort=True) + + else: + evaluator.test_df = overall_df + # todo seg_df if desired + + stats, _ = evaluator.return_metrics(evaluator.test_df, cf.class_dict) + # ---column headers--- + col_headers = ["Experiment Name", "Model", "CV Folds", "Spatial Dim", "Clustering Kind", "Clustering IoU", "Merge-2D-to-3D IoU"] + for s in stats: + assert "overall" in s['name'].split(" ")[0] + if cf.class_dict[cf.patient_class_of_interest] in s['name']: + for metric in metrics_to_score: + #if metric in s.keys() and not np.isnan(s[metric]): + col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], metric)) + for mean in means_to_score: + if mean in s.keys() and not np.isnan(s[mean]): + col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], mean+"_folds_mean")) + else: + print("skipping {}".format(mean)) + results_df = pd.DataFrame(columns=col_headers) + # --- columns content--- + row = [] + row.append('{}'.format(cf.exp_dir.split(os.sep)[-1])) + model = 'frcnn' if (cf.model=="mrcnn" and cf.frcnn_mode) else cf.model + row.append('{}'.format(model)) + row.append('{}'.format(str(evaluator.test_df.fold.unique().tolist()).replace(",", ""))) + row.append('{}D'.format(cf.dim)) + row.append('{}'.format(cf.clustering)) + row.append('{}'.format(cf.clustering_iou if cf.clustering else "N/A")) + row.append('{}'.format(cf.merge_3D_iou if cf.merge_2D_to_3D_preds else "N/A")) + for s in stats: + if cf.class_dict[cf.patient_class_of_interest] in s['name']: + for metric in metrics_to_score: + #if metric in s.keys() and not np.isnan(s[metric]): # needed as long as no dice on patient level poss + row.append('{:0.3f} '.format(s[metric])) + for mean in means_to_score: + #if metric in s.keys() and not np.isnan(s[metric]): + if mean+"_folds_mean" in s.keys() and not np.isnan(s[mean+"_folds_mean"]): + row.append('{:0.3f}\u00B1{:0.3f}'.format(s[mean+"_folds_mean"], s[mean+"_folds_std"])) + else: + print("skipping {}".format(mean+"_folds_mean")) + #print("row, clustering, iou, exp", row, cf.clustering, cf.clustering_iou, cf.exp_dir) + results_df.loc[0] = row + + return results_df + +def multiple_clustering_results(dataset_name, exp_dir, plot_dir=None, plot_hist=False): + print("Gathering exp {}".format(exp_dir)) + cf = get_cf(dataset_name, exp_dir) + cf.n_workers = 1 + logger = logging.getLogger("dummy") + logger.setLevel(logging.DEBUG) + #logger.addHandler(logging.StreamHandler()) + cf.exp_dir = exp_dir + cf.test_dir = os.path.join(exp_dir, 'test') + cf.plot_prediction_histograms = False + if plot_dir is None: + #plot_dir = os.path.join(cf.test_dir, 'histograms') + plot_dir = os.path.join("datasets", dataset_name, "misc") + os.makedirs(plot_dir, exist_ok=True) + + # fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if + # os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) + folds = range(cf.n_cv_splits) + clusterings = {None: ['lol'], 'wbc': [0.0, 0.1, 0.2, 0.3, 0.4], 'nms': [0.0, 0.1, 0.2, 0.3, 0.4]} + #clusterings = {'wbc': [0.1,], 'nms': [0.1,]} + #clusterings = {None: ['lol']} + if plot_hist: + clusterings = {None: ['lol'], 'nms': [0.1, ], 'wbc': [0.1, ]} + class_of_interest = cf.patient_class_of_interest + + try: + if plot_hist: + title_fs, text_fs = 16, 13 + fig = plg.plt.figure(figsize=(11, 8)) #width, height + grid = plg.plt.GridSpec(len(clusterings.keys()), max([len(v) for v in clusterings.values()])+1, wspace=0.0, + hspace=0.0, figure=fig) #rows, cols + plg.plt.suptitle("Faster R-CNN+", fontsize=title_fs, va='bottom', y=0.925) + + results_df = pd.DataFrame() + for cl_ix, (clustering, ious) in enumerate(clusterings.items()): + cf.clustering = clustering + for iou_ix, iou in enumerate(ious): + cf.clustering_iou = iou + print(r"Producing Results for Clustering {} @ IoU {}".format(cf.clustering, cf.clustering_iou)) + overall_test_df = pd.DataFrame() + for fold in folds[:]: + cf.fold = fold + cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) + + predictor = predictor_file.Predictor(cf, net=None, logger=logger, mode='analysis') + results_list = predictor.load_saved_predictions() + logger.info('starting evaluation...') + evaluator = evaluator_file.Evaluator(cf, logger, mode='test') + evaluator.evaluate_predictions(results_list) + #evaluator.score_test_df(max_fold=100) + overall_test_df = overall_test_df.append(evaluator.test_df) + + results_df = results_df.append(cluster_results_to_df(dataset_name, overall_df=overall_test_df,cf=cf, + logger=logger)) + + if plot_hist: + if clustering=='wbc' and iou_ix==len(ious)-1: + # plot n_missing histogram for last wbc clustering only + out_filename = os.path.join(plot_dir, 'analysis_n_missing_overall_hist_{}_{}.png'.format(clustering, iou)) + ax = fig.add_subplot(grid[cl_ix, iou_ix+1]) + plg.plot_wbc_n_missing(cf, overall_test_df, outfile=out_filename, fs=text_fs, ax=ax) + ax.set_title("WBC Missing Predictions per Cluster.", fontsize=title_fs) + #ax.set_ylabel(r"Average Missing Preds per Cluster (%)") + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + ax.text(0.07, 0.87, "{}) WBC".format(chr(len(clusterings.keys())*len(ious)+97)), transform=ax.transAxes, color=cf.white, fontsize=title_fs, + bbox=dict(boxstyle='square', facecolor='black', edgecolor='none', alpha=0.9)) + overall_test_df = overall_test_df[overall_test_df.pred_class == class_of_interest] + overall_test_df = overall_test_df[overall_test_df.det_type!='patient_tn'] + out_filename = "analysis_fold_overall_hist_{}_{}.png".format(clustering, iou) + out_filename = os.path.join(plot_dir, out_filename) + ax = fig.add_subplot(grid[cl_ix, iou_ix]) + plg.plot_prediction_hist(cf, overall_test_df, out_filename, fs=text_fs, ax=ax) + ax.text(0.11, 0.87, "{}) {}".format(chr((cl_ix+1)*len(ious)+96), clustering.upper() if clustering else "Raw Preds"), transform=ax.transAxes, color=cf.white, + bbox=dict(boxstyle='square', facecolor='black', edgecolor='none', alpha=0.9), fontsize=title_fs) + if cl_ix==0 and iou_ix==0: + ax.set_title("Prediction Histograms Malignant Class", fontsize=title_fs) + ax.legend(loc="best", fontsize=text_fs) + else: + ax.set_title("") + #analyze_test_df(dataset_name, cf=cf, logger=logger) + if plot_hist: + #plg.plt.subplots_adjust(top=0.) + plg.plt.savefig(os.path.join(plot_dir, "combined_hist_plot.pdf"), dpi=600, bbox_inches='tight') + + except FileNotFoundError as e: + print("Ignoring exp dir {} due to\n{}".format(exp_dir, e)) + logger.handlers = [] + del cf; del logger + return results_df + +def gather_clustering_results(dataset_name, exp_parent_dir, exps_filter=None, processes=os.cpu_count()//2): + exp_dirs = [os.path.join(exp_parent_dir, i) for i in os.listdir(exp_parent_dir + "/") if + os.path.isdir(os.path.join(exp_parent_dir, i))]#[:1] + if exps_filter is not None: + exp_dirs = [ed for ed in exp_dirs if not exps_filter in ed] + # for debugging + #exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6" + #exp_dirs = [exp_dir,] + #exp_dirs = ["/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn2d_cl_bs10",] + + results_df = pd.DataFrame() + + p = NoDaemonProcessPool(processes=processes) + mp_inputs = [(dataset_name, exp_dir) for exp_dir in exp_dirs][:] + results_dfs = p.starmap(multiple_clustering_results, mp_inputs) + p.close() + p.join() + for df in results_dfs: + results_df = results_df.append(df) + + results_df.to_csv(os.path.join(exp_parent_dir, "df_cluster_summary.csv"), index=False) + + return results_df + +def plot_cluster_results_grid(cf, res_df, ylim=None, out_file=None): + """ + :param cf: + :param res_df: results over a single dimension setting (2D or 3D), over all clustering methods and ious. + :param out_file: + :return: + """ + is_2d = np.all(res_df["Spatial Dim"]=="2D") + # pandas has problems with recognising "N/A" string --> replace by None + #res_df['Merge-2D-to-3D IoU'].iloc[res_df['Merge-2D-to-3D IoU'] == "N/A"] = None + n_rows = 3#4 if is_2d else 3 + grid = plg.plt.GridSpec(n_rows, 5, wspace=0.4, hspace=0.3) + + fig = plg.plt.figure(figsize=(11,6)) + + splits = res_df["Model"].unique().tolist() # need to be model names + for split in splits: + assoc_exps = res_df[res_df["Model"]==split]["Experiment Name"].unique() + if len(assoc_exps)>1: + print("Model {} has multiple experiments:\n{}".format(split, assoc_exps)) + #res_df = res_df.where(~(res_df["Model"] == split), res_df["Experiment Name"], axis=0) + raise Exception("Multiple Experiments") + + sort_map = {'detection_fpn': 0, 'mrcnn':1, 'frcnn':2, 'retina_net':3, 'retina_unet':4} + splits.sort(key=sort_map.__getitem__) + #colors = [cf.color_palette[ix+3 % len(cf.color_palette)] for ix in range(len(splits))] + color_map = {'detection_fpn': cf.magenta, 'mrcnn':cf.blue, 'frcnn': cf.dark_blue, 'retina_net': cf.aubergine, 'retina_unet': cf.purple} + + colors = [color_map[split] for split in splits] + alphas = [0.9,] * len(splits) + legend_handles = [] + model_renamer = {'detection_fpn': "Detection U-Net", 'mrcnn': "Mask R-CNN", 'frcnn': "Faster R-CNN+", 'retina_net': "RetinaNet", 'retina_unet': "Retina U-Net"} + + for rix, c_kind in zip([0, 1],['wbc', 'nms']): + kind_df = res_df[res_df['Clustering Kind'] == c_kind] + groups = kind_df['Clustering IoU'].unique() + #for cix, iou in enumerate(groups): + assert np.all([split in splits for split in kind_df["Model"].unique()]) #need to be model names + ax = fig.add_subplot(grid[rix,:]) + bar_values = [kind_df[kind_df["Model"]==split]["rois_malignant : ap_folds_mean"] for split in splits] + bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values] + bar_values = [ [float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values ] + + + xlabel='' if rix == 0 else "Clustering IoU" + ylabel = str(c_kind.upper()) + " / AP" + lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds, + ax=ax, ylabel=ylabel, xlabel=xlabel) + legend_handles.append(lh) + if rix == 0: + ax.axes.get_xaxis().set_ticks([]) + #ax.spines['top'].set_visible(False) + #ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + #ax.spines['left'].set_visible(False) + else: + ax.spines['top'].set_visible(False) + #ticklab = ax.xaxis.get_ticklabels() + #trans = ticklab.get_transform() + ax.xaxis.set_label_coords(0.05, -0.05) + ax.set_ylim(0.,ylim) + + if is_2d: + # only 2d-3d merging @ 0.1 + ax = fig.add_subplot(grid[2, 1]) + kind_df = res_df[(res_df['Clustering Kind'] == 'None') & ~(res_df['Merge-2D-to-3D IoU'].isna())] + groups = kind_df['Clustering IoU'].unique() + bar_values = [kind_df[kind_df["Model"] == split]["rois_malignant : ap_folds_mean"] for split in splits] + bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values] + bar_values = np.array([[float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values]) + lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds, + ax=ax, ylabel="2D-3D Merging\nOnly / AP") + legend_handles.append(lh) + ax.axes.get_xaxis().set_ticks([]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + ax.set_ylim(0., ylim) + + next_row = 2 + next_col = 2 + else: + next_row = 2 + next_col = 2 + + # No clustering at all + ax = fig.add_subplot(grid[next_row, next_col]) + kind_df = res_df[(res_df['Clustering Kind'] == 'None') & (res_df['Merge-2D-to-3D IoU'].isna())] + groups = kind_df['Clustering IoU'].unique() + bar_values = [kind_df[kind_df["Model"] == split]["rois_malignant : ap_folds_mean"] for split in splits] + bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values] + bar_values = np.array([[float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values]) + lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds, + ax=ax, ylabel="No Clustering / AP") + legend_handles.append(lh) + #plg.suppress_axes_lines(ax) + #ax = fig.add_subplot(grid[next_row, 0]) + #ax.set_ylabel("No Clustering") + #plg.suppress_axes_lines(ax) + ax.axes.get_xaxis().set_ticks([]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + ax.set_ylim(0., ylim) + + + ax = fig.add_subplot(grid[next_row, 3]) + # awful hot fix: only legend_handles[0] used in order to have same order as in plots. + legend_handles = [plg.mpatches.Patch(color=handle[0], alpha=handle[1], label=model_renamer[handle[2]]) for handle in legend_handles[0]] + ax.legend(handles=legend_handles) + ax.axis('off') + + fig.suptitle('Prostate {} Results over Clustering Settings'.format(res_df["Spatial Dim"].unique().item()), fontsize=14) + + if out_file is not None: + plg.plt.savefig(out_file) + + return + +def get_plot_clustering_results(dataset_name, exp_parent_dir, res_from_file=True, exps_filter=None): + if not res_from_file: + results_df = gather_clustering_results(dataset_name, exp_parent_dir, exps_filter=exps_filter) + else: + results_df = pd.read_csv(os.path.join(exp_parent_dir, "df_cluster_summary.csv")) + if os.path.isfile(os.path.join(exp_parent_dir, "df_cluster_summary_no_clustering_2D.csv")): + results_df = results_df.append(pd.read_csv(os.path.join(exp_parent_dir, "df_cluster_summary_no_clustering_2D.csv"))) + + cf = get_cf(dataset_name) + if np.count_nonzero(results_df["Spatial Dim"] == "3D") >0: + # 3D + plot_cluster_results_grid(cf, results_df[results_df["Spatial Dim"] == "3D"], ylim=0.52, out_file=os.path.join(exp_parent_dir, "cluster_results_3D.pdf")) + if np.count_nonzero(results_df["Spatial Dim"] == "2D") > 0: + # 2D + plot_cluster_results_grid(cf, results_df[results_df["Spatial Dim"]=="2D"], ylim=0.4, out_file=os.path.join(exp_parent_dir, "cluster_results_2D.pdf")) + + +def plot_single_results(cf, exp_dir, plot_files, res_df=None): + out_file = os.path.join(exp_dir, "inference_analysis", "single_results.pdf") + + plot_files = utils.load_obj(plot_files) + batch = plot_files["batch"] + results_dict = plot_files["res_dict"] + cf.roi_items = ['class_targets'] + + class_renamer = {1: "GS 6", 2: "GS $\geq 7$"} + gs_renamer = {60: "6", 71: "7a"} + + if "adcb" in exp_dir: + modality = "adcb" + elif "t2" in exp_dir: + modality = "t2" + else: + modality = "b" + text_fs = 16 + + if modality=="t2": + n_rows, n_cols = 2, 3 + gt_col = 1 + fig_w, fig_h = 14, 4 + input_x, input_y = 0.05, 0.9 + z_ix = 11 + thresh = 0.22 + input_title = "Input" + elif modality=="b": + n_rows, n_cols = 2, 6 + gt_col = 2 # = gt_span + fig_w, fig_h = 14, 4 + input_x, input_y = 0.08, 0.8 + z_ix = 8 + thresh = 0.16 + input_title = " Input" + elif modality=="adcb": + n_rows, n_cols = 2, 7 + gt_col = 3 + fig_w, fig_h = 14, 4 + input_x, input_y = 0.08, 0.8 + z_ix = 8 + thresh = 0.16 + input_title = "Input" + fig_w, fig_h = 12, 3.87 + fig = plg.plt.figure(figsize=(fig_w, fig_h)) + grid = plg.plt.GridSpec(n_rows, n_cols, wspace=0.0, hspace=0.0, figure=fig) + cf.plot_class_ids = True + + if modality=="t2": + ax = fig.add_subplot(grid[:, 0]) + ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') + ax.set_title("Input", size=text_fs) + ax.text(0.05, 0.9, "T2", size=text_fs, color=cf.white, transform=ax.transAxes, + bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) + ax.axis("off") + elif modality=="b": + for m_ix, b in enumerate([50, 500, 1000, 1500]): + ax = fig.add_subplot(grid[int(np.round(m_ix/4+0.0001)), m_ix%2]) + print(int(np.round(m_ix/4+0.0001)), m_ix%2) + ax.imshow(batch['patient_data'][0, m_ix, :, :, z_ix], cmap='gray') + ax.text(input_x, input_y, r"{}{}".format("$b=$" if m_ix==0 else "", b), size=text_fs, color=cf.white, transform=ax.transAxes, + bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) + ax.axis("off") + if b==50: + ax.set_title(input_title, size=text_fs) + elif modality=="adcb": + for m_ix, b in enumerate(["ADC", 50, 500, 1000, 1500]): + p_ix = m_ix + 1 if m_ix>2 else m_ix + ax = fig.add_subplot(grid[int(np.round(p_ix/6+0.0001)), p_ix%3]) + print(int(np.round(p_ix/4+0.0001)), p_ix%2) + ax.imshow(batch['patient_data'][0, m_ix, :, :, z_ix], cmap='gray') + ax.text(input_x, input_y, r"{}{}".format("$b=$" if m_ix==1 else "", b), size=text_fs, color=cf.white, transform=ax.transAxes, + bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) + ax.axis("off") + if b==50: + ax.set_title(input_title, size=text_fs) + + ax_gt = fig.add_subplot(grid[:, gt_col:gt_col+2]) # GT + ax_pred = fig.add_subplot(grid[:, gt_col+2:gt_col+4]) # Prediction + #ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') + #ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') + #ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8) + plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, patient_items=True, + vol_slice_picks=[z_ix,], show_gt_labels=True, box_score_thres=thresh, plot_mods=True, + out_file=None, dpi=600, return_fig=False, axes={'gt':ax_gt, 'pred':ax_pred}, fontsize=text_fs) + + + ax_gt.set_title("Ground Truth", size=text_fs) + ax_pred.set_title("Prediction", size=text_fs) + texts = list(ax_gt.texts) + ax_gt.texts = [] + for text in texts: + cl_id = int(text.get_text()) + x, y = text.get_position() + text_str = "GS="+str(gs_renamer[cf.class_id2label[cl_id].gleasons[0]]) + ax_gt.text(x-4*text_fs//2, y, text_str, color=text.get_color(), + fontsize=text_fs, bbox=dict(facecolor=text.get_bbox_patch().get_facecolor(), alpha=0.7, edgecolor='none', clip_on=True, pad=0)) + texts = list(ax_pred.texts) + ax_pred.texts = [] + for text in texts: + x, y = text.get_position() + x -= 4 * text_fs // 2 + try: + cl_id = int(text.get_text()) + text_str = class_renamer[cl_id] + except ValueError: + text_str = text.get_text() + if text.get_bbox_patch().get_facecolor()[:3]==cf.dark_green: + x -= 4* text_fs + ax_pred.text(x, y, text_str, color=text.get_color(), + fontsize=text_fs, bbox=dict(facecolor=text.get_bbox_patch().get_facecolor(), alpha=0.7, edgecolor='none', clip_on=True, pad=0)) + + ax_gt.axis("off") + ax_pred.axis("off") + + plg.plt.tight_layout() + + if out_file is not None: + plg.plt.savefig(out_file, dpi=600, bbox_inches='tight') + + + + return + +def find_suitable_examples(exp_dir1, exp_dir2): + test_df1 = analyze_test_df('lidc',exp_dir1) + test_df2 = analyze_test_df('lidc', exp_dir2) + test_df1 = test_df1[test_df1.pred_score>0.3] + test_df2 = test_df2[test_df2.pred_score > 0.3] + + tp_df1 = test_df1[test_df1.det_type == 'det_tp'] + + tp_pids = tp_df1.pid.unique() + tp_fp_pids = test_df2[(test_df2.pid.isin(tp_pids)) & + ((test_df2.regressions-test_df2.rg_targets).abs()>1)].pid.unique() + cand_df = tp_df1[tp_df1.pid.isin(tp_fp_pids)] + sorter = (cand_df.regressions - cand_df.rg_targets).abs().argsort() + cand_df = cand_df.iloc[sorter] + print("Good guesses for examples: ", cand_df.pid.unique()[:20]) + return + +def plot_single_results_lidc(): + dataset_name = 'lidc' + exp_dir1 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rg_copiedparams' + exp_dir2 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_copiedparams' + cf = get_cf(dataset_name, exp_dir1) + #file_names = [f_name for f_name in os.listdir(os.path.join(exp_dir, 'inference_analysis')) if f_name.endswith('.pkl')] + # file_names = [os.path.join(exp_dir, "inference_analysis", f_name) for f_name in file_names] + file_names = ['bytes_merged_boxes_fold_0_pid_0296a.pkl', 'bytes_merged_boxes_fold_2_pid_0416a.pkl', + 'bytes_merged_boxes_fold_1_pid_0635a.pkl', "bytes_merged_boxes_fold_0_pid_0811a.pkl", + "bytes_merged_boxes_fold_0_pid_0969a.pkl", + # 'bytes_merged_boxes_fold_0_pid_0484a.pkl', 'bytes_merged_boxes_fold_0_pid_0492a.pkl', + # 'bytes_merged_boxes_fold_0_pid_0505a.pkl','bytes_merged_boxes_fold_2_pid_0164a.pkl', + # 'bytes_merged_boxes_fold_3_pid_0594a.pkl', + + + ] + z_ics = [167, 159, + 107, 194, + 177, + # 84, 145, + # 212, 219, + # 67 + ] + plot_files = [ + {'files': [os.path.join(exp_dir, "inference_analysis", f_name) for exp_dir in [exp_dir1, exp_dir2]], + 'z_ix': z_ix} for (f_name, z_ix) in zip(file_names, z_ics) + ] + + info_df_path = '/mnt/HDD2TB/Documents/data/lidc/pp_20190318/patient_gts_{}/info_df.pickle'.format(cf.training_gts) + info_df = pd.read_pickle(info_df_path) + + #cf.training_gts = 'sa' + cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois + + text_fs = 8 + fig = plg.plt.figure(figsize=(6, 9.9)) #w, h + #fig = plg.plt.figure(figsize=(6, 6.5)) + #fig.subplots_adjust(hspace=0, wspace=0) + grid = plg.plt.GridSpec(len(plot_files), 3, wspace=0.0, hspace=0.0, figure=fig) #rows, cols + cf.plot_class_ids = True + + + for f_ix, pack in enumerate(plot_files): + z_ix = plot_files[f_ix]['z_ix'] + for model_ix in range(2)[::-1]: + print("f_ix, m_ix", f_ix, model_ix) + plot_file = utils.load_obj(plot_files[f_ix]['files'][model_ix]) + batch = plot_file["batch"] + pid = batch["pid"][0] + batch['patient_rg_bin_targets_sa'] = info_df[info_df.pid == pid]['class_target'].tolist() + # apply same filter as with merged GTs: need at least two non-zero votes to consider a RoI. + batch['patient_rg_bin_targets_sa'] = [[four_votes for four_votes in batch_el if + np.count_nonzero(four_votes>0)>=2] for batch_el in + batch['patient_rg_bin_targets_sa']] + results_dict = plot_file["res_dict"] + + # pred + ax = fig.add_subplot(grid[f_ix, model_ix+1]) + plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, + vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.2, + plot_mods=False, + out_file=None, dpi=600, patient_items=True, return_fig=False, + axes={'pred': ax}) + if f_ix==0: + ax.set_title("{}".format("Reg R-CNN" if model_ix==0 else "Mask R-CNN"), size=text_fs*1.3) + else: + ax.set_title("") + + ax.axis("off") + #grid.tight_layout(fig) + + # GT + if model_ix==0: + ax = fig.add_subplot(grid[f_ix, 0]) + # ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') + # ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8) + boxes_fig = plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, + vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.1, + plot_mods=False, seg_cmap="rg", + out_file=None, dpi=600, patient_items=True, return_fig=False, + axes={'gt':ax}) + ax.set_ylabel(r"$\mathbf{"+chr(f_ix+97)+")}$ " + ax.get_ylabel()) + ax.set_ylabel("") + if f_ix==0: + ax.set_title("Ground Truth", size=text_fs*1.3) + else: + ax.set_title("") + + + #fig_patches = fig_leg.get_patches() + patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if not label.id in [0,]] + #fig.legends.append(fig_leg) + plg.plt.figlegend(handles=patches, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0., + ncol=len(patches), bbox_transform=fig.transFigure, title="Binned Malignancy Score", + fontsize= text_fs) + plg.plt.tight_layout() + out_file = os.path.join(exp_dir1, "inference_analysis", "lidc_example_results_solarized.pdf") + if out_file is not None: + plg.plt.savefig(out_file, dpi=600, bbox_inches='tight') + + +def box_clustering(exp_dir='', plot_dir=None): + import datasets.prostate.data_loader as dl + cf = get_cf('prostate', exp_dir) + if plot_dir is None: + plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc') + + fig = plg.plt.figure(figsize=(10, 4)) + #fig.subplots_adjust(hspace=0, wspace=0) + grid = plg.plt.GridSpec(2, 3, wspace=0.0, hspace=0., figure=fig) + fs = 14 + xyA = (.9, 0.5) + xyB = (0.05, .5) + + patch_size = np.array([200, 320]) + clustering_iou = 0.1 + img_y, img_x = patch_size + + boxes = [ + {'box_coords': [img_y * 0.2, img_x * 0.04, img_y * 0.55, img_x * 0.31], 'box_score': 0.45, 'box_cl': 1, + 'regression': 2., 'rg_bin': cf.rg_val_to_bin_id(1.), + 'box_patch_center_factor': 1., 'ens_ix': 1, 'box_n_overlaps': 1.}, + {'box_coords': [img_y*0.05, img_x*0.05, img_y*0.5, img_x*0.3], 'box_score': 0.85, 'box_cl': 2, + 'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.), + 'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.}, + {'box_coords': [img_y * 0.1, img_x * 0.2, img_y * 0.4, img_x * 0.7], 'box_score': 0.95, 'box_cl': 2, + 'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.), + 'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.}, + {'box_coords': [img_y * 0.80, img_x * 0.35, img_y * 0.95, img_x * 0.85], 'box_score': 0.6, 'box_cl': 2, + 'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.), + 'box_patch_center_factor': 1., 'ens_ix': 1, 'box_n_overlaps': 1.}, + {'box_coords': [img_y * 0.85, img_x * 0.4, img_y * 0.93, img_x * 0.9], 'box_score': 0.85, 'box_cl': 2, + 'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.), + 'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.}, + ] + for box in boxes: + c = box['box_coords'] + box_centers = np.array([(c[ii + 2] - c[ii]) / 2 for ii in range(len(c) // 2)]) + box['box_patch_center_factor'] = np.mean( + [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in + zip(box_centers, patch_size / 2)]) + print("pc fact", box['box_patch_center_factor']) + + box_coords = np.array([box['box_coords'] for box in boxes]) + box_scores = np.array([box['box_score'] for box in boxes]) + box_cl_ids = np.array([box['box_cl'] for box in boxes]) + ax0 = fig.add_subplot(grid[:,:2]) + plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_unclustered.png"), ax=ax0) + ax0.text(*xyA, 'a) Raw ', horizontalalignment='right', verticalalignment='center', transform=ax0.transAxes, + weight='bold', fontsize=fs) + + nms_boxes = [] + for cl in range(1,3): + cl_boxes = [box for box in boxes if box['box_cl'] == cl ] + box_coords = np.array([box['box_coords'] for box in cl_boxes]) + box_scores = np.array([box['box_score'] for box in cl_boxes]) + if 0 not in box_scores.shape: + keep_ix = mutils.nms_numpy(box_coords, box_scores, thresh=clustering_iou) + else: + keep_ix = [] + nms_boxes += [cl_boxes[ix] for ix in keep_ix] + box_coords = np.array([box['box_coords'] for box in nms_boxes]) + box_scores = np.array([box['box_score'] for box in nms_boxes]) + box_cl_ids = np.array([box['box_cl'] for box in nms_boxes]) + ax1 = fig.add_subplot(grid[1, 2]) + nms_color = cf.black + plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_nms_iou_{}.png".format(clustering_iou)), ax=ax1) + ax1.text(*xyB, ' c) NMS', horizontalalignment='left', verticalalignment='center', transform=ax1.transAxes, + weight='bold', color=nms_color, fontsize=fs) + + #------ WBC ------------------- + regress_flag = False + + wbc_boxes = [] + for cl in range(1,3): + cl_boxes = [box for box in boxes if box['box_cl'] == cl] + box_coords = np.array([box['box_coords'] for box in cl_boxes]) + box_scores = np.array([box['box_score'] for box in cl_boxes]) + box_center_factor = np.array([b['box_patch_center_factor'] for b in cl_boxes]) + box_n_overlaps = np.array([b['box_n_overlaps'] for b in cl_boxes]) + box_ens_ix = np.array([b['ens_ix'] for b in cl_boxes]) + box_regressions = np.array([b['regression'] for b in cl_boxes]) if regress_flag else None + box_rg_bins = np.array([b['rg_bin'] if 'rg_bin' in b.keys() else float('NaN') for b in cl_boxes]) + box_rg_uncs = np.array([b['rg_uncertainty'] if 'rg_uncertainty' in b.keys() else float('NaN') for b in cl_boxes]) + if 0 not in box_scores.shape: + keep_scores, keep_coords, keep_n_missing, keep_regressions, keep_rg_bins, keep_rg_uncs = \ + predictor_file.weighted_box_clustering(box_coords, box_scores, box_center_factor, box_n_overlaps, box_rg_bins, box_rg_uncs, + box_regressions, box_ens_ix, clustering_iou, n_ens=1) + + for boxix in range(len(keep_scores)): + clustered_box = {'box_type': 'det', 'box_coords': keep_coords[boxix], + 'box_score': keep_scores[boxix], 'cluster_n_missing': keep_n_missing[boxix], + 'box_pred_class_id': cl} + if regress_flag: + clustered_box.update({'regression': keep_regressions[boxix], + 'rg_uncertainty': keep_rg_uncs[boxix], + 'rg_bin': keep_rg_bins[boxix]}) + wbc_boxes.append(clustered_box) + + box_coords = np.array([box['box_coords'] for box in wbc_boxes]) + box_scores = np.array([box['box_score'] for box in wbc_boxes]) + box_cl_ids = np.array([box['box_pred_class_id'] for box in wbc_boxes]) + ax2 = fig.add_subplot(grid[0, 2]) + wbc_color = cf.black + plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_wbc_iou_{}.png".format(clustering_iou)), ax=ax2) + ax2.text(*xyB, ' b) WBC', horizontalalignment='left', verticalalignment='center', transform=ax2.transAxes, + weight='bold', color=wbc_color, fontsize=fs) + # ax2.spines['bottom'].set_color(wbc_color) + # ax2.spines['top'].set_color(wbc_color) + # ax2.spines['right'].set_color(wbc_color) + # ax2.spines['left'].set_color(wbc_color) + + from matplotlib.patches import ConnectionPatch + con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA="axes fraction", coordsB="axes fraction", + axesA=ax0, axesB=ax2, color=wbc_color, lw=1.5, arrowstyle='-|>') + ax0.add_artist(con) + + con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA="axes fraction", coordsB="axes fraction", + axesA=ax0, axesB=ax1, color=nms_color, lw=1.5, arrowstyle='-|>') + ax0.add_artist(con) + # ax0.text(0.5, 0.5, "Test", size=30, va="center", ha="center", rotation=30, + # bbox=dict(boxstyle="angled,pad=0.5", alpha=0.2)) + plg.plt.tight_layout() + plg.plt.savefig(os.path.join(plot_dir, "box_clustering.pdf"), bbox_inches='tight') + +def sketch_AP_AUC(plot_dir=None, draw_auc=True): + from sklearn.metrics import roc_curve, roc_auc_score + from understanding_metrics import get_det_types + import matplotlib.transforms as mtrans + cf = get_cf('prostate', '') + if plot_dir is None: + plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('.') + + if draw_auc: + fig = plg.plt.figure(figsize=(7, 6)) #width, height + # fig.subplots_adjust(hspace=0, wspace=0) + grid = plg.plt.GridSpec(2, 2, wspace=0.23, hspace=.45, figure=fig) #rows, cols + else: + fig = plg.plt.figure(figsize=(12, 3)) #width, height + # fig.subplots_adjust(hspace=0, wspace=0) + grid = plg.plt.GridSpec(1, 3, wspace=0.23, hspace=.45, figure=fig) #rows, cols + fs = 13 + text_fs = 11 + optim_color = cf.dark_green + non_opt_color = cf.aubergine + + df = pd.DataFrame(columns=['pred_score', 'class_label', 'pred_class', 'det_type', 'match_iou']) + df2 = df.copy() + df["pred_score"] = [0,0.3,0.25,0.2, 0.8, 0.9, 0.9, 0.9, 0.9] + df["class_label"] = [0,0,0,0, 1, 1, 1, 1, 1] + df["det_type"] = get_det_types(df) + df["match_iou"] = [0.1] * len(df) + + df2["pred_score"] = [0, 0.77, 0.5, 1., 0.5, 0.35, 0.3, 0., 0.7, 0.85, 0.9] + df2["class_label"] = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + df2["det_type"] = get_det_types(df2) + df2["match_iou"] = [0.1] * len(df2) + + #------ PRC ------- + # optimal + if draw_auc: + ax = fig.add_subplot(grid[1, 0]) + else: + ax = fig.add_subplot(grid[0, 2]) + pr, rc = evaluator_file.compute_prc(df) + ax.plot(rc, pr, color=optim_color, label="Optimal Detection") + ax.fill_between(rc, pr, alpha=0.33, color=optim_color) + + # suboptimal + pr, rc = evaluator_file.compute_prc(df2) + ax.plot(rc, pr, color=non_opt_color, label="Suboptimal") + ax.fill_between(rc, pr, alpha=0.33, color=non_opt_color) + #plt.title() + #plt.legend(loc=3 if c == 'prc' else 4) + ax.set_ylabel('precision', fontsize=text_fs) + ax.set_ylim((0., 1.1)) + ax.set_xlabel('recall', fontsize=text_fs) + ax.set_title('Precision-Recall Curves', fontsize=fs) + #ax.legend(ncol=2, loc='center')#, bbox_to_anchor=(0.5, 1.05)) + + + #---- ROC curve + if draw_auc: + ax = fig.add_subplot(grid[1, 1]) + roc = roc_curve(df.class_label.tolist(), df.pred_score.tolist()) + ax.plot(roc[0], roc[1], color=optim_color) + ax.fill_between(roc[0], roc[1], alpha=0.33, color=optim_color) + ax.set_xlabel('false-positive rate', fontsize=text_fs) + ax.set_ylim((0., 1.1)) + ax.set_ylabel('recall', fontsize=text_fs) + + roc = roc_curve(df2.class_label.tolist(), df2.pred_score.tolist()) + ax.plot(roc[0], roc[1], color=non_opt_color) + ax.fill_between(roc[0], roc[1], alpha=0.33, color=non_opt_color) + + roc = ([0, 1], [0, 1]) + ax.plot(roc[0], roc[1], color=cf.gray, linestyle='dashed', label="random predictor") + + ax.set_title('ROC Curves', fontsize=fs) + ax.legend(ncol=2, loc='lower right', fontsize=text_fs) + + #--- hist optimal + text_left = 0.05 + ax = fig.add_subplot(grid[0, 0]) + tn_count = df.det_type.tolist().count('det_tn') + AUC = roc_auc_score(df.class_label, df.pred_score) + df = df[(df.det_type=="det_tp") | (df.det_type=="det_fp") | (df.det_type=="det_fn")] + labels = df.class_label.values + preds = df.pred_score.values + type_list = df.det_type.tolist() + + ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="FP") + ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="FN at score 0 and TP") + #ax.axvline(x=cf.min_det_thresh, alpha=0.4, color=cf.orange, linewidth=1.5, label="min det thresh") + fp_count = type_list.count('det_fp') + fn_count = type_list.count('det_fn') + tp_count = type_list.count('det_tp') + pos_count = fn_count + tp_count + if draw_auc: + text = "AP: {:.2f} ROC-AUC: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df, 0.0, False)), AUC) + else: + text = "AP: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df, 0.0, False))) + text += 'TP: {} FP: {} FN: {} TN: {}\npositives: {}'.format(tp_count, fp_count, fn_count, tn_count, pos_count) + + ax.text(text_left,4, text, fontsize=text_fs) + ax.set_yscale('log') + ax.set_ylim(bottom=10**-2, top=10**2) + ax.set_xlabel("prediction score", fontsize=text_fs) + ax.set_ylabel("occurences", fontsize=text_fs) + #autoAxis = ax.axis() + # rec = plg.mpatches.Rectangle((autoAxis[0] - 0.7, autoAxis[2] - 0.2), (autoAxis[1] - autoAxis[0]) + 1, + # (autoAxis[3] - autoAxis[2]) + 0.4, fill=False, lw=2) + # rec = plg.mpatches.Rectangle((autoAxis[0] , autoAxis[2] ), (autoAxis[1] - autoAxis[0]) , + # (autoAxis[3] - autoAxis[2]) , fill=False, lw=2, color=optim_color) + # rec = ax.add_patch(rec) + # rec.set_clip_on(False) + plg.plt.setp(ax.spines.values(), color=optim_color, linewidth=2) + ax.set_facecolor((*optim_color,0.1)) + ax.set_title("Detection Histograms", fontsize=fs) + + ax = fig.add_subplot(grid[0, 1]) + tn_count = df2.det_type.tolist().count('det_tn') + AUC = roc_auc_score(df2.class_label, df2.pred_score) + df2 = df2[(df2.det_type=="det_tp") | (df2.det_type=="det_fp") | (df2.det_type=="det_fn")] + labels = df2.class_label.values + preds = df2.pred_score.values + type_list = df2.det_type.tolist() + + ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="FP") + ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="FN at score 0 and TP") + # ax.axvline(x=cf.min_det_thresh, alpha=0.4, color=cf.orange, linewidth=1.5, label="min det thresh") + fp_count = type_list.count('det_fp') + fn_count = type_list.count('det_fn') + tp_count = type_list.count('det_tp') + pos_count = fn_count + tp_count + if draw_auc: + text = "AP: {:.2f} ROC-AUC: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df2, 0.0, False)), AUC) + else: + text = "AP: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df2, 0.0, False))) + text += 'TP: {} FP: {} FN: {} TN: {}\npositives: {}'.format(tp_count, fp_count, fn_count, tn_count, pos_count) + + ax.text(text_left, 4*10**0, text, fontsize=text_fs) + ax.set_yscale('log') + ax.margins(y=10e2) + ax.set_ylim(bottom=10**-2, top=10**2) + ax.set_xlabel("prediction score", fontsize=text_fs) + ax.set_yticks([]) + plg.plt.setp(ax.spines.values(), color=non_opt_color, linewidth=2) + ax.set_facecolor((*non_opt_color, 0.05)) + ax.legend(ncol=2, loc='upper center', bbox_to_anchor=(0.5, 1.18), fontsize=text_fs) + + if draw_auc: + # Draw a horizontal line + line = plg.plt.Line2D([0.1, .9], [0.48, 0.48], transform=fig.transFigure, color="black") + fig.add_artist(line) + + outfile = os.path.join(plot_dir, "metrics.png") + print("Saving plot to {}".format(outfile)) + plg.plt.savefig(outfile, bbox_inches='tight', dpi=600) + + return + +def draw_toy_cylinders(plot_dir=None): + source_path = "datasets/toy" + if plot_dir is None: + plot_dir = os.path.join(source_path, "misc") + #plot_dir = '/home/gregor/Dropbox/Thesis/Main/tmp' + os.makedirs(plot_dir, exist_ok=True) + + cf = get_cf('toy', '') + cf.pre_crop_size = [2200, 2200,1] #y,x,z; + #cf.dim = 2 + cf.ambiguities = {"radius_calib": (1., 1. / 6) } + cf.pp_blur_min_intensity = 0.2 + + generate_toys = utils.import_module("generate_toys", os.path.join(source_path, 'generate_toys.py')) + ToyGen = generate_toys.ToyGenerator(cf) + + fig = plg.plt.figure(figsize=(10, 8.2)) #width, height + grid = plg.plt.GridSpec(4, 5, wspace=0.0, hspace=.0, figure=fig) #rows, cols + fs, text_fs = 16, 14 + text_x, text_y = 0.5, 0.85 + true_gt_col, dist_gt_col = cf.dark_green, cf.blue + true_cmap = {1:true_gt_col} + + img = np.random.normal(loc=0.0, scale=cf.noise_scale, size=ToyGen.sample_size) + img[img < 0.] = 0. + # one-hot-encoded seg + seg = np.zeros((cf.num_classes + 1, *ToyGen.sample_size)).astype('uint8') + undistorted_seg = np.copy(seg) + applied_gt_distort = False + + class_id, shape = 1, 'cylinder' + #all_radii = ToyGen.generate_sample_radii(class_ids, shapes) + enlarge_f = 20 + all_radii = np.array([np.mean(label.bin_vals) if label.id!=5 else label.bin_vals[0]+5 for label in cf.bin_labels if label.id!=0]) + bins = [(min(label.bin_vals), max(label.bin_vals)) for label in cf.bin_labels] + bin_edges = [(bins[i][1] + bins[i + 1][0])*enlarge_f / 2 for i in range(len(bins) - 1)] + all_radii = [np.array([r*enlarge_f, r*enlarge_f, 1]) for r in all_radii] # extend to required 3D format + regress_targets, undistorted_rg_targets = [], [] + ics = np.argwhere(np.ones(seg[0].shape)) # indices ics equal positions within img/volume + center = np.array([dim//2 for dim in img.shape]) + + # for illustrating GT distribution, keep scale same size + #x = np.linspace(mu - 300, mu + 300, 100) + x = np.linspace(0, 50*enlarge_f, 500) + ax_gauss = fig.add_subplot(grid[3, :]) + mus, sigmas = [], [] + + for roi_ix, radii in enumerate(all_radii): + print('processing {} {}'.format(roi_ix, radii)) + cur_img, cur_seg, cur_undistorted_seg, cur_regress_targets, cur_undistorted_rg_targets, cur_applied_gt_distort = \ + ToyGen.draw_object(img.copy(), seg.copy(), undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort, + roi_ix, class_id, shape, np.copy(radii), center) + + ax = fig.add_subplot(grid[0,roi_ix]) + ax.imshow(cur_img[...,0], cmap='gray', vmin=0) + ax.set_title("r{}".format(roi_ix+1), fontsize=fs) + if roi_ix==0: + ax.set_ylabel(r"$\mathbf{a)}$ Input", fontsize=fs) + plg.suppress_axes_lines(ax) + else: + ax.axis('off') + + ax = fig.add_subplot(grid[1, roi_ix]) + ax.imshow(cur_img[..., 0], cmap='gray') + ax.imshow(plg.to_rgba(np.argmax(cur_undistorted_seg[...,0], axis=0), true_cmap), alpha=0.8) + ax.text(text_x, text_y, r"$r_{a}=$"+"{:.1f}".format(cur_undistorted_rg_targets[roi_ix][0]/enlarge_f), transform=ax.transAxes, + color=cf.white, bbox=dict(facecolor=true_gt_col, alpha=0.7, edgecolor=cf.white, clip_on=False,pad=2.5), + fontsize=text_fs, ha='center', va='center') + if roi_ix==0: + ax.set_ylabel(r"$\mathbf{b)}$ Exact GT", fontsize=fs) + plg.suppress_axes_lines(ax) + else: + ax.axis('off') + ax = fig.add_subplot(grid[2, roi_ix]) + ax.imshow(cur_img[..., 0], cmap='gray') + ax.imshow(plg.to_rgba(np.argmax(cur_seg[..., 0], axis=0), cf.cmap), alpha=0.7) + ax.text(text_x, text_y, r"$r_{a}=$"+"{:.1f}".format(cur_regress_targets[roi_ix][0]/enlarge_f), transform=ax.transAxes, + color=cf.white, bbox=dict(facecolor=cf.blue, alpha=0.7, edgecolor=cf.white, clip_on=False,pad=2.5), + fontsize=text_fs, ha='center', va='center') + if roi_ix == 0: + ax.set_ylabel(r"$\mathbf{c)}$ Noisy GT", fontsize=fs) + plg.suppress_axes_lines(ax) + else: + ax.axis('off') + + # GT distributions + assert radii[0]==radii[1] + mu, sigma = radii[0], radii[0] * cf.ambiguities["radius_calib"][1] + ax_gauss.axvline(mu, color=true_gt_col) + ax_gauss.text(mu, -0.003, "$r=${:.0f}".format(mu/enlarge_f), color=true_gt_col, fontsize=text_fs, ha='center', va='center', + bbox = dict(facecolor='none', alpha=0.7, edgecolor=true_gt_col, clip_on=False, pad=2.5)) + mus.append(mu); sigmas.append(sigma) + lower_bound = max(bin_edges[roi_ix], min(x))# if roi_ix>0 else 2*mu-bin_edges[roi_ix+1] + upper_bound = bin_edges[roi_ix+1] if len(bin_edges)>roi_ix+1 else max(x)#2*mu-bin_edges[roi_ix] + if roi_ix, head_length = 0.05, head_width = .005", lw=1)) + #ax_gauss.arrow(1, 0.5, 0., 0.1) + handles = [plg.mpatches.Patch(facecolor=dist_gt_col, label='Inexact Seg.', alpha=0.7, edgecolor='none'), + mlines.Line2D([], [], color=dist_gt_col, marker=r'$\curlywedge$', linestyle='none', markersize=11, label='GT Sampling Distr.'), + mlines.Line2D([], [], color=true_gt_col, marker='|', markersize=12, label='Exact GT Radius.', linestyle='none'), + plg.mpatches.Patch(facecolor=true_gt_col, label='a)-c) Exact Seg., d) Bin', alpha=0.7, edgecolor='none')] + fig.legend(handles=handles, loc="lower center", ncol=len(handles), fontsize=text_fs) + outfile = os.path.join(plot_dir, "toy_cylinders.png") + print("Saving plot to {}".format(outfile)) + plg.plt.savefig(outfile, bbox_inches='tight', dpi=600) + + + return + +def seg_det_cityscapes_example(plot_dir=None): + cf = get_cf('cityscapes', '') + source_path = "datasets/cityscapes" + if plot_dir is None: + plot_dir = os.path.join(source_path, "misc") + os.makedirs(plot_dir, exist_ok=True) + + + dl = utils.import_module("dl", os.path.join(source_path, 'data_loader.py')) + #from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates + data_set = dl.Dataset(cf) + Converter = dl.ConvertSegToBoundingBoxCoordinates(2, cf.roi_items) + + fig = plg.plt.figure(figsize=(9, 3)) #width, height + grid = plg.plt.GridSpec(1, 2, wspace=0.05, hspace=.0, figure=fig) #rows, cols + fs, text_fs = 12, 10 + + nice_imgs = ["bremen000099000019", "hamburg000000033506", "frankfurt000001058914",] + img_id = nice_imgs[2] + #img_id = np.random.choice(data_set.set_ids) + + + print("Selected img", img_id) + img = np.load(data_set[img_id]["img"]).transpose(1,2,0) + seg = np.load(data_set[img_id]["seg"]) + cl_targs = data_set[img_id]["class_targets"] + roi_ids = np.unique(seg[seg > 0]) + # ---- detection example ----- + cl_id2name = {1: "h", 2: "v"} + color_palette = [cf.purple, cf.aubergine, cf.magenta, cf.dark_blue, cf.blue, cf.bright_blue, cf.cyan, cf.dark_green, + cf.green, cf.dark_yellow, cf.yellow, cf.orange, cf.red, cf.dark_red, cf.bright_red] + n_colors = len(color_palette) + cmap = {roi_id : color_palette[(roi_id-1)%n_colors] for roi_id in roi_ids} + cmap[0] = (1,1,1,0.) + + ax = fig.add_subplot(grid[0, 1]) + ax.imshow(img) + ax.imshow(plg.to_rgba(seg, cmap), alpha=0.7) + + data_dict = Converter(**{'seg':seg[np.newaxis, np.newaxis], 'class_targets': [cl_targs]}) # needs batch dim and channel + for roi_ix, bb_target in enumerate(data_dict['bb_target'][0]): + [y1, x1, y2, x2] = bb_target + width, height = x2 - x1, y2 - y1 + cl_id = cl_targs[roi_ix] + label = cf.class_id2label[cl_id] + text_x, text_y = x2, y1 + id_text = cl_id2name[cl_id] + text_str = '{}'.format(id_text) + text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) + #ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=text_fs, ha="center", va="center") + edgecolor = label.color + bbox = plg.mpatches.Rectangle((x1, y1), width, height, linewidth=1.05, edgecolor=edgecolor, facecolor='none') + ax.add_patch(bbox) + ax.axis('off') + + # ---- seg example ----- + for roi_id in roi_ids: + seg[seg==roi_id] = cl_targs[roi_id-1] + + ax = fig.add_subplot(grid[0,0]) + ax.imshow(img) + ax.imshow(plg.to_rgba(seg, cf.cmap), alpha=0.7) + ax.axis('off') + + plg.plt.tight_layout() + outfile = os.path.join(plot_dir, "cityscapes_example.png") + print("Saving plot to {}".format(outfile)) + plg.plt.savefig(outfile, bbox_inches='tight', dpi=600) + + + + + +if __name__=="__main__": + stime = time.time() + #seg_det_cityscapes_example() + #box_clustering() + #sketch_AP_AUC(draw_auc=False) + #draw_toy_cylinders() + #prostate_GT_examples(plot_dir="/home/gregor/Dropbox/Thesis/Main/MFPPresentation/graphics") + #prostate_results_static() + #prostate_dataset_stats(plot_dir="/home/gregor/Dropbox/Thesis/Main/MFPPresentation/graphics", show_splits=False) + #lidc_dataset_stats() + #lidc_sa_dataset_stats() + #lidc_annotator_confusion() + #lidc_merged_sa_joint_plot() + #lidc_annotator_dissent_images() + exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6" + #multiple_clustering_results('prostate', exp_dir, plot_hist=True) + exp_parent_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments" + exp_parent_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments_debug_retinas" + #get_plot_clustering_results('prostate', exp_parent_dir, res_from_file=False) + + exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6" + #cf = get_cf('prostate', exp_dir) + #plot_file = os.path.join(exp_dir, "inference_analysis/bytes_merged_boxes_fold_1_pid_177.pkl") + #plot_single_results(cf, exp_dir, plot_file) + + exp_dir1 = "/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rg_bs8" + exp_dir2 = "/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8" + #find_suitable_examples(exp_dir1, exp_dir2) + #plot_single_results_lidc() + plot_dir = "/home/gregor/Dropbox/Thesis/MICCAI2019/Graphics" + #lidc_results_static(plot_dir=plot_dir) + #toy_results_static(plot_dir=plot_dir) + plot_lidc_dissent_and_example(plot_dir=plot_dir, confusion_matrix=True, numbering=False, example_title="LIDC example result") + + mins, secs = divmod((time.time() - stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/inference_analysis.py b/inference_analysis.py new file mode 100644 index 0000000..cce9bc9 --- /dev/null +++ b/inference_analysis.py @@ -0,0 +1,173 @@ +"""for presentations etc""" + +import plotting as plg + +import sys +import os +import pickle + +import numpy as np +import pandas as pd +import torch + +import utils.exp_utils as utils +import utils.model_utils as mutils +from predictor import Predictor +from evaluator import Evaluator + + +def find_pid_in_splits(pid, exp_dir=None): + if exp_dir is None: + exp_dir = cf.exp_dir + check_file = os.path.join(exp_dir, 'fold_ids.pickle') + with open(check_file, 'rb') as handle: + splits = pickle.load(handle) + + finds = [] + for i, split in enumerate(splits): + if pid in split: + finds.append(i) + print("Pid {} found in split {}".format(pid, i)) + if not len(finds)==1: + raise Exception("pid {} found in more than one split: {}".format(pid, finds)) + return finds[0] + +def plot_train_forward(slices=None): + with torch.no_grad(): + batch = next(val_gen) + results_dict = net.train_forward(batch, is_validation=True) #seg preds are int preds already + + out_file = os.path.join(anal_dir, "straight_val_inference_fold_{}".format(str(cf.fold))) + plg.view_batch(cf, batch, res_dict=results_dict, show_info=False, legend=True, + out_file=out_file, slices=slices) + +def plot_forward(pid, slices=None): + with torch.no_grad(): + batch = batch_gen['test'].generate_train_batch(pid=pid) + results_dict = net.test_forward(batch) #seg preds are only seg_logits! need to take argmax. + + if 'seg_preds' in results_dict.keys(): + results_dict['seg_preds'] = np.argmax(results_dict['seg_preds'], axis=1)[:,np.newaxis] + + out_file = os.path.join(anal_dir, "straight_inference_fold_{}_pid_{}".format(str(cf.fold), pid)) + plg.view_batch(cf, batch, res_dict=results_dict, show_info=False, legend=True, show_gt_labels=True, + out_file=out_file, sample_picks=slices) + + +def plot_merged_boxes(results_list, pid, plot_mods=False, show_seg_ids="all", show_info=True, show_gt_boxes=True, + s_picks=None, vol_slice_picks=None, score_thres=None): + """ + + :param results_list: holds (results_dict, pid) + :param pid: + :return: + """ + results_dict = [res_dict for (res_dict, pid_) in results_list if pid_==pid][0] + #seg preds are discarded in predictor pipeline. + #del results_dict['seg_preds'] + + batch = batch_gen['test'].generate_train_batch(pid=pid) + out_file = os.path.join(anal_dir, "merged_boxes_fold_{}_pid_{}_thres_{}.png".format(str(cf.fold), pid, str(score_thres).replace(".","_"))) + + utils.save_obj({'res_dict':results_dict, 'batch':batch}, os.path.join(anal_dir, "bytes_merged_boxes_fold_{}_pid_{}".format(str(cf.fold), pid))) + + plg.view_batch(cf, batch, res_dict=results_dict, show_info=show_info, legend=False, sample_picks=s_picks, + show_seg_pred=True, show_seg_ids=show_seg_ids, show_gt_boxes=show_gt_boxes, + box_score_thres=score_thres, vol_slice_picks=vol_slice_picks, show_gt_labels=True, + plot_mods=plot_mods, out_file=out_file, has_colorchannels=cf.has_colorchannels, dpi=600) + + return + + + + + +if __name__=="__main__": + class Args(): + def __init__(self): + #self.dataset_name = "datasets/prostate" + self.dataset_name = "datasets/lidc" + #self.exp_dir = "datasets/toy/experiments/mrcnnal2d_clkengal" # detunet2d_di_bs16_ps512" + #self.exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_retinau3d_cl_bs6" + #self.exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6" + #self.exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments_t2/gs6071_mrcnn3d_cl_bs6_lessaug" + #self.exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn3d_cl_bs6" + #self.exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8" + self.exp_dir = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rg_bs8' + #self.exp_dir = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_bs8' + + self.server_env = False + args = Args() + + + data_loader = utils.import_module('dl', os.path.join(args.dataset_name, "data_loader.py")) + + config_file = utils.import_module('cf', os.path.join(args.exp_dir, "configs.py")) + cf = config_file.Configs() + cf.exp_dir = args.exp_dir + cf.test_dir = cf.exp_dir + + pid = '0811a' + cf.fold = find_pid_in_splits(pid) + #cf.fold = 0 + cf.merge_2D_to_3D_preds = False + if cf.merge_2D_to_3D_preds: + cf.dim==3 + cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) + anal_dir = os.path.join(cf.exp_dir, "inference_analysis") + + logger = utils.get_logger(cf.exp_dir) + model = utils.import_module('model', os.path.join(cf.exp_dir, "model.py")) + torch.backends.cudnn.benchmark = cf.dim == 3 + net = model.net(cf, logger).cuda() + test_predictor = Predictor(cf, None, logger, mode='test') + test_evaluator = Evaluator(cf, logger, mode='test') + #val_gen = data_loader.get_train_generators(cf, logger, data_statistics=False)['val_sampling'] + batch_gen = data_loader.get_test_generator(cf, logger) + weight_paths = [os.path.join(cf.fold_dir, '{}_best_params.pth'.format(rank)) for rank in + test_predictor.epoch_ranking] + try: + pids = batch_gen["test"].dataset_pids + except: + pids = batch_gen["test"].generator.dataset_pids + print("pids in test set: ", pids) + #pid = pids[0] + #assert pid in pids + + # load already trained model weights + rank = 0 + weight_path = weight_paths[rank] + with torch.no_grad(): + pass + net.load_state_dict(torch.load(weight_path)) + net.eval() + # generate a batch from test set and show results + if not os.path.isdir(anal_dir): + os.mkdir(anal_dir) + + #plot_train_forward() + #plot_forward(pids[0]) + #net.actual_dims() + #batch_gen = data_loader.get_test_generator(cf, logger) + merged_boxes_file = os.path.join(cf.fold_dir, "merged_box_results") + try: + results_list = utils.load_obj(merged_boxes_file+".pkl") + print("loaded merged boxes from file.") + except FileNotFoundError: + results_list = test_predictor.load_saved_predictions() + utils.save_obj(results_list, merged_boxes_file) + + + cf.plot_class_ids = False + for pid in [pid,]:#['0317a',]:#pids[2:8]: + assert pid in [res[1] for res in results_list] + plot_merged_boxes(results_list, pid=pid, show_info=True, show_gt_boxes=True, show_seg_ids="all", score_thres=0.13, + s_picks=None, vol_slice_picks=None, plot_mods=False) + + + + + + + + diff --git a/models/backbone.py b/models/backbone.py new file mode 100644 index 0000000..fed3d34 --- /dev/null +++ b/models/backbone.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch.nn as nn +import torch.nn.functional as F + + +class ConvGenerator(): + """conv-layer generator to avoid 2D vs. 3D distinction in code. + """ + + def __init__(self, dim): + self.dim = dim + + def __call__(self, c_in, c_out, ks, pad=0, stride=1, norm=None, relu='relu'): + """provides generic conv-layer modules for set dimension. + :param c_in: number of in_channels. + :param c_out: number of out_channels. + :param ks: kernel size. + :param pad: pad size. + :param stride: kernel stride. + :param norm: string specifying type of feature map normalization. If None, no normalization is applied. + :param relu: string specifying type of nonlinearity. If None, no nonlinearity is applied. + :return: 2D or 3D conv-layer module. + """ + + if self.dim == 2: + module = nn.Conv2d(c_in, c_out, kernel_size=ks, padding=pad, stride=stride) + if norm is not None: + if norm == 'instance_norm': + norm_layer = nn.InstanceNorm2d(c_out) + elif norm == 'batch_norm': + norm_layer = nn.BatchNorm2d(c_out) + else: + raise ValueError('norm type as specified in configs is not implemented...') + module = nn.Sequential(module, norm_layer) + + elif self.dim==3: + module = nn.Conv3d(c_in, c_out, kernel_size=ks, padding=pad, stride=stride) + if norm is not None: + if norm == 'instance_norm': + norm_layer = nn.InstanceNorm3d(c_out) + elif norm == 'batch_norm': + norm_layer = nn.BatchNorm3d(c_out) + else: + raise ValueError('norm type as specified in configs is not implemented... {}'.format(norm)) + module = nn.Sequential(module, norm_layer) + else: + raise Exception("Invalid dimension {} in conv-layer generation.".format(self.dim)) + + if relu is not None: + if relu == 'relu': + relu_layer = nn.ReLU(inplace=True) + elif relu == 'leaky_relu': + relu_layer = nn.LeakyReLU(inplace=True) + else: + raise ValueError('relu type as specified in configs is not implemented...') + module = nn.Sequential(module, relu_layer) + + return module + + +class Interpolate(nn.Module): + def __init__(self, scale_factor, mode): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + + def forward(self, x): + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False) + return x + +class ResBlock(nn.Module): + + def __init__(self, start_filts, planes, end_filts, conv, stride=1, identity_skip=True, norm=None, relu='relu'): + """Builds a residual net block. + :param start_filts: #input channels to the block. + :param planes: #channels in block's hidden layers. set start_filts>planes 0])) + out_box_coords.append(box_coords) + out_max_scores.append(max_scores) + return seg_logits, out_box_coords, out_max_scores + + def train_forward(self, batch, **kwargs): + """ + train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data + for processing, computes losses, and stores outputs in a dictionary. + :param batch: dictionary containing 'data', 'seg', etc. + :param kwargs: + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes] + 'torch_loss': 1D torch tensor for backprop. + 'class_loss': classification loss for monitoring. here: dummy array, since no classification conducted. + """ + + img = torch.from_numpy(batch['data']).cuda().float() + seg = torch.from_numpy(batch['seg']).cuda().long() + seg_ohe = torch.from_numpy(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda() + results_dict = {} + seg_logits, box_coords, max_scores = self.forward(img) + + # no extra class loss applied in this model. pass dummy tensor for monitoring. + results_dict['class_loss'] = np.nan + + results_dict['boxes'] = [[] for _ in range(img.shape[0])] + for cix in range(len(self.cf.class_dict.keys())): + for bix in range(img.shape[0]): + for rix in range(len(max_scores[cix][bix])): + if max_scores[cix][bix][rix] > self.cf.detection_min_confidence: + results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]), + 'box_score': max_scores[cix][bix][rix], + 'box_pred_class_id': cix + 1, # add 0 for background. + 'box_type': 'det'}) + + for bix in range(img.shape[0]): + for tix in range(len(batch['bb_target'][bix])): + gt_box = {'box_coords': batch['bb_target'][bix][tix], 'box_type': 'gt'} + for name in self.cf.roi_items: + gt_box.update({name: batch[name][bix][tix]}) + + results_dict['boxes'][bix].append(gt_box) + + # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both. + loss = torch.tensor([0.], dtype=torch.float, requires_grad=False).cuda() + seg_pred = F.softmax(seg_logits, dim=1) + if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce': + loss += 1 - mutils.batch_dice(seg_pred, seg_ohe.float(), false_positive_weight=float(self.cf.fp_dice_weight)) + + if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce': + loss += F.cross_entropy(seg_logits, seg[:, 0], weight=torch.FloatTensor(self.cf.wce_weights).cuda()) + + results_dict['torch_loss'] = loss + seg_pred = seg_pred.argmax(dim=1).unsqueeze(dim=1).cpu().data.numpy() + results_dict['seg_preds'] = seg_pred + if 'dice' in self.cf.metrics: + results_dict['batch_dices'] = mutils.dice_per_batch_and_class(seg_pred, batch["seg"], + self.cf.num_seg_classes, convert_to_ohe=True) + #self.logger.info("loss: {0:.2f}".format(loss.item())) + return results_dict + + + def test_forward(self, batch, **kwargs): + """ + test method. wrapper around forward pass of network without usage of any ground truth information. + prepares input data for processing and stores outputs in a dictionary. + :param batch: dictionary containing 'data' + :param kwargs: + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes] + """ + img = torch.FloatTensor(batch['data']).cuda() + seg_logits, box_coords, max_scores = self.forward(img) + + results_dict = {} + results_dict['boxes'] = [[] for _ in range(img.shape[0])] + for cix in range(len(box_coords)): + for bix in range(img.shape[0]): + for rix in range(len(max_scores[cix][bix])): + if max_scores[cix][bix][rix] > self.cf.detection_min_confidence: + results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]), + 'box_score': max_scores[cix][bix][rix], + 'box_pred_class_id': cix + 1, + 'box_type': 'det'}) + results_dict['seg_preds'] = F.softmax(seg_logits, dim=1).cpu().data.numpy() + + return results_dict + diff --git a/models/detection_unet.py b/models/detection_unet.py new file mode 100644 index 0000000..142f560 --- /dev/null +++ b/models/detection_unet.py @@ -0,0 +1,545 @@ +import warnings +import os +import shutil +import time + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + + +import utils.exp_utils as utils +import utils.model_utils as mutils + +''' +Use nn.DataParallel to use more than one GPU +''' + +def center_crop_2D_image_batched(img, crop_size): + # from batch generator tools from https://github.com/MIC-DKFZ/batchgenerators + # dim 0 is batch, dim 1 is channel, dim 2 and 3 are x y + center = np.array(img.shape[2:]) / 2. + if not hasattr(crop_size, "__iter__"): + center_crop = [int(crop_size)] * (len(img.shape) - 2) + else: + center_crop = np.array(crop_size) + assert len(center_crop) == (len( + img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)" + return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), + int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.)] + +def center_crop_3D_image_batched(img, crop_size): + # dim 0 is batch, dim 1 is channel, dim 2, 3 and 4 are x y z + center = np.array(img.shape[2:]) / 2. + if not hasattr(crop_size, "__iter__"): + center_crop = np.array([int(crop_size)] * (len(img.shape) - 2)) + else: + center_crop = np.array(crop_size) + assert len(center_crop) == (len( + img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)" + return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.), + int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.), + int(center[2] - center_crop[2] / 2.):int(center[2] + center_crop[2] / 2.)] + + +def centercrop_vol(tensor, size): + """:param tensor: tensor whose last two dimensions should be centercropped to size + :param size: 2- or 3-int tuple of target (height, width(,depth)) + """ + dim = len(size) + if dim==2: + center_crop_2D_image_batched(tensor, size) + elif dim==3: + center_crop_2D_image_batched(tensor, size) + else: + raise Exception("invalid size argument {} encountered in centercrop".format(size)) + + """this below worked so fine, when optional z-dim was first spatial dim instead of last + h_, w_ = size[0], size[1] #target size + (h,w) = tensor.size()[-2:] #orig size + dh, dw = h-h_, w-w_ #deltas + if dim == 3: + d_ = size[2] + d = tensor.size()[-3] + dd = d-d_ + + if h_=h: + print("no h crop") + warn.warn("no height crop applied since target dims larger equal orig dims") + if w_=w: + warn.warn("no width crop applied since target dims larger equal orig dims") + if dim == 3: + if d_ < d: + tensor = tensor[..., dd // 2:-int(math.ceil(dd / 2.)),:,:] + elif d_ >= d: + warn.warn("no depth crop applied since target dims larger equal orig dims") + """ + + return tensor + +def dimcalc_conv2D(dims,F=3,s=1,pad="same"): + r""" + :param dims: orig width, height as (2,)-np.array + :param F: quadratic kernel size + :param s: stride + :param pad: pad + """ + if pad=="same": + pad = (F-1)//2 + h, w = dims[0], dims[1] + return np.floor([(h + 2*pad-F)/s+1, (w+ 2*pad-F)/s+1]) + +def dimcalc_transconv2D(dims,F=2,s=2): + r""" + :param dims: orig width, height as (2,)-np.array + :param F: quadratic kernel size + :param s: stride + """ + + h, w = dims[0], dims[1] + return np.array([(h-1)*s+F, (w-1)*s+F]) + +def dimcalc_Unet_std(init_dims, F=3, F_pool=2, F_up=2, s=1, s_pool=2, s_up=2, pad=0): + r"""Calculate theoretic dimensions of feature maps throughout layers of this U-net. + """ + dims = np.array(init_dims) + print("init dims: ", dims) + + def down(dims): + for i in range(2): + dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad) + dims = dimcalc_conv2D(dims, F=F_pool, s=s_pool) + return dims.astype(int) + def up(dims): + for i in range(2): + dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad) + dims = dimcalc_transconv2D(dims, F=F_up,s=s_up) + return dims.astype(int) + + stage = 1 + for i in range(4): + dims = down(dims) + print("stage ", stage, ": ", dims) + stage+=1 + for i in range(4): + dims = up(dims) + print("stage ", stage, ": ", dims) + stage+=1 + for i in range(2): + dims = dimcalc_conv2D(dims,F=F,s=s, pad=pad).astype(int) + print("final output size: ", dims) + return dims + +def dimcalc_Unet(init_dims, F=3, F_pool=2, F_up=2, s=1, s_pool=2, s_up=2, pad=0): + r"""Calculate theoretic dimensions of feature maps throughout layers of this U-net. + """ + dims = np.array(init_dims) + print("init dims: ", dims) + + def down(dims): + for i in range(3): + dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad) + dims = dimcalc_conv2D(dims, F=F_pool, s=s_pool) + return dims.astype(int) + def up(dims): + dims = dimcalc_transconv2D(dims, F=F_up,s=s_up) + for i in range(3): + dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad) + return dims.astype(int) + + stage = 1 + for i in range(6): + dims = down(dims) + print("stage ", stage, ": ", dims) + stage+=1 + for i in range(3): + dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad) + for i in range(6): + dims = up(dims) + print("stage ", stage, ": ", dims) + stage+=1 + dims = dims.astype(int) + print("final output size: ", dims) + return dims + + + +class horiz_conv(nn.Module): + def __init__(self, in_chans, out_chans, kernel_size, c_gen, norm, pad=0, relu="relu", bottleneck=True): + super(horiz_conv, self).__init__() + #TODO maybe make res-block? + if bottleneck: + bottleneck = int(np.round((in_chans+out_chans)*3/8)) + #print("bottleneck:", bottleneck) + else: + bottleneck = out_chans + self.conv = nn.Sequential( + c_gen(in_chans, bottleneck, kernel_size, pad=pad, norm=norm, relu=relu), #TODO maybe use norm only on last conv? + c_gen(bottleneck, out_chans, kernel_size, pad=pad, norm=norm, relu=relu), #TODO maybe make bottleneck? + #c_gen(out_chans, out_chans, kernel_size, pad=pad, norm=norm, relu=relu), + ) + def forward(self, x): + x = self.conv(x) + return x + +class up(nn.Module): + def __init__(self, in_chans, out_chans, kernel_size, interpol, c_gen, norm, pad=0, relu="relu", stride_ip=2): + super(up, self).__init__() + self.dim = c_gen.dim + self.upsample = interpol(stride_ip, "bilinear") if self.dim==2 else interpol(stride_ip, "trilinear") #TODO check if fits with spatial dims order in data + self.reduce_chans = c_gen(in_chans, out_chans, ks=1, norm=norm, relu=None) + self.horiz = horiz_conv(out_chans*2, out_chans, kernel_size, c_gen, norm=norm, pad=pad, relu=relu) + + def forward(self, x, skip_inp): + #TODO maybe add highway weights in skips? + x = self.upsample(x) + x = self.reduce_chans(x) + #print("shape x, skip", x.shape, skip_inp.shape) + targ_size = x.size()[-self.dim:] #ft map x,y,z (spatial) + skip_inp = centercrop_vol(skip_inp, targ_size) + assert targ_size == skip_inp.size()[-self.dim:], "corresp. skip and forward dimensions don't match" + x = torch.cat((x,skip_inp),dim=1) + x = self.horiz(x) + return x + + +class net(nn.Module): + r"""U-Net with few more steps than standard. + + Dimensions: + feature maps have dims ...xhxwxd, d=feature map depth, h, w = orig + img height, width. h,w each are downsized by unpadded forward-convs and pooling, + upsized by upsampling or upconvolution. + If :math:`F\times F` is the single kernel_size and stride is :math:`s\geq 1`, + :math:`k` is the number of kernels in the conv, i.e. the resulting feature map depth, + (all may differ between operations), then + + :Forward Conv: input :math:`h \times w \times d` is converted to + .. math:: \left[ (h-F)//s+1 \right] \times \left[ (w-F)//s+1 \right] \times k + + :Pooling: input :math:`h \times w \times d` is converted to + .. math:: \left[ (h-F)//s+1 \right] \times \left[ (w-F)//s+1 \right] \times d, + pooling filters have no depths => orig depths preserved. + + :Up-Conv.: input :math:`h \times w \times d` is converted to + .. math:: \left[ (h-1)s + F \right] \times \left[ (w-1)s + F \right] \times k + """ + + + def down(self, in_chans, out_chans, kernel_size, kernel_size_m, pad=0, relu="relu",maintain_z=False): + """generate encoder block + :param in_chans: + :param out_chans: + :param kernel_size: + :param pad: + :return: + """ + if maintain_z and self.dim==3: + stride_pool = (2,2,1) + if not hasattr(kernel_size_m, "__iter__"): + kernel_size_m = [kernel_size_m]*self.dim + kernel_size_m = (*kernel_size_m[:-1], 1) + else: + stride_pool = 2 + module = nn.Sequential( + nn.MaxPool2d(kernel_size_m, stride=stride_pool) if self.dim == 2 else nn.MaxPool3d( + kernel_size_m, stride=stride_pool), + #--> needs stride 2 in z in upsampling as well! + horiz_conv(in_chans, out_chans, kernel_size, self.c_gen, self.norm, pad, relu=relu) + ) + return module + + def up(self, in_chans, out_chans, kernel_size, pad=0, relu="relu", maintain_z=False): + """generate decoder block + :param in_chans: + :param out_chans: + :param kernel_size: + :param pad: + :param relu: + :return: + """ + if maintain_z and self.dim==3: + stride_ip = (2,2,1) + else: + stride_ip = 2 + + module = up(in_chans, out_chans, kernel_size, self.Interpolator, self.c_gen, norm=self.norm, pad=pad, + relu=relu, stride_ip=stride_ip) + + return module + + + def __init__(self, cf, logger): + super(net, self).__init__() + + self.cf = cf + self.dim = cf.dim + self.norm = cf.norm + self.logger = logger + backbone = utils.import_module('bbone', cf.backbone_path) + self.c_gen = backbone.ConvGenerator(cf.dim) + self.Interpolator = backbone.Interpolate + + #down = DownBlockGen(cf.dim) + #up = UpBlockGen(cf.dim, backbone.Interpolate) + down = self.down + up = self.up + + pad = cf.pad + if pad=="same": + pad = (cf.kernel_size-1)//2 + + + self.dims = "not yet recorded" + self.is_cuda = False + + self.init = horiz_conv(len(cf.channels), cf.init_filts, cf.kernel_size, self.c_gen, self.norm, pad=pad, + relu=cf.relu) + + self.down1 = down(cf.init_filts, cf.init_filts*2, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu) + self.down2 = down(cf.init_filts*2, cf.init_filts*4, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu) + self.down3 = down(cf.init_filts*4, cf.init_filts*6, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu) + self.down4 = down(cf.init_filts*6, cf.init_filts*8, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu, + maintain_z=True) + self.down5 = down(cf.init_filts*8, cf.init_filts*12, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu, + maintain_z=True) + #self.down6 = down(cf.init_filts*10, cf.init_filts*14, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu) + + #self.up1 = up(cf.init_filts*14, cf.init_filts*10, cf.kernel_size, pad=pad, relu=cf.relu) + self.up2 = up(cf.init_filts*12, cf.init_filts*8, cf.kernel_size, pad=pad, relu=cf.relu, maintain_z=True) + self.up3 = up(cf.init_filts*8, cf.init_filts*6, cf.kernel_size, pad=pad, relu=cf.relu, maintain_z=True) + self.up4 = up(cf.init_filts*6, cf.init_filts*4, cf.kernel_size, pad=pad, relu=cf.relu) + self.up5 = up(cf.init_filts*4, cf.init_filts*2, cf.kernel_size, pad=pad, relu=cf.relu) + self.up6 = up(cf.init_filts*2, cf.init_filts, cf.kernel_size, pad=pad, relu=cf.relu) + + self.seg = self.c_gen(cf.init_filts, cf.num_seg_classes, 1, norm=None, relu=None) #TODO maybe apply norm too? + + + # initialize parameters + if self.cf.weight_init == "custom": + logger.info("Tried to use custom weight init which is not defined. Using pytorch default.") + elif self.cf.weight_init: + mutils.initialize_weights(self) + else: + logger.info("using default pytorch weight init") + + + def forward(self, x): + r'''Forward application of network-function. + + :param x: input to the network, expected as torch.tensor of dims + .. math:: batch\_size \times channels \times height \times width + requires_grad should be True for training + ''' + #self.dims = np.array([x.size()[-self.dim-1:]]) + + x1 = self.init(x) + #self.dims = np.vstack((self.dims, x1.size()[-self.dim-1:])) + + #---downwards--- + x2 = self.down1(x1) + #self.dims = np.vstack((self.dims, x2.size()[-self.dim-1:])) + x3 = self.down2(x2) + #self.dims = np.vstack((self.dims, x3.size()[-self.dim-1:])) + x4 = self.down3(x3) + #self.dims = np.vstack((self.dims, x4.size()[-self.dim-1:])) + x5 = self.down4(x4) + #self.dims = np.vstack((self.dims, x5.size()[-self.dim-1:])) + #x6 = self.down5(x5) + #self.dims = np.vstack((self.dims, x6.size()[-self.dim-1:])) + + #---bottom--- + x = self.down5(x5) + #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:])) + + #---upwards--- + #x = self.up1(x, x6) + #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:])) + x = self.up2(x, x5) + #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:])) + x = self.up3(x, x4) + #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:])) + x = self.up4(x, x3) + #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:])) + x = self.up5(x, x2) + #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:])) + + x = self.up6(x, x1) + #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:])) + + # ---final--- + x = self.seg(x) + #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:])) + + seg_logits = x + out_box_coords, out_scores = [], [] + seg_probs = F.softmax(seg_logits.detach(), dim=1).cpu().data.numpy() + #seg_probs = F.softmax(seg_logits, dim=1) + + assert seg_logits.shape[1]==self.cf.num_seg_classes + for cl in range(1, seg_logits.shape[1]): + hard_mask = np.copy(seg_probs).argmax(1) + #hard_mask = seg_probs.clone().argmax(1) + hard_mask[hard_mask != cl] = 0 + hard_mask[hard_mask == cl] = 1 + # perform connected component analysis on argmaxed predictions, + # draw boxes around components and return coordinates. + box_coords, rois = mutils.get_coords(hard_mask, self.cf.n_roi_candidates, self.cf.dim) + + # for each object, choose the highest softmax score (in the respective class) + # of all pixels in the component as object score. + scores = [[] for b_inst in range(x.shape[0])] # np.zeros((out_features.shape[0], self.cf.n_roi_candidates)) + for b_inst, brois in enumerate(rois): + for nix, nroi in enumerate(brois): + score_det = np.max if self.cf.score_det == "max" else np.median # score determination + scores[b_inst].append(score_det(seg_probs[b_inst, cl][nroi > 0])) + out_box_coords.append(box_coords) + out_scores.append(scores) + + return seg_logits, out_box_coords, out_scores + + # noinspection PyCallingNonCallable + def train_forward(self, batch, **kwargs): + """ + train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data + for processing, computes losses, and stores outputs in a dictionary. + :param batch: dictionary containing 'data', 'seg', etc. + :param kwargs: + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes] + 'torch_loss': 1D torch tensor for backprop. + 'class_loss': classification loss for monitoring. here: dummy array, since no classification conducted. + """ + + img = torch.from_numpy(batch["data"]).cuda() + seg = torch.from_numpy(batch["seg"]).long().cuda() + seg_ohe = torch.from_numpy(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda() + + results_dict = {} + seg_logits, box_coords, scores = self.forward(img) + + # no extra class loss applied in this model. pass dummy tensor for monitoring. + results_dict['class_loss'] = np.nan + + results_dict['boxes'] = [[] for _ in range(img.shape[0])] + for cix in range(len(self.cf.class_dict.keys())): + for bix in range(img.shape[0]): + for rix in range(len(scores[cix][bix])): + if scores[cix][bix][rix] > self.cf.detection_min_confidence: + results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]), + 'box_score': scores[cix][bix][rix], + 'box_pred_class_id': cix + 1, # add 0 for background. + 'box_type': 'det', + }) + + for bix in range(img.shape[0]): #bix = batch-element index + for tix in range(len(batch['bb_target'][bix])): #target index + gt_box = {'box_coords': batch['bb_target'][bix][tix], 'box_type': 'gt'} + for name in self.cf.roi_items: + gt_box.update({name: batch[name][bix][tix]}) + results_dict['boxes'][bix].append(gt_box) + + # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both. + seg_pred = F.softmax(seg_logits, 1) + loss = torch.tensor([0.], dtype=torch.float, requires_grad=False).cuda() + if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce': + loss += 1 - mutils.batch_dice(seg_pred, seg_ohe.float(), + false_positive_weight=float(self.cf.fp_dice_weight)) + + if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce': + loss += F.cross_entropy(seg_logits, seg[:, 0], weight=torch.FloatTensor(self.cf.wce_weights).cuda(), + reduction='elementwise_mean') + + results_dict['torch_loss'] = loss + seg_pred = seg_pred.argmax(dim=1).unsqueeze(dim=1).cpu().data.numpy() + results_dict['seg_preds'] = seg_pred + if 'dice' in self.cf.metrics: + results_dict['batch_dices'] = mutils.dice_per_batch_and_class(seg_pred, batch["seg"], + self.cf.num_seg_classes, convert_to_ohe=True) + #print("batch dice scores ", results_dict['batch_dices'] ) + # self.logger.info("loss: {0:.2f}".format(loss.item())) + return results_dict + + def test_forward(self, batch, **kwargs): + """ + test method. wrapper around forward pass of network without usage of any ground truth information. + prepares input data for processing and stores outputs in a dictionary. + :param batch: dictionary containing 'data' + :param kwargs: + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes] + """ + img = torch.FloatTensor(batch['data']).cuda() + seg_logits, box_coords, scores = self.forward(img) + + results_dict = {} + results_dict['boxes'] = [[] for b_inst in range(img.shape[0])] + for cix in range(len(box_coords)): #class index + for bix in range(img.shape[0]): #batch instance + for rix in range(len(scores[cix][bix])): #range(self.cf.n_roi_candidates): + if scores[cix][bix][rix] > self.cf.detection_min_confidence: + results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]), + 'box_score': scores[cix][bix][rix], + 'box_pred_class_id': cix + 1, + 'box_type': 'det'}) + # carry probs instead of preds to use for multi-model voting in predictor + results_dict['seg_preds'] = F.softmax(seg_logits, dim=1).cpu().data.numpy() + + + return results_dict + + + def actual_dims(self, print_=True): + r"""Return dimensions of actually calculated layers at beginning of each block. + """ + if print_: + print("dimensions as recorded in forward pass: ") + for stage in range(len(self.dims)): + print("Stage ", stage, ": ", self.dims[stage]) + return self.dims + + def cuda(self, device=None): + r"""Moves all model parameters and buffers to the GPU. + + This also makes associated parameters and buffers different objects. So + it should be called before constructing optimizer if the module will + live on GPU while being optimized. + + Arguments: + device (int, optional): if specified, all parameters will be + copied to that device + + Returns: + Module: self + """ + try: + self.loss_f = self.loss_f.cuda() + except: + pass + self.is_cuda = True + return self._apply(lambda t: t.cuda(device)) + + def cpu(self): + r"""Moves all model parameters and buffers to the CPU. + + Returns: + Module: self + """ + self.is_cuda = False + return self._apply(lambda t: t.cpu()) + + + + + \ No newline at end of file diff --git a/models/mrcnn.py b/models/mrcnn.py new file mode 100644 index 0000000..ca1566d --- /dev/null +++ b/models/mrcnn.py @@ -0,0 +1,758 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn +published under MIT license. +""" +import os +from multiprocessing import Pool +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils + +import utils.model_utils as mutils +import utils.exp_utils as utils + + + +class RPN(nn.Module): + """ + Region Proposal Network. + """ + + def __init__(self, cf, conv): + + super(RPN, self).__init__() + self.dim = conv.dim + + self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu) + self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) + self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) + + + def forward(self, x): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :return: rpn_class_logits (b, 2, n_anchors) + :return: rpn_probs_logits (b, 2, n_anchors) + :return: rpn_bbox (b, 2 * dim, n_anchors) + """ + + # Shared convolutional base of the RPN. + x = self.conv_shared(x) + + # Anchor Score. (batch, anchors per location * 2, y, x, (z)). + rpn_class_logits = self.conv_class(x) + # Reshape to (batch, 2, anchors) + axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) + rpn_class_logits = rpn_class_logits.permute(*axes) + rpn_class_logits = rpn_class_logits.contiguous() + rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2) + + # Softmax on last dimension (fg vs. bg). + rpn_probs = F.softmax(rpn_class_logits, dim=2) + + # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z)) + rpn_bbox = self.conv_bbox(x) + + # Reshape to (batch, 2*dim, anchors) + rpn_bbox = rpn_bbox.permute(*axes) + rpn_bbox = rpn_bbox.contiguous() + rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2) + + return [rpn_class_logits, rpn_probs, rpn_bbox] + + + +class Classifier(nn.Module): + """ + Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a + shared convolutional base and finally branches off the classifier- and regression head. + """ + def __init__(self, cf, conv): + super(Classifier, self).__init__() + + self.cf = cf + self.dim = conv.dim + self.in_channels = cf.end_filts + self.pool_size = cf.pool_size + self.pyramid_levels = cf.pyramid_levels + # instance_norm does not work with spatial dims (1, 1, (1)) + norm = cf.norm if cf.norm != 'instance_norm' else None + + self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu) + self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu) + self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim) + + + if 'regression' in self.cf.prediction_tasks: + self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * cf.regression_n_features) + self.rg_n_feats = cf.regression_n_features + #classify into bins of regression values + elif 'regression_bin' in self.cf.prediction_tasks: + self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * len(cf.bin_labels)) + self.rg_n_feats = len(cf.bin_labels) + else: + self.linear_regressor = lambda x: torch.zeros((x.shape[0], cf.head_classes * 1), dtype=torch.float32).fill_(float('NaN')).cuda() + self.rg_n_feats = 1 #cf.regression_n_features + if 'class' in self.cf.prediction_tasks: + self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes) + else: + assert cf.head_classes == 2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes".format(cf.head_classes) + self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda() + + + def forward(self, x, rois): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :param rois: normalized box coordinates as proposed by the RPN to be forwarded through + the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements + have been merged to one vector, while the origin info has been stored for re-allocation. + :return: mrcnn_class_logits (n_proposals, n_head_classes) + :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. + """ + x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) + x = self.conv1(x) + x = self.conv2(x) + x = x.view(-1, self.in_channels * 4) + + mrcnn_bbox = self.linear_bbox(x) + mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2) + mrcnn_class_logits = self.linear_class(x) + mrcnn_regress = self.linear_regressor(x) + mrcnn_regress = mrcnn_regress.view(mrcnn_regress.size()[0], -1, self.rg_n_feats) + + return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress] + + +class Mask(nn.Module): + """ + Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the + output logits to allow for overlapping classes. + """ + def __init__(self, cf, conv): + super(Mask, self).__init__() + self.pool_size = cf.mask_pool_size + self.pyramid_levels = cf.pyramid_levels + self.dim = conv.dim + self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + if conv.dim == 2: + self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) + else: + self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) + + self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True) + self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None) + self.sigmoid = nn.Sigmoid() + + def forward(self, x, rois): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :param rois: normalized box coordinates as proposed by the RPN to be forwarded through + the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements + have been merged to one vector, while the origin info has been stored for re-allocation. + :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z)) + """ + x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.relu(self.deconv(x)) + x = self.conv5(x) + x = self.sigmoid(x) + return x + + +############################################################ +# Loss Functions +############################################################ + +def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize): + """ + :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier. + :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). + :return: loss: torch tensor + :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. + """ + + # Filter out netural anchors + pos_indices = torch.nonzero(rpn_match == 1) + neg_indices = torch.nonzero(rpn_match == -1) + + # loss for positive samples + if not 0 in pos_indices.size(): + pos_indices = pos_indices.squeeze(1) + roi_logits_pos = rpn_class_logits[pos_indices] + pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda()) + else: + pos_loss = torch.FloatTensor([0]).cuda() + + # loss for negative samples: draw hard negative examples (SHEM) + # that match the number of positive samples, but at least 1. + if not 0 in neg_indices.size(): + neg_indices = neg_indices.squeeze(1) + roi_logits_neg = rpn_class_logits[neg_indices] + negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) + roi_probs_neg = F.softmax(roi_logits_neg, dim=1) + neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) + neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) + np_neg_ix = neg_ix.cpu().data.numpy() + #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count) + else: + neg_loss = torch.FloatTensor([0]).cuda() + np_neg_ix = np.array([]).astype('int32') + + loss = (pos_loss + neg_loss) / 2 + return loss, np_neg_ix + + +def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match): + """ + :param rpn_target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). + Uses 0 padding to fill in unsed bbox deltas. + :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) + :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(rpn_match == 1).size(): + + indices = torch.nonzero(rpn_match == 1).squeeze(1) + # Pick bbox deltas that contribute to the loss + rpn_pred_deltas = rpn_pred_deltas[indices] + # Trim target bounding box deltas to the same length as rpn_bbox. + target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :] + # Smooth L1 loss + loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids): + """ + :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dh))) + :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dh))) + :param target_class_ids: (n_sampled_rois) + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(target_class_ids > 0).size(): + positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_roi_class_ids = target_class_ids[positive_roi_ix].long() + target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach() + pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :] + loss = F.smooth_l1_loss(pred_bbox, target_bbox) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids): + """ + :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array. + :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1]. + :param target_class_ids: (n_sampled_rois) + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(target_class_ids > 0).size(): + # Only positive ROIs contribute to the loss. And only + # the class-specific mask of each ROI. + positive_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_class_ids = target_class_ids[positive_ix].long() + y_true = target_masks[positive_ix, :, :].detach() + y_pred = pred_masks[positive_ix, positive_class_ids, :, :] + loss = F.binary_cross_entropy(y_pred, y_true) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_mrcnn_class_loss(tasks, pred_class_logits, target_class_ids): + """ + :param pred_class_logits: (n_sampled_rois, n_classes) + :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension. + :return: loss: torch 1D tensor. + """ + if 'class' in tasks and not 0 in target_class_ids.size(): + loss = F.cross_entropy(pred_class_logits, target_class_ids.long()) + else: + loss = torch.FloatTensor([0.]).cuda() + + return loss + +def compute_mrcnn_regression_loss(tasks, pred, target, target_class_ids): + """regression loss is a distance metric between target vector and predicted regression vector. + :param pred: (n_sampled_rois, n_classes, [n_rg_feats if real regression or 1 if rg_bin task) + :param target: (n_sampled_rois, [n_rg_feats or n_rg_bins]) + :return: differentiable loss, torch 1D tensor on cuda + """ + + if not 0 in target.shape and not 0 in torch.nonzero(target_class_ids > 0).shape: + positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_roi_class_ids = target_class_ids[positive_roi_ix].long() + target = target[positive_roi_ix].detach() + pred = pred[positive_roi_ix, positive_roi_class_ids] + if "regression_bin" in tasks: + loss = F.cross_entropy(pred, target.long()) + else: + loss = F.smooth_l1_loss(pred, target) + #loss = F.mse_loss(pred, target) + else: + loss = torch.FloatTensor([0.]).cuda() + + return loss + +############################################################ +# Detection Layer +############################################################ + +def compute_roi_scores(tasks, batch_rpn_proposals, mrcnn_cl_logits): + """ Depending on the predicition tasks: if no class prediction beyong fg/bg (--> means no additional class + head was applied) use RPN objectness scores as roi scores, otherwise class head scores. + :param cf: + :param batch_rpn_proposals: + :param mrcnn_cl_logits: + :return: + """ + if not 'class' in tasks: + scores = batch_rpn_proposals[:, :, -1].view(-1, 1) + scores = torch.cat((1 - scores, scores), dim=1) + else: + scores = F.softmax(mrcnn_cl_logits, dim=1) + + return scores + +############################################################ +# MaskRCNN Class +############################################################ + +class net(nn.Module): + + + def __init__(self, cf, logger): + + super(net, self).__init__() + self.cf = cf + self.logger = logger + self.build() + + loss_order = ['rpn_class', 'rpn_bbox', 'mrcnn_bbox', 'mrcnn_mask', 'mrcnn_class', 'mrcnn_rg'] + if hasattr(cf, "mrcnn_loss_weights"): + #bring into right order + self.loss_weights = np.array([cf.mrcnn_loss_weights[k] for k in loss_order]) + else: + self.loss_weights = np.array([1.]*len(loss_order)) + + if self.cf.weight_init=="custom": + logger.info("Tried to use custom weight init which is not defined. Using pytorch default.") + elif self.cf.weight_init: + mutils.initialize_weights(self) + else: + logger.info("using default pytorch weight init") + + def build(self): + """Build Mask R-CNN architecture.""" + + # Image size must be dividable by 2 multiple times. + h, w = self.cf.patch_size[:2] + if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5): + raise Exception("Image size must be divisible by 2 at least 5 times " + "to avoid fractions when downscaling and upscaling." + "For example, use 256, 288, 320, 384, 448, 512, ... etc.,i.e.," + "any number x*32 will do!") + + # instantiate abstract multi-dimensional conv generator and load backbone module. + backbone = utils.import_module('bbone', self.cf.backbone_path) + self.logger.info("loaded backbone from {}".format(self.cf.backbone_path)) + conv = backbone.ConvGenerator(self.cf.dim) + + # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head + self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) + self.anchors = torch.from_numpy(self.np_anchors).float().cuda() + self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda() + self.rpn = RPN(self.cf, conv) + self.classifier = Classifier(self.cf, conv) + self.mask = Mask(self.cf, conv) + + def forward(self, img, is_training=True): + """ + :param img: input images (b, c, y, x, (z)). + :return: rpn_pred_logits: (b, n_anchors, 2) + :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d)))) + :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting. + :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score) + :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. + """ + # extract features. + fpn_outs = self.fpn(img) + rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels] + self.mrcnn_feature_maps = rpn_feature_maps + + # loop through pyramid layers and apply RPN. + layer_outputs = [ self.rpn(p_feats) for p_feats in rpn_feature_maps ] + + # concatenate layer outputs. + # convert from list of lists of level outputs to list of lists of outputs across levels. + # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] + outputs = list(zip(*layer_outputs)) + outputs = [torch.cat(list(o), dim=1) for o in outputs] + rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs + # + # # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier. + proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference + batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas, + proposal_count, self.anchors, self.cf) + + # merge batch dimension of proposals while storing allocation info in coordinate dimension. + batch_ixs = torch.arange( + batch_normed_props.shape[0]).cuda().unsqueeze(1).repeat(1,batch_normed_props.shape[1]).view(-1).float() + rpn_rois = batch_normed_props[:, :, :-1].view(-1, batch_normed_props[:, :, :-1].shape[2]) + self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1) + + # this is the first of two forward passes in the second stage, where no activations are stored for backprop. + # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.) + # for inference/monitoring as well as sampling of rois for the loss functions. + # processed in chunks of roi_chunk_size to re-adjust to gpu-memory. + chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size) + bboxes_list, class_logits_list, regressions_list = [], [], [] + with torch.no_grad(): + for chunk in chunked_rpn_rois: + chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk) + bboxes_list.append(chunk_bboxes) + class_logits_list.append(chunk_class_logits) + regressions_list.append(chunk_regressions) + mrcnn_bbox = torch.cat(bboxes_list, 0) + mrcnn_class_logits = torch.cat(class_logits_list, 0) + mrcnn_regressions = torch.cat(regressions_list, 0) + self.mrcnn_roi_scores = compute_roi_scores(self.cf.prediction_tasks, batch_normed_props, mrcnn_class_logits) + + # refine classified proposals, filter and return final detections. + # returns (cf.max_inst_per_batch_element, n_coords+1+...) + detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores, + mrcnn_regressions) + + # forward remaining detections through mask-head to generate corresponding masks. + scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2 + scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda() + + # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ix + detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale + with torch.no_grad(): + detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes) + + return [rpn_pred_logits, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks] + + + def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions=None): + """ + this is the second forward pass through the second stage (features from stage one are re-used). + samples few rois in loss_example_mining and forwards only those for loss computation. + :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels. + :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates. + :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c) + :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores. + :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. + :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal. + :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals. + :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement. + :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals. + :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting. + """ + # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses. + sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \ + mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks, + self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions) + + # re-use feature maps and RPN output from first forward pass. + sample_proposals = self.rpn_rois_batch_info[sample_ics] + if not 0 in sample_proposals.size(): + sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals) + sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals) + else: + sample_logits = torch.FloatTensor().cuda() + sample_deltas = torch.FloatTensor().cuda() + sample_regressions = torch.FloatTensor().cuda() + sample_mask = torch.FloatTensor().cuda() + + return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals, + sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions] + + def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True): + """ + Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. + :param img_shape: + :param detections: shape (n_final_detections, len(info)), where + info=( y1, x1, y2, x2, (z1,z2), batch_ix, pred_class_id, pred_score ) + :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. + :param box_results_list: None or list of output boxes for monitoring/plotting. + each element is a list of boxes per batch element. + :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off). + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. + class-specific return of masks will come with implementation of instance segmentation evaluation. + """ + + detections = detections.cpu().data.numpy() + if self.cf.dim == 2: + detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy() + else: + detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy() + # det masks shape now (n_dets, y,x(,z), n_classes) + # restore batch dimension of merged detections using the batch_ix info. + batch_ixs = detections[:, self.cf.dim*2] + detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] + mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])] + #mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size + + if box_results_list == None: # for test_forward, where no previous list exists. + box_results_list = [[] for _ in range(img_shape[0])] + # seg_logits == seg_probs in mrcnn since mask head finishes with sigmoid (--> image space = [0,1]) + seg_probs = [] + # loop over batch and unmold detections. + for ix in range(img_shape[0]): + + # final masks are one-hot encoded (b, n_classes, y, x, (z)) + final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:])) + #+1 for bg, 0.5 bc mask head classifies only bg/fg with logits between 0,1--> bg is <0.5 + if self.cf.num_classes + 1 != self.cf.num_seg_classes: + self.logger.warning("n of roi-classifier head classes {} doesnt match cf.num_seg_classes {}".format( + self.cf.num_classes + 1, self.cf.num_seg_classes)) + + if not 0 in detections[ix].shape: + boxes = detections[ix][:, :self.cf.dim*2].astype(np.int32) + class_ids = detections[ix][:, self.cf.dim*2 + 1].astype(np.int32) + scores = detections[ix][:, self.cf.dim*2 + 2] + masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids] + regressions = detections[ix][:,self.cf.dim*2+3:] + + # Filter out detections with zero area. Often only happens in early + # stages of training when the network weights are still a bit random. + if self.cf.dim == 2: + exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] + else: + exclude_ix = np.where( + (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] + + if exclude_ix.shape[0] > 0: + boxes = np.delete(boxes, exclude_ix, axis=0) + masks = np.delete(masks, exclude_ix, axis=0) + class_ids = np.delete(class_ids, exclude_ix, axis=0) + scores = np.delete(scores, exclude_ix, axis=0) + regressions = np.delete(regressions, exclude_ix, axis=0) + + # Resize masks to original image size and set boundary threshold. + if return_masks: + for i in range(masks.shape[0]): #masks per this batch instance/element/image + # Convert neural network mask to full size mask + if self.cf.dim == 2: + full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:]) + else: + full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:]) + # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class + # has the max seg_logit value over all instances of that class in one sample + final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0) + final_masks[0] = np.full(final_masks[0].shape, 0.49999999) #effectively min_det_thres at 0.5 per pixel + + # add final predictions to results. + if not 0 in boxes.shape: + for ix2, coords in enumerate(boxes): + box = {'box_coords': coords, 'box_type': 'det', 'box_score': scores[ix2], + 'box_pred_class_id': class_ids[ix2]} + #if (hasattr(self.cf, "convert_cl_to_rg") and self.cf.convert_cl_to_rg): + if "regression_bin" in self.cf.prediction_tasks: + # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin represents + box['rg_bin'] = regressions[ix2].argmax() + box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']] + else: + box['regression'] = regressions[ix2] + if hasattr(self.cf, "rg_val_to_bin_id") and \ + any(['regression' in task for task in self.cf.prediction_tasks]): + box.update({'rg_bin': self.cf.rg_val_to_bin_id(regressions[ix2])}) + + box_results_list[ix].append(box) + + # if no detections were made--> keep full bg mask (zeros). + seg_probs.append(final_masks) + + # create and fill results dictionary. + results_dict = {} + results_dict['boxes'] = box_results_list + results_dict['seg_preds'] = np.array(seg_probs) + + return results_dict + + def train_forward(self, batch, is_validation=False): + """ + train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data + for processing, computes losses, and stores outputs in a dictionary. + :param batch: dictionary containing 'data', 'seg', etc. + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]. + 'torch_loss': 1D torch tensor for backprop. + 'class_loss': classification loss for monitoring. + """ + img = batch['data'] + gt_boxes = batch['bb_target'] + axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1) + gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))] + gt_class_ids = batch['class_targets'] + if 'regression' in self.cf.prediction_tasks: + gt_regressions = batch["regression_targets"] + elif 'regression_bin' in self.cf.prediction_tasks: + gt_regressions = batch["rg_bin_targets"] + else: + gt_regressions = None + + + img = torch.from_numpy(img).cuda().float() + batch_rpn_class_loss = torch.FloatTensor([0]).cuda() + batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda() + + # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. + box_results_list = [[] for _ in range(img.shape[0])] + + #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance + # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop. + rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img) + + mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \ + mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \ + self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions) + + stime = time.time() + #loop over batch + for b in range(img.shape[0]): + if len(gt_boxes[b]) > 0: + # add gt boxes to output list + for tix in range(len(gt_boxes[b])): + gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]} + for name in self.cf.roi_items: + gt_box.update({name: batch[name][b][tix]}) + box_results_list[b].append(gt_box) + + # match gt boxes with anchors to generate targets for RPN losses. + rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b]) + + # add positive anchors used for loss to output list for monitoring. + pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:]) + for p in pos_anchors: + box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) + + else: + rpn_match = np.array([-1]*self.np_anchors.shape[0]) + rpn_target_deltas = np.array([0]) + + rpn_match = torch.from_numpy(rpn_match).cuda() + rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda() + + # compute RPN losses. + rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match, self.cf.shem_poolsize) + rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match) + batch_rpn_class_loss += rpn_class_loss /img.shape[0] + batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0] + + # add negative anchors used for loss to output list for monitoring. + neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][0, neg_anchor_ix], img.shape[2:]) + for n in neg_anchors: + box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) + + # add highest scoring proposals to output list for monitoring. + rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1] + for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]: + box_results_list[b].append({'box_coords': r, 'box_type': 'prop'}) + #print("gt anc matching, rpn losses loop time {:.4f}s".format(time.time()-stime)) + + # add positive and negative roi samples used for mrcnn losses to output list for monitoring. + if not 0 in sample_proposals.shape: + rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy() + for ix, r in enumerate(rois): + box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale, + 'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'}) + + # compute mrcnn losses. + mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids) + mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids) + mrcnn_regressions_loss = compute_mrcnn_regression_loss(self.cf.prediction_tasks, mrcnn_regressions, target_regressions, target_class_ids) + # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode). + # In this case, the mask_loss is taken out of training. + if not self.cf.frcnn_mode: + mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids) + else: + mrcnn_mask_loss = torch.FloatTensor([0]).cuda() + + loss = batch_rpn_class_loss + batch_rpn_bbox_loss +\ + mrcnn_bbox_loss + mrcnn_mask_loss + mrcnn_class_loss + mrcnn_regressions_loss + + # loss= [batch_rpn_class_loss, batch_rpn_bbox_loss, mrcnn_bbox_loss, mrcnn_mask_loss, mrcnn_class_loss, + # mrcnn_regressions_loss] + # loss = torch.tensor([part_loss * self.loss_weights[i] for i, part_loss in enumerate(loss)], requires_grad=True).sum(0, keepdim=True) + + # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class. + #dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]] + #self.logger.info("regression loss {:.3f}".format(mrcnn_regressions_loss.item())) + #self.logger.info("loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, " + # "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(), + # batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(), mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount)) + + # run unmolding of predictions for monitoring and merge all results to one dictionary. + return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train + results_dict = self.get_results(img.shape, detections, detection_masks, box_results_list, + return_masks=return_masks) + results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis] + if 'dice' in self.cf.metrics: + results_dict['batch_dices'] = mutils.dice_per_batch_and_class( + results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True) + + results_dict['torch_loss'] = loss + results_dict['class_loss'] = mrcnn_class_loss.item() + results_dict['bbox_loss'] = mrcnn_bbox_loss.item() + results_dict['rg_loss'] = mrcnn_regressions_loss.item() + results_dict['rpn_class_loss'] = rpn_class_loss.item() + results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item() + + return results_dict + + + def test_forward(self, batch, return_masks=True): + """ + test method. wrapper around forward pass of network without usage of any ground truth information. + prepares input data for processing and stores outputs in a dictionary. + :param batch: dictionary containing 'data' + :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off). + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes] + """ + img = batch['data'] + img = torch.from_numpy(img).float().cuda() + _, _, _, detections, detection_masks = self.forward(img) + results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks) + + return results_dict \ No newline at end of file diff --git a/models/mrcnn_aleatoric.py b/models/mrcnn_aleatoric.py new file mode 100644 index 0000000..30d54d5 --- /dev/null +++ b/models/mrcnn_aleatoric.py @@ -0,0 +1,735 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn +published under MIT license. +""" +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils + +import utils.model_utils as mutils +import utils.exp_utils as utils +#from cuda_functions.nms_2D.pth_nms import nms_gpu as nms_2D +#from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D +#from cuda_functions.roi_align_2D.roi_align.crop_and_resize import CropAndResizeFunction as ra2D +#from cuda_functions.roi_align_3D.roi_align.crop_and_resize import CropAndResizeFunction as ra3D + + +class RPN(nn.Module): + """ + Region Proposal Network. + """ + + def __init__(self, cf, conv): + + super(RPN, self).__init__() + self.dim = conv.dim + + self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu) + self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) + self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) + + + def forward(self, x): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :return: rpn_class_logits (b, 2, n_anchors) + :return: rpn_probs_logits (b, 2, n_anchors) + :return: rpn_bbox (b, 2 * dim, n_anchors) + """ + + # Shared convolutional base of the RPN. + x = self.conv_shared(x) + + # Anchor Score. (batch, anchors per location * 2, y, x, (z)). + rpn_class_logits = self.conv_class(x) + # Reshape to (batch, 2, anchors) + axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) + rpn_class_logits = rpn_class_logits.permute(*axes) + rpn_class_logits = rpn_class_logits.contiguous() + rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2) + + # Softmax on last dimension (fg vs. bg). + rpn_probs = F.softmax(rpn_class_logits, dim=2) + + # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z)) + rpn_bbox = self.conv_bbox(x) + + # Reshape to (batch, 2*dim, anchors) + rpn_bbox = rpn_bbox.permute(*axes) + rpn_bbox = rpn_bbox.contiguous() + rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2) + + return [rpn_class_logits, rpn_probs, rpn_bbox] + +class Classifier(nn.Module): + """ + Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a + shared convolutional base and finally branches off the classifier- and regression head. + """ + def __init__(self, cf, conv): + super(Classifier, self).__init__() + + self.cf = cf + self.dim = conv.dim + self.in_channels = cf.end_filts + self.pool_size = cf.pool_size + self.pyramid_levels = cf.pyramid_levels + # instance_norm does not work with spatial dims (1, 1, (1)) + norm = cf.norm if cf.norm != 'instance_norm' else None + + self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu) + self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu) + self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim) + + + if 'regression_ken_gal' in self.cf.prediction_tasks: + self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes*cf.regression_n_features) + self.uncert_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes) + else: + raise NotImplementedError + if 'class' in self.cf.prediction_tasks: + #raise NotImplementedError + self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes) + else: + assert cf.head_classes==2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes" + self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda() + #assert hasattr(cf, "regression_n_features"), "cannot choose class inference from regression if regression not applied" + + def forward(self, x, rois): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :param rois: normalized box coordinates as proposed by the RPN to be forwarded through + the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements + have been merged to one vector, while the origin info has been stored for re-allocation. + :return: mrcnn_class_logits (n_proposals, n_head_classes) + :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. + :return: mrcnn_regress (n_proposals, n_head_classes, regression_n_features+1) +1 is aleatoric uncertainty + """ + x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) + x = self.conv1(x) + x = self.conv2(x) + x = x.view(-1, self.in_channels * 4) + + mrcnn_bbox = self.linear_bbox(x) + mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2) + mrcnn_class_logits = self.linear_class(x) + mrcnn_regress, uncert_rg = self.linear_regressor(x), self.uncert_regressor(x) + mrcnn_regress = torch.cat((mrcnn_regress.view(mrcnn_regress.shape[0], -1, self.cf.regression_n_features), + uncert_rg.unsqueeze(-1)), dim=2) + + return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress] + +class Mask(nn.Module): + """ + Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the + output logits to allow for overlapping classes. + """ + def __init__(self, cf, conv): + super(Mask, self).__init__() + self.pool_size = cf.mask_pool_size + self.pyramid_levels = cf.pyramid_levels + self.dim = conv.dim + self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + if conv.dim == 2: + self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) + else: + self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) + + self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True) + self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None) + self.sigmoid = nn.Sigmoid() + + def forward(self, x, rois): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :param rois: normalized box coordinates as proposed by the RPN to be forwarded through + the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements + have been merged to one vector, while the origin info has been stored for re-allocation. + :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z)) + """ + x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.relu(self.deconv(x)) + x = self.conv5(x) + x = self.sigmoid(x) + return x + + +############################################################ +# Loss Functions +############################################################ + +def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize): + """ + :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier. + :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). + :return: loss: torch tensor + :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. + """ + + # Filter out netural anchors + pos_indices = torch.nonzero(rpn_match == 1) + neg_indices = torch.nonzero(rpn_match == -1) + + # loss for positive samples + if not 0 in pos_indices.size(): + pos_indices = pos_indices.squeeze(1) + roi_logits_pos = rpn_class_logits[pos_indices] + pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda()) + else: + pos_loss = torch.FloatTensor([0]).cuda() + + # loss for negative samples: draw hard negative examples (SHEM) + # that match the number of positive samples, but at least 1. + if not 0 in neg_indices.size(): + neg_indices = neg_indices.squeeze(1) + roi_logits_neg = rpn_class_logits[neg_indices] + negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) + roi_probs_neg = F.softmax(roi_logits_neg, dim=1) + neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) + neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) + np_neg_ix = neg_ix.cpu().data.numpy() + #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count) + else: + neg_loss = torch.FloatTensor([0]).cuda() + np_neg_ix = np.array([]).astype('int32') + + loss = (pos_loss + neg_loss) / 2 + return loss, np_neg_ix + + +def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match): + """ + :param rpn_target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). + Uses 0 padding to fill in unsed bbox deltas. + :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) + :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(rpn_match == 1).size(): + + indices = torch.nonzero(rpn_match == 1).squeeze(1) + # Pick bbox deltas that contribute to the loss + rpn_pred_deltas = rpn_pred_deltas[indices] + # Trim target bounding box deltas to the same length as rpn_bbox. + target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :] + # Smooth L1 loss + loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids): + """ + :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dh))) + :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dh))) + :param target_class_ids: (n_sampled_rois) + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(target_class_ids > 0).size(): + positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_roi_class_ids = target_class_ids[positive_roi_ix].long() + target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach() + pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :] + loss = F.smooth_l1_loss(pred_bbox, target_bbox) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids): + """ + :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1]. + :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array. + :param target_class_ids: (n_sampled_rois) + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(target_class_ids > 0).size(): + # Only positive ROIs contribute to the loss. And only + # the class specific mask of each ROI. + positive_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_class_ids = target_class_ids[positive_ix].long() + y_true = target_masks[positive_ix, :, :].detach() + y_pred = pred_masks[positive_ix, positive_class_ids, :, :] + loss = F.binary_cross_entropy(y_pred, y_true) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_mrcnn_class_loss(tasks, pred_class_logits, target_class_ids): + """ + :param pred_class_logits: (n_sampled_rois, n_classes) + :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension. + :return: loss: torch 1D tensor. + """ + if 'class' in tasks and not 0 in target_class_ids.size(): + loss = F.cross_entropy(pred_class_logits, target_class_ids.long()) + else: + loss = torch.FloatTensor([0.]).cuda() + + return loss + +def compute_mrcnn_regression_loss(pred, target, target_class_ids): + """regression loss is a distance metric between target vector and predicted regression vector. + :param pred: (n_sample_rois, n_classes, n_regr_feats+1) regression pred where last entry of each regression + pred is the uncertainty parameter + :param target: (n_sample_rois, n_regr_feats) + :param target_class_ids: (n_sample_rois) + :return: differentiable loss, torch 1D tensor on cuda + """ + + if not 0 in torch.nonzero(target_class_ids > 0).size(): + positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_roi_class_ids = target_class_ids[positive_roi_ix].long() + target = target[positive_roi_ix, :].float().detach() + pred = pred[positive_roi_ix, positive_roi_class_ids, :] + + # loss is 1/(2N)*[Sum_i^N exp(-s_i) distance(pred_vec, targ_vec) + s_i] + loss = F.smooth_l1_loss(pred[...,:-1], target, reduction='none').sum(dim=1) * torch.exp(-pred[...,-1]) + loss += pred[...,-1] #regularizer for sigma + loss = 0.5*loss.mean() + else: + loss = torch.FloatTensor([0.]).cuda() + + return loss + +############################################################ +# Detection Layer +############################################################ + +def compute_roi_scores(cf, batch_rpn_proposals, mrcnn_cl_logits): + """Compute scores from uncertainty measures (lower=better) to use for sorting/clustering algos (higher=better). + :param cf: + :param uncert_class: + :param uncert_regression: + :return: + """ + if 'class' in cf.prediction_tasks: + scores = F.softmax(mrcnn_cl_logits, dim=1) + else: + scores = batch_rpn_proposals[:,:,-1].view(-1, 1) + scores = torch.cat((1-scores, scores), dim=1) + + return scores + +############################################################ +# MaskRCNN Class +############################################################ + +class net(nn.Module): + + + def __init__(self, cf, logger): + + super(net, self).__init__() + self.cf = cf + self.logger = logger + self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks]) + self.build() + + + if self.cf.weight_init=="custom": + logger.info("Tried to use custom weight init which is not defined. Using pytorch default.") + elif self.cf.weight_init: + mutils.initialize_weights(self) + else: + logger.info("using default pytorch weight init") + + def build(self): + """Build Mask R-CNN architecture.""" + + # Image size must be dividable by 2 multiple times. + h, w = self.cf.patch_size[:2] + if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5): + raise Exception("Image size must be dividable by 2 at least 5 times " + "to avoid fractions when downscaling and upscaling." + "For example, use 256, 320, 384, 448, 512, ... etc.,i.e.," + "any number x*32 will do!") + + # instantiate abstract multi-dimensional conv generator and load backbone module. + backbone = utils.import_module('bbone', self.cf.backbone_path) + conv = backbone.ConvGenerator(self.cf.dim) + + # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head + self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) + self.anchors = torch.from_numpy(self.np_anchors).float().cuda() + self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda() + self.rpn = RPN(self.cf, conv) + self.classifier = Classifier(self.cf, conv) + self.mask = Mask(self.cf, conv) + + def forward(self, img, is_training=True): + """ + :param img: input images (b, c, y, x, (z)). + :return: rpn_pred_logits: (b, n_anchors, 2) + :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d)))) + :return: batch_unnormed_props: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting. + :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score) + :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. + """ + # extract features. + fpn_outs = self.fpn(img) + rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels] + self.mrcnn_feature_maps = rpn_feature_maps + + # loop through pyramid layers and apply RPN. + layer_outputs = [] # list of lists + for p in rpn_feature_maps: + layer_outputs.append(self.rpn(p)) + + # concatenate layer outputs. + # convert from list of lists of level outputs to list of lists of outputs across levels. + # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] + outputs = list(zip(*layer_outputs)) + outputs = [torch.cat(list(o), dim=1) for o in outputs] + rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs + + # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier. + proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference + batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, self.anchors, self.cf) + + # merge batch dimension of proposals while storing allocation info in coordinate dimension. + batch_ixs = torch.from_numpy(np.repeat(np.arange(batch_normed_props.shape[0]), batch_normed_props.shape[1])).float().cuda() + rpn_rois = batch_normed_props[:,:,:-1].view(-1, batch_normed_props[:,:,:-1].shape[2]) + self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1) + + # this is the first of two forward passes in the second stage, where no activations are stored for backprop. + # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.) + # for inference/monitoring as well as sampling of rois for the loss functions. + # processed in chunks of roi_chunk_size to re-adjust to gpu-memory. + chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size) + bboxes_list, class_logits_list, regressions_list = [], [], [] + with torch.no_grad(): + for chunk in chunked_rpn_rois: + chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk) + bboxes_list.append(chunk_bboxes) + class_logits_list.append(chunk_class_logits) + regressions_list.append(chunk_regressions) + mrcnn_bbox = torch.cat(bboxes_list, 0) + mrcnn_class_logits = torch.cat(class_logits_list, 0) + mrcnn_regressions = torch.cat(regressions_list, 0) + #self.mrcnn_class_logits = F.softmax(mrcnn_class_logits, dim=1) + #why were mrcnn_bbox, class_logs, regress called batch_ ? they have no batch dim, in contrast to batch_normed_props + self.mrcnn_roi_scores = compute_roi_scores(self.cf, batch_normed_props, mrcnn_class_logits) + # refine classified proposals, filter and return final detections. + # returns (cf.max_inst_per_batch_element, n_coords+1+...) + detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores, + mrcnn_regressions) + + # forward remaining detections through mask-head to generate corresponding masks. + scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2 + scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda() + + # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ics + detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale + with torch.no_grad(): + detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes) + + return [rpn_pred_logits, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks] + + def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions): + """ + this is the second forward pass through the second stage (features from stage one are re-used). + samples few rois in loss_example_mining and forwards only those for loss computation. + :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels. can be None. + :param batch_gt_regressions: can be None. + :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates. + :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c) + :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores. + :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. + :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal. + :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals. + :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement. + :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals. + :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting. + """ + # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses. + sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \ + mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks, + self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions) + + # re-use feature maps and RPN output from first forward pass. + sample_proposals = self.rpn_rois_batch_info[sample_ics] + if not 0 in sample_proposals.size(): + sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals) + sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals) + else: + sample_logits = torch.FloatTensor().cuda() + sample_deltas = torch.FloatTensor().cuda() + sample_mask = torch.FloatTensor().cuda() + + return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals, + sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions] + + def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True): + """ + Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. + :param img_shape: + :param detections: shape (n_final_detections, len(info)), where + info=( y1, x1, y2, x2, (z1,z2), batch_ix, pred_class_id, pred_score ) + :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. + :param box_results_list: None or list of output boxes for monitoring/plotting. + each element is a list of boxes per batch element. + :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off). + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. + class-specific return of masks will come with implementation of instance segmentation evaluation. + """ + + detections = detections.cpu().data.numpy() + if self.cf.dim == 2: + detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy() + else: + detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy() + # det masks shape now (n_dets, y,x(,z), n_classes) + # restore batch dimension of merged detections using the batch_ix info. + batch_ixs = detections[:, self.cf.dim*2] + detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] + mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])] + #mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size + + if box_results_list == None: # for test_forward, where no previous list exists. + box_results_list = [[] for _ in range(img_shape[0])] + + seg_logits = [] + # loop over batch and unmold detections. + for ix in range(img_shape[0]): + + # final masks are one-hot encoded (b, n_classes, y, x, (z)) + final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:])) + #+1 for bg, 0.5 bc mask head classifies only bg/fg with logits between 0,1--> bg is <0.5 + if self.cf.num_classes + 1 != self.cf.num_seg_classes: + self.logger.warning("n of box classifier head classes {} doesnt match cf.num_seg_classes {}".format( + self.cf.num_classes + 1, self.cf.num_seg_classes)) + + if not 0 in detections[ix].shape: + boxes = detections[ix][:, :self.cf.dim*2].astype(np.int32) + class_ids = detections[ix][:, self.cf.dim*2 + 1].astype(np.int32) + scores = detections[ix][:, self.cf.dim*2 + 2] + masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids] + regressions = detections[ix][:,self.cf.dim*2+3:] + + # Filter out detections with zero area. Often only happens in early + # stages of training when the network weights are still a bit random. + if self.cf.dim == 2: + exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] + else: + exclude_ix = np.where( + (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] + + if exclude_ix.shape[0] > 0: + boxes = np.delete(boxes, exclude_ix, axis=0) + masks = np.delete(masks, exclude_ix, axis=0) + class_ids = np.delete(class_ids, exclude_ix, axis=0) + scores = np.delete(scores, exclude_ix, axis=0) + regressions = np.delete(regressions, exclude_ix, axis=0) + + # Resize masks to original image size and set boundary threshold. + if return_masks: + for i in range(masks.shape[0]): #masks per this batch instance/element/image + # Convert neural network mask to full size mask + if self.cf.dim == 2: + full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:]) + else: + full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:]) + # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class + # has the max seg_logit value over all instances of that class in one sample + final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0) + final_masks[0] = np.full(final_masks[0].shape, 0.49999999) #effectively min_det_thres at 0.5 per pixel + + # add final predictions to results. + if not 0 in boxes.shape: + for ix2, coords in enumerate(boxes): + box = {'box_coords': coords, 'box_type': 'det', 'box_score': scores[ix2], + 'box_pred_class_id': class_ids[ix2]} + if 'regression_ken_gal' or 'regression_feindt' in self.cf.prediction_tasks: + rg_uncert = np.sqrt(np.exp(regressions[ix2][-1])) + box.update({'regression': regressions[ix2][:-1], 'rg_uncertainty': rg_uncert }) + if hasattr(self.cf, "rg_val_to_bin_id"): + box['rg_bin'] = self.cf.rg_val_to_bin_id(regressions[ix2][:-1]) + box_results_list[ix].append(box) + + # if no detections were made--> keep full bg mask (zeros). + seg_logits.append(final_masks) + + # create and fill results dictionary. + results_dict = {} + results_dict['boxes'] = box_results_list + results_dict['seg_preds'] = np.array(seg_logits) + + return results_dict + + + def train_forward(self, batch, is_validation=False): + """ + train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data + for processing, computes losses, and stores outputs in a dictionary. + :param batch: dictionary containing 'data', 'seg', etc. + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]. + 'torch_loss': 1D torch tensor for backprop. + 'class_loss': classification loss for monitoring. + """ + img = batch['data'] + gt_boxes = batch['bb_target'] + axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1) + gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))] + gt_regressions = batch["regression_targets"] if self.regress_flag else None + gt_class_ids = batch['class_targets'] + + + img = torch.from_numpy(img).float().cuda() + batch_rpn_class_loss = torch.FloatTensor([0]).cuda() + batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda() + + # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. + box_results_list = [[] for _ in range(img.shape[0])] + + #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance + # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop. + rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img) + + mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \ + mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \ + self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions) + + #loop over batch + for b in range(img.shape[0]): + if len(gt_boxes[b]) > 0: + # add gt boxes to output list + for tix in range(len(gt_boxes[b])): + gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]} + for name in self.cf.roi_items: + gt_box.update({name: batch[name][b][tix]}) + box_results_list[b].append(gt_box) + + # match gt boxes with anchors to generate targets for RPN losses. + rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b]) + + # add positive anchors used for loss to output list for monitoring. + pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:]) + for p in pos_anchors: + box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) + + else: + rpn_match = np.array([-1]*self.np_anchors.shape[0]) + rpn_target_deltas = np.array([0]) + + rpn_match = torch.from_numpy(rpn_match).cuda() + rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda() + + # compute RPN losses. + rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match, self.cf.shem_poolsize) + rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match) + batch_rpn_class_loss += rpn_class_loss /img.shape[0] + batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0] + + # add negative anchors used for loss to output list for monitoring. + neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][0, neg_anchor_ix], img.shape[2:]) + for n in neg_anchors: + box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) + + # add highest scoring proposals to output list for monitoring. + rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1] + for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]: + box_results_list[b].append({'box_coords': r, 'box_type': 'prop'}) + + # add positive and negative roi samples used for mrcnn losses to output list for monitoring. + if not 0 in sample_proposals.shape: + rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy() + for ix, r in enumerate(rois): + box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale, + 'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'}) + + # compute mrcnn losses. + mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids) + mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids) + mrcnn_regression_loss = compute_mrcnn_regression_loss(mrcnn_regressions, target_regressions, target_class_ids) + # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode). + # In this case, the mask_loss is taken out of training. + if not self.cf.frcnn_mode: + mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids) + else: + mrcnn_mask_loss = torch.FloatTensor([0]).cuda() + + loss = batch_rpn_class_loss + batch_rpn_bbox_loss +\ + mrcnn_bbox_loss + mrcnn_mask_loss + mrcnn_class_loss + mrcnn_regression_loss + + # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class. + #dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]] + #self.logger.info("regression loss {:.3f}".format(mrcnn_regression_loss.item())) + #self.logger.info("loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, " + # "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(), + # batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(), mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount)) + + # run unmolding of predictions for monitoring and merge all results to one dictionary. + + return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train + results_dict = self.get_results( + img.shape, detections, detection_masks, box_results_list, return_masks=return_masks) + results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis] + if 'dice' in self.cf.metrics: + results_dict['batch_dices'] = mutils.dice_per_batch_and_class( + results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True) + + + results_dict['torch_loss'] = loss + results_dict['class_loss'] = mrcnn_class_loss.item() + results_dict['rg_loss'] = mrcnn_regression_loss.item() + results_dict['bbox_loss'] = mrcnn_bbox_loss.item() + results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item() + results_dict['rpn_class_loss'] = rpn_class_loss.item() + + return results_dict + + + def test_forward(self, batch, return_masks=True): + """ + test method. wrapper around forward pass of network without usage of any ground truth information. + prepares input data for processing and stores outputs in a dictionary. + :param batch: dictionary containing 'data' + :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off). + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes] + """ + img = batch['data'] + img = torch.from_numpy(img).float().cuda() + _, _, _, detections, detection_masks = self.forward(img) + results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks) + + return results_dict \ No newline at end of file diff --git a/models/mrcnn_gan.py b/models/mrcnn_gan.py new file mode 100644 index 0000000..af5632c --- /dev/null +++ b/models/mrcnn_gan.py @@ -0,0 +1,844 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn +published under MIT license. +""" +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils + +import utils.model_utils as mutils +import utils.exp_utils as utils + + +class Generator_RPN(nn.Module): + """ + Region Proposal Network. + """ + + def __init__(self, cf, conv): + + super(Generator_RPN, self).__init__() + self.dim = conv.dim + + #assert cf.batch_size%2==0 + self.conv_shared = conv(cf.end_filts+1, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu) + self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) + self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) + + + def forward(self, x): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :return: rpn_class_logits (b, n_anchors, 2) + :return: rpn_probs_logits (b, n_anchors, 2) + :return: rpn_bbox (b, n_anchors, 2*dim) + """ + # latent vector from vanilla base distribution + z = torch.randn(x.shape[0], 1, *x.shape[2:], requires_grad=True).cuda() + x = torch.cat((x,z), dim=1) + # Shared convolutional base of the RPN. + x = self.conv_shared(x) + + # Anchor Score. (batch, anchors per location * 2, y, x, (z)). + rpn_class_logits = self.conv_class(x) + # Reshape to (batch, anchors, 2) + axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) + rpn_class_logits = rpn_class_logits.permute(*axes) + rpn_class_logits = rpn_class_logits.contiguous() + rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2) + + # Softmax on last dimension (fg vs. bg). + rpn_probs = F.softmax(rpn_class_logits, dim=2) + + # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z)) + rpn_bbox = self.conv_bbox(x) + + # Reshape to (batch, anchors, 2*dim) + rpn_bbox = rpn_bbox.permute(*axes) + rpn_bbox = rpn_bbox.contiguous() + rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2) + + return [rpn_class_logits, rpn_probs, rpn_bbox] + +class RPN_Discriminator(nn.Module): + """ + Region Proposal Network. + """ + + def __init__(self, cf, conv): + + super(RPN_Discriminator, self).__init__() + self.dim = conv.dim + + #assert cf.batch_size%2==0 + self.resizer = nn.Sequential( + conv(cf.end_filts, cf.end_filts//2, ks=3, stride=cf.rpn_anchor_stride, pad=0, relu=cf.relu), + nn.MaxPool2d(kernel_size=3, stride=2, padding=0) if \ + conv.dim == 2 else nn.MaxPool3d(kernel_size=3,stride=(2, 2, 1),padding=0), + conv(cf.end_filts//2, cf.end_filts // 2, ks=1, stride=1, pad=0, relu=cf.relu), + nn.MaxPool2d(kernel_size=3, stride=2, padding=0) if \ + conv.dim == 2 else nn.MaxPool3d(kernel_size=3, stride=(2, 2, 1), padding=0), + + ) + self.in_channels = cf.end_filts * 4 + self.conv2 = conv(cf.end_filts, cf.n_rpn_features, ks=1, stride=1, pad=1, relu=cf.relu) + self.conv3 = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) + + def forward(self, f_maps, probs, deltas): + """ + :param feature_maps: list of tensors of sizes (bsize, cf.end_filts, varying map dimensions) + :param probs: tensor of size (bsize, n_proposals on all fpn layers, 2) + :param deltas: tensor of size (bsize, n_proposals on all fpn layers, cf.dim*2) + :return: + """ + f_maps = [self.resizer(m) for m in f_maps] + x = torch.cat([t.view(t.shape[0], t.shape[1], -1) for t in f_maps], dim=-1) + x = x.view(-1, self.in_channels) + x = torch.cat((x,z), dim=1) + # Shared convolutional base of the RPN. + x = self.conv_shared(x) + + # Anchor Score. (batch, anchors per location * 2, y, x, (z)). + rpn_class_logits = self.conv_class(x) + # Reshape to (batch, 2, anchors) + axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) + rpn_class_logits = rpn_class_logits.permute(*axes) + rpn_class_logits = rpn_class_logits.contiguous() + rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2) + + # Softmax on last dimension (fg vs. bg). + rpn_probs = F.softmax(rpn_class_logits, dim=2) + + # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z)) + rpn_bbox = self.conv_bbox(x) + + # Reshape to (batch, 2*dim, anchors) + rpn_bbox = rpn_bbox.permute(*axes) + rpn_bbox = rpn_bbox.contiguous() + rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2) + + return [rpn_class_logits, rpn_probs, rpn_bbox] + + + + + +class Classifier(nn.Module): + """ + Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a + shared convolutional base and finally branches off the classifier- and regression head. + """ + def __init__(self, cf, conv): + super(Classifier, self).__init__() + + self.cf = cf + self.dim = conv.dim + self.in_channels = cf.end_filts + self.pool_size = cf.pool_size + self.pyramid_levels = cf.pyramid_levels + # instance_norm does not work with spatial dims (1, 1, (1)) + norm = cf.norm if cf.norm != 'instance_norm' else None + + self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu) + self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu) + self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim) + + + if 'regression' in self.cf.prediction_tasks: + self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * cf.regression_n_features) + self.rg_n_feats = cf.regression_n_features + #classify into bins of regression values + elif 'regression_bin' in self.cf.prediction_tasks: + self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * len(cf.bin_labels)) + self.rg_n_feats = len(cf.bin_labels) + else: + self.linear_regressor = lambda x: torch.zeros((x.shape[0], cf.head_classes * cf.regression_n_features), dtype=torch.float32).fill_(float('NaN')).cuda() + self.rg_n_feats = cf.regression_n_features + if 'class' in self.cf.prediction_tasks: + self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes) + else: + assert cf.head_classes == 2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes".format(cf.head_classes) + self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda() + #print("\n\nWARNING: using extra class head\n\n") + #self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes) + + def forward(self, x, rois): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :param rois: normalized box coordinates as proposed by the RPN to be forwarded through + the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements + have been merged to one vector, while the origin info has been stored for re-allocation. + :return: mrcnn_class_logits (n_proposals, n_head_classes) + :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. + """ + x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) + x = self.conv1(x) + x = self.conv2(x) + x = x.view(-1, self.in_channels * 4) + + mrcnn_bbox = self.linear_bbox(x) + mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2) + mrcnn_class_logits = self.linear_class(x) + mrcnn_regress = self.linear_regressor(x) + mrcnn_regress = mrcnn_regress.view(mrcnn_regress.size()[0], -1, self.rg_n_feats) + + return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress] + + +class Mask(nn.Module): + """ + Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the + output logits to allow for overlapping classes. + """ + def __init__(self, cf, conv): + super(Mask, self).__init__() + self.pool_size = cf.mask_pool_size + self.pyramid_levels = cf.pyramid_levels + self.dim = conv.dim + self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) + if conv.dim == 2: + self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) + else: + self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) + + self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True) + self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None) + self.sigmoid = nn.Sigmoid() + + def forward(self, x, rois): + """ + :param x: input feature maps (b, in_channels, y, x, (z)) + :param rois: normalized box coordinates as proposed by the RPN to be forwarded through + the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements + have been merged to one vector, while the origin info has been stored for re-allocation. + :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z)) + """ + x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.relu(self.deconv(x)) + x = self.conv5(x) + x = self.sigmoid(x) + return x + + +############################################################ +# Loss Functions +############################################################ + +def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize): + """ + :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier. + :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). + :return: loss: torch tensor + :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. + """ + + # Filter out netural anchors + pos_indices = torch.nonzero(rpn_match == 1) + neg_indices = torch.nonzero(rpn_match == -1) + + # loss for positive samples + if not 0 in pos_indices.size(): + pos_indices = pos_indices.squeeze(1) + roi_logits_pos = rpn_class_logits[pos_indices] + pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda()) + else: + pos_loss = torch.FloatTensor([0]).cuda() + + # loss for negative samples: draw hard negative examples (SHEM) + # that match the number of positive samples, but at least 1. + if not 0 in neg_indices.size(): + neg_indices = neg_indices.squeeze(1) + roi_logits_neg = rpn_class_logits[neg_indices] + negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) + roi_probs_neg = F.softmax(roi_logits_neg, dim=1) + neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) + neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) + np_neg_ix = neg_ix.cpu().data.numpy() + #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count) + else: + neg_loss = torch.FloatTensor([0]).cuda() + np_neg_ix = np.array([]).astype('int32') + + loss = (pos_loss + neg_loss) / 2 + return loss, np_neg_ix + + +def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match): + """ + :param rpn_target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). + Uses 0 padding to fill in unsed bbox deltas. + :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) + :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(rpn_match == 1).size(): + + indices = torch.nonzero(rpn_match == 1).squeeze(1) + # Pick bbox deltas that contribute to the loss + rpn_pred_deltas = rpn_pred_deltas[indices] + # Trim target bounding box deltas to the same length as rpn_bbox. + target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :] + # Smooth L1 loss + loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_disc_loss(d_target, d_pred, target, shem_poolsize): + + + + + return + + +def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids): + """ + :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dh))) + :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dh))) + :param target_class_ids: (n_sampled_rois) + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(target_class_ids > 0).size(): + positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_roi_class_ids = target_class_ids[positive_roi_ix].long() + target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach() + pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :] + loss = F.smooth_l1_loss(pred_bbox, target_bbox) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids): + """ + :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array. + :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1]. + :param target_class_ids: (n_sampled_rois) + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(target_class_ids > 0).size(): + # Only positive ROIs contribute to the loss. And only + # the class-specific mask of each ROI. + positive_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_class_ids = target_class_ids[positive_ix].long() + y_true = target_masks[positive_ix, :, :].detach() + y_pred = pred_masks[positive_ix, positive_class_ids, :, :] + loss = F.binary_cross_entropy(y_pred, y_true) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_mrcnn_class_loss(tasks, pred_class_logits, target_class_ids): + """ + :param pred_class_logits: (n_sampled_rois, n_classes) + :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension. + :return: loss: torch 1D tensor. + """ + if 'class' in tasks and not 0 in target_class_ids.size(): + #if 0 in target_class_ids.size(): + # print("WARNING: using additional cl head") + loss = F.cross_entropy(pred_class_logits, target_class_ids.long()) + else: + loss = torch.FloatTensor([0.]).cuda() + + return loss + +def compute_mrcnn_regression_loss(tasks, pred, target, target_class_ids): + """regression loss is a distance metric between target vector and predicted regression vector. + :param pred: (n_sampled_rois, n_classes, [n_rg_feats if real regression or 1 if rg_bin task) + :param target: (n_sampled_rois, [n_rg_feats or n_rg_bins]) + :return: differentiable loss, torch 1D tensor on cuda + """ + + if not 0 in target.shape and not 0 in torch.nonzero(target_class_ids > 0).shape: + if "regression_bin" in tasks: + positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_roi_class_ids = target_class_ids[positive_roi_ix].long() + target = target[positive_roi_ix].detach() + pred = pred[positive_roi_ix, positive_roi_class_ids] #are the class logits + loss = F.cross_entropy(pred, target.long()) + else: + positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] + positive_roi_class_ids = target_class_ids[positive_roi_ix].long() + target = target[positive_roi_ix, :].detach() + pred = pred[positive_roi_ix, positive_roi_class_ids, :] + loss = F.smooth_l1_loss(pred, target) + else: + loss = torch.FloatTensor([0.]).cuda() + + return loss + +############################################################ +# Detection Layer +############################################################ + +def compute_roi_scores(cf, batch_rpn_proposals, mrcnn_cl_logits): + """Compute scores from uncertainty measures (lower=better) to use for sorting/clustering algos (higher=better). + :param cf: + :param uncert_class: + :param uncert_regression: + :return: + """ + if not 'class' in cf.prediction_tasks: + scores = batch_rpn_proposals[:, :, -1].view(-1, 1) + scores = torch.cat((1 - scores, scores), dim=1) + else: + #print("WARNING: using extra class head") + scores = F.softmax(mrcnn_cl_logits, dim=1) + + return scores + +############################################################ +# MaskRCNN Class +############################################################ + +class net(nn.Module): + + + def __init__(self, cf, logger): + + super(net, self).__init__() + self.cf = cf + self.logger = logger + self.build() + + + if self.cf.weight_init=="custom": + logger.info("Tried to use custom weight init which is not defined. Using pytorch default.") + elif self.cf.weight_init: + mutils.initialize_weights(self) + else: + logger.info("using default pytorch weight init") + + def build(self): + """Build Mask R-CNN architecture.""" + + # Image size must be dividable by 2 multiple times. + h, w = self.cf.patch_size[:2] + if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5): + raise Exception("Image size must be divisible by 2 at least 5 times " + "to avoid fractions when downscaling and upscaling." + "For example, use 256, 320, 384, 448, 512, ... etc.,i.e.," + "any number x*32 will do!") + + # instantiate abstract multi-dimensional conv generator and load backbone module. + backbone = utils.import_module('bbone', self.cf.backbone_path) + conv = backbone.ConvGenerator(self.cf.dim) + + # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head + self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) + self.anchors = torch.from_numpy(self.np_anchors).float().cuda() + self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda() + self.rpn = Generator_RPN(self.cf, conv) + self.discriminator = RPN_Discriminator(self.cf, conv) + self.classifier = Classifier(self.cf, conv) + self.mask = Mask(self.cf, conv) + + def forward(self, img, is_training=True): + """ + :param img: input images (b, c, y, x, (z)). + :return: rpn_pred_logits: (b, n_anchors, 2) + :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d)))) + :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting. + :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score) + :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. + """ + # extract features. + fpn_outs = self.fpn(img) + rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels] + self.mrcnn_feature_maps = rpn_feature_maps + + # loop through pyramid layers and apply RPN. + layer_outputs = [ self.rpn(p_feats) for p_feats in rpn_feature_maps ] + + # concatenate layer outputs. + # convert from list of lists of level outputs to list of lists of outputs across levels. + # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] + outputs = list(zip(*layer_outputs)) + rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = [torch.cat(list(o), dim=1) for o in outputs] + # + # # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier. + proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference + batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas,proposal_count, + self.anchors, self.cf) + # merge batch dimension of proposals while storing allocation info in coordinate dimension. + batch_ixs = torch.arange( + batch_normed_props.shape[0]).cuda().unsqueeze(1).repeat(1, batch_normed_props.shape[1]).view(-1).float() + rpn_rois = batch_normed_props[:, :, :-1].view(-1, batch_normed_props[:, :, :-1].shape[2]) + self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1) + + # this is the first of two forward passes in the second stage, where no activations are stored for backprop. + # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.) + # for inference/monitoring as well as sampling of rois for the loss functions. + # processed in chunks of roi_batch_size to re-adjust to gpu-memory. + chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_batch_size) + bboxes_list, class_logits_list, regressions_list = [], [], [] + with torch.no_grad(): + for chunk in chunked_rpn_rois: + chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk) + bboxes_list.append(chunk_bboxes) + class_logits_list.append(chunk_class_logits) + regressions_list.append(chunk_regressions) + mrcnn_bbox = torch.cat(bboxes_list, 0) + mrcnn_class_logits = torch.cat(class_logits_list, 0) + mrcnn_regressions = torch.cat(regressions_list, 0) + self.mrcnn_roi_scores = compute_roi_scores(self.cf, batch_normed_props, mrcnn_class_logits) + + # refine classified proposals, filter and return final detections. + # returns (cf.max_inst_per_batch_element, n_coords+1+...) + detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores, + mrcnn_regressions) + + # forward remaining detections through mask-head to generate corresponding masks. + scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2 + scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda() + + # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ix + detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale + with torch.no_grad(): + detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes) + + return rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks + + def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions=None): + """ + this is the second forward pass through the second stage (features from stage one are re-used). + samples few rois in loss_example_mining and forwards only those for loss computation. + :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels. + :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates. + :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c) + :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores. + :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. + :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal. + :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals. + :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement. + :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals. + :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting. + """ + # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses. + sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \ + mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks, + self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions) + + # re-use feature maps and RPN output from first forward pass. + sample_proposals = self.rpn_rois_batch_info[sample_ics] + if not 0 in sample_proposals.size(): + sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals) + sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals) + else: + sample_logits = torch.FloatTensor().cuda() + sample_deltas = torch.FloatTensor().cuda() + sample_regressions = torch.FloatTensor().cuda() + sample_mask = torch.FloatTensor().cuda() + + return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals, + sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions] + + def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True): + """ + Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. + :param img_shape: + :param detections: shape (n_final_detections, len(info)), where + info=( y1, x1, y2, x2, (z1,z2), batch_ix, pred_class_id, pred_score ) + :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. + :param box_results_list: None or list of output boxes for monitoring/plotting. + each element is a list of boxes per batch element. + :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off). + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. + class-specific return of masks will come with implementation of instance segmentation evaluation. + """ + + detections = detections.cpu().data.numpy() + if self.cf.dim == 2: + detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy() + else: + detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy() + # det masks shape now (n_dets, y,x(,z), n_classes) + # restore batch dimension of merged detections using the batch_ix info. + batch_ixs = detections[:, self.cf.dim*2] + detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] + mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])] + #mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size + + if box_results_list == None: # for test_forward, where no previous list exists. + box_results_list = [[] for _ in range(img_shape[0])] + # seg_logits == seg_probs in mrcnn since mask head finishes with sigmoid (--> image space = [0,1]) + seg_probs = [] + # loop over batch and unmold detections. + for ix in range(img_shape[0]): + + # final masks are one-hot encoded (b, n_classes, y, x, (z)) + final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:])) + #+1 for bg, 0.5 bc mask head classifies only bg/fg with logits between 0,1--> bg is <0.5 + if self.cf.num_classes + 1 != self.cf.num_seg_classes: + self.logger.warning("n of box classifier head classes {} doesnt match cf.num_seg_classes {}".format( + self.cf.num_classes + 1, self.cf.num_seg_classes)) + + if not 0 in detections[ix].shape: + boxes = detections[ix][:, :self.cf.dim*2].astype(np.int32) + class_ids = detections[ix][:, self.cf.dim*2 + 1].astype(np.int32) + scores = detections[ix][:, self.cf.dim*2 + 2] + masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids] + regressions = detections[ix][:,self.cf.dim*2+3:] + + # Filter out detections with zero area. Often only happens in early + # stages of training when the network weights are still a bit random. + if self.cf.dim == 2: + exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] + else: + exclude_ix = np.where( + (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] + + if exclude_ix.shape[0] > 0: + boxes = np.delete(boxes, exclude_ix, axis=0) + masks = np.delete(masks, exclude_ix, axis=0) + class_ids = np.delete(class_ids, exclude_ix, axis=0) + scores = np.delete(scores, exclude_ix, axis=0) + regressions = np.delete(regressions, exclude_ix, axis=0) + + # Resize masks to original image size and set boundary threshold. + if return_masks: + for i in range(masks.shape[0]): #masks per this batch instance/element/image + # Convert neural network mask to full size mask + if self.cf.dim == 2: + full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:]) + else: + full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:]) + # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class + # has the max seg_logit value over all instances of that class in one sample + final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0) + final_masks[0] = np.full(final_masks[0].shape, 0.49999999) #effectively min_det_thres at 0.5 per pixel + + # add final predictions to results. + if not 0 in boxes.shape: + for ix2, coords in enumerate(boxes): + box = {'box_coords': coords, 'box_type': 'det'} + box.update({'box_score': scores[ix2], 'box_pred_class_id': class_ids[ix2]}) + #if (hasattr(self.cf, "convert_cl_to_rg") and self.cf.convert_cl_to_rg): + if "regression_bin" in self.cf.prediction_tasks: + # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin stands for + box['rg_bin'] = regressions[ix2].argmax() + box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']] + else: + if hasattr(self.cf, "rg_val_to_bin_id"): + box.update({'rg_bin': self.cf.rg_val_to_bin_id(regressions[ix2])}) + box['regression'] = regressions[ix2] + + box_results_list[ix].append(box) + + # if no detections were made--> keep full bg mask (zeros). + seg_probs.append(final_masks) + + # create and fill results dictionary. + results_dict = {} + results_dict['boxes'] = box_results_list + results_dict['seg_preds'] = np.array(seg_probs) + + return results_dict + + + def train_forward(self, batch, is_validation=False): + """ + train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data + for processing, computes losses, and stores outputs in a dictionary. + :param batch: dictionary containing 'data', 'seg', etc. + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]. + 'torch_loss': 1D torch tensor for backprop. + 'class_loss': classification loss for monitoring. + """ + img = batch['data'] + gt_boxes = batch['bb_target'] + axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1) + gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))] + gt_class_ids = batch['class_targets'] + if 'regression' in self.cf.prediction_tasks: + gt_regressions = batch["regression_targets"] + elif 'regression_bin' in self.cf.prediction_tasks: + gt_regressions = batch["rg_bin_targets"] + else: + gt_regressions = None + + + img = torch.from_numpy(img).float().cuda() + batch_rpn_class_loss = torch.FloatTensor([0]).cuda() + batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda() + # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. + box_results_list = [[] for _ in range(img.shape[0])] + + #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance + # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop. + rpn_class_logits, rpn_probs, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img) + + mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \ + mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \ + self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions) + + rpn_batch_match_targets = torch.zeros(img.shape[0], self.np_anchors.shape[0]).cuda() + rpn_batch_delta_targets = torch.zeros(img.shape[0], self.np_anchors.shape[0], self.cf.dim*2).cuda() + #loop over batch + for b in range(img.shape[0]): + rpn_target_deltas = np.zeros((self.np_anchors.shape[0], self.cf.dim * 2)) + if len(gt_boxes[b]) > 0: + # add gt boxes to output list + for tix in range(len(gt_boxes[b])): + gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]} + for name in self.cf.roi_items: + gt_box.update({name: batch[name][b][tix]}) + box_results_list[b].append(gt_box) + + # match gt boxes with anchors to generate targets for RPN losses. + rpn_match, rpn_t_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b]) + indices = np.nonzero(rpn_match == 1)[0] + rpn_target_deltas[indices] = rpn_t_deltas[:indices.shape[0]] + + # add positive anchors used for loss to output list for monitoring. + # pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:]) + # for p in pos_anchors: + # box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) + else: + rpn_match = np.array([-1]*self.np_anchors.shape[0]) + + rpn_batch_match_targets[b] = torch.from_numpy(rpn_match).cuda() + rpn_batch_delta_targets[b] = torch.from_numpy(rpn_target_deltas).float().cuda() + # compute RPN losses. + #rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match, self.cf.shem_poolsize) + #rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match) + + # batch_rpn_class_loss += rpn_class_loss /img.shape[0] + # batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0] + + # add negative anchors used for loss to output list for monitoring. + # neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][0, neg_anchor_ix], img.shape[2:]) + # for n in neg_anchors: + # box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) + + # add highest scoring proposals to output list for monitoring. + rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1] + for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]: + box_results_list[b].append({'box_coords': r, 'box_type': 'prop'}) + + #filter_anchors(rpn_batch_match_targets, rpn_class_logits, rpn_batch_delta_targets, rpn_pred_deltas, + # self.cf.shem_poolsize) + # todo maybe send fixed number of rois to disc (fill up targets with bg-rois)? + non_neutral_mask = (rpn_batch_match_targets == 1) | (rpn_batch_match_targets == -1) + rpn_batch_match_targets = rpn_batch_match_targets[non_neutral_mask] + rpn_batch_delta_targets = rpn_batch_delta_targets[non_neutral_mask] + rpn_probs = rpn_probs[non_neutral_mask] + rpn_pred_deltas = rpn_pred_deltas[non_neutral_mask] + + # add positive and negative roi samples used for mrcnn losses to output list for monitoring. + # if not 0 in sample_proposals.shape: + # rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy() + # for ix, r in enumerate(rois): + # box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale, + # 'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'}) + + # get discriminator judgement on predicted proposals + # d_z = self.discriminator(self.mrcnn_feature_maps, rpn_probs, rpn_pred_deltas) + d_judgement_gen = self.discriminator(self.mrcnn_feature_maps, rpn_batch_match_targets, rpn_batch_delta_targets) + + # compute Discriminator loss + compute_disc_loss(d_pred_target, d_pred_pred, d_target, self.cf.shem_poolsize) + + + # compute mrcnn losses. + mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids) + mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids) + mrcnn_regressions_loss = compute_mrcnn_regression_loss(self.cf.prediction_tasks, mrcnn_regressions, target_regressions, target_class_ids) + # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode). + # In this case, the mask_loss is taken out of training. + if not self.cf.frcnn_mode: + mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids) + else: + mrcnn_mask_loss = torch.FloatTensor([0]).cuda() + + loss = batch_rpn_class_loss + batch_rpn_bbox_loss +\ + mrcnn_bbox_loss + mrcnn_mask_loss + mrcnn_class_loss + mrcnn_regressions_loss + + # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class. + #dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]] + #self.logger.info("regression loss {:.3f}".format(mrcnn_regressions_loss.item())) + #self.logger.info("loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, " + # "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(), + # batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(), mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount)) + + # run unmolding of predictions for monitoring and merge all results to one dictionary. + if is_validation or self.cf.detect_while_training: + return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train + results_dict = self.get_results( + img.shape, detections, detection_masks, box_results_list, return_masks=return_masks) #TODO make multithreaded? + results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis] + if 'dice' in self.cf.metrics: + results_dict['batch_dices'] = mutils.dice_per_batch_and_class( + results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True) + else: + results_dict = {'boxes': box_results_list} + + results_dict['torch_loss'] = loss + results_dict['class_loss'] = mrcnn_class_loss.item() + results_dict['bbox_loss'] = mrcnn_bbox_loss.item() + results_dict['rg_loss'] = mrcnn_regressions_loss.item() + results_dict['rpn_class_loss'] = rpn_class_loss.item() + results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item() + # #todo remove assert when sufficiently checked + # boxescoords = [b['box_coords'] for boxlist in box_results_list for b in boxlist] + # coords_check = np.array([len(coords) == self.cf.dim*2 for coords in boxescoords]) + # assert np.all(coords_check), "cand box with wrong bcoords dim: {}".format(boxescoords[~coords_check]) + + return results_dict + + + def test_forward(self, batch, return_masks=True): + """ + test method. wrapper around forward pass of network without usage of any ground truth information. + prepares input data for processing and stores outputs in a dictionary. + :param batch: dictionary containing 'data' + :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off). + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes] + """ + img = batch['data'] + img = torch.from_numpy(img).float().cuda() + _, _, _, detections, detection_masks = self.forward(img) + results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks) + + return results_dict \ No newline at end of file diff --git a/models/retina_net.py b/models/retina_net.py new file mode 100644 index 0000000..ac4e17e --- /dev/null +++ b/models/retina_net.py @@ -0,0 +1,782 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Retina Net. According to https://arxiv.org/abs/1708.02002""" + +import utils.model_utils as mutils +import utils.exp_utils as utils +import sys +sys.path.append('../') +from cuda_functions.nms_2D.pth_nms import nms_gpu as nms_2D +from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils + + +class Classifier(nn.Module): + + + def __init__(self, cf, conv): + """ + Builds the classifier sub-network. + """ + super(Classifier, self).__init__() + self.dim = conv.dim + self.n_classes = cf.head_classes + n_input_channels = cf.end_filts + n_features = cf.n_rpn_features + n_output_channels = cf.n_anchors_per_pos * cf.head_classes + anchor_stride = cf.rpn_anchor_stride + + self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) + + + def forward(self, x): + """ + :param x: input feature map (b, in_c, y, x, (z)) + :return: class_logits (b, n_anchors, n_classes) + """ + x = self.conv_1(x) + x = self.conv_2(x) + x = self.conv_3(x) + x = self.conv_4(x) + + class_logits = self.conv_final(x) + axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) + class_logits = class_logits.permute(*axes) + class_logits = class_logits.contiguous() + class_logits = class_logits.view(x.shape[0], -1, self.n_classes) + + return [class_logits] + +class BBRegressor(nn.Module): + + + def __init__(self, cf, conv): + """ + Builds the bb-regression sub-network. + """ + super(BBRegressor, self).__init__() + self.dim = conv.dim + n_input_channels = cf.end_filts + n_features = cf.n_rpn_features + n_output_channels = cf.n_anchors_per_pos * self.dim * 2 + anchor_stride = cf.rpn_anchor_stride + + self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) + + def forward(self, x): + """ + :param x: input feature map (b, in_c, y, x, (z)) + :return: bb_logits (b, n_anchors, dim * 2) + """ + x = self.conv_1(x) + x = self.conv_2(x) + x = self.conv_3(x) + x = self.conv_4(x) + bb_logits = self.conv_final(x) + + axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) + bb_logits = bb_logits.permute(*axes) + bb_logits = bb_logits.contiguous() + bb_logits = bb_logits.view(x.shape[0], -1, self.dim * 2) + + return [bb_logits] + + +class RoIRegressor(nn.Module): + + + def __init__(self, cf, conv, rg_feats): + """ + Builds the RoI-item-regression sub-network. Regression items can be, e.g., malignancy scores of tumors. + """ + super(RoIRegressor, self).__init__() + self.dim = conv.dim + n_input_channels = cf.end_filts + n_features = cf.n_rpn_features + self.rg_feats = rg_feats + n_output_channels = cf.n_anchors_per_pos * self.rg_feats + anchor_stride = cf.rpn_anchor_stride + self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) + self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, + pad=1, relu=None) + + def forward(self, x): + """ + :param x: input feature map (b, in_c, y, x, (z)) + :return: bb_logits (b, n_anchors, dim * 2) + """ + x = self.conv_1(x) + x = self.conv_2(x) + x = self.conv_3(x) + x = self.conv_4(x) + x = self.conv_final(x) + + axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) + x = x.permute(*axes) + x = x.contiguous() + x = x.view(x.shape[0], -1, self.rg_feats) + + return [x] + + + +############################################################ +# Loss Functions +############################################################ +# +def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20): + """ + :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. + :param shem_poolsize: int. factor of top-k candidates to draw from per negative sample (online-hard-example-mining). + :return: loss: torch tensor + :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. + """ + # Positive and Negative anchors contribute to the loss, + # but neutral anchors (match value = 0) don't. + pos_indices = torch.nonzero(anchor_matches > 0) + neg_indices = torch.nonzero(anchor_matches == -1) + + # get positive samples and calucalte loss. + if not 0 in pos_indices.size(): + pos_indices = pos_indices.squeeze(1) + roi_logits_pos = class_pred_logits[pos_indices] + targets_pos = anchor_matches[pos_indices].detach() + pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long()) + else: + pos_loss = torch.FloatTensor([0]).cuda() + + # get negative samples, such that the amount matches the number of positive samples, but at least 1. + # get high scoring negatives by applying online-hard-example-mining. + if not 0 in neg_indices.size(): + neg_indices = neg_indices.squeeze(1) + roi_logits_neg = class_pred_logits[neg_indices] + negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) + roi_probs_neg = F.softmax(roi_logits_neg, dim=1) + neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) + neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) + # return the indices of negative samples, who contributed to the loss for monitoring plots. + np_neg_ix = neg_ix.cpu().data.numpy() + else: + neg_loss = torch.FloatTensor([0]).cuda() + np_neg_ix = np.array([]).astype('int32') + + loss = (pos_loss + neg_loss) / 2 + return loss, np_neg_ix + + +def compute_bbox_loss(target_deltas, pred_deltas, anchor_matches): + """ + :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). + Uses 0 padding to fill in unused bbox deltas. + :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) + :param anchor_matches: tensor (n_anchors). value in [-1, 0, class_ids] for negative, neutral, and positive matched anchors. + i.e., positively matched anchors are marked by class_id >0 + :return: loss: torch 1D tensor. + """ + if not 0 in torch.nonzero(anchor_matches>0).shape: + indices = torch.nonzero(anchor_matches>0).squeeze(1) + + # Pick bbox deltas that contribute to the loss + pred_deltas = pred_deltas[indices] + # Trim target bounding box deltas to the same length as pred_deltas. + target_deltas = target_deltas[:pred_deltas.shape[0], :].detach() + # Smooth L1 loss + loss = F.smooth_l1_loss(pred_deltas, target_deltas) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_rg_loss(tasks, target, pred, anchor_matches): + """ + :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). + Uses 0 padding to fill in unsed bbox deltas. + :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) + :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :return: loss: torch 1D tensor. + """ + if not 0 in target.shape and not 0 in torch.nonzero(anchor_matches>0).shape: + indices = torch.nonzero(anchor_matches>0).squeeze(1) + # Pick rgs that contribute to the loss + pred = pred[indices] + # Trim target + target = target[:pred.shape[0]].detach() + if 'regression_bin' in tasks: + loss = F.cross_entropy(pred, target.long()) + else: + loss = F.smooth_l1_loss(pred, target) + else: + loss = torch.FloatTensor([0]).cuda() + + return loss + +def compute_focal_class_loss(anchor_matches, class_pred_logits, gamma=2.): + """ Focal Loss FL = -(1-q)^g log(q) with q = pred class probability. + + :param anchor_matches: (n_anchors). [-1, 0, class] for negative, neutral, and positive matched anchors. + :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. + :return: loss: torch tensor + :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. + """ + # Positive and Negative anchors contribute to the loss, + # but neutral anchors (match value = 0) don't. + pos_indices = torch.nonzero(anchor_matches > 0).squeeze(-1) # dim=-1 instead of 1 or 0 to cover empty matches. + neg_indices = torch.nonzero(anchor_matches == -1).squeeze(-1) + target_classes = torch.cat( (anchor_matches[pos_indices].long(), torch.LongTensor([0] * neg_indices.shape[0]).cuda()) ) + + non_neutral_indices = torch.cat( (pos_indices, neg_indices) ) + q = F.softmax(class_pred_logits[non_neutral_indices], dim=1) # q shape: (n_non_neutral_anchors, n_classes) + + # one-hot encoded target classes: keep only the pred probs of the correct class. it will receive incentive to be maximized. + # log(q_i) where i = target class --> FL shape (n_anchors,) + # need to transform to indices into flattened tensor to use torch.take + target_locs_flat = q.shape[1] * torch.arange(q.shape[0]).cuda() + target_classes + q = torch.take(q, target_locs_flat) + + FL = torch.log(q) # element-wise log + FL *= -(1-q)**gamma + + # take mean over all considered anchors + FL = FL.sum() / FL.shape[0] + return FL + + + +def refine_detections(anchors, probs, deltas, regressions, batch_ixs, cf): + """Refine classified proposals, filter overlaps and return final + detections. n_proposals here is typically a very large number: batch_size * n_anchors. + This function is hence optimized on trimming down n_proposals. + :param anchors: (n_anchors, 2 * dim) + :param probs: (n_proposals, n_classes) softmax probabilities for all rois as predicted by classifier head. + :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by bbox regressor head. + :param regressions: (n_proposals, n_classes, n_rg_feats) + :param batch_ixs: (n_proposals) batch element assignemnt info for re-allocation. + :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regr)) + """ + anchors = anchors.repeat(len(np.unique(batch_ixs)), 1) + + #flatten foreground probabilities, sort and trim down to highest confidences by pre_nms limit. + fg_probs = probs[:, 1:].contiguous() + flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True) + keep_ix = flat_probs_order[:cf.pre_nms_limit] + # reshape indices to 2D index array with shape like fg_probs. + keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) + + pre_nms_scores = flat_probs[:cf.pre_nms_limit] + pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again. + pre_nms_batch_ixs = batch_ixs[keep_arr[:, 0]] + pre_nms_anchors = anchors[keep_arr[:, 0]] + pre_nms_deltas = deltas[keep_arr[:, 0]] + pre_nms_regressions = regressions[keep_arr[:, 0]] + keep = torch.arange(pre_nms_scores.size()[0]).long().cuda() + + # apply bounding box deltas. re-scale to image coordinates. + std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() + scale = torch.from_numpy(cf.scale).float().cuda() + refined_rois = mutils.apply_box_deltas_2D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale \ + if cf.dim == 2 else mutils.apply_box_deltas_3D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale + + # round and cast to int since we're deadling with pixels now + refined_rois = mutils.clip_to_window(cf.window, refined_rois) + pre_nms_rois = torch.round(refined_rois) + for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)): + + bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] + bix_class_ids = pre_nms_class_ids[bixs] + bix_rois = pre_nms_rois[bixs] + bix_scores = pre_nms_scores[bixs] + + for i, class_id in enumerate(mutils.unique1d(bix_class_ids)): + + ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] + # nms expects boxes sorted by score. + ix_rois = bix_rois[ixs] + ix_scores = bix_scores[ixs] + ix_scores, order = ix_scores.sort(descending=True) + ix_rois = ix_rois[order, :] + ix_scores = ix_scores + + if cf.dim == 2: + class_keep = nms_2D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) + else: + class_keep = nms_3D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) + + # map indices back. + class_keep = keep[bixs[ixs[order[class_keep]]]] + # merge indices over classes for current batch element + b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep))) + + # only keep top-k boxes of current batch-element. + top_ids = pre_nms_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] + b_keep = b_keep[top_ids] + # merge indices over batch elements. + batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep))) + + keep = batch_keep + + # arrange output. + result = torch.cat((pre_nms_rois[keep], + pre_nms_batch_ixs[keep].unsqueeze(1).float(), + pre_nms_class_ids[keep].unsqueeze(1).float(), + pre_nms_scores[keep].unsqueeze(1), + pre_nms_regressions[keep]), dim=1) + + return result + + + +def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None, gt_regressions=None): + """Given the anchors and GT boxes, compute overlaps and identify positive + anchors and deltas to refine them to match their corresponding GT boxes. + + anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))] + gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))] + gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN, + set all positive matches to 1 (foreground) + gt_regressions: [num_gt_rgs, n_rg_feats], if None empty rg_targets are returned + + Returns: + anchor_class_matches: [N] (int32) matches between anchors and GT boxes. class_id = positive anchor, + -1 = negative anchor, 0 = neutral. i.e., positively matched anchors are marked by class_id (which is >0). + anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas. + anchor_rg_targets: [n_anchors, n_rg_feats] + """ + + anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32) + anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim)) + if gt_regressions is not None: + if 'regression_bin' in cf.prediction_tasks: + anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image,)) + else: + anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image, cf.regression_n_features)) + else: + anchor_rg_targets = np.array([]) + + anchor_matching_iou = cf.anchor_matching_iou + + if gt_boxes is None: + anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1) + return anchor_class_matches, anchor_delta_targets, anchor_rg_targets + + # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground) + if gt_class_ids is None: + gt_class_ids = np.array([1] * len(gt_boxes)) + + # Compute overlaps [num_anchors, num_gt_boxes] + overlaps = mutils.compute_overlaps(anchors, gt_boxes) + + # Match anchors to GT Boxes + # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive. + # If an anchor overlaps a GT box with IoU < 0.1 then it's negative. + # Neutral anchors are those that don't match the conditions above, + # and they don't influence the loss function. + # However, don't keep any GT box unmatched (rare, but happens). Instead, + # match it to the closest anchor (even if its max IoU is < 0.1). + + # 1. Set negative anchors first. They get overwritten below if a GT box is + # matched to them. Skip boxes in crowd areas. + anchor_iou_argmax = np.argmax(overlaps, axis=1) + anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] + if anchors.shape[1] == 4: + anchor_class_matches[(anchor_iou_max < 0.1)] = -1 + elif anchors.shape[1] == 6: + anchor_class_matches[(anchor_iou_max < 0.01)] = -1 + else: + raise ValueError('anchor shape wrong {}'.format(anchors.shape)) + + # 2. Set an anchor for each GT box (regardless of IoU value). + gt_iou_argmax = np.argmax(overlaps, axis=0) + for ix, ii in enumerate(gt_iou_argmax): + anchor_class_matches[ii] = gt_class_ids[ix] + + # 3. Set anchors with high overlap as positive. + above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) + anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]] + + # Subsample to balance positive anchors. + ids = np.where(anchor_class_matches > 0)[0] + extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) + if extra > 0: + # Reset the extra ones to neutral + ids = np.random.choice(ids, extra, replace=False) + anchor_class_matches[ids] = 0 + + # Leave all negative proposals negative for now and sample from them later in online hard example mining. + # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. + ids = np.where(anchor_class_matches > 0)[0] + ix = 0 # index into anchor_delta_targets + for i, a in zip(ids, anchors[ids]): + # closest gt box (it might have IoU < anchor_matching_iou) + gt = gt_boxes[anchor_iou_argmax[i]] + + # convert coordinates to center plus width/height. + gt_h = gt[2] - gt[0] + gt_w = gt[3] - gt[1] + gt_center_y = gt[0] + 0.5 * gt_h + gt_center_x = gt[1] + 0.5 * gt_w + # Anchor + a_h = a[2] - a[0] + a_w = a[3] - a[1] + a_center_y = a[0] + 0.5 * a_h + a_center_x = a[1] + 0.5 * a_w + + if cf.dim == 2: + anchor_delta_targets[ix] = [ + (gt_center_y - a_center_y) / a_h, + (gt_center_x - a_center_x) / a_w, + np.log(gt_h / a_h), + np.log(gt_w / a_w)] + else: + gt_d = gt[5] - gt[4] + gt_center_z = gt[4] + 0.5 * gt_d + a_d = a[5] - a[4] + a_center_z = a[4] + 0.5 * a_d + anchor_delta_targets[ix] = [ + (gt_center_y - a_center_y) / a_h, + (gt_center_x - a_center_x) / a_w, + (gt_center_z - a_center_z) / a_d, + np.log(gt_h / a_h), + np.log(gt_w / a_w), + np.log(gt_d / a_d)] + + # normalize. + anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev + if gt_regressions is not None: + anchor_rg_targets[ix] = gt_regressions[anchor_iou_argmax[i]] + + ix += 1 + + return anchor_class_matches, anchor_delta_targets, anchor_rg_targets + +############################################################ +# RetinaNet Class +############################################################ + + +class net(nn.Module): + """Encapsulates the RetinaNet model functionality. + """ + + def __init__(self, cf, logger): + """ + cf: A Sub-class of the cf class + model_dir: Directory to save training logs and trained weights + """ + super(net, self).__init__() + self.cf = cf + self.logger = logger + self.build() + if self.cf.weight_init is not None: + logger.info("using pytorch weight init of type {}".format(self.cf.weight_init)) + mutils.initialize_weights(self) + else: + logger.info("using default pytorch weight init") + + self.debug_acm = [] + + def build(self): + """Build Retina Net architecture.""" + + # Image size must be dividable by 2 multiple times. + h, w = self.cf.patch_size[:2] + if h / 2 ** 5 != int(h / 2 ** 5) or w / 2 ** 5 != int(w / 2 ** 5): + raise Exception("Image size must be divisible by 2 at least 5 times " + "to avoid fractions when downscaling and upscaling." + "For example, use 256, 320, 384, 448, 512, ... etc. ") + + backbone = utils.import_module('bbone', self.cf.backbone_path) + self.logger.info("loaded backbone from {}".format(self.cf.backbone_path)) + conv = backbone.ConvGenerator(self.cf.dim) + + + # build Anchors, FPN, Classifier / Bbox-Regressor -head + self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) + self.anchors = torch.from_numpy(self.np_anchors).float().cuda() + self.fpn = backbone.FPN(self.cf, conv, operate_stride1=self.cf.operate_stride1).cuda() + self.classifier = Classifier(self.cf, conv).cuda() + self.bb_regressor = BBRegressor(self.cf, conv).cuda() + + if 'regression' in self.cf.prediction_tasks: + self.roi_regressor = RoIRegressor(self.cf, conv, self.cf.regression_n_features).cuda() + elif 'regression_bin' in self.cf.prediction_tasks: + # classify into bins of regression values + self.roi_regressor = RoIRegressor(self.cf, conv, len(self.cf.bin_labels)).cuda() + else: + self.roi_regressor = lambda x: [torch.tensor([]).cuda()] + + if self.cf.model == 'retina_unet': + self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=self.cf.norm, relu=None) + + def forward(self, img): + """ + :param img: input img (b, c, y, x, (z)). + """ + # Feature extraction + fpn_outs = self.fpn(img) + if self.cf.model == 'retina_unet': + seg_logits = self.final_conv(fpn_outs[0]) + selected_fmaps = [fpn_outs[i + 1] for i in self.cf.pyramid_levels] + else: + seg_logits = None + selected_fmaps = [fpn_outs[i] for i in self.cf.pyramid_levels] + + # Loop through pyramid layers + class_layer_outputs, bb_reg_layer_outputs, roi_reg_layer_outputs = [], [], [] # list of lists + for p in selected_fmaps: + class_layer_outputs.append(self.classifier(p)) + bb_reg_layer_outputs.append(self.bb_regressor(p)) + roi_reg_layer_outputs.append(self.roi_regressor(p)) + + # Concatenate layer outputs + # Convert from list of lists of level outputs to list of lists + # of outputs across levels. + # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] + class_logits = list(zip(*class_layer_outputs)) + class_logits = [torch.cat(list(o), dim=1) for o in class_logits][0] + bb_outputs = list(zip(*bb_reg_layer_outputs)) + bb_outputs = [torch.cat(list(o), dim=1) for o in bb_outputs][0] + if not 0 == roi_reg_layer_outputs[0][0].shape[0]: + rg_outputs = list(zip(*roi_reg_layer_outputs)) + rg_outputs = [torch.cat(list(o), dim=1) for o in rg_outputs][0] + else: + if self.cf.dim == 2: + n_feats = np.array([p.shape[-2] * p.shape[-1] * self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() + else: + n_feats = np.array([p.shape[-3]*p.shape[-2]*p.shape[-1]*self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() + rg_outputs = torch.zeros((selected_fmaps[0].shape[0], n_feats, self.cf.regression_n_features), + dtype=torch.float32).fill_(float('NaN')).cuda() + + # merge batch_dimension and store info in batch_ixs for re-allocation. + batch_ixs = torch.arange(class_logits.shape[0]).unsqueeze(1).repeat(1, class_logits.shape[1]).view(-1).cuda() + flat_class_softmax = F.softmax(class_logits.view(-1, class_logits.shape[-1]), 1) + flat_bb_outputs = bb_outputs.view(-1, bb_outputs.shape[-1]) + flat_rg_outputs = rg_outputs.view(-1, rg_outputs.shape[-1]) + + detections = refine_detections(self.anchors, flat_class_softmax, flat_bb_outputs, flat_rg_outputs, batch_ixs, + self.cf) + + return detections, class_logits, bb_outputs, rg_outputs, seg_logits + + + def get_results(self, img_shape, detections, seg_logits, box_results_list=None): + """ + Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. + :param img_shape: + :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, + pred_regression) + :param box_results_list: None or list of output boxes for monitoring/plotting. + each element is a list of boxes per batch element. + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. + class-specific return of masks will come with implementation of instance segmentation evaluation. + """ + detections = detections.cpu().data.numpy() + batch_ixs = detections[:, self.cf.dim*2] + detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] + + if box_results_list == None: # for test_forward, where no previous list exists. + box_results_list = [[] for _ in range(img_shape[0])] + + for ix in range(img_shape[0]): + + if not 0 in detections[ix].shape: + + boxes = detections[ix][:, :2 * self.cf.dim].astype(np.int32) + class_ids = detections[ix][:, 2 * self.cf.dim + 1].astype(np.int32) + scores = detections[ix][:, 2 * self.cf.dim + 2] + regressions = detections[ix][:, 2 * self.cf.dim + 3:] + + # Filter out detections with zero area. Often only happens in early + # stages of training when the network weights are still a bit random. + if self.cf.dim == 2: + exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] + else: + exclude_ix = np.where( + (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] + + if exclude_ix.shape[0] > 0: + boxes = np.delete(boxes, exclude_ix, axis=0) + class_ids = np.delete(class_ids, exclude_ix, axis=0) + scores = np.delete(scores, exclude_ix, axis=0) + regressions = np.delete(regressions, exclude_ix, axis=0) + + if not 0 in boxes.shape: + for ix2, score in enumerate(scores): + if score >= self.cf.model_min_confidence: + box = {'box_type': 'det', 'box_coords': boxes[ix2], 'box_score': score, + 'box_pred_class_id': class_ids[ix2]} + if "regression_bin" in self.cf.prediction_tasks: + # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin stands for + box['rg_bin'] = regressions[ix2].argmax() + box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']] + else: + box['regression'] = regressions[ix2] + if hasattr(self.cf, "rg_val_to_bin_id") and \ + any(['regression' in task for task in self.cf.prediction_tasks]): + box['rg_bin'] = self.cf.rg_val_to_bin_id(regressions[ix2]) + box_results_list[ix].append(box) + + + results_dict = {} + results_dict['boxes'] = box_results_list + if seg_logits is None: + # output dummy segmentation for retina_net. + out_logits_shape = list(img_shape) + out_logits_shape[1] = self.cf.num_seg_classes + results_dict['seg_preds'] = np.zeros(out_logits_shape, dtype=np.float16) + #todo: try with seg_preds=None? as to not carry heavy dummy preds. + else: + # output label maps for retina_unet. + results_dict['seg_preds'] = F.softmax(seg_logits, 1).cpu().data.numpy() + + return results_dict + + + def train_forward(self, batch, is_validation=False): + """ + train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data + for processing, computes losses, and stores outputs in a dictionary. + :param batch: dictionary containing 'data', 'seg', etc. + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes]. + 'torch_loss': 1D torch tensor for backprop. + 'class_loss': classification loss for monitoring. + """ + img = batch['data'] + gt_class_ids = batch['class_targets'] + gt_boxes = batch['bb_target'] + if 'regression' in self.cf.prediction_tasks: + gt_regressions = batch["regression_targets"] + elif 'regression_bin' in self.cf.prediction_tasks: + gt_regressions = batch["rg_bin_targets"] + else: + gt_regressions = None + + var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda() + var_seg = torch.LongTensor(batch['seg']).cuda() + + img = torch.from_numpy(img).float().cuda() + torch_loss = torch.FloatTensor([0]).cuda() + + # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. + box_results_list = [[] for _ in range(img.shape[0])] + detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(img) + # loop over batch + for b in range(img.shape[0]): + # add gt boxes to results dict for monitoring. + if len(gt_boxes[b]) > 0: + for tix in range(len(gt_boxes[b])): + gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]} + for name in self.cf.roi_items: + gt_box.update({name: batch[name][b][tix]}) + box_results_list[b].append(gt_box) + + # match gt boxes with anchors to generate targets. + anchor_class_match, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching( + self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b], gt_regressions[b] if gt_regressions is not None else None) + + # add positive anchors used for loss to results_dict for monitoring. + pos_anchors = mutils.clip_boxes_numpy( + self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:]) + for p in pos_anchors: + box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) + + else: + anchor_class_match = np.array([-1]*self.np_anchors.shape[0]) + anchor_target_deltas = np.array([]) + anchor_target_rgs = np.array([]) + + anchor_class_match = torch.from_numpy(anchor_class_match).cuda() + anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda() + anchor_target_rgs = torch.from_numpy(anchor_target_rgs).float().cuda() + + if self.cf.focal_loss: + # compute class loss as focal loss as suggested in original publication, but multi-class. + class_loss = compute_focal_class_loss(anchor_class_match, class_logits[b], gamma=self.cf.focal_loss_gamma) + # sparing appendix of negative anchors for monitoring as not really relevant + else: + # compute class loss with SHEM. + class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b]) + # add negative anchors used for loss to results_dict for monitoring. + neg_anchors = mutils.clip_boxes_numpy( + self.np_anchors[np.argwhere(anchor_class_match == -1)][0, neg_anchor_ix], img.shape[2:]) + for n in neg_anchors: + box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) + rg_loss = compute_rg_loss(self.cf.prediction_tasks, anchor_target_rgs, pred_rgs[b], anchor_class_match) + bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match) + torch_loss += (class_loss + bbox_loss + rg_loss) / img.shape[0] + + + results_dict = self.get_results(img.shape, detections, seg_logits, box_results_list) + results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:, np.newaxis] + + if self.cf.model == 'retina_unet': + seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),var_seg_ohe) + seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0]) + torch_loss += (seg_loss_dice + seg_loss_ce) / 2 + #self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, " + # "mean pixel preds: {5:.5f}".format(torch_loss.item(), batch_class_loss.item(), batch_bbox_loss.item(), + # seg_loss_dice.item(), seg_loss_ce.item(), np.mean(results_dict['seg_preds']))) + if 'dice' in self.cf.metrics: + results_dict['batch_dices'] = mutils.dice_per_batch_and_class( + results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True) + #else: + #self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}".format( + # torch_loss.item(), class_loss.item(), bbox_loss.item())) + + + results_dict['torch_loss'] = torch_loss + results_dict['class_loss'] = class_loss.item() + + return results_dict + + def test_forward(self, batch, **kwargs): + """ + test method. wrapper around forward pass of network without usage of any ground truth information. + prepares input data for processing and stores outputs in a dictionary. + :param batch: dictionary containing 'data' + :return: results_dict: dictionary with keys: + 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: + [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] + 'seg_preds': actually contain seg probabilities since evaluated to seg_preds (via argmax) in predictor. + or dummy seg logits for real retina net (detection only) + """ + img = torch.from_numpy(batch['data']).float().cuda() + detections, _, _, _, seg_logits = self.forward(img) + results_dict = self.get_results(img.shape, detections, seg_logits) + return results_dict \ No newline at end of file diff --git a/plotting.py b/plotting.py new file mode 100644 index 0000000..d53d3e5 --- /dev/null +++ b/plotting.py @@ -0,0 +1,2135 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import matplotlib +# matplotlib.rcParams['font.family'] = ['serif'] +# matplotlib.rcParams['font.serif'] = ['Times New Roman'] +matplotlib.rcParams['mathtext.fontset'] = 'cm' +matplotlib.rcParams['font.family'] = 'STIXGeneral' +matplotlib.use('Agg') #complains with spyder editor, bc spyder imports mpl at startup +from matplotlib.ticker import FormatStrFormatter +import matplotlib.colors as mcolors + +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import matplotlib.patches as mpatches +from matplotlib.ticker import StrMethodFormatter, ScalarFormatter +import SimpleITK as sitk +from tensorboard.backend.event_processing.event_multiplexer import EventMultiplexer + +import sys +import os +import warnings + +from copy import deepcopy + +import numpy as np +import pandas as pd +import scipy.interpolate as interpol + +from utils.exp_utils import IO_safe + +warnings.filterwarnings("ignore", module="matplotlib.image") + + +def make_colormap(seq): + """ Return a LinearSegmentedColormap + seq: a sequence of floats and RGB-tuples. The floats should be increasing + and in the interval (0,1). + """ + seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3] + cdict = {'red': [], 'green': [], 'blue': []} + for i, item in enumerate(seq): + if isinstance(item, float): + r1, g1, b1 = seq[i - 1] + r2, g2, b2 = seq[i + 1] + cdict['red'].append([item, r1, r2]) + cdict['green'].append([item, g1, g2]) + cdict['blue'].append([item, b1, b2]) + return mcolors.LinearSegmentedColormap('CustomMap', cdict) +bw_cmap = make_colormap([(1.,1.,1.), (0.,0.,0.)]) + +#------------------------------------------------------------------------ +#------------- plotting functions, not all are used --------------------- + + +def shape_small_first(shape): + """sort a tuple so that the smallest entry is swapped to the beginning + """ + if len(shape) <= 2: # no changing dimensions if channel-dim is missing + return shape + smallest_dim = np.argmin(shape) + if smallest_dim != 0: # assume that smallest dim is color channel + new_shape = np.array(shape) # to support mask indexing + new_shape = (new_shape[smallest_dim], + *new_shape[(np.arange(len(shape), dtype=int) != smallest_dim)]) + return new_shape + else: + return shape + +def RGB_to_rgb(RGB): + rgb = np.array(RGB) / 255. + return rgb + +def mod_to_rgb(arr, cmap=None): + """convert a single-channel modality img to 3-color-channel img. + :param arr: input img, expected in shape (b,c,)x,y with c=1 + :return: img of shape (...,c') with c'=3 + """ + if len(arr.shape) == 3: + arr = np.squeeze(arr) + elif len(arr.shape) != 2: + raise Exception("Invalid input arr shape: {}".format(arr.shape)) + + if cmap is None: + cmap = "gray" + norm = matplotlib.colors.Normalize() + norm.autoscale(arr) + arr = norm(arr) + arr = np.stack((arr,) * 3, axis=-1) + + return arr + +def to_rgb(arr, cmap): + """ + Transform an integer-labeled segmentation map using an rgb color-map. + :param arr: img_arr w/o a color-channel + :param cmap: dictionary mapping from integer class labels to rgb values + :return: img of shape (...,c) + """ + new_arr = np.zeros(shape=(arr.shape) + (3,)) + for l in cmap.keys(): + ixs = np.where(arr == l) + new_arr[ixs] = np.array([cmap[l][i] for i in range(3)]) + + return new_arr + +def to_rgba(arr, cmap): + """ + Transform an integer-labeled segmentation map using an rgba color-map. + :param arr: img_arr w/o a color-channel + :param cmap: dictionary mapping from integer class labels to rgba values + :return: new array holding rgba-image + """ + new_arr = np.zeros(shape=(arr.shape) + (4,)) + for lab, val in cmap.items(): + # in case no alpha, complement with 100% alpha + if len(val) == 3: + cmap[lab] = (*val, 1.) + assert len(cmap[lab]) == 4, "cmap has color with {} entries".format(len(val)) + + for lab in cmap.keys(): + ixs = np.where(arr == lab) + rgb = np.array(cmap[lab][:3]) + new_arr[ixs] = np.append(rgb, cmap[lab][3]) + + return new_arr + +def bin_seg_to_rgba(arr, color): + """ + Transform a continuously labelled binary segmentation map using an rgba color-map. + values are expected to be 0-1, will give alpha-value + :param arr: img_arr w/o a color-channel + :param color: color to give img + :return: new array holding rgba-image + """ + new_arr = np.zeros(shape=(arr.shape) + (4,)) + + for i in range(arr.shape[0]): + for j in range(arr.shape[1]): + new_arr[i][j] = (*color, arr[i][j]) + + return new_arr + +def suppress_axes_lines(ax): + """ + :param ax: pyplot axes object + """ + ax.axes.get_xaxis().set_ticks([]) + ax.axes.get_yaxis().set_ticks([]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + + return + +def label_bar(ax, rects, labels=None, colors=None, fontsize=10): + """Attach a text label above each bar displaying its height + + :param ax: + :param rects: rectangles as returned by plt.bar() + :param labels: + :param colors: + """ + for ix, rect in enumerate(rects): + height = rect.get_height() + if labels is not None and labels[ix] is not None: + label = labels[ix] + else: + label = '{:g}'.format(height) + if colors is not None and colors[ix] is not None and np.any(np.array(colors[ix])<1): + color = colors[ix] + else: + color = 'black' + ax.text(rect.get_x() + rect.get_width() / 2., 1.007 * height, label, color=color, ha='center', va='bottom', + bbox=dict(facecolor=(1., 1., 1.), edgecolor='none', clip_on=True, pad=0, alpha=0.75), fontsize=fontsize) + +def draw_box_into_arr(arr, box_coords, box_color=None, lw=2): + """ + :param arr: imgs shape, (3,y,x) + :param box_coords: (x1,y1,x2,y2), in ascending order + :param box_color: arr of shape (3,) + :param lw: linewidth in pixels + """ + if box_color is None: + box_color = [1., 0.4, 0.] + + (x1, y1, x2, y2) = box_coords[:4] + + arr = np.swapaxes(arr, 0, -1) + arr[..., y1:y2, x1:x1 + lw, :], arr[..., y1:y2 + lw, x2:x2 + lw, :] = box_color, box_color + arr[..., y1:y1 + lw, x1:x2, :], arr[..., y2:y2 + lw, x1:x2, :] = box_color, box_color + arr = np.swapaxes(arr, 0, -1) + + return arr + +def draw_boxes_into_batch(imgs, batch_boxes, type2color=None, cmap=None): + """ + :param imgs: either the actual batch imgs or a tuple with shape of batch imgs, + need to have 3 color channels, need to be rgb; + """ + if isinstance(imgs, tuple): + img_oshp = imgs + imgs = None + else: + img_oshp = imgs[0].shape + + img_shp = shape_small_first(img_oshp) # c,x/y,y/x now + imgs = np.reshape(imgs, (-1, *img_shp)) + box_imgs = np.empty((len(batch_boxes), *(img_shp))) + + for sample, boxes in enumerate(batch_boxes): + # imgs in batch have shape b,c,x,y, swap c to end + sample_img = np.full(img_shp, 1.) if imgs is None else imgs[sample] + for box in boxes: + if len(box["box_coords"]) > 0: + if type2color is not None and "box_type" in box.keys(): + sample_img = draw_box_into_arr(sample_img, box["box_coords"].astype(np.int32), + type2color[box["box_type"]]) + else: + sample_img = draw_box_into_arr(sample_img, box["box_coords"].astype(np.int32)) + box_imgs[sample] = sample_img + + return box_imgs + + +def plot_prediction_hist(cf, spec_df, outfile, title=None, fs=11, ax=None): + + labels = spec_df.class_label.values + preds = spec_df.pred_score.values + type_list = spec_df.det_type.tolist() if hasattr(spec_df, "det_type") else None + if title is None: + title = outfile.split('/')[-1] + ' count:{}'.format(len(labels)) + close=False + if ax is None: + fig = plt.figure(tight_layout=True) + ax = fig.add_subplot(1,1,1) + close=True + ax.set_yscale('log') + + ax.set_xlabel("Prediction Score", fontsize=fs) + ax.set_ylabel("Occurences", fontsize=fs) + + ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="fp") + ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="fn at score 0 and tp") + ax.axvline(x=cf.min_det_thresh, alpha=1, color=cf.orange, linewidth=1.5, label="min det thresh") + + if type_list is not None: + fp_count = type_list.count('det_fp') + fn_count = type_list.count('det_fn') + tp_count = type_list.count('det_tp') + pos_count = fn_count + tp_count + title += '\ntp:{} fp:{} fn:{} pos:{}'.format(tp_count, fp_count, fn_count, pos_count) + + ax.set_title(title, fontsize=fs) + ax.tick_params(axis='both', which='major', labelsize=fs) + ax.tick_params(axis='both', which='minor', labelsize=fs) + + if close: + ax.legend(loc="best", fontsize=fs) + if cf.server_env: + IO_safe(plt.savefig, fname=outfile, _raise=False) + else: + plt.savefig(outfile) + plt.close() + +def plot_wbc_n_missing(cf, df, outfile, fs=11, ax=None): + """ WBC (weighted box clustering) has parameter n_missing, which shows how many boxes are missing per cluster. + This function plots the average relative amount of missing boxes sorted by cluster score. + :param cf: config. + :param df: dataframe. + :param outfile: path to save image under. + :param fs: fontsize. + :param ax: axes object. + """ + + bins = np.linspace(0., 1., 10) + names = ["{:.1f}".format((bins[i]+(bins[i+1]-bins[i])/2.)*100) for i in range(len(bins)-1)] + classes = df.pred_class.unique() + colors = [cf.class_id2label[cl_id].color for cl_id in classes] + + binned_df = df.copy() + binned_df.loc[:,"pred_score"] = pd.cut(binned_df["pred_score"], bins) + + close=False + if ax is None: + ax = plt.subplot() + close=True + width = 1 / (len(classes) + 1) + group_positions = np.arange(len(names)) + legend_handles = [] + + for ix, cl_id in enumerate(classes): + cl_df = binned_df[binned_df.pred_class==cl_id].groupby("pred_score").agg({"cluster_n_missing": 'mean'}) + ax.bar(group_positions + ix * width, cl_df.cluster_n_missing.values, width=width, color=colors[ix], + alpha=0.4 + ix / 2 / len(classes), edgecolor=colors[ix]) + legend_handles.append(mpatches.Patch(color=colors[ix], label=cf.class_dict[cl_id])) + + title = "Fold {} WBC Missing Preds\nAverage over scores and classes: {:.1f}%".format(cf.fold, df.cluster_n_missing.mean()) + ax.set_title(title, fontsize=fs) + ax.legend(handles=legend_handles, title="Class", loc="best", fontsize=fs, title_fontsize=fs) + ax.set_xticks(group_positions + (len(classes) - 1) * width / 2) + # ax.xaxis.set_major_formatter(StrMethodFormatter('{x:.1f}')) THIS WONT WORK... no clue! + ax.set_xticklabels(names) + ax.tick_params(axis='both', which='major', labelsize=fs) + ax.tick_params(axis='both', which='minor', labelsize=fs) + + ax.set_axisbelow(True) + ax.grid() + ax.set_ylabel(r"Average Missing Preds per Cluster (%)", fontsize=fs) + ax.set_xlabel("Prediction Score", fontsize=fs) + + if close: + if cf.server_env: + IO_safe(plt.savefig, fname=outfile, _raise=False) + else: + plt.savefig(outfile) + plt.close() + +def plot_stat_curves(cf, stats, outfile, fill=False): + """ Plot precision-recall and/or receiver-operating-characteristic curve(s). + :param cf: config. + :param stats: statistics as supplied by Evaluator. + :param outfile: path to save plot under. + :param fill: whether to colorize space between plot and x-axis. + :return: + """ + + for c in ['roc', 'prc']: + plt.figure() + empty_plot = True + for ix, s in enumerate(stats): + if s[c] is not np.nan: + plt.plot(s[c][1], s[c][0], label=s['name'] + '_' + c, marker=None, + color=cf.color_palette[ix%len(cf.color_palette)]) + empty_plot = False + if fill: + plt.fill_between(s[c][1], s[c][0], alpha=0.33, color=cf.color_palette[ix%len(cf.color_palette)]) + if not empty_plot: + plt.title(outfile.split('/')[-1] + '_' + c) + plt.legend(loc=3 if c == 'prc' else 4) + plt.ylabel('precision' if c == 'prc' else '1-spec.') + plt.ylim((0.,1)) + plt.xlabel('recall') + + plt.savefig(outfile + '_' + c) + plt.close() + + +def plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=None, alphas=None, errors=None, ylabel='', xlabel='', + xticklabels=None, yticks=None, yticklabels=None, ylim=None, label_format="{:.3f}", + title=None, ax=None, out_file=None, legend=False, fs=11): + """ Plot a categorically grouped bar chart. + :param cf: config. + :param bar_values: values of the bars. + :param groups: groups/categories that bars belong to. + :param splits: splits within groups, i.e., names of bars. + :param colors: colors. + :param alphas: 1-opacity. + :param errors: values for errorbars. + :param ylabel: label of y-axis. + :param xlabel: label of x-axis. + :param title: plot title. + :param ax: axes object to draw into. if None, new is created. + :param out_file: path to save plot. + :param legend: whether to show a legend. + :param fs: fontsize. + :return: legend handles. + """ + bar_values = np.array(bar_values) + if alphas is None: + alphas = [1.,] * len(splits) + if colors is None: + colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(splits))] + if errors is None: + errors = np.zeros_like(bar_values) + # patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') + # patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) + close=False + if ax is None: + ax = plt.subplot() + close=True + width = 1 / (len(splits) +0.25) + group_positions = np.arange(len(groups)) + + for ix, split in enumerate(splits): + rects = ax.bar(group_positions + ix * width, bar_values[ix], width=width, color=(*colors[ix], 0.8), + edgecolor=colors[ix], yerr=errors[ix], ecolor=(*np.array(colors[ix])*0.8, 1.), capsize=5) + # for ix, bar in enumerate(rects): + # bar.set_hatch(patterns[ix]) + labels = [label_format.format(val) for val in bar_values[ix]] + label_bar(ax, rects, labels, [colors[ix]]*len(labels), fontsize=fs) + + legend_handles = [mpatches.Patch(color=colors[ix], alpha=alphas[ix], label=split) for ix, split in + enumerate(splits)] + if legend: + ax.legend(handles=legend_handles, fancybox=True, framealpha=1., loc="lower center") + legend_handles = [(colors[ix], alphas[ix], split) for ix, split in enumerate(splits)] + + if title is not None: + ax.set_title(title, fontsize=fs) + + ax.set_xticks(group_positions + (len(splits) - 1) * width / 2) + if xticklabels is None: + ax.set_xticklabels(groups, fontsize=fs) + else: + ax.set_xticklabels(xticklabels, fontsize=fs) + ax.set_axisbelow(True) + ax.set_xlabel(xlabel, fontsize=fs) + ax.tick_params(labelsize=fs) + + ax.grid(axis='y') + ax.set_ylabel(ylabel, fontsize=fs) + if yticks is not None: + ax.set_yticks(yticks) + if yticklabels is not None: + ax.set_yticklabels(yticklabels, fontsize=fs) + if ylim is not None: + ax.set_ylim(ylim) + + if out_file is not None: + plt.savefig(out_file, dpi=600) + if close: + plt.close() + + return legend_handles + +def plot_binned_rater_dissent(cf, binned_stats, out_file=None, ax=None, legend=True, fs=11): + """ LIDC-specific plot: rater disagreement as standard deviations within each bin. + :param cf: config. + :param binned_stats: list, ix==bin_id, item: [(roi_mean, roi_std, roi_max, roi_bin_id-roi_max_bin_id) for roi in bin] + :return: + """ + + dissent = [np.array([roi[1] for roi in bin]) for bin in binned_stats] + avg_dissent_first_degree = [np.mean(bin) for bin in dissent] + + groups = list(cf.bin_id2label.keys()) + splits = [r"$1^{st}$ std. dev.",] + colors = [cf.bin_id2label[bin_id].color[:3] for bin_id in groups] + #colors = [cf.blue for bin_id in groups] + alphas = [0.9,] + #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') + #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) + + close=False + if ax is None: + ax = plt.subplot() + close=True + width = 1/(len(splits)+1) + group_positions = np.arange(len(groups)) + + #total_counts = [df.loc[split].sum() for split in splits] + dissent = np.array(avg_dissent_first_degree) + ix=0 + rects = ax.bar(group_positions+ix*width, dissent, color=colors, alpha=alphas[ix], + edgecolor=colors) + #for ix, bar in enumerate(rects): + #bar.set_hatch(patterns[ix]) + labels = ["{:.2f}".format(diss) for diss in dissent] + label_bar(ax, rects, labels, colors, fontsize=fs) + bin_edge_color = cf.blue + ax.axhline(y=0.5, color=bin_edge_color) + ax.text(2.5, 0.38, "bin edge", color=cf.white, fontsize=fs, horizontalalignment="center", + bbox=dict(boxstyle='round', facecolor=(*bin_edge_color, 0.85), edgecolor='none', clip_on=True, pad=0)) + + if legend: + legend_handles = [mpatches.Patch(color=cf.blue ,alpha=alphas[ix], label=split) for ix, split in enumerate(splits)] + ax.legend(handles=legend_handles, loc='lower center', fontsize=fs) + + title = "LIDC-IDRI: Average Std Deviation per Lesion" + plt.title(title) + + ax.set_xticks(group_positions + (len(splits)-1)*width/2) + ax.set_xticklabels(groups, fontsize=fs) + ax.set_axisbelow(True) + #ax.tick_params(axis='both', which='major', labelsize=fs) + #ax.tick_params(axis='both', which='minor', labelsize=fs) + ax.grid() + ax.set_ylabel(r"Average Dissent (MS)", fontsize=fs) + ax.set_xlabel("binned malignancy-score value (ms)", fontsize=fs) + ax.tick_params(labelsize=fs) + if out_file is not None: + plt.savefig(out_file, dpi=600) + + if close: + plt.close() + + return + +def plot_confusion_matrix(cf, cm, out_file=None, ax=None, fs=11, cmap=plt.cm.Blues, color_bar=True): + """ Plot a confusion matrix. + :param cf: config. + :param cm: confusion matrix, e.g., as supplied by metrics.confusion_matrix from scikit-learn. + :return: + """ + + close=False + if ax is None: + ax = plt.subplot() + close=True + + im = ax.imshow(cm, interpolation='nearest', cmap=cmap) + if color_bar: + ax.figure.colorbar(im, ax=ax) + + # Rotate the tick labels and set their alignment. + #plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + # Loop over data dimensions and create text annotations. + fmt = '.0%' if np.mod(cm, 1).any() else 'd' + thresh = cm.max() / 2. + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], fmt), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + ax.set_ylabel(r"Binned Mean MS", fontsize=fs) + ax.set_xlabel("Single-Annotator MS", fontsize=fs) + #ax.tick_params(labelsize=fs) + if close and out_file is not None: + plt.savefig(out_file, dpi=600) + + if close: + plt.close() + else: + return ax + +def plot_data_stats(cf, df, labels=None, out_file=None, ax=None, fs=11): + """ Plot data-set statistics. Shows target counts. Mainly used by Dataset Class in dataloader.py. + :param cf: configs obj + :param df: pandas dataframe + :param out_file: path to save fig in + """ + names = df.columns + if labels is not None: + colors = [label.color for name in names for label in labels if label.name==name] + else: + colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(names))] + #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') + #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) + if ax is None: + fig, ax = plt.subplots(figsize=(14,6), dpi=300) + return_ax = False + else: + return_ax = True + + plt.margins(x=0.01) + plt.subplots_adjust(bottom=0.15) + bar_positions = np.arange(len(names)) + name_counts = df.sum() + total_count = name_counts.sum() + + rects = ax.bar(bar_positions, name_counts, color=colors, alpha=0.9, edgecolor=colors) + labels = ["{:.0f}%".format(count/ total_count*100) for count in name_counts] + label_bar(ax, rects, labels, colors, fontsize=fs) + + title= "Data Set RoI-Target Balance\nTotal #RoIs: {}".format(int(total_count)) + ax.set_title(title, fontsize=fs) + ax.set_xticks(bar_positions) + rotation = "vertical" if np.any([len(str(name)) > 3 for name in names]) else None + if all([isinstance(name, (float, int)) for name in names]): + ax.set_xticklabels(["{:.2f}".format(name) for name in names], rotation=rotation, fontsize=fs) + else: + ax.set_xticklabels(names, rotation=rotation, fontsize=fs) + + ax.set_axisbelow(True) + ax.grid() + ax.set_ylabel(r"#RoIs", fontsize=fs) + ax.set_xlabel(str(df._metadata[0]), fontsize=fs) + ax.tick_params(axis='both', which='major', labelsize=fs) + ax.tick_params(axis='both', which='minor', labelsize=fs) + + if out_file is not None: + plt.savefig(out_file) + + if return_ax: + return ax + else: + plt.close() + +def plot_fold_stats(cf, df, labels=None, out_file=None, ax=None): + """ Similar as plot_data_stats but per single cross-val fold. + :param cf: configs obj + :param df: pandas dataframe + :param out_file: path to save fig in + """ + names = df.columns + splits = df.index + if labels is not None: + colors = [label.color for name in names for label in labels if label.name==name] + else: + colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(names))] + #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') + #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) + if ax is None: + ax = plt.subplot() + return_ax = False + else: + return_ax = True + width = 1/(len(names)+1) + group_positions = np.arange(len(splits)) + legend_handles = [] + + total_counts = [df.loc[split].sum() for split in splits] + + for ix, name in enumerate(names): + rects = ax.bar(group_positions+ix*width, df.loc[:,name], width=width, color=colors[ix], alpha=0.9, + edgecolor=colors[ix]) + #for ix, bar in enumerate(rects): + #bar.set_hatch(patterns[ix]) + labels = ["{:.0f}%".format(df.loc[split, name]/ total_counts[ii]*100) for ii, split in enumerate(splits)] + label_bar(ax, rects, labels, [colors[ix]]*len(group_positions)) + + legend_handles.append(mpatches.Patch(color=colors[ix] ,alpha=0.9, label=name)) + + title= "Fold {} RoI-Target Balances\nTotal #RoIs: {}".format(cf.fold, + int(df.values.sum())) + plt.title(title) + ax.legend(handles=legend_handles) + ax.set_xticks(group_positions + (len(names)-1)*width/2) + ax.set_xticklabels(splits, rotation="vertical" if len(splits)>2 else None, size=12) + ax.set_axisbelow(True) + ax.grid() + ax.set_ylabel(r"#RoIs") + ax.set_xlabel("Set split") + + if out_file is not None: + plt.savefig(out_file) + if return_ax: + return ax + plt.close() + +def plot_batchgen_distribution(cf, pids, p_probs, balance_target, out_file=None): + """plot top n_pids probabilities for drawing a pid into a batch. + :param cf: experiment config object + :param pids: sorted iterable of patient ids + :param p_probs: pid's drawing likelihood, order needs to match the one of pids. + :param out_file: + :return: + """ + n_pids = len(pids) + zip_sorted = np.array(sorted(list(zip(p_probs, pids)), reverse=True)) + names, probs = zip_sorted[:n_pids,1], zip_sorted[:n_pids,0].astype('float32') * 100 + try: + names = [str(int(n)) for n in names] + except ValueError: + names = [str(n) for n in names] + lowest_p = min(p_probs)*100 + fig, ax = plt.subplots(1,1,figsize=(17,5), dpi=200) + rects = ax.bar(names, probs, color=cf.blue, alpha=0.9, edgecolor=cf.blue) + ax = plt.gca() + ax.text(0.8, 0.92, "Lowest prob.: {:.5f}%".format(lowest_p), transform=ax.transAxes, color=cf.white, + bbox=dict(boxstyle='round', facecolor=cf.blue, edgecolor='none', alpha=0.9)) + ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}')) + ax.set_xticklabels(names, rotation="vertical", fontsize=7) + plt.margins(x=0.01) + plt.subplots_adjust(bottom=0.15) + if balance_target=="class_targets": + balance_target = "Class" + elif balance_target=="lesion_gleasons": + balance_target = "GS" + ax.set_title(str(balance_target)+"-Balanced Train Generator: Sampling Likelihood per PID") + ax.set_axisbelow(True) + ax.grid(axis='y') + ax.set_ylabel("Sampling Likelihood (%)") + ax.set_xlabel("PID") + plt.tight_layout() + + if out_file is not None: + plt.savefig(out_file) + + plt.close() + +def plot_batchgen_stats(cf, stats, target_name, unique_ts, out_file=None): + """Plot bar chart showing RoI frequencies and empty-sample count of batch stats recorded by BatchGenerator. + :param cf: config. + :param stats: statistics as supplied by BatchGenerator class. + :param out_file: path to save plot. + """ + + total_samples = cf.num_epochs*cf.num_train_batches*cf.batch_size + if target_name=="class_targets": + target_name = "Class" + label_dict = {cl_id: label for (cl_id, label) in cf.class_id2label.items()} + elif target_name=="lesion_gleasons": + target_name = "Lesion's Gleason Score" + label_dict = cf.gs2label + elif target_name=="rg_bin_targets": + target_name = "Regression-Bin ID" + label_dict = cf.bin_id2label + else: + raise NotImplementedError + names = [label_dict[t_id].name for t_id in unique_ts] + colors = [label_dict[t_id].color for t_id in unique_ts] + + title = "Training Target Frequencies" + title += "\nempty samples: {} ({:.1f}%)".format(stats['empty_samples_count'], stats['empty_samples_count']/total_samples*100) + rects = plt.bar(names, stats['roi_counts'], color=colors, alpha=0.9, edgecolor=colors) + ax = plt.gca() + + ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}')) + ax.set_title(title) + ax.set_axisbelow(True) + ax.grid() + ax.set_ylabel(r"#RoIs") + ax.set_xlabel(target_name) + + total_count = np.sum(stats["roi_counts"]) + labels = ["{:.0f}%".format(count/total_count*100) for count in stats["roi_counts"]] + label_bar(ax, rects, labels, colors) + + if out_file is not None: + plt.savefig(out_file) + + plt.close() + + +def view_3D_array(arr, outfile, elev=30, azim=30): + from mpl_toolkits.mplot3d import Axes3D + + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + + ax.set_aspect("equal") + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + ax.voxels(arr) + ax.view_init(elev=elev, azim=azim) + + plt.savefig(outfile) + +def view_batch(cf, batch, res_dict=None, out_file=None, legend=True, show_info=True, has_colorchannels=False, + isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, + roi_items="all", sample_picks=None, vol_slice_picks=None, + box_score_thres=None, plot_mods=True, dpi=200, vmin=None, return_fig=False): + r""" View data and target entries of a batch. + + Batch expected as dic with entries 'data' and 'seg' holding np.arrays of + size :math:`batch\_size \times modalities \times h \times w` for data + and :math:`batch\_size \times classes \times h \times w` or + :math:`batch\_size \times 1 \times h \times w` for segs. + Classes, even if just dummy, are always needed for plotting since they determine colors. + Pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. + + :param cf: config. + :param batch: batch. + :param res_dict: results dictionary. + :param out_file: path to save plot. + :param legend: whether to show a legend. + :param show_info: whether to show text info about img sizes and type in plot. + :param has_colorchannels: whether image has color channels. + :param isRGB: if image is RGB. + :param show_seg_ids: "all" or None or list with seg classes to show (seg_ids) + :param show_seg_pred: whether to the predicted segmentation. + :param show_gt_boxes: whether to show ground-truth boxes. + :param show_gt_labels: whether to show labels of ground-truth boxes. + :param roi_items: which roi items to show: strings "all" or "targets". --> all roi items in cf.roi_items or only + those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. + empty iterator to show none. + :param sample_picks: which indices of the batch to display. None for all. + :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples + ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob + of amt slices. fg pick requires gt seg. + :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no threshold. + :param plot_mods: whether to plot input modality/modalities. + :param dpi: graphics resolution. + :param vmin: min value for gray-scale cmap in imshow, set to a fix value for inter-batch normalization, or None for + intra-batch. + :param return_fig: whether to return created figure. + """ + + # pfix = prefix, ptfix = postfix + patched_patient = 'patch_crop_coords' in list(batch.keys()) + pfix = 'patient_' if patched_patient else '' + ptfix = '_2d' if (patched_patient and cf.dim == 2 and pfix + 'class_targets_2d' in batch.keys()) else '' + # -------------- get data, set flags ----------------- + try: + btype = type(batch[pfix + 'data']) + data = batch[pfix + 'data'].astype("float32") + seg = batch[pfix + 'seg'] + except AttributeError: # in this case: assume it's single-annotator ground truths + btype = type(batch[pfix + 'data']) + data = batch[pfix + 'data'].astype("float32") + seg = batch[pfix + 'seg'][0] + print("Showing only gts of rater 0") + + data_init_shp, seg_init_shp = data.shape, seg.shape + seg = np.copy(seg) if show_seg_ids else None + plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None + plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 + gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None + class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None + cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] + if roi_items == "all": + roi_items = [it for it in cf_roi_items] + elif roi_items == "targets": + roi_items = [it for it in cf_roi_items if 'targets' in it] + else: + roi_items = [it for it in cf_roi_items if it in roi_items] + + if res_dict is not None: + seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() + and show_seg_ids) else None + if '2D_boxes' in res_dict.keys(): + assert cf.dim==2 + pr_boxes = res_dict["2D_boxes"] + elif 'boxes' in res_dict.keys(): + pr_boxes = res_dict["boxes"] + else: + pr_boxes = None + else: + seg_preds = None + pr_boxes = None + + # -------------- get shapes, apply sample selection ----------------- + (n_samples, mods, h, w), d = data.shape[:4], 0 + + z_ics = [slice(None)] + if has_colorchannels: #has to be 2D + data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c + mods = 1 + else: + if len(data.shape) == 5: # 3dim case + d = data.shape[4] + if vol_slice_picks is None: + z_ics = np.arange(0, d) + elif hasattr(vol_slice_picks, "__iter__") and vol_slice_picks[0]=="random": + z_ics = np.random.choice(np.arange(0, d), size=min(vol_slice_picks[1], d), replace=False) + else: + z_ics = vol_slice_picks + + sample_ics = range(n_samples) + # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however + if data.shape[0]*data.shape[2]*len(z_ics)>8000: + n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) + if len(z_ics)>1 and vol_slice_picks is None: + z_ics = np.random.choice(np.arange(0, data.shape[4]), + size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) + if sample_picks is None: + sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) + + if sample_picks is not None: + sample_ics = [s for s in sample_picks if s in sample_ics] + n_samples = len(sample_ics) + + if not plot_mods: + mods = 0 + if show_seg_ids=="all": + show_seg_ids = np.unique(seg) + if seg_preds is not None and not type(show_seg_ids)==str: + seg_preds = np.copy(seg_preds) + seg_preds = np.where(np.isin(seg_preds, show_seg_ids), seg_preds, 0) + if seg is not None: + if not type(show_seg_ids)==str: #to save time + seg = np.where(np.isin(seg, show_seg_ids), seg, 0) + legend_items = {cf.seg_id2label[seg_id] for seg_id in np.unique(seg) if seg_id != 0} # add seg labels + else: + legend_items = set() + + # -------------- setup figure ----------------- + if isRGB: + data = RGB_to_rgb(data) + if plot_bg is not None: + plot_bg = RGB_to_rgb(plot_bg) + n_cols = mods + if seg is not None or gt_boxes is not None: + n_cols += 1 + if seg_preds is not None or pr_boxes is not None: + n_cols += 1 + + n_rows = n_samples*len(z_ics) + grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) + fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) + title_fs = 12 # fontsize + + sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) + row = 0 # current row + for s_count, s_ix in enumerate(sample_ics): + for z_ix in z_ics: + col = 0 # current col + # ----visualise input data ------------- + if has_colorchannels: + if plot_mods: + ax = fig.add_subplot(grid[row, col]) + ax.imshow(data[s_ix][...,z_ix]) + ax.axis("off") + if row == 0: + plt.title("Input", fontsize=title_fs) + if col == 0: + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix) == slice else z_ix + ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + col += 1 + bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] + else: + for mod in range(mods): + ax = fig.add_subplot(grid[row, col]) + ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) + suppress_axes_lines(ax) + if row == 0: + plt.title("Mod. " + str(mod), fontsize=title_fs) + if col == 0: + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix)==slice else z_ix + ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + col += 1 + bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] + + # ---evtly visualise groundtruths------------------- + if seg is not None or gt_boxes is not None: + # img as bg for gt + ax = fig.add_subplot(grid[row, col]) + ax.imshow(bg_img, cmap="gray", vmin=vmin) + if row == 0: + plt.title("Ground Truth", fontsize=title_fs) + if col == 0: + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix) == slice else z_ix + ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + suppress_axes_lines(ax) + else: + plt.axis('off') + col += 1 + + if seg is not None and seg.shape[1] == 1: + ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cf.cmap), alpha=0.8) + elif seg is not None: + ax.imshow(to_rgba(np.argmax(seg[s_ix][...,z_ix], axis=0), cf.cmap), alpha=0.8) + + # gt bounding boxes + if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: + for j, box in enumerate(gt_boxes[s_ix]): + if d > 0: + [z1, z2] = box[4:] + if not (z1<=z_ix and z_ix<=z2): + box = [] + if len(box) > 0: + [y1, x1, y2, x2] = box[:4] + width, height = x2 - x1, y2 - y1 + if class_targets is not None: + label = cf.class_id2label[class_targets[s_ix][j]] + legend_items.add(label) + if show_gt_labels: + text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 + text_fs = title_fs // 3 + if roi_items is not None: + for name in roi_items: + if name in cf_roi_items and batch[name][s_ix][j] is not None: + if 'class_targets' in name and cf.plot_class_ids: + text_x = x2 #- 2 * text_fs * (len(str(class_targets[s_ix][j]))) # avoid overlap of scores + text_y = y1 #+ 2 * text_fs + text_str = '{}'.format(class_targets[s_ix][j]) + elif 'regression_targets' in name: + text_x, text_y = (x2, y2) + text_str = "[" + " ".join( + ["{:.1f}".format(x) for x in batch[name][s_ix][j]]) + "]" + elif 'rg_bin_targets' in name: + text_x, text_y = (x1, y2) + text_str = '{}'.format(batch[name][s_ix][j]) + else: + text_pos = text_poss.pop(0) + text_x = text_pos[0] #- 2 * text_fs * len(str(batch[name][s_ix][j])) + text_y = text_pos[1] #+ 2 * text_fs + text_str = '{}'.format(batch[name][s_ix][j]) + + ax.text(text_x, text_y, text_str, color=cf.white, fontsize=text_fs, + bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, + pad=0)) + p+=1 + bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, + facecolor='none') + ax.add_patch(bbox) + + # -----evtly visualise predictions ------------- + if pr_boxes is not None or seg_preds is not None: + ax = fig.add_subplot(grid[row, col]) + ax.imshow(bg_img, cmap="gray") + ax.axis("off") + col += 1 + if row == 0: + plt.title("Prediction", fontsize=title_fs) + # ---------- pred boxes ------------------------- + if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: + box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres + for j, box in enumerate(pr_boxes[s_ix]): + plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] + if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): + plot_box = False + + if plot_box: + if d > 0: + [z1, z2] = box["box_coords"][4:] + if not (z1<=z_ix and z_ix<=z2): + box = [] + if len(box) > 0: + [y1, x1, y2, x2] = box["box_coords"][:4] + + width, height = x2 - x1, y2 - y1 + + if box["box_type"] == "det": + label = cf.class_id2label[box["box_pred_class_id"]] + legend_items.add(label) + text_x, text_y = x2, y1 + id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" + text_str = '{}{:.0f}'.format(id_text, box["box_score"] * 100) + text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, + pad=0) + ax.text(text_x, text_y, text_str, color=cf.white, + bbox=text_settings, fontsize=title_fs // 4) + edgecolor = label.color + if 'regression' in box.keys(): + text_x, text_y = x2, y2 + id_text = "["+" ".join(["{:.1f}".format(x) for x in box["regression"]])+"]" #str(box["regression"]) #+ "|" if cf.plot_class_ids else "" + if 'rg_uncertainty' in box.keys() and not np.isnan(box['rg_uncertainty']): + id_text += " | {:.1f}".format(box['rg_uncertainty']) + text_str = '{}'.format(id_text) #, box["box_score"] * 100) + text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', + clip_on=True, pad=0) + ax.text(text_x, text_y, text_str, color=cf.white, + bbox=text_settings, fontsize=title_fs // 4) + if 'rg_bin' in box.keys(): + text_x, text_y = x1, y2 + text_str = '{}'.format(box["rg_bin"]) + text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', + clip_on=True, pad=0) + ax.text(text_x, text_y, text_str, color=cf.white, + bbox=text_settings, fontsize=title_fs // 4) + else: + label = cf.box_type2label[box["box_type"]] + legend_items.add(label) + edgecolor = label.color + + bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, + facecolor='none') + ax.add_patch(bbox) + # ------------ pred segs -------- + if seg_preds is not None: # and seg_preds.shape[1] == 1: + if cf.class_specific_seg: + ax.imshow(to_rgba(seg_preds[s_ix][0][...,z_ix], cf.cmap), alpha=0.8) + else: + ax.imshow(bin_seg_to_rgba(seg_preds[s_ix][0][...,z_ix], cf.orange), alpha=0.8) + + row += 1 + + # -----actions for all batch entries---------- + if legend and len(legend_items) > 0: + patches = [] + for label in legend_items: + if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): + id_text = str(label.id) + ":" + else: + id_text = "" + + patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) + # assumes one image gives enough y-space for 5 legend items + ncols = max(1, len(legend_items) // (5 * n_samples)) + plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), + borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, + fontsize=int(2/3*title_fs)) + # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) + + if show_info: + plt.figtext(0, 0, "Batch content is of type\n{}\nand has shapes\n".format(btype) + \ + "{} for 'data' and {} for 'seg'".format(data_init_shp, seg_init_shp)) + + if out_file is not None: + if cf.server_env: + IO_safe(plt.savefig, fname=out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', _raise=False) + else: + plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight') + if return_fig: + return plt.gcf() + plt.clf() + plt.close() + +def view_batch_paper(cf, batch, res_dict=None, out_file=None, legend=True, show_info=True, has_colorchannels=False, + isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, + roi_items="all", split_ens_ics=False, server_env=True, sample_picks=None, vol_slice_picks=None, + patient_items=False, box_score_thres=None, plot_mods=True, dpi=400, vmin=None, return_fig=False): + r"""view data and target entries of a batch. + + batch expected as dic with entries 'data' and 'seg' holding tensors or nparrays of + size :math:`batch\_size \times modalities \times h \times w` for data + and :math:`batch\_size \times classes \times h \times w` or + :math:`batch\_size \times 1 \times h \times w` for segs. + Classes, even if just dummy, are always needed for plotting since they determine colors. + + :param cf: + :param batch: + :param res_dict: + :param out_file: + :param legend: + :param show_info: + :param has_colorchannels: + :param isRGB: + :param show_seg_ids: + :param show_seg_pred: + :param show_gt_boxes: + :param show_gt_labels: + :param roi_items: strings "all" or "targets" --> all roi items in cf.roi_items or only those which are targets, or + list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator + to show none. + :param split_ens_ics: + :param server_env: + :param sample_picks: which indices of the batch to display. None for all. + :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples + ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob + of amt slices. fg pick requires gt seg. + :param patient_items: set to true if patient-wise batch items should be displayed (need to be contained in batch + and marked via 'patient_' prefix. + :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no thres. + :param plot_mods: + :param dpi: graphics resolution + :param vmin: min value for gs cmap in imshow, set to fix inter-batch, or None for intra-batch. + + pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. + show_seg_ids: "all" or None or list with seg classes to show (seg_ids) + + """ + # pfix = prefix, ptfix = postfix + pfix = 'patient_' if patient_items else '' + ptfix = '_2d' if (patient_items and cf.dim==2) else '' + + # -------------- get data, set flags ----------------- + + btype = type(batch[pfix + 'data']) + data = batch[pfix + 'data'].astype("float32") + seg = batch[pfix + 'seg'] + + # seg = np.array(seg).mean(axis=0, keepdims=True) + # seg[seg>0] = 1. + + print("Showing multirater GT") + data_init_shp, seg_init_shp = data.shape, seg.shape + fg_slices = np.where(np.sum(np.sum(np.squeeze(seg), axis=0), axis=0)>0)[0] + + if len(fg_slices)==0: + print("skipping empty patient") + return + if vol_slice_picks is None: + vol_slice_picks = fg_slices + + print("data shp, seg shp", data_init_shp, seg_init_shp) + + plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None + plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 + gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None + class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None + cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] + if roi_items == "all": + roi_items = [it for it in cf_roi_items] + elif roi_items == "targets": + roi_items = [it for it in cf_roi_items if 'targets' in it] + else: + roi_items = [it for it in cf_roi_items if it in roi_items] + + if res_dict is not None: + seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() + and show_seg_ids) else None + if '2D_boxes' in res_dict.keys(): + assert cf.dim==2 + pr_boxes = res_dict["2D_boxes"] + elif 'boxes' in res_dict.keys(): + pr_boxes = res_dict["boxes"] + else: + pr_boxes = None + else: + seg_preds = None + pr_boxes = None + + # -------------- get shapes, apply sample selection ----------------- + (n_samples, mods, h, w), d = data.shape[:4], 0 + + z_ics = [slice(None)] + if has_colorchannels: #has to be 2D + data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c + mods = 1 + else: + if len(data.shape) == 5: # 3dim case + d = data.shape[4] + if vol_slice_picks is None: + z_ics = np.arange(0, d) + # elif hasattr(vol_slice_picks, "__iter__") and vol_slice_picks[0]=="random": + # z_ics = np.random.choice(np.arange(0, d), size=min(vol_slice_picks[1], d), replace=False) + else: + z_ics = vol_slice_picks + + sample_ics = range(n_samples) + # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however + if data.shape[0]*data.shape[2]*len(z_ics)>8000: + n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) + if len(z_ics)>1: + if vol_slice_picks is None: + z_ics = np.random.choice(np.arange(0, data.shape[4]), + size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) + else: + z_ics = np.random.choice(vol_slice_picks, + size=min(len(vol_slice_picks), max(1,int(8000/(n_picks*data.shape[2])))), replace=False) + + if sample_picks is None: + sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) + + if sample_picks is not None: + sample_ics = [s for s in sample_picks if s in sample_ics] + n_samples = len(sample_ics) + + if not plot_mods: + mods = 0 + if show_seg_ids=="all": + show_seg_ids = np.unique(seg) + + legend_items = set() + + # -------------- setup figure ----------------- + if isRGB: + data = RGB_to_rgb(data) + if plot_bg is not None: + plot_bg = RGB_to_rgb(plot_bg) + n_cols = mods + if seg is not None or gt_boxes is not None: + n_cols += 1 + if seg_preds is not None or pr_boxes is not None: + n_cols += 1 + + n_rows = n_samples*len(z_ics) + grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) + fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) + title_fs = 12 # fontsize + + sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) + row = 0 # current row + for s_count, s_ix in enumerate(sample_ics): + for z_ix in z_ics: + col = 0 # current col + # ----visualise input data ------------- + if has_colorchannels: + if plot_mods: + ax = fig.add_subplot(grid[row, col]) + ax.imshow(data[s_ix][...,z_ix]) + ax.axis("off") + if row == 0: + plt.title("Input", fontsize=title_fs) + if col == 0: + # key = "spec" if "spec" in batch.keys() else "pid" + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix) == slice else z_ix + ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + col += 1 + bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] + else: + for mod in range(mods): + ax = fig.add_subplot(grid[row, col]) + ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) + suppress_axes_lines(ax) + if row == 0: + plt.title("Mod. " + str(mod), fontsize=title_fs) + if col == 0: + # key = "spec" if "spec" in batch.keys() else "pid" + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix)==slice else z_ix + ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + col += 1 + bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] + + # ---evtly visualise groundtruths------------------- + if seg is not None or gt_boxes is not None: + # img as bg for gt + ax = fig.add_subplot(grid[row, col]) + ax.imshow(bg_img, cmap="gray", vmin=vmin) + if row == 0: + plt.title("Ground Truth+ Pred", fontsize=title_fs) + if col == 0: + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix) == slice else z_ix + ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + suppress_axes_lines(ax) + else: + plt.axis('off') + col += 1 + + if seg is not None and seg.shape[1] == 1: + cmap = {1: cf.orange} + ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cmap), alpha=0.8) + + # gt bounding boxes + if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: + for j, box in enumerate(gt_boxes[s_ix]): + if d > 0: + [z1, z2] = box[4:] + if not (z1<=z_ix and z_ix<=z2): + box = [] + if len(box) > 0: + [y1, x1, y2, x2] = box[:4] + # [x1,y1,x2,y2] = box[:4]#:return: coords (x1, y1, x2, y2) + width, height = x2 - x1, y2 - y1 + if class_targets is not None: + label = cf.class_id2label[class_targets[s_ix][j]] + legend_items.add(label) + if show_gt_labels and cf.plot_class_ids: + text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 + text_fs = title_fs // 3 + if roi_items is not None: + for name in roi_items: + if name in cf_roi_items and batch[name][s_ix][j] is not None: + if 'class_targets' in name: + text_x = x2 #- 2 * text_fs * (len(str(class_targets[s_ix][j]))) # avoid overlap of scores + text_y = y1 #+ 2 * text_fs + text_str = '{}'.format(class_targets[s_ix][j]) + elif 'regression_targets' in name: + text_x, text_y = (x2, y2) + text_str = "[" + " ".join( + ["{:.1f}".format(x) for x in batch[name][s_ix][j]]) + "]" + elif 'rg_bin_targets' in name: + text_x, text_y = (x1, y2) + text_str = '{}'.format(batch[name][s_ix][j]) + else: + text_pos = text_poss.pop(0) + text_x = text_pos[0] #- 2 * text_fs * len(str(batch[name][s_ix][j])) + text_y = text_pos[1] #+ 2 * text_fs + text_str = '{}'.format(batch[name][s_ix][j]) + + ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, fontsize=text_fs, + bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, + pad=0)) + p+=1 + bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, + facecolor='none') + ax.add_patch(bbox) + + # # -----evtly visualise predictions ------------- + # if pr_boxes is not None or seg_preds is not None: + # ax = fig.add_subplot(grid[row, col]) + # ax.imshow(bg_img, cmap="gray") + # ax.axis("off") + # col += 1 + # if row == 0: + # plt.title("Prediction", fontsize=title_fs) + + + + # ---------- pred boxes ------------------------- + if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: + box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres + for j, box in enumerate(pr_boxes[s_ix]): + plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] + if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): + plot_box = False + + if plot_box: + if d > 0: + [z1, z2] = box["box_coords"][4:] + if not (z1<=z_ix and z_ix<=z2): + box = [] + if len(box) > 0: + [y1, x1, y2, x2] = box["box_coords"][:4] + + width, height = x2 - x1, y2 - y1 + + if box["box_type"] == "det": + label = cf.bin_id2label[box["rg_bin"]] + color = cf.aubergine + legend_items.add(label) + text_x, text_y = x2, y1 + #id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" + id_text = "fg: " + text_str = '{}{:.0f}'.format(id_text, box["box_score"] * 100) + text_settings = dict(facecolor=color, alpha=0.5, edgecolor='none', clip_on=True, + pad=0.2) + ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, + bbox=text_settings, fontsize=title_fs // 2) + edgecolor = color #label.color + if 'regression' in box.keys(): + text_x, text_y = x2, y2 + id_text = "ms: "+" ".join(["{:.1f}".format(x) for x in box["regression"]])+"" + text_str = '{}'.format(id_text) #, box["box_score"] * 100) + text_settings = dict(facecolor=color, alpha=0.5, edgecolor='none', + clip_on=True, pad=0.2) + ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, + bbox=text_settings, fontsize=title_fs // 2) + if 'rg_bin' in box.keys(): + text_x, text_y = x1, y2 + text_str = '{}'.format(box["rg_bin"]) + text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', + clip_on=True, pad=0) + # ax.text(text_x, text_y, text_str, color=cf.white, + # bbox=text_settings, fontsize=title_fs // 4) + if split_ens_ics and "ens_ix" in box.keys(): + n_aug = box["ens_ix"].split("_")[1] + edgecolor = [c for c in cf.color_palette if not c == cf.green][ + int(n_aug) % (len(cf.color_palette) - 1)] + text_x, text_y = x1, y2 + text_str = "{}".format(box["ens_ix"][2:]) + ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, + fontsize=title_fs // 6) + else: + label = cf.box_type2label[box["box_type"]] + legend_items.add(label) + edgecolor = label.color + + bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, + facecolor='none') + ax.add_patch(bbox) + row += 1 + + # -----actions for all batch entries---------- + if legend and len(legend_items) > 0: + patches = [] + for label in legend_items: + if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): + id_text = str(label.id) + ":" + else: + id_text = "" + + patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) + # assumes one image gives enough y-space for 5 legend items + ncols = max(1, len(legend_items) // (5 * n_samples)) + plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), + borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, + fontsize=int(2/3*title_fs)) + # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) + + if show_info: + plt.figtext(0, 0, "Batch content is of type\n{}\nand has shapes\n".format(btype) + \ + "{} for 'data' and {} for 'seg'".format(data_init_shp, seg_init_shp)) + + if out_file is not None: + plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', tight_layout=True) + if return_fig: + return plt.gcf() + if not (server_env or cf.server_env): + plt.show() + plt.clf() + plt.close() + +def view_batch_thesis(cf, batch, res_dict=None, out_file=None, legend=True, has_colorchannels=False, + isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, show_cl_ids=True, + roi_items="all", server_env=True, sample_picks=None, vol_slice_picks=None, fontsize=12, seg_cmap="class", + patient_items=False, box_score_thres=None, plot_mods=True, dpi=400, vmin=None, return_fig=False, axes=None): + r"""view data and target entries of a batch. + + batch expected as dic with entries 'data' and 'seg' holding tensors or nparrays of + size :math:`batch\_size \times modalities \times h \times w` for data + and :math:`batch\_size \times classes \times h \times w` or + :math:`batch\_size \times 1 \times h \times w` for segs. + Classes, even if just dummy, are always needed for plotting since they determine colors. + + :param cf: + :param batch: + :param res_dict: + :param out_file: + :param legend: + :param show_info: + :param has_colorchannels: + :param isRGB: + :param show_seg_ids: + :param show_seg_pred: + :param show_gt_boxes: + :param show_gt_labels: + :param roi_items: strings "all" or "targets" --> all roi items in cf.roi_items or only those which are targets, or + list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator + to show none. + :param split_ens_ics: + :param server_env: + :param sample_picks: which indices of the batch to display. None for all. + :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples + ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob + of amt slices. fg pick requires gt seg. + :param patient_items: set to true if patient-wise batch items should be displayed (need to be contained in batch + and marked via 'patient_' prefix. + :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no thres. + :param plot_mods: + :param dpi: graphics resolution + :param vmin: min value for gs cmap in imshow, set to fix inter-batch, or None for intra-batch. + + pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. + show_seg_ids: "all" or None or list with seg classes to show (seg_ids) + + """ + # pfix = prefix, ptfix = postfix + pfix = 'patient_' if patient_items else '' + ptfix = '_2d' if (patient_items and cf.dim==2) else '' + + # -------------- get data, set flags ----------------- + + btype = type(batch[pfix + 'data']) + data = batch[pfix + 'data'].astype("float32") + seg = batch[pfix + 'seg'] + + data_init_shp, seg_init_shp = data.shape, seg.shape + fg_slices = np.where(np.sum(np.sum(np.squeeze(seg), axis=0), axis=0)>0)[0] + + if len(fg_slices)==0: + print("skipping empty patient") + return + if vol_slice_picks is None: + vol_slice_picks = fg_slices + + #print("data shp, seg shp", data_init_shp, seg_init_shp) + + plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None + plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 + gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None + class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None + cl_targets_sa = batch[pfix+'class_targets_sa'+ptfix] if pfix+'class_targets_sa'+ptfix in batch.keys() else None + cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] + if roi_items == "all": + roi_items = [it for it in cf_roi_items] + elif roi_items == "targets": + roi_items = [it for it in cf_roi_items if 'targets' in it] + else: + roi_items = [it for it in cf_roi_items if it in roi_items] + + if res_dict is not None: + seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() + and show_seg_ids) else None + if '2D_boxes' in res_dict.keys(): + assert cf.dim==2 + pr_boxes = res_dict["2D_boxes"] + elif 'boxes' in res_dict.keys(): + pr_boxes = res_dict["boxes"] + else: + pr_boxes = None + else: + seg_preds = None + pr_boxes = None + + # -------------- get shapes, apply sample selection ----------------- + (n_samples, mods, h, w), d = data.shape[:4], 0 + + z_ics = [slice(None)] + if has_colorchannels: #has to be 2D + data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c + mods = 1 + else: + if len(data.shape) == 5: # 3dim case + d = data.shape[4] + if vol_slice_picks is None: + z_ics = np.arange(0, d) + else: + z_ics = vol_slice_picks + + sample_ics = range(n_samples) + # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however + if data.shape[0]*data.shape[2]*len(z_ics)>8000: + n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) + if len(z_ics)>1 and vol_slice_picks is None: + z_ics = np.random.choice(np.arange(0, data.shape[4]), + size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) + if sample_picks is None: + sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) + + if sample_picks is not None: + sample_ics = [s for s in sample_picks if s in sample_ics] + n_samples = len(sample_ics) + + if not plot_mods: + mods = 0 + if show_seg_ids=="all": + show_seg_ids = np.unique(seg) + + legend_items = set() + + # -------------- setup figure ----------------- + if isRGB: + data = RGB_to_rgb(data) + if plot_bg is not None: + plot_bg = RGB_to_rgb(plot_bg) + n_cols = mods + if seg is not None or gt_boxes is not None: + n_cols += 1 + if seg_preds is not None or pr_boxes is not None: + n_cols += 1 + + n_rows = n_samples*len(z_ics) + grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) + fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) + title_fs = fontsize # fontsize + text_fs = title_fs * 2 / 3 + + sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) + row = 0 # current row + for s_count, s_ix in enumerate(sample_ics): + for z_ix in z_ics: + col = 0 # current col + # ----visualise input data ------------- + if has_colorchannels: + if plot_mods: + ax = fig.add_subplot(grid[row, col]) + ax.imshow(data[s_ix][...,z_ix]) + ax.axis("off") + if row == 0: + plt.title("Input", fontsize=title_fs) + if col == 0: + # key = "spec" if "spec" in batch.keys() else "pid" + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix) == slice else z_ix + ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + col += 1 + bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] + else: + for mod in range(mods): + ax = fig.add_subplot(grid[row, col]) + ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) + suppress_axes_lines(ax) + if row == 0: + plt.title("Mod. " + str(mod), fontsize=title_fs) + if col == 0: + # key = "spec" if "spec" in batch.keys() else "pid" + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix)==slice else z_ix + ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) + ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number + col += 1 + bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] + + # ---evtly visualise groundtruths------------------- + if seg is not None or gt_boxes is not None: + # img as bg for gt + if axes is not None and 'gt' in axes.keys(): + ax = axes['gt'] + else: + ax = fig.add_subplot(grid[row, col]) + ax.imshow(bg_img, cmap="gray", vmin=vmin) + if row == 0: + ax.set_title("Ground Truth", fontsize=title_fs) + if col == 0: + # key = "spec" if "spec" in batch.keys() else "pid" + specs = batch.get('spec', batch['pid']) + intra_patient_ix = s_ix if type(z_ix) == slice else z_ix + ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) # str(specs[s_ix])[-5:] + ax.set_ylabel("{:s}".format(ylabel), fontsize=text_fs*1.3) # show id-number + suppress_axes_lines(ax) + else: + ax.axis('off') + col += 1 + + # gt bounding boxes + if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: + for j, box in enumerate(gt_boxes[s_ix]): + if d > 0: + [z1, z2] = box[4:] + if not (z1<=z_ix and z_ix<=z2): + box = [] + if len(box) > 0: + [y1, x1, y2, x2] = box[:4] + # [x1,y1,x2,y2] = box[:4]#:return: coords (x1, y1, x2, y2) + width, height = x2 - x1, y2 - y1 + if class_targets is not None: + try: + label = cf.bin_id2label[cf.rg_val_to_bin_id(batch['patient_regression_targets'][s_ix][j])] + except AttributeError: + label = cf.class_id2label[class_targets[s_ix][j]] + legend_items.add(label) + if show_gt_labels and cf.plot_class_ids: + bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, + facecolor='none') + if height<=text_fs*6: + y1 -= text_fs*1.5 + y2 += text_fs*2 + text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 + if roi_items is not None: + for name in roi_items: + if name in cf_roi_items and batch[name][s_ix][j] is not None: + if 'class_targets' in name: + text_str = '{}'.format(class_targets[s_ix][j]) + text_x, text_y = (x2 + 0 * len(text_str) // 4, y2) + elif 'regression_targets' in name: + text_str = 'agg. MS: {:.2f}'.format(batch[name][s_ix][j][0]) + text_x, text_y = (x2 + 0 * len(text_str) // 4, y2) + elif 'rg_bin_targets_sa' in name: + text_str = 'sa. MS: {}'.format(batch[name][s_ix][j]) + text_x, text_y = (x2-0*len(text_str)*text_fs//4, y1) + # elif 'rg_bin_targets' in name: + # text_str = 'agg. ms:{}'.format(batch[name][s_ix][j]) + # text_x, text_y = (x2+0*len(text_str)//4, y1) + + + ax.text(text_x, text_y, text_str, color=cf.black if + (label.color[:3]==cf.yellow or label.color[:3]==cf.green) else cf.white, + fontsize=text_fs, + bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) + p+=1 + ax.add_patch(bbox) + if seg is not None and seg.shape[1] == 1: + #cmap = {1: cf.orange} + # cmap = {label_id: label.color for label_id, label in cf.bin_id2label.items()} + # this whole function is totally only hacked together for a quick very specific case + if seg_cmap == "rg" or seg_cmap=="regression": + cmap = {1: cf.bin_id2label[cf.rg_val_to_bin_id(batch['patient_regression_targets'][s_ix][0])].color} + else: + cmap = cf.class_cmap + ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cmap), alpha=0.8) + + + # # -----evtly visualise predictions ------------- + if pr_boxes is not None or seg_preds is not None: + if axes is not None and 'pred' in axes.keys(): + ax = axes['pred'] + else: + ax = fig.add_subplot(grid[row, col]) + ax.imshow(bg_img, cmap="gray") + ax.axis("off") + col += 1 + if row == 0: + ax.set_title("Prediction", fontsize=title_fs) + + # ---------- pred boxes ------------------------- + if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: + alpha = 0.7 + box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres + for j, box in enumerate(pr_boxes[s_ix]): + plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] + if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): + plot_box = False + + if plot_box: + if d > 0: + [z1, z2] = box["box_coords"][4:] + if not (z1<=z_ix and z_ix<=z2): + box = [] + if len(box) > 0: + [y1, x1, y2, x2] = box["box_coords"][:4] + + width, height = x2 - x1, y2 - y1 + + if box["box_type"] == "det": + try: + label = cf.bin_id2label[cf.rg_val_to_bin_id(box['regression'])] + except AttributeError: + label = cf.class_id2label[box['box_pred_class_id']] + # assert box["rg_bin"] == cf.rg_val_to_bin_id(box['regression']), \ + # "box bin: {}, rg-bin {}".format(box["rg_bin"], cf.rg_val_to_bin_id(box['regression'])) + color = label.color#cf.aubergine + edgecolor = color # label.color + text_color = cf.black if (color[:3]==cf.yellow or color[:3]==cf.green) else cf.white + legend_items.add(label) + bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, + facecolor='none') + if height<=text_fs*6: + y1 -= text_fs*1.5 + y2 += text_fs*2 + text_x, text_y = x2, y1 + #id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" + id_text = "FG: " + text_str = r'{}{:.0f}%'.format(id_text, box["box_score"] * 100) + text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, + pad=0.2) + ax.text(text_x, text_y, text_str, color=text_color, + bbox=text_settings, fontsize=text_fs ) + + if 'regression' in box.keys(): + text_x, text_y = x2, y2 + id_text = "MS: "+" ".join(["{:.2f}".format(x) for x in box["regression"]])+"" + text_str = '{}'.format(id_text) + text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', + clip_on=True, pad=0.2) + ax.text(text_x, text_y, text_str, color=text_color, + bbox=text_settings, fontsize=text_fs) + if 'rg_bin' in box.keys(): + text_x, text_y = x1, y2 + text_str = '{}'.format(box["rg_bin"]) + text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', + clip_on=True, pad=0) + # ax.text(text_x, text_y, text_str, color=cf.white, + # bbox=text_settings, fontsize=title_fs // 4) + if 'box_pred_class_id' in box.keys() and show_cl_ids: + text_x, text_y = x2, y2 + id_text = box["box_pred_class_id"] + text_str = '{}'.format(id_text) + text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', + clip_on=True, pad=0.2) + ax.text(text_x, text_y, text_str, color=text_color, + bbox=text_settings, fontsize=text_fs) + else: + label = cf.box_type2label[box["box_type"]] + legend_items.add(label) + edgecolor = label.color + + ax.add_patch(bbox) + row += 1 + + # -----actions for all batch entries---------- + if legend and len(legend_items) > 0: + patches = [] + for label in legend_items: + if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): + id_text = str(label.id) + ":" + else: + id_text = "" + + patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) + # assumes one image gives enough y-space for 5 legend items + ncols = max(1, len(legend_items) // (5 * n_samples)) + plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), + borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, + fontsize=int(2/3*title_fs)) + # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) + + if out_file is not None: + plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', tight_layout=True) + if return_fig: + return plt.gcf() + if not (server_env or cf.server_env): + plt.show() + plt.clf() + plt.close() + + +def view_slices(cf, img, seg=None, ids=None, title="", out_dir=None, legend=True, + cmap=None, label_remap=None, instance_labels=False): + """View slices of a 3D image overlayed with corresponding segmentations. + + :params img, seg: expected as 3D-arrays + """ + if isinstance(img, sitk.SimpleITK.Image): + img = sitk.GetArrayViewFromImage(img) + elif isinstance(img, np.ndarray): + #assume channels dim is smallest and in either first or last place + if np.argmin(img.shape)==2: + img = np.moveaxis(img, 2,0) + else: + raise Exception("view_slices got unexpected img type.") + + if seg is not None: + if isinstance(seg, sitk.SimpleITK.Image): + seg = sitk.GetArrayViewFromImage(seg) + elif isinstance(img, np.ndarray): + if np.argmin(seg.shape)==2: + seg = np.moveaxis(seg, 2,0) + else: + raise Exception("view_slices got unexpected seg type.") + + if label_remap is not None: + for (key, val) in label_remap.items(): + seg[seg==key] = val + + if instance_labels: + class Label(): + def __init__(self, id, name, color): + self.id = id + self.name = name + self.color = color + + legend_items = {Label(seg_id, "instance_{}".format(seg_id), + cf.color_palette[seg_id%len(cf.color_palette)]) for + seg_id in np.unique(seg)} + if cmap is None: + cmap = {label.id : label.color for label in legend_items} + else: + legend_items = {cf.seg_id2label[seg_id] for seg_id in np.unique(seg)} + if cmap is None: + cmap = {label.id : label.color for label in legend_items} + + + slices = img.shape[0] + if seg is not None: + assert slices==seg.shape[0], "Img and seg have different amt of slices." + grid = gridspec.GridSpec(int(np.ceil(slices/4)),4) + fig = plt.figure(figsize=(10, slices/4*2.5)) + rng = np.arange(slices, dtype='uint8') + if not ids is None: + rng = rng[ids] + for s in rng: + ax = fig.add_subplot(grid[int(s/4),int(s%4)]) + ax.imshow(img[s], cmap="gray") + if not seg is None: + ax.imshow(to_rgba(seg[s], cmap), alpha=0.9) + if legend and int(s/4)==0 and int(s%4)==3: + patches = [mpatches.Patch(color=label.color, + label="{}".format(label.name)) for label in legend_items] + ncols = 1 + plt.legend(handles=patches,bbox_to_anchor=(1.05, 1), loc=2, + borderaxespad=0., ncol=ncols) + plt.title("slice {}, {}".format(s, img[s].shape)) + plt.axis('off') + + plt.suptitle(title) + if out_dir is not None: + plt.savefig(out_dir, dpi=300, pad_inches=0.0, bbox_inches='tight') + if not cf.server_env: + plt.show() + plt.close() + + +def plot_txt(cf, txts, labels=None, title="", x_label="", y_labels=["",""], y_ranges=(None,None), + twin_axes=(), smooth=None, out_dir=None): + """Read and plot txt data, either from file (txts is paths) or directly (txts is arrays). + + :param twin_axes: plot two y-axis over same x-axis. twin_axes expected as + tuple defining which txt files (determined via indices) share the second y-axis. + """ + if isinstance(txts, str) or not hasattr(txts, '__iter__'): + txts = [txts] + + fig = plt.figure() + ax1 = fig.add_subplot(1,1,1) + if len(twin_axes)>0: + ax2 = ax1.twinx() + for i, txt in enumerate(txts): + if isinstance(txt, str): + arr = np.genfromtxt(txt, delimiter=',',skip_header=1, usecols=(1,2)) + else: + arr = txt + if i in twin_axes: + ax = ax2 + else: + ax = ax1 + if smooth is not None: + spline_graph = interpol.UnivariateSpline(arr[:,0], arr[:,1], k=5, s=float(smooth)) + ax.plot(arr[:, 0], spline_graph(arr[:,0]), color=cf.color_palette[i % len(cf.color_palette)], + marker='', markersize=2, linestyle='solid') + ax.plot(arr[:,0], arr[:,1], color=cf.color_palette[i%len(cf.color_palette)], + marker='', markersize=2, linestyle='solid', label=labels[i], alpha=0.5 if smooth else 1.) + plt.title(title) + + ax1.set_xlabel(x_label) + ax1.set_ylabel(y_labels[0]) + if y_ranges[0] is not None: + ax1.set_ylim(y_ranges[0]) + if len(twin_axes)>0: + ax2.set_ylabel(y_labels[1]) + if y_ranges[1] is not None: + ax2.set_ylim(y_ranges[1]) + + plt.grid() + + if labels is not None: + ax1.legend(loc="upper center") + if len(twin_axes)>0: + ax2.legend(loc=4) + + if out_dir is not None: + plt.savefig(out_dir, dpi=200) + return fig + +def plot_tboard_logs(cf, log_dir, tag_filters=[""], inclusive_filters=True, out_dir=None, x_label="", + y_labels=["",""], y_ranges=(None,None), twin_axes=(), smooth=None): + """Plot (only) tboard scalar logs from given log_dir for multiple runs sorted by tag. + """ + print("log dir", log_dir) + mpl = EventMultiplexer().AddRunsFromDirectory(log_dir) #EventAccumulator(log_dir) + mpl.Reload() + + # Print tags of contained entities, use these names to retrieve entities as below + #print(mpl.Runs()) + scalars = {runName : data['scalars'] for (runName, data) in mpl.Runs().items() if len(data['scalars'])>0} + print("scalars", scalars) + tags = {} + tag_filters = [tag_filter.lower() for tag_filter in tag_filters] + for (runName, runtags) in scalars.items(): + print("rn", runName.lower()) + check = np.any if inclusive_filters else np.all + if np.any([tag_filter in runName.lower() for tag_filter in tag_filters]): + for runtag in runtags: + #if tag_filter in runtag.lower(): + if runtag not in tags: + tags[runtag] = [runName] + else: + tags[runtag].append(runName) + print("tags ", tags) + for (tag, runNames) in tags.items(): + print("runnames ", runNames) + print("tag", tag) + tag_scalars = [] + labels = [] + for run in runNames: + #mpl.Scalars returns ScalarEvents array holding wall_time, step, value per time step (shape series_length x 3) + #print(mpl.Scalars(runName, tag)[0]) + run_scalars = [(s.step, s.value) for s in mpl.Scalars(run, tag)] + print(np.array(run_scalars).shape) + tag_scalars.append(np.array(run_scalars)) + print("run", run) + labels.append("/".join(run.split("/")[-2:])) + #print("tag scalars ", tag_scalars) + if out_dir is not None: + out_path = os.path.join(out_dir,tag.replace("/","_")) + else: + out_path = None + plot_txt(txts=tag_scalars, labels=labels, title=tag, out_dir=out_path, cf=cf, + x_label=x_label, y_labels=y_labels, y_ranges=y_ranges, twin_axes=twin_axes, smooth=smooth) + + +def plot_box_legend(cf, box_coords=None, class_id=None, out_dir=None): + """plot a blank box explaining box annotations. + :param cf: + :return: + """ + if class_id is None: + class_id = 1 + + img = np.ones(cf.patch_size[:2]) + dim_max = max(cf.patch_size[:2]) + width, height = cf.patch_size[0] // 2, cf.patch_size[1] // 2 + if box_coords is None: + # lower left corner + x1, y1 = width // 2, height // 2 + x2, y2 = x1 + width, y1 + height + else: + y1, x1, y2, x2 = box_coords + + fig = plt.figure(tight_layout=True, dpi=300) + ax = fig.add_subplot(111) + title_fs = 36 + label = cf.class_id2label[class_id] + # legend_items.add(label) + ax.set_facecolor(cf.beige) + ax.imshow(img, cmap='gray', vmin=0., vmax=1., alpha=0) + # ax.axis('off') + # suppress_axes_lines(ax) + ax.set_xticks([]) + ax.set_yticks([]) + + text_x, text_y = x2 * 0.85, y1 + id_text = "class id" + " | " if cf.plot_class_ids else "" + text_str = '{}{}'.format(id_text, "confidence") + text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, + pad=0) + ax.text(text_x, text_y, text_str, color=cf.white, + bbox=text_settings, fontsize=title_fs // 4) + edgecolor = label.color + if any(['regression' in task for task in cf.prediction_tasks]): + text_x, text_y = x2 * 0.85, y2 + id_text = "regression" + if any(['ken_gal' in task or 'feindt' in task for task in cf.prediction_tasks]): + id_text += " | uncertainty" + text_str = '{}'.format(id_text) + ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) + if 'regression_bin' in cf.prediction_tasks or hasattr(cf, "rg_val_to_bin_id"): + text_x, text_y = x1, y2 + text_str = 'Rg. Bin' + ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) + + if 'lesion_gleasons' in cf.observables_rois: + text_x, text_y = x1, y1 + text_str = 'Gleason Score' + ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) + + bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=1., edgecolor=edgecolor, facecolor='none') + ax.add_patch(bbox) + if out_dir is not None: + plt.savefig(os.path.join(out_dir, "box_legend.png")) + +def plot_boxes(cf, box_coords, patch_size=None, scores=None, class_ids=None, out_file=None, ax=None): + + if patch_size is None: + patch_size = cf.patch_size[:2] + if class_ids is None: + class_ids = np.ones((len(box_coords),), dtype='uint8') + if scores is None: + scores = np.ones((len(box_coords),), dtype='uint8') + + img = np.ones(patch_size) + + y1, x1, y2, x2 = box_coords[:,0], box_coords[:,1], box_coords[:,2], box_coords[:,3] + width, height = x2-x1, y2-y1 + + close = False + if ax is None: + fig = plt.figure(tight_layout=True, dpi=300) + ax = fig.add_subplot(111) + close = True + title_fs = 56 + + ax.set_facecolor((*cf.gray,0.15)) + ax.imshow(img, cmap='gray', vmin=0., vmax=1., alpha=0) + #ax.axis('off') + #suppress_axes_lines(ax) + ax.set_xticks([]) + ax.set_yticks([]) + + for bix, cl_id in enumerate(class_ids): + label = cf.class_id2label[cl_id] + text_x, text_y = x2[bix] -20, y1[bix] +5 + id_text = class_ids[bix] if cf.plot_class_ids else "" + text_str = '{}{}{:.0f}'.format(id_text, " | ", scores[bix] * 100) + text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) + ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) + edgecolor = label.color + + bbox = mpatches.Rectangle((x1[bix], y1[bix]), width[bix], height[bix], linewidth=1., edgecolor=edgecolor, facecolor='none') + ax.add_patch(bbox) + + if out_file is not None: + plt.savefig(out_file) + if close: + plt.close() + + + +if __name__=="__main__": + cluster_exp_root = "/mnt/E132-Cluster-Projects" + #dataset="prostate/" + dataset = "lidc/" + exp_name = "ms13_mrcnnal3d_rg_bs8_480k" + #exp_dir = os.path.join("datasets", dataset, "experiments", exp_name) + # exp_dir = os.path.join(cluster_exp_root, dataset, "experiments", exp_name) + # log_dir = os.path.join(exp_dir, "logs") + # sys.path.append(exp_dir) + # from configs import Configs + # cf = configs() + # + # #print("logdir", log_dir) + # #out_dir = os.path.join(cf.source_dir, log_dir.replace("/", "_")) + # #print("outdir", out_dir) + # log_dir = os.path.join(cf.source_dir, log_dir) + # plot_tboard_logs(cf, log_dir, tag_filters=["train/lesion_avp", "val/lesion_ap", "val/lesion_avp", "val/patient_lesion_avp"], smooth=2.2, out_dir=log_dir, # y_ranges=([0,900], [0,0.8]), + # twin_axes=[1], y_labels=["counts",""], x_label="epoch") + + #plot_box_legend(cf, out_dir=exp_dir) + + + diff --git a/predictor.py b/predictor.py new file mode 100644 index 0000000..6b92782 --- /dev/null +++ b/predictor.py @@ -0,0 +1,1005 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +from multiprocessing import Pool +import pickle +import time +import copy + +import numpy as np +import torch +from scipy.stats import norm +from collections import OrderedDict +import pandas as pd + +import plotting as plg +import utils.model_utils as mutils +import utils.exp_utils as utils + + +def get_mirrored_patch_crops(patch_crops, org_img_shape): + mirrored_patch_crops = [] + mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]] + if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], + ii[3], ii[4], ii[5]] + for ii in patch_crops]) + + mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] + if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2], ii[4], ii[5]] + for ii in patch_crops]) + + mirrored_patch_crops.append([[org_img_shape[2] - ii[1], + org_img_shape[2] - ii[0], + org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2]] + if len(ii) == 4 else + [org_img_shape[2] - ii[1], + org_img_shape[2] - ii[0], + org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2], ii[4], ii[5]] + for ii in patch_crops]) + + return mirrored_patch_crops + +def get_mirrored_patch_crops_ax_dep(patch_crops, org_img_shape, mirror_axes): + mirrored_patch_crops = [] + for ax_ix, axes in enumerate(mirror_axes): + if isinstance(axes, (int, float)) and int(axes) == 0: + mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]] + if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], + ii[2], ii[3], ii[4], ii[5]] + for ii in patch_crops]) + elif isinstance(axes, (int, float)) and int(axes) == 1: + mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] + if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2], ii[4], ii[5]] + for ii in patch_crops]) + elif hasattr(axes, "__iter__") and (tuple(axes) == (0, 1) or tuple(axes) == (1, 0)): + mirrored_patch_crops.append([[org_img_shape[2] - ii[1], + org_img_shape[2] - ii[0], + org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2]] + if len(ii) == 4 else + [org_img_shape[2] - ii[1], + org_img_shape[2] - ii[0], + org_img_shape[3] - ii[3], + org_img_shape[3] - ii[2], ii[4], ii[5]] + for ii in patch_crops]) + else: + raise Exception("invalid mirror axes {} in get mirrored patch crops".format(axes)) + + return mirrored_patch_crops + +def apply_wbc_to_patient(inputs): + """ + wrapper around prediction box consolidation: weighted box clustering (wbc). processes a single patient. + loops over batch elements in patient results (1 in 3D, slices in 2D) and foreground classes, + aggregates and stores results in new list. + :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is + one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D + predictions, and a dummy batch dimension of 1 for 3D predictions. + :return. pid: string. patient id. + """ + regress_flag, in_patient_results_list, pid, class_dict, clustering_iou, n_ens = inputs + out_patient_results_list = [[] for _ in range(len(in_patient_results_list))] + + for bix, b in enumerate(in_patient_results_list): + + for cl in list(class_dict.keys()): + + boxes = [(ix, box) for ix, box in enumerate(b) if + (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] + box_coords = np.array([b[1]['box_coords'] for b in boxes]) + box_scores = np.array([b[1]['box_score'] for b in boxes]) + box_center_factor = np.array([b[1]['box_patch_center_factor'] for b in boxes]) + box_n_overlaps = np.array([b[1]['box_n_overlaps'] for b in boxes]) + try: + box_patch_id = np.array([b[1]['patch_id'] for b in boxes]) + except KeyError: #backward compatibility for already saved pred results ... omg + box_patch_id = np.array([b[1]['ens_ix'] for b in boxes]) + box_regressions = np.array([b[1]['regression'] for b in boxes]) if regress_flag else None + box_rg_bins = np.array([b[1]['rg_bin'] if 'rg_bin' in b[1].keys() else float('NaN') for b in boxes]) + box_rg_uncs = np.array([b[1]['rg_uncertainty'] if 'rg_uncertainty' in b[1].keys() else float('NaN') for b in boxes]) + + if 0 not in box_scores.shape: + keep_scores, keep_coords, keep_n_missing, keep_regressions, keep_rg_bins, keep_rg_uncs = \ + weighted_box_clustering(box_coords, box_scores, box_center_factor, box_n_overlaps, box_rg_bins, box_rg_uncs, + box_regressions, box_patch_id, clustering_iou, n_ens) + + + for boxix in range(len(keep_scores)): + clustered_box = {'box_type': 'det', 'box_coords': keep_coords[boxix], + 'box_score': keep_scores[boxix], 'cluster_n_missing': keep_n_missing[boxix], + 'box_pred_class_id': cl} + if regress_flag: + clustered_box.update({'regression': keep_regressions[boxix], + 'rg_uncertainty': keep_rg_uncs[boxix], + 'rg_bin': keep_rg_bins[boxix]}) + + out_patient_results_list[bix].append(clustered_box) + + # add gt boxes back to new output list. + out_patient_results_list[bix].extend([box for box in b if box['box_type'] == 'gt']) + + return [out_patient_results_list, pid] + + +def weighted_box_clustering(box_coords, scores, box_pc_facts, box_n_ovs, box_rg_bins, box_rg_uncs, + box_regress, box_patch_id, thresh, n_ens): + """Consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal ensembling. + clusters predictions together with iou > thresh (like in NMS). Output score and coordinate for one cluster are the + average weighted by individual patch center factors (how trustworthy is this candidate measured by how centered + its position within the patch is) and the size of the corresponding box. + The number of expected predictions at a position is n_data_aug * n_temp_ens * n_overlaps_at_position + (1 prediction per unique patch). Missing predictions at a cluster position are defined as the number of unique + patches in the cluster, which did not contribute any predict any boxes. + :param dets: (n_dets, (y1, x1, y2, x2, (z1), (z2), scores, box_pc_facts, box_n_ovs). + :param box_coords: y1, x1, y2, x2, (z1), (z2). + :param scores: confidence scores. + :param box_pc_facts: patch-center factors from position on patch tiles. + :param box_n_ovs: number of patch overlaps at box position. + :param box_rg_bins: regression bin predictions. + :param box_rg_uncs: (n_dets,) regression uncertainties (from model mrcnn_aleatoric). + :param box_regress: (n_dets, n_regression_features). + :param box_patch_id: ensemble index. + :param thresh: threshold for iou_matching. + :param n_ens: number of models, that are ensembled. (-> number of expected predictions per position). + :return: keep_scores: (n_keep) new scores of boxes to be kept. + :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept. + """ + + dim = 2 if box_coords.shape[1] == 4 else 3 + y1 = box_coords[:,0] + x1 = box_coords[:,1] + y2 = box_coords[:,2] + x2 = box_coords[:,3] + + areas = (y2 - y1 + 1) * (x2 - x1 + 1) + if dim == 3: + z1 = box_coords[:, 4] + z2 = box_coords[:, 5] + areas *= (z2 - z1 + 1) + + # order is the sorted index. maps order to index o[1] = 24 (rank1, ix 24) + order = scores.argsort()[::-1] + + keep_scores = [] + keep_coords = [] + keep_n_missing = [] + keep_regress = [] + keep_rg_bins = [] + keep_rg_uncs = [] + + while order.size > 0: + i = order[0] # highest scoring element + yy1 = np.maximum(y1[i], y1[order]) + xx1 = np.maximum(x1[i], x1[order]) + yy2 = np.minimum(y2[i], y2[order]) + xx2 = np.minimum(x2[i], x2[order]) + + w = np.maximum(0, xx2 - xx1 + 1) + h = np.maximum(0, yy2 - yy1 + 1) + inter = w * h + + if dim == 3: + zz1 = np.maximum(z1[i], z1[order]) + zz2 = np.minimum(z2[i], z2[order]) + d = np.maximum(0, zz2 - zz1 + 1) + inter *= d + + # overlap between currently highest scoring box and all boxes. + ovr = inter / (areas[i] + areas[order] - inter) + ovr_fl = inter.astype('float64') / (areas[i] + areas[order] - inter.astype('float64')) + assert np.all(ovr==ovr_fl), "ovr {}\n ovr_float {}".format(ovr, ovr_fl) + # get all the predictions that match the current box to build one cluster. + matches = np.nonzero(ovr > thresh)[0] + + match_n_ovs = box_n_ovs[order[matches]] + match_pc_facts = box_pc_facts[order[matches]] + match_patch_id = box_patch_id[order[matches]] + match_ov_facts = ovr[matches] + match_areas = areas[order[matches]] + match_scores = scores[order[matches]] + + # weight all scores in cluster by patch factors, and size. + match_score_weights = match_ov_facts * match_areas * match_pc_facts + match_scores *= match_score_weights + + # for the weighted average, scores have to be divided by the number of total expected preds at the position + # of the current cluster. 1 Prediction per patch is expected. therefore, the number of ensembled models is + # multiplied by the mean overlaps of patches at this position (boxes of the cluster might partly be + # in areas of different overlaps). + n_expected_preds = n_ens * np.mean(match_n_ovs) + # the number of missing predictions is obtained as the number of patches, + # which did not contribute any prediction to the current cluster. + n_missing_preds = np.max((0, n_expected_preds - np.unique(match_patch_id).shape[0])) + + # missing preds are given the mean weighting + # (expected prediction is the mean over all predictions in cluster). + denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights) + + # compute weighted average score for the cluster + avg_score = np.sum(match_scores) / denom + + # compute weighted average of coordinates for the cluster. now only take existing + # predictions into account. + avg_coords = [np.sum(y1[order[matches]] * match_scores) / np.sum(match_scores), + np.sum(x1[order[matches]] * match_scores) / np.sum(match_scores), + np.sum(y2[order[matches]] * match_scores) / np.sum(match_scores), + np.sum(x2[order[matches]] * match_scores) / np.sum(match_scores)] + + if dim == 3: + avg_coords.append(np.sum(z1[order[matches]] * match_scores) / np.sum(match_scores)) + avg_coords.append(np.sum(z2[order[matches]] * match_scores) / np.sum(match_scores)) + + if box_regress is not None: + # compute wt. avg. of regression vectors (component-wise average) + avg_regress = np.sum(box_regress[order[matches]] * match_scores[:, np.newaxis], axis=0) / np.sum( + match_scores) + avg_rg_bins = np.round(np.sum(box_rg_bins[order[matches]] * match_scores) / np.sum(match_scores)) + avg_rg_uncs = np.sum(box_rg_uncs[order[matches]] * match_scores) / np.sum(match_scores) + else: + avg_regress = np.array(float('NaN')) + avg_rg_bins = np.array(float('NaN')) + avg_rg_uncs = np.array(float('NaN')) + + # some clusters might have very low scores due to high amounts of missing predictions. + # filter out the with a conservative threshold, to speed up evaluation. + if avg_score > 0.01: + keep_scores.append(avg_score) + keep_coords.append(avg_coords) + keep_n_missing.append((n_missing_preds / n_expected_preds * 100)) # relative + keep_regress.append(avg_regress) + keep_rg_uncs.append(avg_rg_uncs) + keep_rg_bins.append(avg_rg_bins) + + # get index of all elements that were not matched and discard all others. + inds = np.nonzero(ovr <= thresh)[0] + inds_where = np.where(ovr<=thresh)[0] + assert np.all(inds == inds_where), "inds_nonzero {} \ninds_where {}".format(inds, inds_where) + order = order[inds] + + return keep_scores, keep_coords, keep_n_missing, keep_regress, keep_rg_bins, keep_rg_uncs + + +def apply_nms_to_patient(inputs): + + in_patient_results_list, pid, class_dict, iou_thresh = inputs + out_patient_results_list = [] + + + # collect box predictions over batch dimension (slices) and store slice info as slice_ids. + for batch in in_patient_results_list: + batch_el_boxes = [] + for cl in list(class_dict.keys()): + det_boxes = [box for box in batch if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] + + box_coords = np.array([box['box_coords'] for box in det_boxes]) + box_scores = np.array([box['box_score'] for box in det_boxes]) + if 0 not in box_scores.shape: + keep_ix = mutils.nms_numpy(box_coords, box_scores, iou_thresh) + else: + keep_ix = [] + + batch_el_boxes += [det_boxes[ix] for ix in keep_ix] + + batch_el_boxes += [box for box in batch if box['box_type'] == 'gt'] + out_patient_results_list.append(batch_el_boxes) + + assert len(in_patient_results_list) == len(out_patient_results_list), "batch dim needs to be maintained, in: {}, out {}".format(len(in_patient_results_list), len(out_patient_results_list)) + + return [out_patient_results_list, pid] + +def nms_2to3D(dets, thresh): + """ + Merges 2D boxes to 3D cubes. For this purpose, boxes of all slices are regarded as lying in one slice. + An adaptation of Non-maximum suppression is applied where clusters are found (like in NMS) with the extra constraint + that suppressed boxes have to have 'connected' z coordinates w.r.t the core slice (cluster center, highest + scoring box, the prevailing box). 'connected' z-coordinates are determined + as the z-coordinates with predictions until the first coordinate for which no prediction is found. + + example: a cluster of predictions was found overlap > iou thresh in xy (like NMS). The z-coordinate of the highest + scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57. + Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not because nothing was + found in 47, so 47 is a 'hole', which interrupts the connection). Only the boxes corresponding to these coordinates + are suppressed. All others are kept for building of further clusters. + + This algorithm works better with a certain min_confidence of predictions, because low confidence (e.g. noisy/cluttery) + predictions can break the relatively strong assumption of defining cubes' z-boundaries at the first 'hole' in the cluster. + + :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id) + :param thresh: iou matchin threshold (like in NMS). + :return: keep: (n_keep,) 1D tensor of indices to be kept. + :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes. + """ + + y1 = dets[:, 0] + x1 = dets[:, 1] + y2 = dets[:, 2] + x2 = dets[:, 3] + assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: + where maximum is taken needs to be the lower coordinate""" + scores = dets[:, -2] + slice_id = dets[:, -1] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + keep_z = [] + + while order.size > 0: # order is the sorted index. maps order to index: order[1] = 24 means (rank1, ix 24) + i = order[0] # highest scoring element + yy1 = np.maximum(y1[i], y1[order]) # highest scoring element still in >order<, is compared to itself: okay? + xx1 = np.maximum(x1[i], x1[order]) + yy2 = np.minimum(y2[i], y2[order]) + xx2 = np.minimum(x2[i], x2[order]) + + h = np.maximum(0.0, yy2 - yy1 + 1) + w = np.maximum(0.0, xx2 - xx1 + 1) + inter = h * w + + iou = inter / (areas[i] + areas[order] - inter) + matches = np.argwhere( + iou > thresh) # get all the elements that match the current box and have a lower score + + slice_ids = slice_id[order[matches]] + core_slice = slice_id[int(i)] + upper_holes = [ii for ii in np.arange(core_slice, np.max(slice_ids)) if ii not in slice_ids] + lower_holes = [ii for ii in np.arange(np.min(slice_ids), core_slice) if ii not in slice_ids] + max_valid_slice_id = np.min(upper_holes) if len(upper_holes) > 0 else np.max(slice_ids) + min_valid_slice_id = np.max(lower_holes) if len(lower_holes) > 0 else np.min(slice_ids) + z_matches = matches[(slice_ids <= max_valid_slice_id) & (slice_ids >= min_valid_slice_id)] + + # expand by one z voxel since box content is surrounded w/o overlap, i.e., z-content computed as z2-z1 + z1 = np.min(slice_id[order[z_matches]]) - 1 + z2 = np.max(slice_id[order[z_matches]]) + 1 + + keep.append(i) + keep_z.append([z1, z2]) + order = np.delete(order, z_matches, axis=0) + + return keep, keep_z + +def apply_2d_3d_merging_to_patient(inputs): + """ + wrapper around 2Dto3D merging operation. Processes a single patient. Takes 2D patient results (slices in batch dimension) + and returns 3D patient results (dummy batch dimension of 1). Applies an adaption of Non-Maximum Surpression + (Detailed methodology is described in nms_2to3D). + :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each box is + one dictionary: [[box_0, ...], [box_n,...]]. + :return. pid: string. patient id. + """ + + in_patient_results_list, pid, class_dict, merge_3D_iou = inputs + out_patient_results_list = [] + + for cl in list(class_dict.keys()): + det_boxes, slice_ids = [], [] + # collect box predictions over batch dimension (slices) and store slice info as slice_ids. + for batch_ix, batch in enumerate(in_patient_results_list): + batch_element_det_boxes = [(ix, box) for ix, box in enumerate(batch) if + (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] + det_boxes += batch_element_det_boxes + slice_ids += [batch_ix] * len(batch_element_det_boxes) + + box_coords = np.array([batch[1]['box_coords'] for batch in det_boxes]) + box_scores = np.array([batch[1]['box_score'] for batch in det_boxes]) + slice_ids = np.array(slice_ids) + + if 0 not in box_scores.shape: + keep_ix, keep_z = nms_2to3D( + np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou) + else: + keep_ix, keep_z = [], [] + + # store kept predictions in new results list and add corresponding z-dimension info to coordinates. + # for kix, kz in zip(keep_ix, keep_z): + # out_patient_results_list.append({'box_type': 'det', 'box_coords': list(box_coords[kix]) + kz, + # 'box_score': box_scores[kix], 'box_pred_class_id': cl}) + for kix, kz in zip(keep_ix, keep_z): + keep_box = det_boxes[kix][1] + keep_box['box_coords'] = list(keep_box['box_coords']) + kz + out_patient_results_list.append(keep_box) + + gt_boxes = [box for b in in_patient_results_list for box in b if box['box_type'] == 'gt'] + if len(gt_boxes) > 0: + assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D." + out_patient_results_list += gt_boxes + + return [[out_patient_results_list], pid] # additional list wrapping is extra batch dim. + + +class Predictor: + """ + Prediction pipeline: + - receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader. + - forwards patches through model in chunks of batch_size. (method: batch_tiling_forward) + - unmolds predictions (boxes and segmentations) to original patient coordinates. (method: spatial_tiling_forward) + + Ensembling (mode == 'test'): + - for inference, forwards 4 mirrored versions of image to through model and unmolds predictions afterwards + accordingly (method: data_aug_forward) + - for inference, loads multiple parameter-sets of the trained model corresponding to different epochs. for each + parameter-set loops over entire test set, runs prediction pipeline for each patient. (method: predict_test_set) + + Consolidation of predictions: + - consolidates a patient's predictions (boxes, segmentations) collected over patches, data_aug- and temporal ensembling, + performs clustering and weighted averaging (external function: apply_wbc_to_patient) to obtain consistent outptus. + - for 2D networks, consolidates box predictions to 3D cubes via clustering (adaption of non-maximum surpression). + (external function: apply_2d_3d_merging_to_patient) + + Ground truth handling: + - dissmisses any ground truth boxes returned by the model (happens in validation mode, patch-based groundtruth) + - if provided by data loader, adds patient-wise ground truth to the final predictions to be passed to the evaluator. + """ + def __init__(self, cf, net, logger, mode): + + self.cf = cf + self.batch_size = cf.batch_size + self.logger = logger + self.mode = mode + self.net = net + self.n_ens = 1 + self.rank_ix = '0' + self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks]) + + if self.cf.merge_2D_to_3D_preds: + assert self.cf.dim == 2, "Merge 2Dto3D only valid for 2D preds, but current dim is {}.".format(self.cf.dim) + + if self.mode == 'test': + try: + self.epoch_ranking = np.load(os.path.join(self.cf.fold_dir, 'epoch_ranking.npy'))[:cf.test_n_epochs] + except: + raise RuntimeError('no epoch ranking file in fold directory. ' + 'seems like you are trying to run testing without prior training...') + self.n_ens = cf.test_n_epochs + if self.cf.test_aug_axes is not None: + self.n_ens *= (len(self.cf.test_aug_axes)+1) + self.example_plot_dir = os.path.join(cf.test_dir, "example_plots") + os.makedirs(self.example_plot_dir, exist_ok=True) + + def batch_tiling_forward(self, batch): + """ + calls the actual network forward method. in patch-based prediction, the batch dimension might be overladed + with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of + batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are discarded). + test mode calls the test forward method, no ground truth required / involved. + :return. results_dict: stores the results for one patient. dictionary with keys: + - 'boxes': list over batch elements. each element is a list over boxes, where each box is + one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, + and a dummy batch dimension of 1 for 3D predictions. + - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) + - loss / class_loss (only in validation mode) + """ + #self.logger.info('forwarding (patched) patient with shape: {}'.format(batch['data'].shape)) + + img = batch['data'] + + if img.shape[0] <= self.batch_size: + + if self.mode == 'val': + # call training method to monitor losses + results_dict = self.net.train_forward(batch, is_validation=True) + # discard returned ground-truth boxes (also training info boxes). + results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] + elif self.mode == 'test': + results_dict = self.net.test_forward(batch, return_masks=self.cf.return_masks_in_test) + + else: #needs batch tiling + split_ixs = np.split(np.arange(img.shape[0]), np.arange(img.shape[0])[::self.batch_size]) + chunk_dicts = [] + for chunk_ixs in split_ixs[1:]: # first split is elements before 0, so empty + b = {k: batch[k][chunk_ixs] for k in batch.keys() + if (isinstance(batch[k], np.ndarray) and batch[k].shape[0] == img.shape[0])} + if self.mode == 'val': + chunk_dicts += [self.net.train_forward(b, is_validation=True)] + else: + chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)] + + results_dict = {} + # flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...]) + results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']] + results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']]) + + if self.mode == 'val': + # if hasattr(self.cf, "losses_to_monitor"): + # loss_names = self.cf.losses_to_monitor + # else: + # loss_names = {name for dic in chunk_dicts for name in dic if 'loss' in name} + # estimate patient loss by mean over batch_chunks. Most similar to training loss. + results_dict['torch_loss'] = torch.mean(torch.cat([d['torch_loss'] for d in chunk_dicts])) + results_dict['class_loss'] = np.mean([d['class_loss'] for d in chunk_dicts]) + # discard returned ground-truth boxes (also training info boxes). + results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] + + return results_dict + + def spatial_tiling_forward(self, batch, patch_crops = None, n_aug='0'): + """ + forwards batch to batch_tiling_forward method and receives and returns a dictionary with results. + if patch-based prediction, the results received from batch_tiling_forward will be on a per-patch-basis. + this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates. + Patch-origin information of all box-predictions will be needed for consolidation, hence it is stored as + 'patch_id', which is a unique string for each patch (also takes current data aug and temporal epoch instances + into account). all box predictions get additional information about the amount overlapping patches at the + respective position (used for consolidation). + :return. results_dict: stores the results for one patient. dictionary with keys: + - 'boxes': list over batch elements. each element is a list over boxes, where each box is + one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, + and a dummy batch dimension of 1 for 3D predictions. + - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) + - monitor_values (only in validation mode) + returned dict is a flattened version with 1 batch instance (3D) or slices (2D) + """ + + if patch_crops is not None: + #print("patch_crops not None, applying patch center factor") + + patches_dict = self.batch_tiling_forward(batch) + results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]} + #bc of ohe--> channel dim of seg has size num_classes + out_seg_shape = list(batch['original_img_shape']) + out_seg_shape[1] = patches_dict["seg_preds"].shape[1] + out_seg_preds = np.zeros(out_seg_shape, dtype=np.float16) + patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8') + for pix, pc in enumerate(patch_crops): + if self.cf.dim == 3: + out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix] + patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1 + elif self.cf.dim == 2: + out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix] + patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1 + + out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0] + results_dict['seg_preds'] = out_seg_preds + + for pix, pc in enumerate(patch_crops): + patch_boxes = patches_dict['boxes'][pix] + for box in patch_boxes: + + # add unique patch id for consolidation of predictions. + box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix) + # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers. + # hence they will be down-weighted for consolidation, using the 'box_patch_center_factor', which is + # obtained by a gaussian distribution over positions in the patch and average over spatial dimensions. + # Also the info 'box_n_overlaps' is stored for consolidation, which represents the amount of + # overlapping patches at the box's position. + + c = box['box_coords'] + #box_centers = np.array([(c[ii] + c[ii+2])/2 for ii in range(len(c)//2)]) + box_centers = [(c[ii] + c[ii + 2]) / 2 for ii in range(2)] + if self.cf.dim == 3: + box_centers.append((c[4] + c[5]) / 2) + box['box_patch_center_factor'] = np.mean( + [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in + zip(box_centers, np.array(self.cf.patch_size) / 2)]) + if self.cf.dim == 3: + c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]]) + int_c = [int(np.floor(ii)) if ix%2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)] + box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, int_c[1]:int_c[3], int_c[0]:int_c[2], int_c[4]:int_c[5]]) + results_dict['boxes'][0].append(box) + else: + c += np.array([pc[0], pc[2], pc[0], pc[2]]) + int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)] + box['box_n_overlaps'] = np.mean( + patch_overlap_map[pc[4], :, int_c[1]:int_c[3], int_c[0]:int_c[2]]) + results_dict['boxes'][pc[4]].append(box) + + if self.mode == 'val': + results_dict['torch_loss'] = patches_dict['torch_loss'] + results_dict['class_loss'] = patches_dict['class_loss'] + + else: + results_dict = self.batch_tiling_forward(batch) + for b in results_dict['boxes']: + for box in b: + box['box_patch_center_factor'] = 1 + box['box_n_overlaps'] = 1 + box['patch_id'] = self.rank_ix + '_' + n_aug + + return results_dict + + def data_aug_forward(self, batch): + """ + in val_mode: passes batch through to spatial_tiling method without data_aug. + in test_mode: if cf.test_aug is set in configs, createst 4 mirrored versions of the input image, + passes all of them to the next processing step (spatial_tiling method) and re-transforms returned predictions + to original image version. + :return. results_dict: stores the results for one patient. dictionary with keys: + - 'boxes': list over batch elements. each element is a list over boxes, where each box is + one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, + and a dummy batch dimension of 1 for 3D predictions. + - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) + - loss / class_loss (only in validation mode) + """ + patch_crops = batch['patch_crop_coords'] if self.patched_patient else None + results_list = [self.spatial_tiling_forward(batch, patch_crops)] + org_img_shape = batch['original_img_shape'] + + if self.mode == 'test' and self.cf.test_aug_axes is not None: + if isinstance(self.cf.test_aug_axes, (int, float)): + self.cf.test_aug_axes = (self.cf.test_aug_axes,) + #assert np.all(np.array(self.cf.test_aug_axes)= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] + assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] + chunk_dict['boxes'][ix][boxix]['box_coords'] = coords + # re-transform segmentation predictions. + chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=axis) + + elif hasattr(sp_axis, "__iter__") and tuple(sp_axis)==(0,1) or tuple(sp_axis)==(1,0): + #NEED: mirrored patch crops are given as [(y-axis), (x-axis), (y-,x-axis)], obey this order! + # mirroring along two axes at same time + batch['data'] = np.flip(np.flip(img, axis=axis[0]), axis=axis[1]).copy() + chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[n_aug], n_aug=str(n_aug)) + # re-transform coordinates. + for ix in range(len(chunk_dict['boxes'])): + for boxix in range(len(chunk_dict['boxes'][ix])): + coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy() + coords[sp_axis[0]] = org_img_shape[axis[0]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[0]+2] + coords[sp_axis[0]+2] = org_img_shape[axis[0]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[0]] + coords[sp_axis[1]] = org_img_shape[axis[1]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[1]+2] + coords[sp_axis[1]+2] = org_img_shape[axis[1]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[1]] + assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] + assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] + chunk_dict['boxes'][ix][boxix]['box_coords'] = coords + # re-transform segmentation predictions. + chunk_dict['seg_preds'] = np.flip(np.flip(chunk_dict['seg_preds'], axis=axis[0]), axis=axis[1]).copy() + + else: + raise Exception("Invalid axis type {} in test augs".format(type(axis))) + results_list.append(chunk_dict) + + batch['data'] = img + + # aggregate all boxes/seg_preds per batch element from data_aug predictions. + results_dict = {} + results_dict['boxes'] = [[item for d in results_list for item in d['boxes'][batch_instance]] + for batch_instance in range(org_img_shape[0])] + # results_dict['seg_preds'] = np.array([[item for d in results_list for item in d['seg_preds'][batch_instance]] + # for batch_instance in range(org_img_shape[0])]) + results_dict['seg_preds'] = np.stack([dic['seg_preds'] for dic in results_list], axis=1) + # needs segs probs in seg_preds entry: + results_dict['seg_preds'] = np.sum(results_dict['seg_preds'], axis=1) #add up seg probs from different augs per class + + if self.mode == 'val': + results_dict['torch_loss'] = results_list[0]['torch_loss'] + results_dict['class_loss'] = results_list[0]['class_loss'] + + return results_dict + + def load_saved_predictions(self): + """loads raw predictions saved by self.predict_test_set. aggregates and/or merges 2D boxes to 3D cubes for + evaluation (if model predicts 2D but evaluation is run in 3D), according to settings config. + :return: list_of_results_per_patient: list over patient results. each entry is a dict with keys: + - 'boxes': list over batch elements. each element is a list over boxes, where each box is + one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions + (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. + - 'batch_dices': dice scores as recorded in raw prediction results. + - 'seg_preds': not implemented yet. could replace dices by seg preds to have raw seg info available, however + would consume critically large memory amount. todo evaluation of instance/semantic segmentation. + """ + + results_file = 'pred_results.pkl' if not self.cf.held_out_test_set else 'pred_results_held_out.pkl' + if not self.cf.held_out_test_set or self.cf.eval_test_fold_wise: + self.logger.info("loading saved predictions of fold {}".format(self.cf.fold)) + with open(os.path.join(self.cf.fold_dir, results_file), 'rb') as handle: + results_list = pickle.load(handle) + box_results_list = [(res_dict["boxes"], pid) for res_dict, pid in results_list] + + da_factor = len(self.cf.test_aug_axes)+1 if self.cf.test_aug_axes is not None else 1 + self.n_ens = self.cf.test_n_epochs * da_factor + self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format( + len(results_list), self.n_ens)) + else: + self.logger.info("loading saved predictions of hold-out test set") + fold_dirs = sorted([os.path.join(self.cf.exp_dir, f) for f in os.listdir(self.cf.exp_dir) if + os.path.isdir(os.path.join(self.cf.exp_dir, f)) and f.startswith("fold")]) + + results_list = [] + folds_loaded = 0 + for fold in range(self.cf.n_cv_splits): + fold_dir = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold)) + if fold_dir in fold_dirs: + with open(os.path.join(fold_dir, results_file), 'rb') as handle: + fold_list = pickle.load(handle) + results_list += fold_list + folds_loaded += 1 + else: + self.logger.info("Skipping fold {} since no saved predictions found.".format(fold)) + box_results_list = [] + for res_dict, pid in results_list: #without filtering gt out: + box_results_list.append((res_dict['boxes'], pid)) + #it's usually not right to filter out gts here, is it? + + da_factor = len(self.cf.test_aug_axes)+1 if self.cf.test_aug_axes is not None else 1 + self.n_ens = self.cf.test_n_epochs * da_factor * folds_loaded + + # -------------- aggregation of boxes via clustering ----------------- + + if self.cf.clustering == "wbc": + self.logger.info('applying WBC to test-set predictions with iou {} and n_ens {} over {} patients'.format( + self.cf.clustering_iou, self.n_ens, len(box_results_list))) + + mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens] for ii + in box_results_list] + del box_results_list + pool = Pool(processes=self.cf.n_workers) + box_results_list = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1) + pool.close() + pool.join() + del mp_inputs + elif self.cf.clustering == "nms": + self.logger.info('applying standard NMS to test-set predictions with iou {} over {} patients.'.format( + self.cf.clustering_iou, len(box_results_list))) + pool = Pool(processes=self.cf.n_workers) + mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in box_results_list] + del box_results_list + box_results_list = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1) + pool.close() + pool.join() + del mp_inputs + + if self.cf.merge_2D_to_3D_preds: + self.logger.info('applying 2Dto3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou)) + pool = Pool(processes=self.cf.n_workers) + mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in box_results_list] + box_results_list = pool.map(apply_2d_3d_merging_to_patient, mp_inputs, chunksize=1) + pool.close() + pool.join() + del mp_inputs + + for ix in range(len(results_list)): + assert np.all(results_list[ix][1] == box_results_list[ix][1]), "pid mismatch between loaded and aggregated results" + results_list[ix][0]["boxes"] = box_results_list[ix][0] + + return results_list # holds (results_dict, pid) + + def predict_patient(self, batch): + """ + predicts one patient. + called either directly via loop over validation set in exec.py (mode=='val') + or from self.predict_test_set (mode=='test). + in val mode: adds 3D ground truth info to predictions and runs consolidation and 2Dto3D merging of predictions. + in test mode: returns raw predictions (ground truth addition, consolidation, 2D to 3D merging are + done in self.predict_test_set, because patient predictions across several epochs might be needed + to be collected first, in case of temporal ensembling). + :return. results_dict: stores the results for one patient. dictionary with keys: + - 'boxes': list over batch elements. each element is a list over boxes, where each box is + one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions + (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. + - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) + - loss / class_loss (only in validation mode) + """ + if self.mode=="test": + self.logger.info('predicting patient {} for fold {} '.format(np.unique(batch['pid']), self.cf.fold)) + + # True if patient is provided in patches and predictions need to be tiled. + self.patched_patient = 'patch_crop_coords' in list(batch.keys()) + + # forward batch through prediction pipeline. + results_dict = self.data_aug_forward(batch) + #has seg probs in entry 'seg_preds' + + if self.mode == 'val': + for b in range(batch['patient_bb_target'].shape[0]): + for t in range(len(batch['patient_bb_target'][b])): + gt_box = {'box_type': 'gt', 'box_coords': batch['patient_bb_target'][b][t], + 'class_targets': batch['patient_class_targets'][b][t]} + for name in self.cf.roi_items: + gt_box.update({name : batch['patient_'+name][b][t]}) + results_dict['boxes'][b].append(gt_box) + + if 'dice' in self.cf.metrics: + if self.patched_patient: + assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape." + results_dict['batch_dices'] = mutils.dice_per_batch_and_class( + results_dict['seg_preds'], batch["patient_seg"] if self.patched_patient else batch['seg'], + self.cf.num_seg_classes, convert_to_ohe=True) + if self.patched_patient and self.cf.clustering == "wbc": + wbc_input = [self.regress_flag, results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.clustering_iou, self.n_ens] + results_dict['boxes'] = apply_wbc_to_patient(wbc_input)[0] + elif self.patched_patient: + nms_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.clustering_iou] + results_dict['boxes'] = apply_nms_to_patient(nms_inputs)[0] + + if self.cf.merge_2D_to_3D_preds: + results_dict['2D_boxes'] = results_dict['boxes'] + merge_dims_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.merge_3D_iou] + results_dict['boxes'] = apply_2d_3d_merging_to_patient(merge_dims_inputs)[0] + + return results_dict + + def predict_test_set(self, batch_gen, return_results=True): + """ + wrapper around test method, which loads multiple (or one) epoch parameters (temporal ensembling), loops through + the test set and collects predictions per patient. Also flattens the results per patient and epoch + and adds optional ground truth boxes for evaluation. Saves out the raw result list for later analysis and + optionally consolidates and returns predictions immediately. + :return: (optionally) list_of_results_per_patient: list over patient results. each entry is a dict with keys: + - 'boxes': list over batch elements. each element is a list over boxes, where each box is + one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions + (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. + - 'seg_preds': not implemented yet. todo evaluation of instance/semantic segmentation. + """ + + # -------------- raw predicting ----------------- + dict_of_patients_results = OrderedDict() + set_of_result_types = set() + # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling). + weight_paths = [os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch)) for epoch in self.epoch_ranking] + + + for rank_ix, weight_path in enumerate(weight_paths): + self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path))) + self.net.load_state_dict(torch.load(weight_path)) + self.net.eval() + self.rank_ix = str(rank_ix) + with torch.no_grad(): + plot_batches = np.random.choice(np.arange(batch_gen['n_test']), size=self.cf.n_test_plots, replace=False) + for i in range(batch_gen['n_test']): + batch = next(batch_gen['test']) + pid = np.unique(batch['pid']) + assert len(pid)==1 + pid = pid[0] + + if not pid in dict_of_patients_results.keys(): # store batch info in patient entry of results dict. + dict_of_patients_results[pid] = {} + dict_of_patients_results[pid]['results_dicts'] = [] + dict_of_patients_results[pid]['patient_bb_target'] = batch['patient_bb_target'] + + for name in self.cf.roi_items: + dict_of_patients_results[pid]["patient_"+name] = batch["patient_"+name] + stime = time.time() + results_dict = self.predict_patient(batch) #only holds "boxes", "seg_preds" + # needs ohe seg probs in seg_preds entry: + results_dict['seg_preds'] = np.argmax(results_dict['seg_preds'], axis=1)[:,np.newaxis] + self.logger.info("predicting patient {} with weight rank {} (progress: {}/{}) took {:.2f}s".format( + str(pid), rank_ix, (rank_ix)*batch_gen['n_test']+(i+1), len(weight_paths)*batch_gen['n_test'], time.time()-stime)) + + if i in plot_batches and (not self.patched_patient or 'patient_data' in batch.keys()): + try: + # view qualitative results of random test case + out_file = os.path.join(self.example_plot_dir, + 'batch_example_test_{}_rank_{}.png'.format(self.cf.fold, rank_ix)) + plg.view_batch(self.cf, batch, res_dict=results_dict, out_file=out_file, + show_seg_ids='dice' in self.cf.metrics, + has_colorchannels=self.cf.has_colorchannels, show_gt_labels=True) + except Exception as e: + self.logger.info("WARNING: error in view_batch: {}".format(e)) + + if 'dice' in self.cf.metrics: + if self.patched_patient: + assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape." + results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'], + batch["patient_seg"] if self.patched_patient else batch['seg'], + self.cf.num_seg_classes, convert_to_ohe=True) + + dict_of_patients_results[pid]['results_dicts'].append({k:v for k,v in results_dict.items() + if k in ["boxes", "batch_dices"]}) + # collect result types to know which ones to look for when saving + set_of_result_types.update(dict_of_patients_results[pid]['results_dicts'][-1].keys()) + + + + # -------------- re-order, save raw results ----------------- + self.logger.info('finished predicting test set. starting aggregation of predictions.') + results_per_patient = [] + for pid, p_dict in dict_of_patients_results.items(): + # dict_of_patients_results[pid]['results_list'] has length batch['n_test'] + + results_dict = {} + # collect all boxes/seg_preds of same batch_instance over temporal instances. + b_size = len(p_dict['results_dicts'][0]["boxes"]) + for res_type in [rtype for rtype in set_of_result_types if rtype in ["boxes", "batch_dices"]]:#, "seg_preds"]]: + if not 'batch' in res_type: #assume it's results on batch-element basis + results_dict[res_type] = [[item for rank_dict in p_dict['results_dicts'] for item in rank_dict[res_type][batch_instance]] + for batch_instance in range(b_size)] + else: + results_dict[res_type] = [] + for dict in p_dict['results_dicts']: + if 'dice' in res_type: + item = dict[res_type] #dict['batch_dices'] has shape (num_seg_classes,) + assert len(item) == self.cf.num_seg_classes, \ + "{}, {}".format(len(item), self.cf.num_seg_classes) + else: + raise NotImplementedError + results_dict[res_type].append(item) + # rdict[dice] shape (n_rank_epochs (n_saved_ranks), nsegclasses) + # calc mean over test epochs so inline with shape from sampling + results_dict[res_type] = np.mean(results_dict[res_type], axis=0) #maybe error type with other than dice + + if not hasattr(self.cf, "eval_test_separately") or not self.cf.eval_test_separately: + # add unpatched 2D or 3D (if dim==3 or merge_2D_to_3D) ground truth boxes for evaluation. + for b in range(p_dict['patient_bb_target'].shape[0]): + for targ in range(len(p_dict['patient_bb_target'][b])): + gt_box = {'box_type': 'gt', 'box_coords':p_dict['patient_bb_target'][b][targ], + 'class_targets': p_dict['patient_class_targets'][b][targ]} + for name in self.cf.roi_items: + gt_box.update({name: p_dict["patient_"+name][b][targ]}) + results_dict['boxes'][b].append(gt_box) + + results_per_patient.append([results_dict, pid]) + + out_string = 'pred_results_held_out' if self.cf.held_out_test_set else 'pred_results' + with open(os.path.join(self.cf.fold_dir, '{}.pkl'.format(out_string)), 'wb') as handle: + pickle.dump(results_per_patient, handle) + + if return_results: + # -------------- results processing, clustering, etc. ----------------- + final_patient_box_results = [ (res_dict["boxes"], pid) for res_dict,pid in results_per_patient ] + if self.cf.clustering == "wbc": + self.logger.info('applying WBC to test-set predictions with iou = {} and n_ens = {}.'.format( + self.cf.clustering_iou, self.n_ens)) + mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens] for ii in final_patient_box_results] + del final_patient_box_results + pool = Pool(processes=self.cf.n_workers) + final_patient_box_results = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1) + pool.close() + pool.join() + del mp_inputs + elif self.cf.clustering == "nms": + self.logger.info('applying standard NMS to test-set predictions with iou = {}.'.format(self.cf.clustering_iou)) + pool = Pool(processes=self.cf.n_workers) + mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in final_patient_box_results] + del final_patient_box_results + final_patient_box_results = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1) + pool.close() + pool.join() + del mp_inputs + + if self.cf.merge_2D_to_3D_preds: + self.logger.info('applying 2D-to-3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou)) + mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in final_patient_box_results] + del final_patient_box_results + pool = Pool(processes=self.cf.n_workers) + final_patient_box_results = pool.map(apply_2d_3d_merging_to_patient, mp_inputs, chunksize=1) + pool.close() + pool.join() + del mp_inputs + # final_patient_box_results holds [avg_boxes, pid] if wbc + for ix in range(len(results_per_patient)): + assert results_per_patient[ix][1] == final_patient_box_results[ix][1], "should be same pid" + results_per_patient[ix][0]["boxes"] = final_patient_box_results[ix][0] + # results_per_patient = [(res_dict["boxes"] = boxes, pid) for (boxes,pid) in final_patient_box_results] + + return results_per_patient # holds list of (results_dict, pid) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1161169 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,113 @@ +absl-py==0.7.1 +alabaster==0.7.11 +asn1crypto==0.24.0 +astor==0.7.1 +astroid==2.0.1 +Babel==2.6.0 +backcall==0.1.0 +batchgenerators==0.18.2 +bleach==2.1.3 +certifi==2018.4.16 +cffi==1.11.5 +chardet==3.0.4 +cloudpickle==0.5.3 +cryptography==2.3 +cycler==0.10.0 +Cython==0.29.6 +dask==0.18.1 +decorator==4.3.0 +docutils==0.14 +entrypoints==0.2.3 +future==0.16.0 +gast==0.2.0 +grpcio==1.13.0 +h5py==2.9.0 +html5lib==1.0.1 +idna==2.7 +imagesize==1.0.0 +isort==4.3.4 +jedi==0.12.1 +jeepney==0.3.1 +Jinja2==2.10.1 +joblib==0.13.2 +jsonschema==2.6.0 +Keras-Applications==1.0.7 +Keras-Preprocessing==1.0.9 +keyring==13.2.1 +kiwisolver==1.0.1 +lazy-object-proxy==1.3.1 +linecache2==1.0.0 +Markdown==2.6.11 +MarkupSafe==1.0 +matplotlib==3.0.3 +mccabe==0.6.1 +mistune==0.8.3 +mock==2.0.0 +nbconvert==5.3.1 +nbformat==4.4.0 +networkx==2.1 +nibabel==2.3.0 +nilearn==0.4.2 +numpy==1.14.5 +numpydoc==0.8.0 +nvidia-ml-py3==7.352.0 +packaging==17.1 +pandas==0.24.2 +pandocfilters==1.4.2 +parso==0.3.1 +pathlib==1.0.1 +pbr==5.1.3 +pexpect==4.6.0 +pickleshare==0.7.4 +Pillow==6.2.1 +prompt-toolkit==2.0.9 +protobuf==3.7.1 +psutil==5.4.6 +ptyprocess==0.6.0 +pycodestyle==2.4.0 +pycparser==2.18 +pyflakes==2.0.0 +Pygments==2.2.0 +pylint==2.0.1 +PyOpenGL==3.1.0 +pyparsing==2.2.0 +PyQt5==5.9.2 +python-dateutil==2.7.3 +pytz==2018.5 +PyWavelets==0.5.2 +pyzmq==17.1.0 +QtAwesome==0.4.4 +qtconsole==4.3.1 +QtPy==1.4.2 +requests==2.22.0 +rope==0.14.0 +scikit-image==0.14.0 +scikit-learn==0.21.3 +scipy==1.1.0 +SecretStorage==3.0.1 +simplegeneric==0.8.1 +SimpleITK==1.2.2 +sip==4.19.8 +six==1.11.0 +snowballstemmer==1.2.1 +Sphinx==1.7.6 +sphinxcontrib-websupport==1.1.0 +tensorboard==1.13.1 +tensorboardX==1.6 +tensorflow==1.13.1 +tensorflow-estimator==1.13.0 +termcolor==1.1.0 +testpath==0.3.1 +toolz==0.9.0 +torch==0.4.1 +torchvision==0.2.1 +tornado==5.1 +traceback2==1.4.0 +traitlets==4.3.2 +typed-ast==1.1.0 +unittest2==1.1.0 +urllib3==1.25.3 +wcwidth==0.1.7 +webencodings==0.5.1 +Werkzeug==0.15.5 +wrapt==1.10.11 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..0a64bb7 --- /dev/null +++ b/setup.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from distutils.core import setup +from setuptools import find_packages + +req_file = "requirements.txt" + +def parse_requirements(filename): + lineiter = (line.strip() for line in open(filename)) + return [line for line in lineiter if line and not line.startswith("#")] + +install_reqs = parse_requirements(req_file) + +setup(name='model', + version='latest', + packages=find_packages(exclude=['test', 'test.*']), + install_requires=install_reqs, + dependency_links=[], + ) \ No newline at end of file diff --git a/shell_scripts/ana_starter.sh b/shell_scripts/ana_starter.sh new file mode 100644 index 0000000..1eeb63d --- /dev/null +++ b/shell_scripts/ana_starter.sh @@ -0,0 +1,11 @@ +mode=${1} +dataset_name=${2} + +source_dir=/home/gregor/Documents/medicaldetectiontoolkit + +exps_dir=/home/gregor/networkdrives/E132-Cluster-Projects/${dataset_name}/experiments_float_data +exps_dirs=$(ls -d ${exps_dir}/*) +for dir in ${exps_dirs}; do + echo "starting ${mode} in ${dir}" + (python ${source_dir}/exec.py --use_stored_settings --mode ${mode} --dataset_name ${dataset_name} --exp_dir ${dir}) || (echo "FAILED!") +done diff --git a/shell_scripts/cluster_runner_meddec.sh b/shell_scripts/cluster_runner_meddec.sh new file mode 100644 index 0000000..d884226 --- /dev/null +++ b/shell_scripts/cluster_runner_meddec.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +#Usage: +# -->not true?: this script has to be started from the same directory the python files called below lie in (e.g. exec.py lies in meddetectiontkit). +# part of the slurm-job name you pass to sbatch will be the experiment folder's name. +# you need to pass 3 positional arguments to this script (cluster_runner_..sh #1 #2 #3): +# -#1 source directory in which main source code (framework) is located (e.g. medicaldetectiontoolkit/) +# -#2 the exp_dir where job-specific code was copied before by create_exp and exp results are safed by exec.py +# -#3 absolute path to dataset-specific code in source dir +# -#4 mode to run +# -#5 folds to run on + +source_dir=${1} +exp_dir=${2} +dataset_abs_path=${3} +mode=${4} +folds=${5} +resume=$6 + +#known problem: trap somehow does not execute the rm -r tmp_dir command when using scancel on job +#trap clean_up EXIT KILL TERM ABRT QUIT + +job_dir=/ssd/ramien/${LSB_JOBID} + +tmp_dir_data=${job_dir}/data +mkdir $tmp_dir_data + +tmp_dir_cache=${job_dir}/cache +mkdir $tmp_dir_cache +CUDA_CACHE_PATH=$tmp_dir_cache +export CUDA_CACHE_PATH + + +#data must not lie permantly on nodes' ssd, only during training time +#needs to be named with the SLURM_JOB_ID to not be automatically removed +#can permanently lie on /datasets drive --> copy from there before every experiment +#files on datasets are saved as npz (compressed) --> use data_manager.py to copy and unpack into .npy; is done implicitly in exec.py + +#(tensorboard --logdir ${exp_dir}/.. --port 1337 || echo "tboard startup failed")& # || tensorboard --logdir ${exp_dir}/.. --port 1338)& +#tboard_pid=$! + +#clean_up() { +# rm -rf ${job_dir}; +#} + +export OMP_NUM_THREADS=1 # this is a work-around fix for batchgenerators to deal with numpy-inherent multi-threading. + +if [ ! -z "${folds}" ]; then + if [ -z "${resume}" ]; then + resume='None' + else + resume=${exp_dir}"/fold_${folds}/last_state.pth" + echo "Resuming from checkpoint at ${resume}." + fi + python ${source_dir}/exec.py --use_stored_settings --server_env --dataset_name ${dataset_abs_path} --data_dest ${tmp_dir_data} --exp_dir ${exp_dir} --mode ${mode} --folds ${folds} --resume_from_checkpoint ${resume} + +else + python ${source_dir}/exec.py --use_stored_settings --server_env --dataset_name ${dataset_abs_path} --data_dest ${tmp_dir_data} --exp_dir ${exp_dir} --mode ${mode} + +fi + + + + + diff --git a/shell_scripts/job_starter.sh b/shell_scripts/job_starter.sh new file mode 100644 index 0000000..bfcf052 --- /dev/null +++ b/shell_scripts/job_starter.sh @@ -0,0 +1,179 @@ +#!/bin/bash +#wrapper for cluster_runner_....sh which copies job-specific, frequently changing files (e.g. configs.py) before the actual sbatch job +#is submitted since the job might pend in queue before execution --> hazard of job-specific files being unintentionally changed during queue wait time. +#positonal +# -arg #1 identifies the folder name of the dataset-related code (e.g. >prostate< or >cityscapes<) within the code source directory +# -arg #2 is the experiment and first part of slurm job name, +# optional args and flags: +# -c / --create: (flag) whether to create the exp, i.e., if this is a new start of the exp with configs etc from source dir. +# -f / --folds FOLDS: (option) fold(s) to run on (FOLDS needs to be only one int or string of multiple ints separated by space), default None (-->set to all in config) +# -m / --mode MODE: (option) string, one of "train", "train_test", "test", defaults to "train_test" +# -p / --exp_parent_dir: (option) name of parent_dir rel to dataset folder on cluster. exp_dir is exp_parent_dir/exp_name, if not given defaults to "experiments" +# -q / --queue: (option) which queue (-q parameter for bsub) to send job to. default: gputest. others: gputest-short (max 5h jobs). +# -w / --which: (option) same as argument -m to bsub; host or host list (string separated by space) to send the job to. +# use nodenameXX where XX==nr of node or nodenameXX,nodenameYY,... or nodename[XX-YY]. nodename is e.g. e132-comp. +# --gmem: (option) how much gpu memory to request for job (in gigabytes), defaults to 11.9. Currently, the smaller nodes have 11.9G, the larger ones 31.7G. +# --resume: (flag) only with explicit fold argument, if set, resumes from checkpoint in exp_dir/fold_x/last_state.pth. +# --no_parallel: (flag) if set, folds won't start as parallel jobs on cluster, but run sequentially in one job. + +dataset_name="${1}" +exp_name="${2}" + +#arguments not passed, e.g. $7 if no seventh argument, are null. +if [ ! -z "${18}" ]; then #-z checks if is null string + echo "Error: Received too many arguments." + exit +fi + +#make args optional: move up if some args are missing inbetween +while [ ${#} -gt 2 ]; do + case "${3}" in + -c|--create) + create_exp="c" + shift + ;; + -f|--folds) + folds="${4}" + shift; shift + ;; + -m|--mode) + mode="${4}" + shift; shift + ;; + -p|--exp_parent_dir) + exp_parent_dir="${4}" + shift; shift + ;; + -q|--queue) + queue="${4}" + shift; shift + ;; + -w|--which) + which="${4}" + shift; shift + ;; + --gmem) + gmem="${4}" + shift; shift + ;; + --resume) + resume=true + shift + ;; + --no_parallel) + no_parallel=true + shift + ;; + *) + echo "Invalid argument/option passed: ${3}" + exit 1 + ;; + esac +done + +# default values +if [ -z ${exp_parent_dir} ]; then + exp_parent_dir="experiments" +fi + +if [ -z ${mode} ]; then + mode="train_test" +fi + +if [ -z ${queue} ]; then + queue="gputest" +fi + + +if [ -z ${gmem} ]; then + gmem="11" +fi + + +root_dir=/home/ramien #assumes /home/ramien exists +prep_node=ramien@e132-comp07 #node used for prep tasks like create_exp +#medicaldetectiontoolkit +source_dir=${root_dir}/medicaldetectiontoolkit + +dataset_abs_path=${source_dir}/datasets/${dataset_name} #set as second argument passed to this script +exp_parent_dir=/datasets/data_ramien/${dataset_name}/${exp_parent_dir} +#exp_parent_dir=/home/gregor/Documents/medicaldetectiontoolkit/datasets/${dataset_name}/experiments #for testing this script +# /dataset is not mounted on log-in/job submission nodes (would maybe make sense, I feel), only on queue gputest's nodes e132-compXX. +ssh ${prep_node} "mkdir -p ${exp_parent_dir}" +exp_dir=${exp_parent_dir}/${exp_name} + +#activate virtualenv that has all the packages: +source_dl="module load python/3.6.1; source ${root_dir}/.virtualenvs/deeplearning36/bin/activate" + +# TODO as long as no fix available: this script needs to be started directly from the prep node. :/ would be nice if (most importantly +# 'module ...' would also work over ssh, but somehow some commands are not availabe over the ssh-induced shell (even when using it as interactive). +eval ${source_dl} + +#if create_exp, check if would overwrite existing exp_dir +if [ ! -z ${create_exp} ] && [ ${create_exp} = "c" ]; then #-n doesnt work as replacement for !-z + if [ -d ${exp_dir} ]; then + echo "Please confirm to overwrite exp ${exp_name} settings, (Y/n): "; read confirmation + if ([ "${confirmation}" = "y" ] || [ "${confirmation}" = "yes" ] || [ "${confirmation}" = "Y" ] || [ -z "${confirmation}" ]); then + echo "Overwriting ${exp_name}" + else + echo "Exiting due to overwrite denial. Adjust options." + exit + fi + fi + #echo "opts: name ${exp_name}, ${source_dir}/exec.py --server_env --mode create_exp --exp_dir ${exp_dir} --dataset_name ${dataset_abs_path}" + echo "Creating ${exp_name}" + #ssh ${prep_node} "${source_dl}; python ${source_dir}/exec.py --server_env --mode create_exp --exp_dir ${exp_dir} --dataset_name ${dataset_abs_path};" + python ${source_dir}/exec.py --server_env --mode create_exp --exp_dir ${exp_dir} --dataset_name ${dataset_abs_path} +else + if [ ! -d ${exp_dir} ]; then + echo "Experiment directory ${exp_dir} does not exist." + echo "Run create_exp? (Y/n): "; read confirmation + if ([ "${confirmation}" = "y" ] || [ "${confirmation}" = "yes" ] || [ "${confirmation}" = "Y" ] || [ -z "${confirmation}" ]); then + echo "Creating ${exp_name}" + python ${source_dir}/exec.py --server_env --mode create_exp --exp_dir ${exp_dir} --dataset_name ${dataset_abs_path} + fi + fi +fi + +#if not create_exp, check if would overwrite existing folds (possibly valuable trained params!) +if [ -z ${create_exp} ] && ([ ${mode} = "train" ] || [ ${mode} = "train_test" ]) && [ -z "${resume}" ]; then + for f in ${folds}; do #if folds is null this check won't apply and folds will be quietly overwritten. + if [ -d ${exp_dir}/fold_${f} ]; then #-d checks if is dir + echo "please confirm to overwrite fold_${f}, (Y/n):"; read confirmation + if ([ "${confirmation}" = "y" ] || [ "${confirmation}" = "yes" ] || [ "${confirmation}" = "Y" ] || [ -z "${confirmation}" ]); then + echo "Overwriting "${exp_name}/fold_${f} + else + echo "Exiting due to overwrite denial. Adjust options." + exit + fi + fi + done +fi + +if [ ! -z "${folds}" ] && [ -z ${no_parallel} ]; then #WHY do i need to convert to string again? + for f in ${folds}; do + out_file=${exp_dir}/logs/fold_${f}_lsf_output.out + bsub_opts="bsub -N -q '${queue}' -J '${dataset_name} ${exp_name} fold ${f} ${mode}' -gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=${gmem}G -oo ${out_file}" + if [ ! -z ${which} ]; then + bsub_opts="${bsub_opts} -m ${which}" + fi + #echo ${bsub_opts} #${exp_name}" fold ""${f}"" ""${mode}" #--gres=${gres} --time=${time} -w ${which} + eval "${bsub_opts} 'sh cluster_runner_meddec.sh' ${source_dir} ${exp_dir} ${dataset_abs_path} ${mode} ${f} ${resume}" + done +else + #echo ${exp_name}" fold ""${f}"" ""${mode}" + if [ ! -z ${resume} ]; then + echo "You need to explicitly specify folds if you would like to resume from a checkpoint. Exiting." + exit + fi + out_file=${exp_dir}/lsf_output.out + bsub_opts="bsub -N -q '${queue}' -J '${dataset_name} ${exp_name} folds ${folds} ${mode}' -gpu num=1:j_exclusive=yes:mode=exclusive_process:gmem=${gmem}G -oo '${out_file}'" + if [ ! -z ${which} ]; then + bsub_opts="${bsub_opts} -m ${which}" + fi + eval "${bsub_opts} 'sh cluster_runner_meddec.sh' ${source_dir} ${exp_dir} ${dataset_abs_path} ${mode} ${folds} ${resume}" + echo "Started in no parallel, folds:" ${folds} +fi + + + diff --git a/understanding_metrics.py b/understanding_metrics.py new file mode 100644 index 0000000..6e1532f --- /dev/null +++ b/understanding_metrics.py @@ -0,0 +1,66 @@ + +""" +Created at 06/12/18 13:34 +@author: gregor +""" +import sys +import os +import numpy as np +import pandas as pd +from sklearn.metrics import roc_auc_score, average_precision_score +from sklearn.metrics import roc_curve, precision_recall_curve + +import plotting as plg +import evaluator + +sys.path.append("datasets/prostate/") +from configs import Configs + +""" This is just a supplementary file which you may use to demonstrate or understand detection metrics. +""" + + +def get_det_types(df): + det_types = [] + for ix, score in enumerate(df["pred_score"]): + if score > 0 and df["class_label"][ix] == 1: + det_types.append("det_tp") + elif score > 0 and df["class_label"][ix] == 0: + det_types.append("det_fp") + elif score == 0 and df["class_label"][ix] == 1: + det_types.append("det_fn") + elif score == 0 and df["class_label"][ix] == 0: + det_types.append("det_tn") + return det_types + + +if __name__=="__main__": + cf = Configs() + + working_dir = "/home/gregor/Documents/ramien/Thesis/UnderstandingMetrics" + + df = pd.DataFrame(columns=['pred_score', 'class_label', 'pred_class', 'det_type', 'match_iou']) + + df["pred_score"] = [0.3, 0.] + df["class_label"] = [0, 1] + #df["pred_class"] = [1]*len(df) + det_types = get_det_types(df) + + df["det_type"] = det_types + df["match_iou"] = [0.1]*len(df) + + prc_own = evaluator.compute_prc(df) + all_stats = [{"prc":prc_own, 'roc':np.nan, 'name': "demon"}] + plg.plot_stat_curves(cf, all_stats, os.path.join(working_dir, "understanding_ap_own"), fill=True) + + prc_sk = precision_recall_curve(df.class_label.tolist(), df.pred_score.tolist()) + all_stats = [{"prc":prc_sk, 'roc':np.nan, 'name': "demon"}] + plg.plot_stat_curves(cf, all_stats, os.path.join(working_dir, "understanding_ap"), fill=True) + + ap = evaluator.get_roi_ap_from_df((df, 0.02, False)) + ap_sk = average_precision_score(df.class_label.tolist(), df.pred_score.tolist()) + print("roi_ap_from_df (own implement):",ap) + print("aver_prec_sc (sklearn):",ap_sk) + + plg.plot_prediction_hist(cf, df, os.path.join(working_dir, "understanding_ap.png"), title="AP_own {:.2f}, AP_sklearn {:.2f}".format(ap, ap_sk)) + diff --git a/unittests.py b/unittests.py new file mode 100644 index 0000000..e1b1937 --- /dev/null +++ b/unittests.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import unittest + +import os +import pickle +import time +from multiprocessing import Pool + +import numpy as np +import pandas as pd + +import utils.exp_utils as utils +import utils.model_utils as mutils + +""" Note on unittests: run this file either in the way intended for unittests by starting the script with + python -m unittest unittests.py or start it as a normal python file as python unittests.py. + You can selective run single tests by calling python -m unittest unittests.TestClassOfYourChoice, where + TestClassOfYourChoice is the name of the test defined below, e.g., CompareFoldSplits. +""" + + + +def inspect_info_df(pp_dir): + """ use your debugger to look into the info df of a pp dir. + :param pp_dir: preprocessed-data directory + """ + + info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle")) + + return + +#------- perform integrity checks on data set(s) ----------- +class VerifyLIDCSAIntegrity(unittest.TestCase): + """ Perform integrity checks on preprocessed single-annotator GTs of LIDC data set. + """ + @staticmethod + def check_patient_sa_gt(pid, pp_dir, check_meta_files, check_info_df): + + faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids']) + + all_segs = np.load(os.path.join(pp_dir, pid + "_rois.npz"), mmap_mode='r') + all_segs = all_segs[list(all_segs.keys())[0]] + all_roi_ids = np.unique(all_segs[all_segs > 0]) + assert len(all_roi_ids) == np.max(all_segs), "roi ids not consecutive" + if check_meta_files: + meta_file = os.path.join(pp_dir, pid + "_meta_info.pickle") + with open(meta_file, "rb") as handle: + info = pickle.load(handle) + assert info["pid"] == pid, "wrong pid in meta_file" + all_cl_targets = info["class_target"] + if check_info_df: + info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle")) + pid_info = info_df[info_df.pid == pid] + assert len(pid_info) == 1, "found {} entries for pid {} in info df, expected exactly 1".format(len(pid_info), + pid) + if check_meta_files: + assert pid_info[ + "class_target"] == all_cl_targets, "meta_info and info_df class targets mismatch:\n{}\n{}".format( + pid_info["class_target"], all_cl_targets) + all_cl_targets = pid_info["class_target"].iloc[0] + assert len(all_roi_ids) == len(all_cl_targets) + for rater in range(4): + seg = all_segs[rater] + roi_ids = np.unique(seg[seg > 0]) + cl_targs = np.array([roi[rater] for roi in all_cl_targets]) + assert np.count_nonzero(cl_targs) == len(roi_ids), "rater {} has targs {} but roi ids {}".format(rater, cl_targs, roi_ids) + assert len(cl_targs) >= len(roi_ids), "not all marked rois have a label" + for zeroix_roi_id, rating in enumerate(cl_targs): + if not ((rating > 0) == (np.any(seg == zeroix_roi_id + 1))): + print("\n\nFAULTY CASE:", end=" ", ) + print("pid {}, rater {}, cl_targs {}, ids {}\n".format(pid, rater, cl_targs, roi_ids)) + faulty_cases = faulty_cases.append( + {'pid': pid, 'rater': rater, 'cl_targets': cl_targs, 'roi_ids': roi_ids}, ignore_index=True) + print("finished checking pid {}, {} faulty cases".format(pid, len(faulty_cases))) + return faulty_cases + + def check_sa_gts(self, pp_dir, pid_subset=None, check_meta_files=False, check_info_df=True, processes=os.cpu_count()): + report_name = "verify_seg_label_pairings.csv" + pids = {file_name.split("_")[0] for file_name in os.listdir(pp_dir) if file_name not in [report_name, "info_df.pickle"]} + if pid_subset is not None: + pids = [pid for pid in pids if pid in pid_subset] + + + faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids']) + + p = Pool(processes=processes) + mp_args = zip(pids, [pp_dir]*len(pids), [check_meta_files]*len(pids), [check_info_df]*len(pids)) + patient_cases = p.starmap(self.check_patient_sa_gt, mp_args) + p.close(); p.join() + faulty_cases = faulty_cases.append(patient_cases, sort=False) + + + print("\n\nfaulty case count {}".format(len(faulty_cases))) + print(faulty_cases) + findings_file = os.path.join(pp_dir, "verify_seg_label_pairings.csv") + faulty_cases.to_csv(findings_file) + + assert len(faulty_cases)==0, "there was a faulty case in data set {}.\ncheck {}".format(pp_dir, findings_file) + + def test(self): + pp_root = "/mnt/HDD2TB/Documents/data/" + pp_dir = "lidc/pp_20190805" + gt_dir = os.path.join(pp_root, pp_dir, "patient_gts_sa") + self.check_sa_gts(gt_dir, check_meta_files=True, check_info_df=False, pid_subset=None) # ["0811a", "0812a"]) + +#------ compare segmentation gts of preprocessed data sets ------ +class CompareSegGTs(unittest.TestCase): + """ load and compare pre-processed gts by dice scores of segmentations. + + """ + @staticmethod + def group_seg_paths(ref_path, comp_paths): + # not working recursively + ref_files = [fn for fn in os.listdir(ref_path) if + os.path.isfile(os.path.join(ref_path, fn)) and 'seg' in fn and fn.endswith('.npy')] + + comp_files = [[os.path.join(c_path, fn) for c_path in comp_paths] for fn in ref_files] + + ref_files = [os.path.join(ref_path, fn) for fn in ref_files] + + return zip(ref_files, comp_files) + + @staticmethod + def load_calc_dice(paths): + dices = [] + ref_seg = np.load(paths[0])[np.newaxis, np.newaxis] + n_classes = len(np.unique(ref_seg)) + ref_seg = mutils.get_one_hot_encoding(ref_seg, n_classes) + + for c_file in paths[1]: + c_seg = np.load(c_file)[np.newaxis, np.newaxis] + assert n_classes == len(np.unique(c_seg)), "unequal nr of objects/classes betw segs {} {}".format(paths[0], + c_file) + c_seg = mutils.get_one_hot_encoding(c_seg, n_classes) + + dice = mutils.dice_per_batch_inst_and_class(c_seg, ref_seg, n_classes, convert_to_ohe=False) + dices.append(dice) + print("processed ref_path {}".format(paths[0])) + return np.mean(dices), np.std(dices) + + def iterate_files(self, grouped_paths, processes=os.cpu_count()): + p = Pool(processes) + + means_stds = np.array(p.map(self.load_calc_dice, grouped_paths)) + + p.close(); p.join() + min_dice = np.min(means_stds[:, 0]) + print("min mean dice {:.2f}, max std {:.4f}".format(min_dice, np.max(means_stds[:, 1]))) + assert min_dice > 1-1e5, "compared seg gts have insufficient minimum mean dice overlap of {}".format(min_dice) + + def test(self): + ref_path = '/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071' + comp_paths = ['/mnt/HDD2TB/Documents/data/prostate/data_t2_190419_ps384_gs6071', ] + paths = self.group_seg_paths(ref_path, comp_paths) + self.iterate_files(paths) + +#------- check if cross-validation fold splits of different experiments are identical ---------- +class CompareFoldSplits(unittest.TestCase): + """ Find evtl. differences in cross-val file splits across different experiments. + """ + @staticmethod + def group_id_paths(ref_exp_dir, comp_exp_dirs): + + f_name = 'fold_ids.pickle' + + ref_paths = os.path.join(ref_exp_dir, f_name) + assert os.path.isfile(ref_paths), "ref file {} does not exist.".format(ref_paths) + + + ref_paths = [ref_paths for comp_ed in comp_exp_dirs] + comp_paths = [os.path.join(comp_ed, f_name) for comp_ed in comp_exp_dirs] + + return zip(ref_paths, comp_paths) + + @staticmethod + def comp_fold_ids(mp_input): + fold_ids1, fold_ids2 = mp_input + with open(fold_ids1, 'rb') as f: + fold_ids1 = pickle.load(f) + try: + with open(fold_ids2, 'rb') as f: + fold_ids2 = pickle.load(f) + except FileNotFoundError: + print("comp file {} does not exist.".format(fold_ids2)) + return + + n_splits = len(fold_ids1) + assert n_splits == len(fold_ids2), "mismatch n splits: ref has {}, comp {}".format(n_splits, len(fold_ids2)) + split_diffs = [np.setdiff1d(fold_ids1[s], fold_ids2[s]) for s in range(n_splits)] + all_equal = np.any(split_diffs) + return (split_diffs, all_equal) + + def iterate_exp_dirs(self, ref_exp, comp_exps, processes=os.cpu_count()): + + grouped_paths = list(self.group_id_paths(ref_exp, comp_exps)) + print("performing {} comparisons of cross-val file splits".format(len(grouped_paths))) + p = Pool(processes) + split_diffs = p.map(self.comp_fold_ids, grouped_paths) + p.close(); p.join() + + df = pd.DataFrame(index=range(0,len(grouped_paths)), columns=["ref", "comp", "all_equal"])#, "diffs"]) + for ix, (ref, comp) in enumerate(grouped_paths): + df.iloc[ix] = [ref, comp, split_diffs[ix][1]]#, split_diffs[ix][0]] + + print("Any splits not equal?", df.all_equal.any()) + assert not df.all_equal.any(), "a split set is different from reference split set, {}".format(df[~df.all_equal]) + + def test(self): + exp_parent_dir = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/' + ref_exp = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn2d_cl_bs10' + comp_exps = [os.path.join(exp_parent_dir, p) for p in os.listdir(exp_parent_dir)] + comp_exps = [p for p in comp_exps if os.path.isdir(p) and p != ref_exp] + self.iterate_exp_dirs(ref_exp, comp_exps) + + +#------- check if cross-validation fold splits of a single experiment are actually incongruent (as required) ---------- +class VerifyFoldSplits(unittest.TestCase): + """ Check, for a single fold_ids file, i.e., for a single experiment, if the assigned folds (assignment of data + identifiers) is actually incongruent. No overlaps between folds are required for a correct cross validation. + """ + @staticmethod + def verify_fold_ids(splits): + for i, split1 in enumerate(splits): + for j, split2 in enumerate(splits): + if j > i: + inter = np.intersect1d(split1, split2) + if len(inter) > 0: + raise Exception("Split {} and {} intersect by pids {}".format(i, j, inter)) + def test(self): + exp_dir = "/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/dev" + check_file = os.path.join(exp_dir, 'fold_ids.pickle') + with open(check_file, 'rb') as handle: + splits = pickle.load(handle) + self.verify_fold_ids(splits) + +if __name__=="__main__": + stime = time.time() + + unittest.main() + + mins, secs = divmod((time.time() - stime), 60) + h, mins = divmod(mins, 60) + t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) + print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/utils/dataloader_utils.py b/utils/dataloader_utils.py new file mode 100644 index 0000000..c838ee6 --- /dev/null +++ b/utils/dataloader_utils.py @@ -0,0 +1,655 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import plotting as plg + +import os +from multiprocessing import Pool +import pickle +import warnings + +import numpy as np +import pandas as pd +from batchgenerators.transforms.abstract_transforms import AbstractTransform +from scipy.ndimage.measurements import label as lb +from torch.utils.data import Dataset as torchDataset +from batchgenerators.dataloading.data_loader import SlimDataLoaderBase + +import utils.exp_utils as utils +import data_manager as dmanager + + +for msg in ["This figure includes Axes that are not compatible with tight_layout", + "Data has no positive values, and therefore cannot be log-scaled."]: + warnings.filterwarnings("ignore", msg) + + +class AttributeDict(dict): + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + +################################## +# data loading, organisation # +################################## + + +class fold_generator: + """ + generates splits of indices for a given length of a dataset to perform n-fold cross-validation. + splits each fold into 3 subsets for training, validation and testing. + This form of cross validation uses an inner loop test set, which is useful if test scores shall be reported on a + statistically reliable amount of patients, despite limited size of a dataset. + If hold out test set is provided and hence no inner loop test set needed, just add test_idxs to the training data in the dataloader. + This creates straight-forward train-val splits. + :returns names list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix. + """ + def __init__(self, seed, n_splits, len_data): + """ + :param seed: Random seed for splits. + :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation + :param len_data: number of elements in the dataset. + """ + self.tr_ix = [] + self.val_ix = [] + self.te_ix = [] + self.slicer = None + self.missing = 0 + self.fold = 0 + self.len_data = len_data + self.n_splits = n_splits + self.myseed = seed + self.boost_val = 0 + + def init_indices(self): + + t = list(np.arange(self.l)) + # round up to next splittable data amount. + split_length = int(np.ceil(len(t) / float(self.n_splits))) + self.slicer = split_length + self.mod = len(t) % self.n_splits + if self.mod > 0: + # missing is the number of folds, in which the new splits are reduced to account for missing data. + self.missing = self.n_splits - self.mod + + self.te_ix = t[:self.slicer] + self.tr_ix = t[self.slicer:] + self.val_ix = self.tr_ix[:self.slicer] + self.tr_ix = self.tr_ix[self.slicer:] + + def new_fold(self): + + slicer = self.slicer + if self.fold < self.missing : + slicer = self.slicer - 1 + + temp = self.te_ix + + # catch exception mod == 1: test set collects 1+ data since walk through both roudned up splits. + # account for by reducing last fold split by 1. + if self.fold == self.n_splits-2 and self.mod ==1: + temp += self.val_ix[-1:] + self.val_ix = self.val_ix[:-1] + + self.te_ix = self.val_ix + self.val_ix = self.tr_ix[:slicer] + self.tr_ix = self.tr_ix[slicer:] + temp + + + def get_fold_names(self): + names_list = [] + rgen = np.random.RandomState(self.myseed) + cv_names = np.arange(self.len_data) + + rgen.shuffle(cv_names) + self.l = len(cv_names) + self.init_indices() + + for split in range(self.n_splits): + train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix] + names_list.append([train_names, val_names, test_names, self.fold]) + self.new_fold() + self.fold += 1 + + return names_list + + + +class FoldGenerator(): + r"""takes a set of elements (identifiers) and randomly splits them into the specified amt of subsets. + """ + + def __init__(self, identifiers, seed, n_splits=5): + self.ids = np.array(identifiers) + self.n_splits = n_splits + self.seed = seed + + def generate_splits(self, n_splits=None): + if n_splits is None: + n_splits = self.n_splits + + rgen = np.random.RandomState(self.seed) + rgen.shuffle(self.ids) + self.splits = list(np.array_split(self.ids, n_splits, axis=0)) # already returns list, but to be sure + return self.splits + + +class Dataset(torchDataset): + r"""Parent Class for actual Dataset classes to inherit from! + """ + def __init__(self, cf, data_sourcedir=None): + super(Dataset, self).__init__() + self.cf = cf + + self.data_sourcedir = cf.data_sourcedir if data_sourcedir is None else data_sourcedir + self.data_dir = cf.data_dir if hasattr(cf, 'data_dir') else self.data_sourcedir + + self.data_dest = cf.data_dest if hasattr(cf, "data_dest") else self.data_sourcedir + + self.data = {} + self.set_ids = [] + + def copy_data(self, cf, file_subset, keep_packed=False, del_after_unpack=False): + if os.path.normpath(self.data_sourcedir) != os.path.normpath(self.data_dest): + self.data_sourcedir = os.path.join(self.data_sourcedir, '') + args = AttributeDict({ + "source" : self.data_sourcedir, + "destination" : self.data_dest, + "recursive" : True, + "cp_only_npz" : False, + "keep_packed" : keep_packed, + "del_after_unpack" : del_after_unpack, + "threads" : 16 if self.cf.server_env else os.cpu_count() + }) + dmanager.copy(args, file_subset=file_subset) + self.data_dir = self.data_dest + + + + def __len__(self): + return len(self.data) + def __getitem__(self, id): + """Return a sample of the dataset, i.e.,the dict of the id + """ + return self.data[id] + def __iter__(self): + return self.data.__iter__() + + def init_FoldGenerator(self, seed, n_splits): + self.fg = FoldGenerator(self.set_ids, seed=seed, n_splits=n_splits) + + def generate_splits(self, check_file): + if not os.path.exists(check_file): + self.fg.generate_splits() + with open(check_file, 'wb') as handle: + pickle.dump(self.fg.splits, handle) + else: + with open(check_file, 'rb') as handle: + self.fg.splits = pickle.load(handle) + + def calc_statistics(self, subsets=None, plot_dir=None, overall_stats=True): + + if self.df is None: + self.df = pd.DataFrame() + balance_t = self.cf.balance_target if hasattr(self.cf, "balance_target") else "class_targets" + self.df._metadata.append(balance_t) + if balance_t=="class_targets": + mapper = lambda cl_id: self.cf.class_id2label[cl_id] + labels = self.cf.class_id2label.values() + elif balance_t=="rg_bin_targets": + mapper = lambda rg_bin: self.cf.bin_id2label[rg_bin] + labels = self.cf.bin_id2label.values() + # elif balance_t=="regression_targets": + # # todo this wont work + # mapper = lambda rg_val: AttributeDict({"name":rg_val}) #self.cf.bin_id2label[self.cf.rg_val_to_bin_id(rg_val)] + # labels = self.cf.bin_id2label.values() + elif balance_t=="lesion_gleasons": + mapper = lambda gs: self.cf.gs2label[gs] + labels = self.cf.gs2label.values() + else: + mapper = lambda x: AttributeDict({"name":x}) + labels = None + for pid, subj_data in self.data.items(): + unique_ts, counts = np.unique(subj_data[balance_t], return_counts=True) + self.df = self.df.append(pd.DataFrame({"pid": [pid], + **{mapper(unique_ts[i]).name: [counts[i]] for i in + range(len(unique_ts))}}), ignore_index=True, sort=True) + self.df = self.df.fillna(0) + + if overall_stats: + df = self.df.drop("pid", axis=1) + df = df.reindex(sorted(df.columns), axis=1).astype('uint32') + print("Overall dataset roi counts per target kind:"); print(df.sum()) + if subsets is not None: + self.df["subset"] = np.nan + self.df["display_order"] = np.nan + for ix, (subset, pids) in enumerate(subsets.items()): + self.df.loc[self.df.pid.isin(pids), "subset"] = subset + self.df.loc[self.df.pid.isin(pids), "display_order"] = ix + df = self.df.groupby("subset").agg("sum").drop("pid", axis=1, errors='ignore').astype('int64') + df = df.sort_values(by=['display_order']).drop('display_order', axis=1) + df = df.reindex(sorted(df.columns), axis=1) + + print("Fold {} dataset roi counts per target kind:".format(self.cf.fold)); print(df) + if plot_dir is not None: + os.makedirs(plot_dir, exist_ok=True) + if subsets is not None: + plg.plot_fold_stats(self.cf, df, labels, os.path.join(plot_dir, "data_stats_fold_" + str(self.cf.fold))+".pdf") + if overall_stats: + plg.plot_data_stats(self.cf, df, labels, os.path.join(plot_dir, 'data_stats_overall.pdf')) + + return df, labels + + +def get_class_balanced_patients(all_pids, class_targets, batch_size, num_classes, random_ratio=0): + ''' + samples towards equilibrium of classes (on basis of total RoI counts). for highly imbalanced dataset, this might be a too strong requirement. + :param class_targets: dic holding {patient_specifier : ROI class targets}, list position of ROI target corresponds to respective seg label - 1 + :param batch_size: + :param num_classes: + :return: + ''' + # assert len(all_pids)>=batch_size, "not enough eligible pids {} to form a single batch of size {}".format(len(all_pids), batch_size) + class_counts = {k: 0 for k in range(1,num_classes+1)} + not_picked = np.array(all_pids) + batch_patients = np.empty((batch_size,), dtype=not_picked.dtype) + rarest_class = np.random.randint(1,num_classes+1) + + for ix in range(batch_size): + if len(not_picked) == 0: + warnings.warn("Dataset too small to generate batch with unique samples; => recycling.") + not_picked = np.array(all_pids) + + np.random.shuffle(not_picked) #this could actually go outside(above) the loop. + pick = not_picked[0] + for cand in not_picked: + if np.count_nonzero(class_targets[cand] == rarest_class) > 0: + pick = cand + cand_rarest_class = np.argmin([np.count_nonzero(class_targets[cand] == cl) for cl in + range(1,num_classes+1)])+1 + # if current batch already bigger than the batch random ratio, then + # check that weakest class in this patient is not the weakest in current batch (since needs to be boosted) + # also that at least one roi of this patient belongs to weakest class. If True, keep patient, else keep looking. + if (cand_rarest_class != rarest_class and np.count_nonzero(class_targets[cand] == rarest_class) > 0) \ + or ix < int(batch_size * random_ratio): + break + + for c in range(1,num_classes+1): + class_counts[c] += np.count_nonzero(class_targets[pick] == c) + if not ix < int(batch_size * random_ratio) and class_counts[rarest_class] == 0: # means searched thru whole set without finding rarest class + print("Class {} not represented in current dataset.".format(rarest_class)) + rarest_class = np.argmin(([class_counts[c] for c in range(1,num_classes+1)]))+1 + batch_patients[ix] = pick + not_picked = not_picked[not_picked != pick] # removes pick + + return batch_patients + + +class BatchGenerator(SlimDataLoaderBase): + """ + create the training/validation batch generator. Randomly sample batch_size patients + from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array. + :param data: data dictionary as provided by 'load_dataset' + :param img_modalities: list of strings ['adc', 'b1500'] from config + :param batch_size: number of patients to sample for the batch + :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.) + :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array. + """ + + def __init__(self, cf, data, n_batches=None): + super(BatchGenerator, self).__init__(data, cf.batch_size, n_batches) + self.cf = cf + self.plot_dir = os.path.join(self.cf.plot_dir, 'train_generator') + + self.dataset_length = len(self._data) + self.dataset_pids = list(self._data.keys()) + self.eligible_pids = self.dataset_pids + + self.stats = {"roi_counts": np.zeros((self.cf.num_classes,), dtype='uint32'), "empty_samples_count": 0} + + if hasattr(cf, "balance_target"): + # WARNING: "balance targets are only implemented for 1-d targets (or 1-component vectors)" + self.balance_target = cf.balance_target + else: + self.balance_target = "class_targets" + self.targets = {k:v[self.balance_target] for (k,v) in self._data.items()} + + def balance_target_distribution(self, plot=False): + """ + :param all_pids: + :param self.targets: dic holding {patient_specifier : patient-wise-unique ROI targets} + :return: probability distribution over all pids. draw without replace from this. + """ + # get unique foreground targets per patient, assign -1 to an "empty" patient (has no foreground) + patient_ts = [np.unique(lst) if len([t for t in lst if np.any(t>0)])>0 else [-1] for lst in self.targets.values()] + #bg_mask = np.array([np.all(lst == [-1]) for lst in patient_ts]) + unique_ts, t_counts = np.unique([t for lst in patient_ts for t in lst if t!=-1], return_counts=True) + t_probs = t_counts.sum() / t_counts + t_probs /= t_probs.sum() + t_probs = {t : t_probs[ix] for ix, t in enumerate(unique_ts)} + t_probs[-1] = 0. + # fail if balance target is not a number (i.e., a vector) + self.p_probs = np.array([ max([t_probs[t] for t in lst]) for lst in patient_ts ]) + #normalize + self.p_probs /= self.p_probs.sum() + # rescale probs of empty samples + # if not 0 == self.p_probs[bg_mask].shape[0]: + # #rescale_f = (1 - self.cf.empty_samples_ratio) / self.p_probs[~bg_mask].sum() + # rescale_f = 1 / self.p_probs[~bg_mask].sum() + # self.p_probs *= rescale_f + # self.p_probs[bg_mask] = 0. #self.cf.empty_samples_ratio/self.p_probs[bg_mask].shape[0] + + self.unique_ts = unique_ts + + if plot: + os.makedirs(self.plot_dir, exist_ok=True) + plg.plot_batchgen_distribution(self.cf, self.dataset_pids, self.p_probs, self.balance_target, + out_file=os.path.join(self.plot_dir, + "train_gen_distr_"+str(self.cf.fold)+".png")) + return self.p_probs + + + def generate_train_batch(self): + # to be overriden by child + # everything done in here is per batch + # print statements in here get confusing due to multithreading + + return + + def print_stats(self, logger=None, file=None, plot_file=None, plot=True): + print_f = utils.CombinedPrinter(logger, file) + + print_f('\nFinal Training Stats\n') + print_f('*********************\n') + total_count = np.sum(self.stats['roi_counts']) + for tix, count in enumerate(self.stats['roi_counts']): + #name = self.cf.class_dict[tix] if self.balance_target=="class_targets" else str(self.unique_ts[tix]) + name=str(self.unique_ts[tix]) + print_f('{}: {} rois seen ({:.1f}%).\n'.format(name, count, count / total_count * 100)) + total_samples = self.cf.num_epochs*self.cf.num_train_batches*self.cf.batch_size + print_f('empty samples seen: {} ({:.1f}%).\n'.format(self.stats['empty_samples_count'], + self.stats['empty_samples_count']/total_samples*100)) + if plot: + if plot_file is None: + plot_file = os.path.join(self.plot_dir, "train_gen_stats_{}.png".format(self.cf.fold)) + os.makedirs(self.plot_dir, exist_ok=True) + plg.plot_batchgen_stats(self.cf, self.stats, self.balance_target, self.unique_ts, plot_file) + +class PatientBatchIterator(SlimDataLoaderBase): + """ + creates a val/test generator. Step through the dataset and return dictionaries per patient. + 2D is a special case of 3D patching with patch_size[2] == 1 (slices) + Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets. + Appends patient targets anyway for evaluation. + For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions. + + This iterator/these batches are not intended to go through MTaugmenter afterwards + """ + + def __init__(self, cf, data): + super(PatientBatchIterator, self).__init__(data, 0) + self.cf = cf + + self.dataset_length = len(self._data) + self.dataset_pids = list(self._data.keys()) + + def generate_train_batch(self, pid=None): + # to be overriden by child + + return + +################################### +# transforms, image manipulation # +################################### + +def get_patch_crop_coords(img, patch_size, min_overlap=30): + """ + _:param img (y, x, (z)) + _:param patch_size: list of len 2 (2D) or 3 (3D). + _:param min_overlap: minimum required overlap of patches. + If too small, some areas are poorly represented only at edges of single patches. + _:return ndarray: shape (n_patches, 2*dim). crop coordinates for each patch. + """ + crop_coords = [] + for dim in range(len(img.shape)): + n_patches = int(np.ceil(img.shape[dim] / patch_size[dim])) + + # no crops required in this dimension, add image shape as coordinates. + if n_patches == 1: + crop_coords.append([(0, img.shape[dim])]) + continue + + # fix the two outside patches to coords patchsize/2 and interpolate. + center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1) + + if (patch_size[dim] - center_dists) < min_overlap: + n_patches += 1 + center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1) + + patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)]) + dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers] + crop_coords.append(dim_crop_coords) + + coords_mesh_grid = [] + for ymin, ymax in crop_coords[0]: + for xmin, xmax in crop_coords[1]: + if len(crop_coords) == 3 and patch_size[2] > 1: + for zmin, zmax in crop_coords[2]: + coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax]) + elif len(crop_coords) == 3 and patch_size[2] == 1: + for zmin in range(img.shape[2]): + coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1]) + else: + coords_mesh_grid.append([ymin, ymax, xmin, xmax]) + return np.array(coords_mesh_grid).astype(int) + + + +def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None): + """ + one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee + + :param image: nd image. can be anything + :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If + len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of + the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape) + Example: + image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh? + image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768). + + :param mode: see np.pad for documentation + :param return_slicer: if True then this function will also return what coords you will need to use when cropping back + to original shape + :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is + divisibly by that number (can also be a list with an entry for each axis). Whatever is missing to match that will + be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None) + :param kwargs: see np.pad for documentation + """ + if kwargs is None: + kwargs = {} + + if new_shape is not None: + old_shape = np.array(image.shape[-len(new_shape):]) + else: + assert shape_must_be_divisible_by is not None + assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)) + new_shape = image.shape[-len(shape_must_be_divisible_by):] + old_shape = new_shape + + num_axes_nopad = len(image.shape) - len(new_shape) + + new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))] + + if not isinstance(new_shape, np.ndarray): + new_shape = np.array(new_shape) + + if shape_must_be_divisible_by is not None: + if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)): + shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape) + else: + assert len(shape_must_be_divisible_by) == len(new_shape) + + for i in range(len(new_shape)): + if new_shape[i] % shape_must_be_divisible_by[i] == 0: + new_shape[i] -= shape_must_be_divisible_by[i] + + new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))]) + + difference = new_shape - old_shape + pad_below = difference // 2 + pad_above = difference // 2 + difference % 2 + pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) + res = np.pad(image, pad_list, mode, **kwargs) + if not return_slicer: + return res + else: + pad_list = np.array(pad_list) + pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1] + slicer = list(slice(*i) for i in pad_list) + return res, slicer + +def convert_seg_to_bounding_box_coordinates(data_dict, dim, roi_item_keys, get_rois_from_seg=False, + class_specific_seg=False): + '''adapted from batchgenerators + + :param data_dict: seg: segmentation with labels indicating roi_count (get_rois_from_seg=False) or classes (get_rois_from_seg=True), + class_targets: list where list index corresponds to roi id (roi_count) + :param dim: + :param roi_item_keys: keys of the roi-wise items in data_dict to process + :param n_rg_feats: nr of regression vector features + :param get_rois_from_seg: + :return: coords (y1,x1,y2,x2 (,z1,z2)) + ''' + + bb_target = [] + roi_masks = [] + roi_items = {name:[] for name in roi_item_keys} + out_seg = np.copy(data_dict['seg']) + for b in range(data_dict['seg'].shape[0]): + + p_coords_list = [] #p for patient? + p_roi_masks_list = [] + p_roi_items_lists = {name:[] for name in roi_item_keys} + + if np.sum(data_dict['seg'][b] != 0) > 0: + if get_rois_from_seg: + clusters, n_cands = lb(data_dict['seg'][b]) + data_dict['class_targets'][b] = [data_dict['class_targets'][b]] * n_cands + else: + n_cands = int(np.max(data_dict['seg'][b])) + + rois = np.array( + [(data_dict['seg'][b] == ii) * 1 for ii in range(1, n_cands + 1)], dtype='uint8') # separate clusters + + for rix, r in enumerate(rois): + if np.sum(r != 0) > 0: # check if the roi survived slicing (3D->2D) and data augmentation (cropping etc.) + seg_ixs = np.argwhere(r != 0) + coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, + np.max(seg_ixs[:, 2]) + 1] + if dim == 3: + coord_list.extend([np.min(seg_ixs[:, 3]) - 1, np.max(seg_ixs[:, 3]) + 1]) + + p_coords_list.append(coord_list) + p_roi_masks_list.append(r) + # add background class = 0. rix is a patient wide index of lesions. since 'class_targets' is + # also patient wide, this assignment is not dependent on patch occurrences. + for name in roi_item_keys: + # if name == "class_targets": + # # add background class = 0. rix is a patient-wide index of lesions. since 'class_targets' is + # # also patient wide, this assignment is not dependent on patch occurrences. + # p_roi_items_lists[name].append(data_dict[name][b][rix]+1) + # else: + p_roi_items_lists[name].append(data_dict[name][b][rix]) + + assert data_dict["class_targets"][b][rix]>=1, "convertsegtobbox produced bg roi w cl targ {} and unique roi seg {}".format(data_dict["class_targets"][b][rix], np.unique(r)) + + + if class_specific_seg: + out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_targets'][b][rix] #+ 1 + + if not class_specific_seg: + out_seg[b][data_dict['seg'][b] > 0] = 1 + + bb_target.append(np.array(p_coords_list)) + roi_masks.append(np.array(p_roi_masks_list)) + for name in roi_item_keys: + roi_items[name].append(np.array(p_roi_items_lists[name])) + + + else: + bb_target.append([]) + roi_masks.append(np.zeros_like(data_dict['seg'][b], dtype='uint8')[None]) + for name in roi_item_keys: + roi_items[name].append(np.array([])) + + if get_rois_from_seg: + data_dict.pop('class_targets', None) + + data_dict['bb_target'] = np.array(bb_target) + data_dict['roi_masks'] = np.array(roi_masks) + data_dict['seg'] = out_seg + for name in roi_item_keys: + data_dict[name] = np.array(roi_items[name]) + + + return data_dict + +class ConvertSegToBoundingBoxCoordinates(AbstractTransform): + """ Converts segmentation masks into bounding box coordinates. + """ + + def __init__(self, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False): + self.dim = dim + self.roi_item_keys = roi_item_keys + self.get_rois_from_seg = get_rois_from_seg + self.class_specific_seg = class_specific_seg + + def __call__(self, **data_dict): + return convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.roi_item_keys, self.get_rois_from_seg, + self.class_specific_seg) + + + + + +############################# +# data packing / unpacking # not used, data_manager.py used instead +############################# + +def get_case_identifiers(folder): + case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")] + return case_identifiers + + +def convert_to_npy(npz_file): + if not os.path.isfile(npz_file[:-3] + "npy"): + a = np.load(npz_file)['data'] + np.save(npz_file[:-3] + "npy", a) + + +def unpack_dataset(folder, threads=8): + case_identifiers = get_case_identifiers(folder) + p = Pool(threads) + npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers] + p.map(convert_to_npy, npz_files) + p.close() + p.join() + + +def delete_npy(folder): + case_identifiers = get_case_identifiers(folder) + npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers] + npy_files = [i for i in npy_files if os.path.isfile(i)] + for n in npy_files: + os.remove(n) \ No newline at end of file diff --git a/utils/exp_utils.py b/utils/exp_utils.py new file mode 100644 index 0000000..26a3485 --- /dev/null +++ b/utils/exp_utils.py @@ -0,0 +1,630 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +#import plotting as plg + +import sys +import os +import subprocess +import threading +import pickle +import importlib.util +import psutil +from functools import partial +import time + +import logging +from tensorboardX import SummaryWriter + +from collections import OrderedDict +import numpy as np +import pandas as pd +import torch + + +def import_module(name, path): + """ + correct way of importing a module dynamically in python 3. + :param name: name given to module instance. + :param path: path to module. + :return: module: returned module instance. + """ + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +def save_obj(obj, name): + """Pickle a python object.""" + with open(name + '.pkl', 'wb') as f: + pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + +def load_obj(file_path): + with open(file_path, 'rb') as handle: + return pickle.load(handle) + +def IO_safe(func, *args, _tries=5, _raise=True, **kwargs): + """ Wrapper calling function func with arguments args and keyword arguments kwargs to catch input/output errors + on cluster. + :param func: function to execute (intended to be read/write operation to a problematic cluster drive, but can be + any function). + :param args: positional args of func. + :param kwargs: kw args of func. + :param _tries: how many attempts to make executing func. + """ + for _try in range(_tries): + try: + return func(*args, **kwargs) + except OSError as e: # to catch cluster issues with network drives + if _raise: + raise e + else: + print("After attempting execution {} time{}, following error occurred:\n{}".format(_try+1,"" if _try==0 else "s", e)) + continue + + +def query_nvidia_gpu(device_id, d_keyword=None, no_units=False): + """ + :param device_id: + :param d_keyword: -d, --display argument (keyword(s) for selective display), all are selected if None + :return: dict of gpu-info items + """ + cmd = ['nvidia-smi', '-i', str(device_id), '-q'] + if d_keyword is not None: + cmd += ['-d', d_keyword] + outp = subprocess.check_output(cmd).strip().decode('utf-8').split("\n") + outp = [x for x in outp if len(x)>0] + headers = [ix for ix, item in enumerate(outp) if len(item.split(":"))==1] + [len(outp)] + + out_dict = {} + for lix, hix in enumerate(headers[:-1]): + head = outp[hix].strip().replace(" ", "_").lower() + out_dict[head] = {} + for lix2 in range(hix, headers[lix+1]): + try: + key, val = [x.strip().lower() for x in outp[lix2].split(":")] + if no_units: + val = val.split()[0] + out_dict[head][key] = val + except: + pass + + return out_dict + +class CombinedPrinter(object): + """combined print function. + prints to logger and/or file if given, to normal print if non given. + + """ + def __init__(self, logger=None, file=None): + + if logger is None and file is None: + self.out = [print] + elif logger is None: + self.out = [file.write] + elif file is None: + self.out = [logger.info] + else: + self.out = [logger.info, file.write] + + def __call__(self, string): + for fct in self.out: + fct(string) + +class Nvidia_GPU_Logger(object): + def __init__(self): + self.count = None + + def get_vals(self): + + cmd = ['nvidia-settings', '-t', '-q', 'GPUUtilization'] + gpu_util = subprocess.check_output(cmd).strip().decode('utf-8').split(",") + gpu_util = dict([f.strip().split("=") for f in gpu_util]) + cmd[-1] = 'UsedDedicatedGPUMemory' + gpu_used_mem = subprocess.check_output(cmd).strip().decode('utf-8') + current_vals = {"gpu_mem_alloc": gpu_used_mem, "gpu_graphics_util": int(gpu_util['graphics']), + "gpu_mem_util": gpu_util['memory'], "time": time.time()} + return current_vals + + def loop(self): + i = 0 + while True: + self.get_vals() + self.log["time"].append(time.time()) + self.log["gpu_util"].append(self.current_vals["gpu_graphics_util"]) + if self.count != None: + i += 1 + if i == count: + exit(0) + time.sleep(self.interval) + + def start(self, interval=1.): + self.interval = interval + self.start_time = time.time() + self.log = {"time": [], "gpu_util": []} + if self.interval is not None: + thread = threading.Thread(target=self.loop) + thread.daemon = True + thread.start() + +class CombinedLogger(object): + """Combine console and tensorboard logger and record system metrics. + """ + def __init__(self, name, log_dir, server_env=True, fold="", sysmetrics_interval=2): + self.pylogger = logging.getLogger(name) + self.tboard = SummaryWriter(log_dir=log_dir) + self.times = {} + self.fold = fold + # monitor system metrics (cpu, mem, ...) + if not server_env: + self.sysmetrics = pd.DataFrame(columns=["global_step", "rel_time", r"CPU (%)", "mem_used (GB)", r"mem_used (%)", + r"swap_used (GB)", r"gpu_utilization (%)"], dtype="float16") + for device in range(torch.cuda.device_count()): + self.sysmetrics["mem_allocd (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan + self.sysmetrics["mem_cached (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan + self.sysmetrics_start(sysmetrics_interval) + + def __getattr__(self, attr): + """delegate all undefined method requests to objects of + this class in order pylogger, tboard (first find first serve). + E.g., combinedlogger.add_scalars(...) should trigger self.tboard.add_scalars(...) + """ + for obj in [self.pylogger, self.tboard]: + if attr in dir(obj): + return getattr(obj, attr) + raise AttributeError("CombinedLogger has no attribute {}".format(attr)) + + + def time(self, name, toggle=None): + """record time-spans as with a stopwatch. + :param name: + :param toggle: True^=On: start time recording, False^=Off: halt rec. if None determine from current status. + :return: either start-time or last recorded interval + """ + if toggle is None: + if name in self.times.keys(): + toggle = not self.times[name]["toggle"] + else: + toggle = True + + if toggle: + if not name in self.times.keys(): + self.times[name] = {"total": 0, "last":0} + elif self.times[name]["toggle"] == toggle: + print("restarting running stopwatch") + self.times[name]["last"] = time.time() + self.times[name]["toggle"] = toggle + return time.time() + else: + if toggle == self.times[name]["toggle"]: + self.info("WARNING: tried to stop stopped stop watch: {}.".format(name)) + self.times[name]["last"] = time.time()-self.times[name]["last"] + self.times[name]["total"] += self.times[name]["last"] + self.times[name]["toggle"] = toggle + return self.times[name]["last"] + + def get_time(self, name=None, kind="total", format=None, reset=False): + """ + :param name: + :param kind: 'total' or 'last' + :param format: None for float, "hms"/"ms" for (hours), mins, secs as string + :param reset: reset time after retrieving + :return: + """ + if name is None: + times = self.times + if reset: + self.reset_time() + return times + + else: + time = self.times[name][kind] + if format == "hms": + m, s = divmod(time, 60) + h, m = divmod(m, 60) + time = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(m), int(s)) + elif format == "ms": + m, s = divmod(time, 60) + time = "{:02d}m:{:02d}s".format(int(m), int(s)) + if reset: + self.reset_time(name) + return time + + def reset_time(self, name=None): + if name is None: + self.times = {} + else: + del self.times[name] + + + def sysmetrics_update(self, global_step=None): + if global_step is None: + global_step = time.strftime("%x_%X") + mem = psutil.virtual_memory() + mem_used = (mem.total-mem.available) + gpu_vals = self.gpu_logger.get_vals() + rel_time = time.time()-self.sysmetrics_start_time + self.sysmetrics.loc[len(self.sysmetrics)] = [global_step, rel_time, + psutil.cpu_percent(), mem_used/1024**3, mem_used/mem.total*100, + psutil.swap_memory().used/1024**3, int(gpu_vals['gpu_graphics_util']), + *[torch.cuda.memory_allocated(d)/1024**3 for d in range(torch.cuda.device_count())], + *[torch.cuda.memory_cached(d)/1024**3 for d in range(torch.cuda.device_count())] + ] + return self.sysmetrics.loc[len(self.sysmetrics)-1].to_dict() + + def sysmetrics2tboard(self, metrics=None, global_step=None, suptitle=None): + tag = "per_time" + if metrics is None: + metrics = self.sysmetrics_update(global_step=global_step) + tag = "per_epoch" + + if suptitle is not None: + suptitle = str(suptitle) + elif self.fold!="": + suptitle = "Fold_"+str(self.fold) + if suptitle is not None: + self.tboard.add_scalars(suptitle+"/System_Metrics/"+tag, {k:v for (k,v) in metrics.items() if (k!="global_step" + and k!="rel_time")}, global_step) + + def sysmetrics_loop(self): + try: + os.nice(-19) + except: + print("System-metrics logging has no superior process priority.") + while True: + metrics = self.sysmetrics_update() + self.sysmetrics2tboard(metrics, global_step=metrics["rel_time"]) + #print("thread alive", self.thread.is_alive()) + time.sleep(self.sysmetrics_interval) + + def sysmetrics_start(self, interval): + if interval is not None: + self.sysmetrics_interval = interval + self.gpu_logger = Nvidia_GPU_Logger() + self.sysmetrics_start_time = time.time() + self.thread = threading.Thread(target=self.sysmetrics_loop) + self.thread.daemon = True + self.thread.start() + + def sysmetrics_save(self, out_file): + + self.sysmetrics.to_pickle(out_file) + + + def metrics2tboard(self, metrics, global_step=None, suptitle=None): + """ + :param metrics: {'train': dataframe, 'val':df}, df as produced in + evaluator.py.evaluate_predictions + """ + #print("metrics", metrics) + if global_step is None: + global_step = len(metrics['train'][list(metrics['train'].keys())[0]])-1 + if suptitle is not None: + suptitle = str(suptitle) + else: + suptitle = "Fold_"+str(self.fold) + + for key in ['train', 'val']: + #series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k} + loss_series = {} + unc_series = {} + bin_stat_series = {} + mon_met_series = {} + for tag,val in metrics[key].items(): + val = val[-1] #maybe remove list wrapping, recording in evaluator? + if 'bin_stats' in tag.lower() and not np.isnan(val): + bin_stat_series["{}".format(tag.split("/")[-1])] = val + elif 'uncertainty' in tag.lower() and not np.isnan(val): + unc_series["{}".format(tag)] = val + elif 'loss' in tag.lower() and not np.isnan(val): + loss_series["{}".format(tag)] = val + elif not np.isnan(val): + mon_met_series["{}".format(tag)] = val + + self.tboard.add_scalars(suptitle+"/Binary_Statistics/{}".format(key), bin_stat_series, global_step) + self.tboard.add_scalars(suptitle + "/Uncertainties/{}".format(key), unc_series, global_step) + self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step) + self.tboard.add_scalars(suptitle+"/Monitor_Metrics/{}".format(key), mon_met_series, global_step) + self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step) + return + + def batchImgs2tboard(self, batch, results_dict, cmap, boxtype2color, img_bg=False, global_step=None): + raise NotImplementedError("not up-to-date, problem with importing plotting-file, torchvision dependency.") + if len(batch["seg"].shape)==5: #3D imgs + slice_ix = np.random.randint(batch["seg"].shape[-1]) + seg_gt = plg.to_rgb(batch['seg'][:,0,:,:,slice_ix], cmap) + seg_pred = plg.to_rgb(results_dict['seg_preds'][:,0,:,:,slice_ix], cmap) + + mod_img = plg.mod_to_rgb(batch["data"][:,0,:,:,slice_ix]) if img_bg else None + + elif len(batch["seg"].shape)==4: + seg_gt = plg.to_rgb(batch['seg'][:,0,:,:], cmap) + seg_pred = plg.to_rgb(results_dict['seg_preds'][:,0,:,:], cmap) + mod_img = plg.mod_to_rgb(batch["data"][:,0]) if img_bg else None + else: + raise Exception("batch content has wrong format: {}".format(batch["seg"].shape)) + + #from here on only works in 2D + seg_gt = np.transpose(seg_gt, axes=(0,3,1,2)) #previous shp: b,x,y,c + seg_pred = np.transpose(seg_pred, axes=(0,3,1,2)) + + + seg = np.concatenate((seg_gt, seg_pred), axis=0) + # todo replace torchvision (tv) dependency + seg = tv.utils.make_grid(torch.from_numpy(seg), nrow=2) + self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.", seg, global_step=global_step) + + if img_bg: + bg_img = np.transpose(mod_img, axes=(0,3,1,2)) + else: + bg_img = seg_gt + box_imgs = plg.draw_boxes_into_batch(bg_img, results_dict["boxes"], boxtype2color) + box_imgs = tv.utils.make_grid(torch.from_numpy(box_imgs), nrow=4) + self.tboard.add_image("Batch bboxes", box_imgs, global_step=global_step) + + return + + def __del__(self): # otherwise might produce multiple prints e.g. in ipython console + for hdlr in self.pylogger.handlers: + hdlr.close() + self.tboard.close() + self.pylogger.handlers = [] + del self.pylogger + +def get_logger(exp_dir, server_env=False, sysmetrics_interval=2): + log_dir = os.path.join(exp_dir, "logs") + logger = CombinedLogger('medical_detection', os.path.join(log_dir, "tboard"), server_env=server_env, + sysmetrics_interval=sysmetrics_interval) + logger.setLevel(logging.DEBUG) + log_file = os.path.join(log_dir, 'exec.log') + + logger.addHandler(logging.FileHandler(log_file)) + if not server_env: + logger.addHandler(ColorHandler()) + else: + logger.addHandler(logging.StreamHandler()) + logger.pylogger.propagate = False + print('Logging to {}'.format(log_file)) + + return logger + +def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True): + """ + I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir. + This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime. + Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone). + Provides robust structure for cloud deployment. + :param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp) + :param exp_path: path to experiment directory. + :param server_env: boolean flag. pass to configs script for cloud deployment. + :param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing + experiment directory, else creates experiment directory on the fly using configs/model scripts from source code. + :param is_training: boolean flag. distinguishes train vs. inference mode. + :return: configs object. + """ + + if is_training: + + if use_stored_settings: + cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) + cf = cf_file.Configs(server_env) + # in this mode, previously saved model and backbone need to be found in exp dir. + if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \ + not os.path.isfile(os.path.join(exp_path, 'backbone.py')): + raise Exception("Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.") + cf.model_path = os.path.join(exp_path, 'model.py') + cf.backbone_path = os.path.join(exp_path, 'backbone.py') + else: # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model + if not os.path.exists(exp_path): + os.mkdir(exp_path) + # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.) + subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), shell=True) + subprocess.call('cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), shell=True) + cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py')) + cf = cf_file.Configs(server_env) + subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True) + subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True) + if os.path.isfile(os.path.join(exp_path, "fold_ids.pickle")): + subprocess.call('rm {}'.format(os.path.join(exp_path, "fold_ids.pickle")), shell=True) + + else: # testing, use model and backbone stored in exp dir. + cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) + cf = cf_file.Configs(server_env) + cf.model_path = os.path.join(exp_path, 'model.py') + cf.backbone_path = os.path.join(exp_path, 'backbone.py') + + cf.exp_dir = exp_path + cf.test_dir = os.path.join(cf.exp_dir, 'test') + cf.plot_dir = os.path.join(cf.exp_dir, 'plots') + if not os.path.exists(cf.test_dir): + os.mkdir(cf.test_dir) + if not os.path.exists(cf.plot_dir): + os.mkdir(cf.plot_dir) + cf.experiment_name = exp_path.split("/")[-1] + cf.dataset_name = dataset_path + cf.server_env = server_env + cf.created_fold_id_pickle = False + + return cf + +class ModelSelector: + ''' + saves a checkpoint after each epoch as 'last_state' (can be loaded to continue interrupted training). + saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be ensembled + to improve performance. + ''' + + def __init__(self, cf, logger): + + self.cf = cf + self.saved_epochs = [-1] * cf.save_n_models + self.logger = logger + + + def run_model_selection(self, net, optimizer, monitor_metrics, epoch): + """rank epoch via weighted mean from self.cf.model_selection_criteria: {criterion : weight} + :param net: + :param optimizer: + :param monitor_metrics: + :param epoch: + :return: + """ + crita = self.cf.model_selection_criteria #shorter alias + + non_nan_scores = {} + for criterion in crita.keys(): + #exclude first entry bc its dummy None entry + non_nan_scores[criterion] = [0 if (ii is None or np.isnan(ii)) else ii for ii in monitor_metrics['val'][criterion]][1:] + n_epochs = len(non_nan_scores[criterion]) + epochs_scores = [] + for e_ix in range(n_epochs): + epochs_scores.append(np.sum([weight * non_nan_scores[criterion][e_ix] for + criterion,weight in crita.items()])/len(crita.keys())) + + # ranking of epochs according to model_selection_criterion + epoch_ranking = np.argsort(epochs_scores)[::-1] + 1 #epochs start at 1 + + # if set in configs, epochs < min_save_thresh are discarded from saving process. + epoch_ranking = epoch_ranking[epoch_ranking >= self.cf.min_save_thresh] + + # check if current epoch is among the top-k epchs. + if epoch in epoch_ranking[:self.cf.save_n_models]: + if self.cf.server_env: + IO_safe(torch.save, net.state_dict(), os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch))) + # save epoch_ranking to keep info for inference. + IO_safe(np.save, os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models]) + else: + torch.save(net.state_dict(), os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch))) + np.save(os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models]) + self.logger.info( + "saving current epoch {} at rank {}".format(epoch, np.argwhere(epoch_ranking == epoch))) + # delete params of the epoch that just fell out of the top-k epochs. + for se in [int(ii.split('_')[0]) for ii in os.listdir(self.cf.fold_dir) if 'best_params' in ii]: + if se in epoch_ranking[self.cf.save_n_models:]: + subprocess.call('rm {}'.format(os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(se))), + shell=True) + self.logger.info('deleting epoch {} at rank {}'.format(se, np.argwhere(epoch_ranking == se))) + + state = { + 'epoch': epoch, + 'state_dict': net.state_dict(), + 'optimizer': optimizer.state_dict(), + } + + if self.cf.server_env: + IO_safe(torch.save, state, os.path.join(self.cf.fold_dir, 'last_state.pth')) + else: + torch.save(state, os.path.join(self.cf.fold_dir, 'last_state.pth')) + + +def load_checkpoint(checkpoint_path, net, optimizer): + + checkpoint = torch.load(checkpoint_path) + net.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + return checkpoint['epoch'] + + +def prepare_monitoring(cf): + """ + creates dictionaries, where train/val metrics are stored. + """ + metrics = {} + # first entry for loss dict accounts for epoch starting at 1. + metrics['train'] = OrderedDict()# [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) + metrics['val'] = OrderedDict()# [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) + metric_classes = [] + if 'rois' in cf.report_score_level: + metric_classes.extend([v for k, v in cf.class_dict.items()]) + if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: + metric_classes.extend([v for k, v in cf.bin_dict.items()]) + if 'patient' in cf.report_score_level: + metric_classes.extend(['patient_'+cf.class_dict[cf.patient_class_of_interest]]) + if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: + metric_classes.extend(['patient_' + cf.bin_dict[cf.patient_bin_of_interest]]) + for cl in metric_classes: + for m in cf.metrics: + metrics['train'][cl + '_' + m] = [np.nan] + metrics['val'][cl + '_' + m] = [np.nan] + + return metrics + + +class _AnsiColorizer(object): + """ + A colorizer is an object that loosely wraps around a stream, allowing + callers to write text to the stream in a particular color. + + Colorizer classes must implement C{supported()} and C{write(text, color)}. + """ + _colors = dict(black=30, red=31, green=32, yellow=33, + blue=34, magenta=35, cyan=36, white=37, default=39) + + def __init__(self, stream): + self.stream = stream + + @classmethod + def supported(cls, stream=sys.stdout): + """ + A class method that returns True if the current platform supports + coloring terminal output using this method. Returns False otherwise. + """ + if not stream.isatty(): + return False # auto color only on TTYs + try: + import curses + except ImportError: + return False + else: + try: + try: + return curses.tigetnum("colors") > 2 + except curses.error: + curses.setupterm() + return curses.tigetnum("colors") > 2 + except: + raise + # guess false in case of error + return False + + def write(self, text, color): + """ + Write the given text to the stream in the given color. + + @param text: Text to be written to the stream. + + @param color: A string label for a color. e.g. 'red', 'white'. + """ + color = self._colors[color] + self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text)) + +class ColorHandler(logging.StreamHandler): + + + def __init__(self, stream=sys.stdout): + super(ColorHandler, self).__init__(_AnsiColorizer(stream)) + + def emit(self, record): + msg_colors = { + logging.DEBUG: "green", + logging.INFO: "default", + logging.WARNING: "red", + logging.ERROR: "red" + } + color = msg_colors.get(record.levelno, "blue") + self.stream.write(record.msg + "\n", color) + + + diff --git a/utils/model_utils.py b/utils/model_utils.py new file mode 100644 index 0000000..aef2bbb --- /dev/null +++ b/utils/model_utils.py @@ -0,0 +1,1471 @@ +#!/usr/bin/env python +# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +""" +Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn +published under MIT license. +""" +import time +import warnings +warnings.filterwarnings('ignore', '.*From scipy 0.13.0, the output shape of zoom()*') + +import numpy as np +import math +import scipy.misc +import scipy.ndimage +from scipy.ndimage.measurements import label as lb +import torch +from torch.autograd import Variable + +from cuda_functions.nms_2D.pth_nms import nms_gpu as nms_2D +from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D +from cuda_functions.roi_align_2D.roi_align.crop_and_resize import CropAndResizeFunction as ra2D +from cuda_functions.roi_align_3D.roi_align.crop_and_resize import CropAndResizeFunction as ra3D + + +############################################################ +# Segmentation Processing +############################################################ + +def sum_tensor(input, axes, keepdim=False): + axes = np.unique(axes) + if keepdim: + for ax in axes: + input = input.sum(ax, keepdim=True) + else: + for ax in sorted(axes, reverse=True): + input = input.sum(int(ax)) + return input + +def get_one_hot_encoding(y, n_classes): + """ + transform a numpy label array to a one-hot array of the same shape. + :param y: array of shape (b, 1, y, x, (z)). + :param n_classes: int, number of classes to unfold in one-hot encoding. + :return y_ohe: array of shape (b, n_classes, y, x, (z)) + """ + + dim = len(y.shape) - 2 + if dim == 2: + y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32') + elif dim == 3: + y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32') + else: + raise Exception("invalid dimensions {} encountered".format(y.shape)) + for cl in np.arange(n_classes): + y_ohe[:, cl][y[:, 0] == cl] = 1 + return y_ohe + +def dice_per_batch_inst_and_class(pred, y, n_classes, convert_to_ohe=True, smooth=1e-8): + #actually per batch_instance not batch + ''' + computes dice scores per batch instance and class. + :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1) + :param y: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes] + :param n_classes: int + :return: dice scores of shape (b, c) + ''' + if convert_to_ohe: + pred = get_one_hot_encoding(pred, n_classes) + y = get_one_hot_encoding(y, n_classes) + axes = tuple(range(2, len(pred.shape))) + intersect = np.sum(pred*y, axis=axes) + denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes) + dice = (2.0*intersect + smooth) / (denominator + smooth) + return dice + +def dice_per_batch_and_class(pred, targ, n_classes, convert_to_ohe=True, smooth=1e-8): + ''' + computes dice scores per batch and class. + :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1) + :param targ: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes]) + :param n_classes: int + :param smooth: Laplacian smooth, https://en.wikipedia.org/wiki/Additive_smoothing + :return: dice scores of shape (b, c) + ''' + if convert_to_ohe: + pred = get_one_hot_encoding(pred, n_classes) + targ = get_one_hot_encoding(targ, n_classes) + axes = (0, *list(range(2, len(pred.shape)))) #(0,2,3(,4)) + + intersect = np.sum(pred * targ, axis=axes) + + denominator = np.sum(pred, axis=axes) + np.sum(targ, axis=axes) + dice = (2.0 * intersect + smooth) / (denominator + smooth) + + assert dice.shape==(n_classes,), "dice shp {}".format(dice.shape) + return dice + + +def batch_dice(pred, y, false_positive_weight=1.0, eps=1e-6): + ''' + compute soft dice over batch. this is a differentiable score and can be used as a loss function. + only dice scores of foreground classes are returned, since training typically + does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of. + This way, single patches with missing foreground classes can not produce faulty gradients. + :param pred: (b, c, y, x, (z)), softmax probabilities (network output). + :param y: (b, c, y, x, (z)), one hote encoded segmentation mask. + :param false_positive_weight: float [0,1]. For weighting of imbalanced classes, + reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances. + :return: soft dice score (float).This function discards the background score and returns the mena of foreground scores. + ''' + # todo also use additive smooth here instead of eps? + if len(pred.size()) == 4: + axes = (0, 2, 3) + intersect = sum_tensor(pred * y, axes, keepdim=False) + denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False) + return torch.mean((2 * intersect / (denom + eps))[1:]) #only fg dice here. + + if len(pred.size()) == 5: + axes = (0, 2, 3, 4) + intersect = sum_tensor(pred * y, axes, keepdim=False) + denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False) + return torch.mean((2 * intersect / (denom + eps))[1:]) #only fg dice here. + else: + raise ValueError('wrong input dimension in dice loss') + + +############################################################ +# Bounding Boxes +############################################################ + +def compute_iou_2D(box, boxes, box_area, boxes_area): + """Calculates IoU of the given box with the array of the given boxes. + box: 1D vector [y1, x1, y2, x2] THIS IS THE GT BOX + boxes: [boxes_count, (y1, x1, y2, x2)] + box_area: float. the area of 'box' + boxes_area: array of length boxes_count. + + Note: the areas are passed in rather than calculated here for + efficency. Calculate once in the caller to avoid duplicate work. + """ + # Calculate intersection areas + y1 = np.maximum(box[0], boxes[:, 0]) + y2 = np.minimum(box[2], boxes[:, 2]) + x1 = np.maximum(box[1], boxes[:, 1]) + x2 = np.minimum(box[3], boxes[:, 3]) + intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) + union = box_area + boxes_area[:] - intersection[:] + iou = intersection / union + + return iou + + +def compute_iou_3D(box, boxes, box_volume, boxes_volume): + """Calculates IoU of the given box with the array of the given boxes. + box: 1D vector [y1, x1, y2, x2, z1, z2] (typically gt box) + boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)] + box_area: float. the area of 'box' + boxes_area: array of length boxes_count. + + Note: the areas are passed in rather than calculated here for + efficency. Calculate once in the caller to avoid duplicate work. + """ + # Calculate intersection areas + y1 = np.maximum(box[0], boxes[:, 0]) + y2 = np.minimum(box[2], boxes[:, 2]) + x1 = np.maximum(box[1], boxes[:, 1]) + x2 = np.minimum(box[3], boxes[:, 3]) + z1 = np.maximum(box[4], boxes[:, 4]) + z2 = np.minimum(box[5], boxes[:, 5]) + intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) * np.maximum(z2 - z1, 0) + union = box_volume + boxes_volume[:] - intersection[:] + iou = intersection / union + + return iou + + + +def compute_overlaps(boxes1, boxes2): + """Computes IoU overlaps between two sets of boxes. + boxes1, boxes2: [N, (y1, x1, y2, x2)]. / 3D: (z1, z2)) + For better performance, pass the largest set first and the smaller second. + :return: (#boxes1, #boxes2), ious of each box of 1 machted with each of 2 + """ + # Areas of anchors and GT boxes + if boxes1.shape[1] == 4: + area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) + area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) + # Compute overlaps to generate matrix [boxes1 count, boxes2 count] + # Each cell contains the IoU value. + overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0])) + for i in range(overlaps.shape[1]): + box2 = boxes2[i] #this is the gt box + overlaps[:, i] = compute_iou_2D(box2, boxes1, area2[i], area1) + return overlaps + + else: + # Areas of anchors and GT boxes + volume1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) * (boxes1[:, 5] - boxes1[:, 4]) + volume2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) * (boxes2[:, 5] - boxes2[:, 4]) + # Compute overlaps to generate matrix [boxes1 count, boxes2 count] + # Each cell contains the IoU value. + overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0])) + for i in range(boxes2.shape[0]): + box2 = boxes2[i] # this is the gt box + overlaps[:, i] = compute_iou_3D(box2, boxes1, volume2[i], volume1) + return overlaps + + + +def box_refinement(box, gt_box): + """Compute refinement needed to transform box to gt_box. + box and gt_box are [N, (y1, x1, y2, x2)] / 3D: (z1, z2)) + """ + height = box[:, 2] - box[:, 0] + width = box[:, 3] - box[:, 1] + center_y = box[:, 0] + 0.5 * height + center_x = box[:, 1] + 0.5 * width + + gt_height = gt_box[:, 2] - gt_box[:, 0] + gt_width = gt_box[:, 3] - gt_box[:, 1] + gt_center_y = gt_box[:, 0] + 0.5 * gt_height + gt_center_x = gt_box[:, 1] + 0.5 * gt_width + + dy = (gt_center_y - center_y) / height + dx = (gt_center_x - center_x) / width + dh = torch.log(gt_height / height) + dw = torch.log(gt_width / width) + result = torch.stack([dy, dx, dh, dw], dim=1) + + if box.shape[1] > 4: + depth = box[:, 5] - box[:, 4] + center_z = box[:, 4] + 0.5 * depth + gt_depth = gt_box[:, 5] - gt_box[:, 4] + gt_center_z = gt_box[:, 4] + 0.5 * gt_depth + dz = (gt_center_z - center_z) / depth + dd = torch.log(gt_depth / depth) + result = torch.stack([dy, dx, dz, dh, dw, dd], dim=1) + + return result + + + +def unmold_mask_2D(mask, bbox, image_shape): + """Converts a mask generated by the neural network into a format similar + to it's original shape. + mask: [height, width] of type float. A small, typically 28x28 mask. + bbox: [y1, x1, y2, x2]. The box to fit the mask in. + + Returns a binary mask with the same size as the original image. + """ + y1, x1, y2, x2 = bbox + out_zoom = [y2 - y1, x2 - x1] + zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)] + + mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32) + + # Put the mask in the right location. + full_mask = np.zeros(image_shape[:2]) #only y,x + full_mask[y1:y2, x1:x2] = mask + return full_mask + + +def unmold_mask_2D_torch(mask, bbox, image_shape): + """Converts a mask generated by the neural network into a format similar + to it's original shape. + mask: [height, width] of type float. A small, typically 28x28 mask. + bbox: [y1, x1, y2, x2]. The box to fit the mask in. + + Returns a binary mask with the same size as the original image. + """ + y1, x1, y2, x2 = bbox + out_zoom = [(y2 - y1).float(), (x2 - x1).float()] + zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)] + + mask = mask.unsqueeze(0).unsqueeze(0) + mask = torch.nn.functional.interpolate(mask, scale_factor=zoom_factor) + mask = mask[0][0] + #mask = scipy.ndimage.zoom(mask.cpu().numpy(), zoom_factor, order=1).astype(np.float32) + #mask = torch.from_numpy(mask).cuda() + # Put the mask in the right location. + full_mask = torch.zeros(image_shape[:2]) # only y,x + full_mask[y1:y2, x1:x2] = mask + return full_mask + + + +def unmold_mask_3D(mask, bbox, image_shape): + """Converts a mask generated by the neural network into a format similar + to it's original shape. + mask: [height, width] of type float. A small, typically 28x28 mask. + bbox: [y1, x1, y2, x2, z1, z2]. The box to fit the mask in. + + Returns a binary mask with the same size as the original image. + """ + y1, x1, y2, x2, z1, z2 = bbox + out_zoom = [y2 - y1, x2 - x1, z2 - z1] + zoom_factor = [i/j for i,j in zip(out_zoom, mask.shape)] + mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32) + + # Put the mask in the right location. + full_mask = np.zeros(image_shape[:3]) + full_mask[y1:y2, x1:x2, z1:z2] = mask + return full_mask + +def nms_numpy(box_coords, scores, thresh): + """ non-maximum suppression on 2D or 3D boxes in numpy. + :param box_coords: [y1,x1,y2,x2 (,z1,z2)] with y1<=y2, x1<=x2, z1<=z2. + :param scores: ranking scores (higher score == higher rank) of boxes. + :param thresh: IoU threshold for clustering. + :return: + """ + y1 = box_coords[:, 0] + x1 = box_coords[:, 1] + y2 = box_coords[:, 2] + x2 = box_coords[:, 3] + assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: + coordinates of which maxima are taken need to be the lower coordinates""" + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + + is_3d = box_coords.shape[1] == 6 + if is_3d: # 3-dim case + z1 = box_coords[:, 4] + z2 = box_coords[:, 5] + assert np.all(z1<=z2), """"the definition of the coordinates is crucially important here: + coordinates of which maxima are taken need to be the lower coordinates""" + areas *= (z2 - z1 + 1) + + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: # order is the sorted index. maps order to index: order[1] = 24 means (rank1, ix 24) + i = order[0] # highest scoring element + yy1 = np.maximum(y1[i], y1[order]) # highest scoring element still in >order<, is compared to itself, that is okay. + xx1 = np.maximum(x1[i], x1[order]) + yy2 = np.minimum(y2[i], y2[order]) + xx2 = np.minimum(x2[i], x2[order]) + + h = np.maximum(0.0, yy2 - yy1 + 1) + w = np.maximum(0.0, xx2 - xx1 + 1) + inter = h * w + + if is_3d: + zz1 = np.maximum(z1[i], z1[order]) + zz2 = np.minimum(z2[i], z2[order]) + d = np.maximum(0.0, zz2 - zz1 + 1) + inter *= d + + iou = inter / (areas[i] + areas[order] - inter) + + non_matches = np.nonzero(iou <= thresh)[0] # get all elements that were not matched and discard all others. + #print("iou keep {}: {}, non_matches {}".format(i, iou, order[non_matches])) + order = order[non_matches] + keep.append(i) + #print("total keep", keep) + return keep + + + +############################################################ +# M-RCNN +############################################################ + +def refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, batch_anchors, cf): + """ + Receives anchor scores and selects a subset to pass as proposals + to the second stage. Filtering is done based on anchor scores and + non-max suppression to remove overlaps. It also applies bounding + box refinment details to anchors. + :param rpn_pred_probs: (b, n_anchors, 2) + :param rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d)))) + :return: batch_normalized_props: Proposals in normalized coordinates (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score)) + :return: batch_out_proposals: Box coords + RPN foreground scores + for monitoring/plotting (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score)) + """ + std_dev = torch.from_numpy(cf.rpn_bbox_std_dev[None]).float().cuda() + norm = torch.from_numpy(cf.scale).float().cuda() + anchors = batch_anchors.clone() + + batch_scores = rpn_pred_probs[:, :, 1] + # norm deltas + batch_deltas = rpn_pred_deltas * std_dev + batch_normalized_props = [] + batch_out_proposals = [] + + # loop over batch dimension. + for ix in range(batch_scores.shape[0]): + + scores = batch_scores[ix] + deltas = batch_deltas[ix] + + # improve performance by trimming to top anchors by score + # and doing the rest on the smaller subset. + pre_nms_limit = min(cf.pre_nms_limit, anchors.size()[0]) + scores, order = scores.sort(descending=True) + order = order[:pre_nms_limit] + scores = scores[:pre_nms_limit] + deltas = deltas[order, :] + + # apply deltas to anchors to get refined anchors and filter with non-maximum suppression. + if batch_deltas.shape[-1] == 4: + boxes = apply_box_deltas_2D(anchors[order, :], deltas) + boxes = clip_boxes_2D(boxes, cf.window) + keep = nms_2D(torch.cat((boxes, scores.unsqueeze(1)), 1), cf.rpn_nms_threshold) + + else: + boxes = apply_box_deltas_3D(anchors[order, :], deltas) + boxes = clip_boxes_3D(boxes, cf.window) + keep = nms_3D(torch.cat((boxes, scores.unsqueeze(1)), 1), cf.rpn_nms_threshold) + + keep = keep[:proposal_count] + boxes = boxes[keep, :] + rpn_scores = scores[keep][:, None] + + # pad missing boxes with 0. + if boxes.shape[0] < proposal_count: + n_pad_boxes = proposal_count - boxes.shape[0] + zeros = torch.zeros([n_pad_boxes, boxes.shape[1]]).cuda() + boxes = torch.cat([boxes, zeros], dim=0) + zeros = torch.zeros([n_pad_boxes, rpn_scores.shape[1]]).cuda() + rpn_scores = torch.cat([rpn_scores, zeros], dim=0) + + # concat box and score info for monitoring/plotting. + batch_out_proposals.append(torch.cat((boxes, rpn_scores), 1).cpu().data.numpy()) + # normalize dimensions to range of 0 to 1. + normalized_boxes = boxes / norm + assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found" + + # add again batch dimension + batch_normalized_props.append(torch.cat((normalized_boxes, rpn_scores), 1).unsqueeze(0)) + + batch_normalized_props = torch.cat(batch_normalized_props) + batch_out_proposals = np.array(batch_out_proposals) + + return batch_normalized_props, batch_out_proposals + +def pyramid_roi_align(feature_maps, rois, pool_size, pyramid_levels, dim): + """ + Implements ROI Pooling on multiple levels of the feature pyramid. + :param feature_maps: list of feature maps, each of shape (b, c, y, x , (z)) + :param rois: proposals (normalized coords.) as returned by RPN. contain info about original batch element allocation. + (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs) + :param pool_size: list of poolsizes in dims: [x, y, (z)] + :param pyramid_levels: list. [0, 1, 2, ...] + :return: pooled: pooled feature map rois (n_proposals, c, poolsize_y, poolsize_x, (poolsize_z)) + + Output: + Pooled regions in the shape: [num_boxes, height, width, channels]. + The width and height are those specific in the pool_shape in the layer + constructor. + """ + boxes = rois[:, :dim*2] + batch_ixs = rois[:, dim*2] + + # Assign each ROI to a level in the pyramid based on the ROI area. + if dim == 2: + y1, x1, y2, x2 = boxes.chunk(4, dim=1) + else: + y1, x1, y2, x2, z1, z2 = boxes.chunk(6, dim=1) + + h = y2 - y1 + w = x2 - x1 + + # Equation 1 in https://arxiv.org/abs/1612.03144. Account for + # the fact that our coordinates are normalized here. + # divide sqrt(h*w) by 1 instead image_area. + roi_level = (4 + log2(torch.sqrt(h*w))).round().int().clamp(pyramid_levels[0], pyramid_levels[-1]) + # if Pyramid contains additional level P6, adapt the roi_level assignment accordingly. + if len(pyramid_levels) == 5: + roi_level[h*w > 0.65] = 5 + + # Loop through levels and apply ROI pooling to each. + pooled = [] + box_to_level = [] + for level_ix, level in enumerate(pyramid_levels): + ix = roi_level == level + if not ix.any(): + continue + ix = torch.nonzero(ix)[:, 0] + level_boxes = boxes[ix, :] + # re-assign rois to feature map of original batch element. + ind = batch_ixs[ix].int() + + # Keep track of which box is mapped to which level + box_to_level.append(ix) + + # Stop gradient propogation to ROI proposals + level_boxes = level_boxes.detach() + + # Crop and Resize + # From Mask R-CNN paper: "We sample four regular locations, so + # that we can evaluate either max or average pooling. In fact, + # interpolating only a single value at each bin center (without + # pooling) is nearly as effective." + # + # Here we use the simplified approach of a single value per bin, + # which is how is done in tf.crop_and_resize() + # + # Also fixed a bug from original implementation, reported in: + # https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35 + + if len(pool_size) == 2: + pooled_features = ra2D(pool_size[0], pool_size[1], 0)(feature_maps[level_ix], level_boxes, ind) + else: + pooled_features = ra3D(pool_size[0], pool_size[1], pool_size[2], 0)(feature_maps[level_ix], level_boxes, ind) + + pooled.append(pooled_features) + + + # Pack pooled features into one tensor + pooled = torch.cat(pooled, dim=0) + + # Pack box_to_level mapping into one array and add another + # column representing the order of pooled boxes + box_to_level = torch.cat(box_to_level, dim=0) + + # Rearrange pooled features to match the order of the original boxes + _, box_to_level = torch.sort(box_to_level) + pooled = pooled[box_to_level, :, :] + + return pooled + +def refine_detections(cf, batch_ixs, rois, deltas, scores, regressions): + """ + Refine classified proposals (apply deltas to rpn rois), filter overlaps (nms) and return final detections. + + :param rois: (n_proposals, 2 * dim) normalized boxes as proposed by RPN. n_proposals = batch_size * POST_NMS_ROIS + :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by mrcnn bbox regressor. + :param batch_ixs: (n_proposals) batch element assignment info for re-allocation. + :param scores: (n_proposals, n_classes) probabilities for all classes per roi as predicted by mrcnn classifier. + :param regressions: (n_proposals, n_classes, regression_features (+1 for uncertainty if predicted) regression vector + :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, *regression vector features)) + """ + # class IDs per ROI. Since scores of all classes are of interest (not just max class), all are kept at this point. + class_ids = [] + fg_classes = cf.head_classes - 1 + # repeat vectors to fill in predictions for all foreground classes. + for ii in range(1, fg_classes + 1): + class_ids += [ii] * rois.shape[0] + class_ids = torch.from_numpy(np.array(class_ids)).cuda() + + batch_ixs = batch_ixs.repeat(fg_classes) + rois = rois.repeat(fg_classes, 1) + deltas = deltas.repeat(fg_classes, 1, 1) + scores = scores.repeat(fg_classes, 1) + regressions = regressions.repeat(fg_classes, 1, 1) + + # get class-specific scores and bounding box deltas + idx = torch.arange(class_ids.size()[0]).long().cuda() + # using idx instead of slice [:,] squashes first dimension. + #len(class_ids)>scores.shape[1] --> probs is broadcasted by expansion from fg_classes-->len(class_ids) + batch_ixs = batch_ixs[idx] + deltas_specific = deltas[idx, class_ids] + class_scores = scores[idx, class_ids] + regressions = regressions[idx, class_ids] + + # apply bounding box deltas. re-scale to image coordinates. + std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() + scale = torch.from_numpy(cf.scale).float().cuda() + refined_rois = apply_box_deltas_2D(rois, deltas_specific * std_dev) * scale if cf.dim == 2 else \ + apply_box_deltas_3D(rois, deltas_specific * std_dev) * scale + + # round and cast to int since we're dealing with pixels now + refined_rois = clip_to_window(cf.window, refined_rois) + refined_rois = torch.round(refined_rois) + + # filter out low confidence boxes + keep = idx + keep_bool = (class_scores >= cf.model_min_confidence) + if not 0 in torch.nonzero(keep_bool).size(): + + score_keep = torch.nonzero(keep_bool)[:, 0] + pre_nms_class_ids = class_ids[score_keep] + pre_nms_rois = refined_rois[score_keep] + pre_nms_scores = class_scores[score_keep] + pre_nms_batch_ixs = batch_ixs[score_keep] + + for j, b in enumerate(unique1d(pre_nms_batch_ixs)): + + bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] + bix_class_ids = pre_nms_class_ids[bixs] + bix_rois = pre_nms_rois[bixs] + bix_scores = pre_nms_scores[bixs] + + for i, class_id in enumerate(unique1d(bix_class_ids)): + + ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] + # nms expects boxes sorted by score. + ix_rois = bix_rois[ixs] + ix_scores = bix_scores[ixs] + ix_scores, order = ix_scores.sort(descending=True) + ix_rois = ix_rois[order, :] + + if cf.dim == 2: + class_keep = nms_2D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) + else: + class_keep = nms_3D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) + + # map indices back. + class_keep = keep[score_keep[bixs[ixs[order[class_keep]]]]] + # merge indices over classes for current batch element + b_keep = class_keep if i == 0 else unique1d(torch.cat((b_keep, class_keep))) + + # only keep top-k boxes of current batch-element + top_ids = class_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] + b_keep = b_keep[top_ids] + + # merge indices over batch elements. + batch_keep = b_keep if j == 0 else unique1d(torch.cat((batch_keep, b_keep))) + + keep = batch_keep + + else: + keep = torch.tensor([0]).long().cuda() + + # arrange output + output = [refined_rois[keep], batch_ixs[keep].unsqueeze(1)] + output += [class_ids[keep].unsqueeze(1).float(), class_scores[keep].unsqueeze(1)] + output += [regressions[keep]] + + result = torch.cat(output, dim=1) + # shape: (n_keeps, catted feats), catted feats: [0:dim*2] are box_coords, [dim*2] are batch_ics, + # [dim*2+1] are class_ids, [dim*2+2] are scores, [dim*2+3:] are regression vector features (incl uncertainty) + return result + + +def loss_example_mining(cf, batch_proposals, batch_gt_boxes, batch_gt_masks, batch_roi_scores, + batch_gt_class_ids, batch_gt_regressions): + """ + Subsamples proposals for mrcnn losses and generates targets. Sampling is done per batch element, seems to have positive + effects on training, as opposed to sampling over entire batch. Negatives are sampled via stochastic hard-example mining + (SHEM), where a number of negative proposals is drawn from larger pool of highest scoring proposals for stochasticity. + Scoring is obtained here as the max over all foreground probabilities as returned by mrcnn_classifier (worked better than + loss-based class-balancing methods like "online hard-example mining" or "focal loss".) + + Classification-regression duality: regressions can be given along with classes (at least fg/bg, only class scores + are used for ranking). + + :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs). + boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS. + :param mrcnn_class_logits: (n_proposals, n_classes) + :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates. + :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c) + :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels. + if no classes predicted (only fg/bg from RPN): expected as pseudo classes [0, 1] for bg, fg. + :param batch_gt_regressions: list over b elements. Each element is a regression target vector. if None--> pseudo + :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions. + :return: target_class_ids: (n_sampled_rois)containing target class labels of sampled proposals. + :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement. + :return: target_masks: (n_sampled_rois, y, x, (z)) containing target masks of sampled proposals. + """ + # normalization of target coordinates + #global sample_regressions + if cf.dim == 2: + h, w = cf.patch_size + scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda() + else: + h, w, z = cf.patch_size + scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda() + + + positive_count = 0 + negative_count = 0 + sample_positive_indices = [] + sample_negative_indices = [] + sample_deltas = [] + sample_masks = [] + sample_class_ids = [] + if batch_gt_regressions is not None: + sample_regressions = [] + else: + target_regressions = torch.FloatTensor().cuda() + + # loop over batch and get positive and negative sample rois. + for b in range(len(batch_gt_boxes)): + + gt_masks = torch.from_numpy(batch_gt_masks[b]).float().cuda() + gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda() + if batch_gt_regressions is not None: + gt_regressions = torch.from_numpy(batch_gt_regressions[b]).float().cuda() + + #if np.any(batch_gt_class_ids[b] > 0): # skip roi selection for no gt images. + if np.any([len(coords)>0 for coords in batch_gt_boxes[b]]): + gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale + else: + gt_boxes = torch.FloatTensor().cuda() + + # get proposals and indices of current batch element. + proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1] + batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1) + + # Compute overlaps matrix [proposals, gt_boxes] + if not 0 in gt_boxes.size(): + if gt_boxes.shape[1] == 4: + assert cf.dim == 2, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim) + overlaps = bbox_overlaps_2D(proposals, gt_boxes) + else: + assert cf.dim == 3, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim) + overlaps = bbox_overlaps_3D(proposals, gt_boxes) + + # Determine positive and negative ROIs + roi_iou_max = torch.max(overlaps, dim=1)[0] + # 1. Positive ROIs are those with >= 0.5 IoU with a GT box + positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3) + # 2. Negative ROIs are those with < 0.1 with every GT box. + negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01) + else: + positive_roi_bool = torch.FloatTensor().cuda() + negative_roi_bool = torch.from_numpy(np.array([1]*proposals.shape[0])).cuda() + + # Sample Positive ROIs + if not 0 in torch.nonzero(positive_roi_bool).size(): + positive_indices = torch.nonzero(positive_roi_bool).squeeze(1) + positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio) + rand_idx = torch.randperm(positive_indices.size()[0]) + rand_idx = rand_idx[:positive_samples].cuda() + positive_indices = positive_indices[rand_idx] + positive_samples = positive_indices.size()[0] + positive_rois = proposals[positive_indices, :] + # Assign positive ROIs to GT boxes. + positive_overlaps = overlaps[positive_indices, :] + roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1] + roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :] + roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment] + if batch_gt_regressions is not None: + roi_gt_regressions = gt_regressions[roi_gt_box_assignment] + + # Compute bbox refinement targets for positive ROIs + deltas = box_refinement(positive_rois, roi_gt_boxes) + std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda() + deltas /= std_dev + + # Assign positive ROIs to GT masks + roi_masks = gt_masks[roi_gt_box_assignment,:,:] + + # Compute mask targets + boxes = positive_rois + box_ids = torch.arange(roi_masks.size()[0]).int().cuda() + + if len(cf.mask_shape) == 2: + masks = ra2D(cf.mask_shape[0], cf.mask_shape[1], 0)(roi_masks.unsqueeze(1), boxes, box_ids) + else: + masks = ra3D(cf.mask_shape[0], cf.mask_shape[1], cf.mask_shape[2], 0)(roi_masks.unsqueeze(1), boxes, box_ids) + + masks = masks.squeeze(1) + # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with + # binary cross entropy loss. + masks = torch.round(masks) + + sample_positive_indices.append(batch_element_indices[positive_indices]) + sample_deltas.append(deltas) + sample_masks.append(masks) + sample_class_ids.append(roi_gt_class_ids) + if batch_gt_regressions is not None: + sample_regressions.append(roi_gt_regressions) + positive_count += positive_samples + else: + positive_samples = 0 + + # Sample negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM. + if not 0 in torch.nonzero(negative_roi_bool).size(): + negative_indices = torch.nonzero(negative_roi_bool).squeeze(1) + r = 1.0 / cf.roi_positive_ratio + b_neg_count = np.max((int(r * positive_samples - positive_samples), 1)) + roi_scores_neg = batch_roi_scores[batch_element_indices[negative_indices]] + raw_sampled_indices = shem(roi_scores_neg, b_neg_count, cf.shem_poolsize) + sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]]) + negative_count += raw_sampled_indices.size()[0] + + if len(sample_positive_indices) > 0: + target_deltas = torch.cat(sample_deltas) + target_masks = torch.cat(sample_masks) + target_class_ids = torch.cat(sample_class_ids) + if batch_gt_regressions is not None: + target_regressions = torch.cat(sample_regressions) + + # Pad target information with zeros for negative ROIs. + if positive_count > 0 and negative_count > 0: + sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0) + zeros = torch.zeros(negative_count, cf.dim * 2).cuda() + target_deltas = torch.cat([target_deltas, zeros], dim=0) + zeros = torch.zeros(negative_count, *cf.mask_shape).cuda() + target_masks = torch.cat([target_masks, zeros], dim=0) + zeros = torch.zeros(negative_count).int().cuda() + target_class_ids = torch.cat([target_class_ids, zeros], dim=0) + if batch_gt_regressions is not None: + # regression targets need to have 0 as background/negative with below practice + if 'regression_bin' in cf.prediction_tasks: + zeros = torch.zeros(negative_count, dtype=torch.float).cuda() + else: + zeros = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda() + target_regressions = torch.cat([target_regressions, zeros], dim=0) + + elif positive_count > 0: + sample_indices = torch.cat(sample_positive_indices) + elif negative_count > 0: + sample_indices = torch.cat(sample_negative_indices) + target_deltas = torch.zeros(negative_count, cf.dim * 2).cuda() + target_masks = torch.zeros(negative_count, *cf.mask_shape).cuda() + target_class_ids = torch.zeros(negative_count).int().cuda() + if batch_gt_regressions is not None: + if 'regression_bin' in cf.prediction_tasks: + target_regressions = torch.zeros(negative_count, dtype=torch.float).cuda() + else: + target_regressions = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda() + else: + sample_indices = torch.LongTensor().cuda() + target_class_ids = torch.IntTensor().cuda() + target_deltas = torch.FloatTensor().cuda() + target_masks = torch.FloatTensor().cuda() + target_regressions = torch.FloatTensor().cuda() + + return sample_indices, target_deltas, target_masks, target_class_ids, target_regressions + +############################################################ +# Anchors +############################################################ + +def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride): + """ + scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] + ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] + shape: [height, width] spatial shape of the feature map over which + to generate anchors. + feature_stride: Stride of the feature map relative to the image in pixels. + anchor_stride: Stride of anchors on the feature map. For example, if the + value is 2 then generate anchors for every other feature map pixel. + """ + # Get all combinations of scales and ratios + scales, ratios = np.meshgrid(np.array(scales), np.array(ratios)) + scales = scales.flatten() + ratios = ratios.flatten() + + # Enumerate heights and widths from scales and ratios + heights = scales / np.sqrt(ratios) + widths = scales * np.sqrt(ratios) + + # Enumerate shifts in feature space + shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride + shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride + shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y) + + # Enumerate combinations of shifts, widths, and heights + box_widths, box_centers_x = np.meshgrid(widths, shifts_x) + box_heights, box_centers_y = np.meshgrid(heights, shifts_y) + + # Reshape to get a list of (y, x) and a list of (h, w) + box_centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2]) + box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2]) + + # Convert to corner coordinates (y1, x1, y2, x2) + boxes = np.concatenate([box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1) + return boxes + + + +def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride): + """ + scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] + ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] + shape: [height, width] spatial shape of the feature map over which + to generate anchors. + feature_stride: Stride of the feature map relative to the image in pixels. + anchor_stride: Stride of anchors on the feature map. For example, if the + value is 2 then generate anchors for every other feature map pixel. + """ + # Get all combinations of scales and ratios + + scales_xy, ratios_meshed = np.meshgrid(np.array(scales_xy), np.array(ratios)) + scales_xy = scales_xy.flatten() + ratios_meshed = ratios_meshed.flatten() + + # Enumerate heights and widths from scales and ratios + heights = scales_xy / np.sqrt(ratios_meshed) + widths = scales_xy * np.sqrt(ratios_meshed) + depths = np.tile(np.array(scales_z), len(ratios_meshed)//np.array(scales_z)[..., None].shape[0]) + + # Enumerate shifts in feature space + shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy #translate from fm positions to input coords. + shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy + shifts_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z) + shifts_x, shifts_y, shifts_z = np.meshgrid(shifts_x, shifts_y, shifts_z) + + # Enumerate combinations of shifts, widths, and heights + box_widths, box_centers_x = np.meshgrid(widths, shifts_x) + box_heights, box_centers_y = np.meshgrid(heights, shifts_y) + box_depths, box_centers_z = np.meshgrid(depths, shifts_z) + + # Reshape to get a list of (y, x, z) and a list of (h, w, d) + box_centers = np.stack( + [box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3]) + box_sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3]) + + # Convert to corner coordinates (y1, x1, y2, x2, z1, z2) + boxes = np.concatenate([box_centers - 0.5 * box_sizes, + box_centers + 0.5 * box_sizes], axis=1) + + boxes = np.transpose(np.array([boxes[:, 0], boxes[:, 1], boxes[:, 3], boxes[:, 4], boxes[:, 2], boxes[:, 5]]), axes=(1, 0)) + return boxes + + +def generate_pyramid_anchors(logger, cf): + """Generate anchors at different levels of a feature pyramid. Each scale + is associated with a level of the pyramid, but each ratio is used in + all levels of the pyramid. + + from configs: + :param scales: cf.RPN_ANCHOR_SCALES , for conformity with retina nets: scale entries need to be list, e.g. [[4], [8], [16], [32]] + :param ratios: cf.RPN_ANCHOR_RATIOS , e.g. [0.5, 1, 2] + :param feature_shapes: cf.BACKBONE_SHAPES , e.g. [array of shapes per feature map] [80, 40, 20, 10, 5] + :param feature_strides: cf.BACKBONE_STRIDES , e.g. [2, 4, 8, 16, 32, 64] + :param anchors_stride: cf.RPN_ANCHOR_STRIDE , e.g. 1 + :return anchors: (N, (y1, x1, y2, x2, (z1), (z2)). All generated anchors in one array. Sorted + with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on. + """ + scales = cf.rpn_anchor_scales + ratios = cf.rpn_anchor_ratios + feature_shapes = cf.backbone_shapes + anchor_stride = cf.rpn_anchor_stride + pyramid_levels = cf.pyramid_levels + feature_strides = cf.backbone_strides + + logger.info("anchor scales {} and feature map shapes {}".format(scales, feature_shapes)) + expected_anchors = [np.prod(feature_shapes[level]) * len(ratios) * len(scales['xy'][level]) for level in pyramid_levels] + + anchors = [] + for lix, level in enumerate(pyramid_levels): + if len(feature_shapes[level]) == 2: + anchors.append(generate_anchors(scales['xy'][level], ratios, feature_shapes[level], + feature_strides['xy'][level], anchor_stride)) + elif len(feature_shapes[level]) == 3: + anchors.append(generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, feature_shapes[level], + feature_strides['xy'][level], feature_strides['z'][level], anchor_stride)) + else: + raise Exception("invalid feature_shapes[{}] size {}".format(level, feature_shapes[level])) + logger.info("level {}: expected anchors {}, built anchors {}.".format(level, expected_anchors[lix], anchors[-1].shape)) + + out_anchors = np.concatenate(anchors, axis=0) + logger.info("Total: expected anchors {}, built anchors {}.".format(np.sum(expected_anchors), out_anchors.shape)) + + return out_anchors + + + +def apply_box_deltas_2D(boxes, deltas): + """Applies the given deltas to the given boxes. + boxes: [N, 4] where each row is y1, x1, y2, x2 + deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)] + """ + # Convert to y, x, h, w + height = boxes[:, 2] - boxes[:, 0] + width = boxes[:, 3] - boxes[:, 1] + center_y = boxes[:, 0] + 0.5 * height + center_x = boxes[:, 1] + 0.5 * width + # Apply deltas + center_y += deltas[:, 0] * height + center_x += deltas[:, 1] * width + height *= torch.exp(deltas[:, 2]) + width *= torch.exp(deltas[:, 3]) + # Convert back to y1, x1, y2, x2 + y1 = center_y - 0.5 * height + x1 = center_x - 0.5 * width + y2 = y1 + height + x2 = x1 + width + result = torch.stack([y1, x1, y2, x2], dim=1) + return result + + + +def apply_box_deltas_3D(boxes, deltas): + """Applies the given deltas to the given boxes. + boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2 + deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)] + """ + # Convert to y, x, h, w + height = boxes[:, 2] - boxes[:, 0] + width = boxes[:, 3] - boxes[:, 1] + depth = boxes[:, 5] - boxes[:, 4] + center_y = boxes[:, 0] + 0.5 * height + center_x = boxes[:, 1] + 0.5 * width + center_z = boxes[:, 4] + 0.5 * depth + # Apply deltas + center_y += deltas[:, 0] * height + center_x += deltas[:, 1] * width + center_z += deltas[:, 2] * depth + height *= torch.exp(deltas[:, 3]) + width *= torch.exp(deltas[:, 4]) + depth *= torch.exp(deltas[:, 5]) + # Convert back to y1, x1, y2, x2 + y1 = center_y - 0.5 * height + x1 = center_x - 0.5 * width + z1 = center_z - 0.5 * depth + y2 = y1 + height + x2 = x1 + width + z2 = z1 + depth + result = torch.stack([y1, x1, y2, x2, z1, z2], dim=1) + return result + + + +def clip_boxes_2D(boxes, window): + """ + boxes: [N, 4] each col is y1, x1, y2, x2 + window: [4] in the form y1, x1, y2, x2 + """ + boxes = torch.stack( \ + [boxes[:, 0].clamp(float(window[0]), float(window[2])), + boxes[:, 1].clamp(float(window[1]), float(window[3])), + boxes[:, 2].clamp(float(window[0]), float(window[2])), + boxes[:, 3].clamp(float(window[1]), float(window[3]))], 1) + return boxes + +def clip_boxes_3D(boxes, window): + """ + boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2 + window: [6] in the form y1, x1, y2, x2, z1, z2 + """ + boxes = torch.stack( \ + [boxes[:, 0].clamp(float(window[0]), float(window[2])), + boxes[:, 1].clamp(float(window[1]), float(window[3])), + boxes[:, 2].clamp(float(window[0]), float(window[2])), + boxes[:, 3].clamp(float(window[1]), float(window[3])), + boxes[:, 4].clamp(float(window[4]), float(window[5])), + boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1) + return boxes + +from matplotlib import pyplot as plt + + +def clip_boxes_numpy(boxes, window): + """ + boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D. + window: iamge shape (y, x, (z)) + """ + if boxes.shape[1] == 4: + boxes = np.concatenate( + (np.clip(boxes[:, 0], 0, window[0])[:, None], + np.clip(boxes[:, 1], 0, window[0])[:, None], + np.clip(boxes[:, 2], 0, window[1])[:, None], + np.clip(boxes[:, 3], 0, window[1])[:, None]), 1 + ) + + else: + boxes = np.concatenate( + (np.clip(boxes[:, 0], 0, window[0])[:, None], + np.clip(boxes[:, 1], 0, window[0])[:, None], + np.clip(boxes[:, 2], 0, window[1])[:, None], + np.clip(boxes[:, 3], 0, window[1])[:, None], + np.clip(boxes[:, 4], 0, window[2])[:, None], + np.clip(boxes[:, 5], 0, window[2])[:, None]), 1 + ) + + return boxes + + + +def bbox_overlaps_2D(boxes1, boxes2): + """Computes IoU overlaps between two sets of boxes. + boxes1, boxes2: [N, (y1, x1, y2, x2)]. + """ + # 1. Tile boxes2 and repeate boxes1. This allows us to compare + # every boxes1 against every boxes2 without loops. + # TF doesn't have an equivalent to np.repeate() so simulate it + # using tf.tile() and tf.reshape. + + boxes1_repeat = boxes2.size()[0] + boxes2_repeat = boxes1.size()[0] + + boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,4) + boxes2 = boxes2.repeat(boxes2_repeat,1) + + # 2. Compute intersections + b1_y1, b1_x1, b1_y2, b1_x2 = boxes1.chunk(4, dim=1) + b2_y1, b2_x1, b2_y2, b2_x2 = boxes2.chunk(4, dim=1) + y1 = torch.max(b1_y1, b2_y1)[:, 0] + x1 = torch.max(b1_x1, b2_x1)[:, 0] + y2 = torch.min(b1_y2, b2_y2)[:, 0] + x2 = torch.min(b1_x2, b2_x2)[:, 0] + #--> expects x11 produced in bbox_overlaps_2D" + overlaps = iou.view(boxes2_repeat, boxes1_repeat) #--> per gt box: ious of all proposal boxes with that gt box + + return overlaps + +def bbox_overlaps_3D(boxes1, boxes2): + """Computes IoU overlaps between two sets of boxes. + boxes1, boxes2: [N, (y1, x1, y2, x2, z1, z2)]. + """ + # 1. Tile boxes2 and repeate boxes1. This allows us to compare + # every boxes1 against every boxes2 without loops. + # TF doesn't have an equivalent to np.repeate() so simulate it + # using tf.tile() and tf.reshape. + boxes1_repeat = boxes2.size()[0] + boxes2_repeat = boxes1.size()[0] + boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,6) + boxes2 = boxes2.repeat(boxes2_repeat,1) + + # 2. Compute intersections + b1_y1, b1_x1, b1_y2, b1_x2, b1_z1, b1_z2 = boxes1.chunk(6, dim=1) + b2_y1, b2_x1, b2_y2, b2_x2, b2_z1, b2_z2 = boxes2.chunk(6, dim=1) + y1 = torch.max(b1_y1, b2_y1)[:, 0] + x1 = torch.max(b1_x1, b2_x1)[:, 0] + y2 = torch.min(b1_y2, b2_y2)[:, 0] + x2 = torch.min(b1_x2, b2_x2)[:, 0] + z1 = torch.max(b1_z1, b2_z1)[:, 0] + z2 = torch.min(b1_z2, b2_z2)[:, 0] + zeros = Variable(torch.zeros(y1.size()[0]), requires_grad=False) + if y1.is_cuda: + zeros = zeros.cuda() + intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros) + + # 3. Compute unions + b1_volume = (b1_y2 - b1_y1) * (b1_x2 - b1_x1) * (b1_z2 - b1_z1) + b2_volume = (b2_y2 - b2_y1) * (b2_x2 - b2_x1) * (b2_z2 - b2_z1) + union = b1_volume[:,0] + b2_volume[:,0] - intersection + + # 4. Compute IoU and reshape to [boxes1, boxes2] + iou = intersection / union + overlaps = iou.view(boxes2_repeat, boxes1_repeat) + return overlaps + +def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None): + """Given the anchors and GT boxes, compute overlaps and identify positive + anchors and deltas to refine them to match their corresponding GT boxes. + + anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))] + gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))] + gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN, + set all positive matches to 1 (foreground) + + Returns: + anchor_class_matches: [N] (int32) matches between anchors and GT boxes. + 1 = positive anchor, -1 = negative anchor, 0 = neutral + anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas. + """ + + anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32) + anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim)) + anchor_matching_iou = cf.anchor_matching_iou + + if gt_boxes is None: + anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1) + return anchor_class_matches, anchor_delta_targets + + # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground) + if gt_class_ids is None: + gt_class_ids = np.array([1] * len(gt_boxes)) + + # Compute overlaps [num_anchors, num_gt_boxes] + overlaps = compute_overlaps(anchors, gt_boxes) + + # Match anchors to GT Boxes + # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive. + # If an anchor overlaps a GT box with IoU < 0.1 then it's negative. + # Neutral anchors are those that don't match the conditions above, + # and they don't influence the loss function. + # However, don't keep any GT box unmatched (rare, but happens). Instead, + # match it to the closest anchor (even if its max IoU is < 0.1). + + # 1. Set negative anchors first. They get overwritten below if a GT box is + # matched to them. Skip boxes in crowd areas. + anchor_iou_argmax = np.argmax(overlaps, axis=1) + anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] + if anchors.shape[1] == 4: + anchor_class_matches[(anchor_iou_max < 0.1)] = -1 + elif anchors.shape[1] == 6: + anchor_class_matches[(anchor_iou_max < 0.01)] = -1 + else: + raise ValueError('anchor shape wrong {}'.format(anchors.shape)) + + # 2. Set an anchor for each GT box (regardless of IoU value). + gt_iou_argmax = np.argmax(overlaps, axis=0) + for ix, ii in enumerate(gt_iou_argmax): + anchor_class_matches[ii] = gt_class_ids[ix] + + # 3. Set anchors with high overlap as positive. + above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) + anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]] + + # Subsample to balance positive anchors. + ids = np.where(anchor_class_matches > 0)[0] + extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) + if extra > 0: + # Reset the extra ones to neutral + ids = np.random.choice(ids, extra, replace=False) + anchor_class_matches[ids] = 0 + + # Leave all negative proposals negative for now and sample from them later in online hard example mining. + # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. + ids = np.where(anchor_class_matches > 0)[0] + ix = 0 # index into anchor_delta_targets + for i, a in zip(ids, anchors[ids]): + # closest gt box (it might have IoU < anchor_matching_iou) + gt = gt_boxes[anchor_iou_argmax[i]] + + # convert coordinates to center plus width/height. + gt_h = gt[2] - gt[0] + gt_w = gt[3] - gt[1] + gt_center_y = gt[0] + 0.5 * gt_h + gt_center_x = gt[1] + 0.5 * gt_w + # Anchor + a_h = a[2] - a[0] + a_w = a[3] - a[1] + a_center_y = a[0] + 0.5 * a_h + a_center_x = a[1] + 0.5 * a_w + + if cf.dim == 2: + anchor_delta_targets[ix] = [ + (gt_center_y - a_center_y) / a_h, + (gt_center_x - a_center_x) / a_w, + np.log(gt_h / a_h), + np.log(gt_w / a_w), + ] + + else: + gt_d = gt[5] - gt[4] + gt_center_z = gt[4] + 0.5 * gt_d + a_d = a[5] - a[4] + a_center_z = a[4] + 0.5 * a_d + + anchor_delta_targets[ix] = [ + (gt_center_y - a_center_y) / a_h, + (gt_center_x - a_center_x) / a_w, + (gt_center_z - a_center_z) / a_d, + np.log(gt_h / a_h), + np.log(gt_w / a_w), + np.log(gt_d / a_d) + ] + + # normalize. + anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev + ix += 1 + + return anchor_class_matches, anchor_delta_targets + + + +def clip_to_window(window, boxes): + """ + window: (y1, x1, y2, x2) / 3D: (z1, z2). The window in the image we want to clip to. + boxes: [N, (y1, x1, y2, x2)] / 3D: (z1, z2) + """ + boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[2])) + boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[3])) + boxes[:, 2] = boxes[:, 2].clamp(float(window[0]), float(window[2])) + boxes[:, 3] = boxes[:, 3].clamp(float(window[1]), float(window[3])) + + if boxes.shape[1] > 5: + boxes[:, 4] = boxes[:, 4].clamp(float(window[4]), float(window[5])) + boxes[:, 5] = boxes[:, 5].clamp(float(window[4]), float(window[5])) + + return boxes + +############################################################ +# Connected Componenent Analysis +############################################################ + +def get_coords(binary_mask, n_components, dim): + """ + loops over batch to perform connected component analysis on binary input mask. computes box coordinates around + n_components - biggest components (rois). + :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class. + :param n_components: int. number of components to extract per batch element and class. + :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2)) + :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2)) + """ + assert len(binary_mask.shape)==dim+1 + binary_mask = binary_mask.astype('uint8') + batch_coords = [] + batch_components = [] + for ix,b in enumerate(binary_mask): + clusters, n_cands = lb(b) # performs connected component analysis. + uniques, counts = np.unique(clusters, return_counts=True) + keep_uniques = uniques[1:][np.argsort(counts[1:])[::-1]][:n_components] #only keep n_components largest components + p_components = np.array([(clusters == ii) * 1 for ii in keep_uniques]) # separate clusters and concat + p_coords = [] + if p_components.shape[0] > 0: + for roi in p_components: + mask_ixs = np.argwhere(roi != 0) + + # get coordinates around component. + roi_coords = [np.min(mask_ixs[:, 0]) - 1, np.min(mask_ixs[:, 1]) - 1, np.max(mask_ixs[:, 0]) + 1, + np.max(mask_ixs[:, 1]) + 1] + if dim == 3: + roi_coords += [np.min(mask_ixs[:, 2]), np.max(mask_ixs[:, 2])+1] + p_coords.append(roi_coords) + + p_coords = np.array(p_coords) + + #clip coords. + p_coords[p_coords < 0] = 0 + p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2] + if dim == 3: + p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1] + + batch_coords.append(p_coords) + batch_components.append(p_components) + return batch_coords, batch_components + + +# noinspection PyCallingNonCallable +def get_coords_gpu(binary_mask, n_components, dim): + """ + loops over batch to perform connected component analysis on binary input mask. computes box coordiantes around + n_components - biggest components (rois). + :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class. + :param n_components: int. number of components to extract per batch element and class. + :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2)) + :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2)) + """ + raise Exception("throws floating point exception") + assert len(binary_mask.shape)==dim+1 + binary_mask = binary_mask.type(torch.uint8) + batch_coords = [] + batch_components = [] + for ix,b in enumerate(binary_mask): + clusters, n_cands = lb(b.cpu().data.numpy()) # peforms connected component analysis. + clusters = torch.from_numpy(clusters).cuda() + uniques = torch.unique(clusters) + counts = torch.stack([(clusters==unique).sum() for unique in uniques]) + keep_uniques = uniques[1:][torch.sort(counts[1:])[1].flip(0)][:n_components] #only keep n_components largest components + p_components = torch.cat([(clusters == ii).unsqueeze(0) for ii in keep_uniques]).cuda() # separate clusters and concat + p_coords = [] + if p_components.shape[0] > 0: + for roi in p_components: + mask_ixs = torch.nonzero(roi) + + # get coordinates around component. + roi_coords = [torch.min(mask_ixs[:, 0]) - 1, torch.min(mask_ixs[:, 1]) - 1, + torch.max(mask_ixs[:, 0]) + 1, + torch.max(mask_ixs[:, 1]) + 1] + if dim == 3: + roi_coords += [torch.min(mask_ixs[:, 2]), torch.max(mask_ixs[:, 2])+1] + p_coords.append(roi_coords) + + p_coords = torch.tensor(p_coords) + + #clip coords. + p_coords[p_coords < 0] = 0 + p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2] + if dim == 3: + p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1] + + batch_coords.append(p_coords) + batch_components.append(p_components) + return batch_coords, batch_components + + +############################################################ +# Pytorch Utility Functions +############################################################ + + +def unique1d(tensor): + """discard all elements of tensor that occur more than once; make tensor unique. + :param tensor: + :return: + """ + if tensor.size()[0] == 0 or tensor.size()[0] == 1: + return tensor + tensor = tensor.sort()[0] + unique_bool = tensor[1:] != tensor [:-1] + first_element = Variable(torch.ByteTensor([True]), requires_grad=False) + if tensor.is_cuda: + first_element = first_element.cuda() + unique_bool = torch.cat((first_element, unique_bool),dim=0) + return tensor[unique_bool.data] + + + +def log2(x): + """Implementatin of Log2. Pytorch doesn't have a native implementation.""" + ln2 = Variable(torch.log(torch.FloatTensor([2.0])), requires_grad=False) + if x.is_cuda: + ln2 = ln2.cuda() + return torch.log(x) / ln2 + + + +def intersect1d(tensor1, tensor2): + aux = torch.cat((tensor1, tensor2), dim=0) + aux = aux.sort(descending=True)[0] + return aux[:-1][(aux[1:] == aux[:-1]).data] + + + +def shem(roi_probs_neg, negative_count, poolsize): + """ + stochastic hard example mining: from a list of indices (referring to non-matched predictions), + determine a pool of highest scoring (worst false positives) of size negative_count*poolsize. + Then, sample n (= negative_count) predictions of this pool as negative examples for loss. + :param roi_probs_neg: tensor of shape (n_predictions, n_classes). + :param negative_count: int. + :param poolsize: int. + :return: (negative_count). indices refer to the positions in roi_probs_neg. If pool smaller than expected due to + limited negative proposals availabel, this function will return sampled indices of number < negative_count without + throwing an error. + """ + # sort according to higehst foreground score. + probs, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True) + select = torch.tensor((poolsize * int(negative_count), order.size()[0])).min().int() + + pool_indices = order[:select] + rand_idx = torch.randperm(pool_indices.size()[0]) + return pool_indices[rand_idx[:negative_count].cuda()] + + +############################################################ +# Weight Init +############################################################ + + +def initialize_weights(net): + """Initialize model weights. Current Default in Pytorch (version 0.4.1) is initialization from a uniform distriubtion. + Will expectably be changed to kaiming_uniform in future versions. + """ + init_type = net.cf.weight_init + + for m in [module for module in net.modules() if type(module) in [torch.nn.Conv2d, torch.nn.Conv3d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d, + torch.nn.Linear]]: + if init_type == 'xavier_uniform': + torch.nn.init.xavier_uniform_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + + elif init_type == 'xavier_normal': + torch.nn.init.xavier_normal_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + + elif init_type == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0) + if m.bias is not None: + fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data) + bound = 1 / np.sqrt(fan_out) + torch.nn.init.uniform_(m.bias, -bound, bound) + + elif init_type == "kaiming_normal": + torch.nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0) + if m.bias is not None: + fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data) + bound = 1 / np.sqrt(fan_out) + torch.nn.init.normal_(m.bias, -bound, bound) + net.logger.info("applied {} weight init.".format(init_type)) \ No newline at end of file