diff --git a/README.md b/README.md index 6a62c4c..f961ce5 100644 --- a/README.md +++ b/README.md @@ -1,180 +1,163 @@ +Copyright © German Cancer Research Center (DKFZ), Division of Medical Image Computing (MIC). +Please make sure that your usage of this code is in compliance with the code license. + ## Introduction This repository holds the code framework used in the paper Reg R-CNN: Lesion Detection and Grading under Noisy Labels [1]. The framework is a fork of MIC's [medicaldetectiontoolkit](https://github.com/MIC-DKFZ/medicaldetectiontoolkit) with added regression capabilities. As below figure shows, the regression capability allows for the preservation of ordinal relations in the training signal as opposed to a standard categorical classification loss like the cross entropy loss (see publication for details).


Network Reg R-CNN is a version of Mask R-CNN [2] but with a regressor in place of the object-class head (see figure below). In this scenario, the first stage makes foreground (fg) vs background (bg) detections, then the regression head determines the class on an ordinal scale. Consequently, prediction confidence scores are taken from the first stage as opposed to the head in the original Mask R-CNN.


In the configs file of a data set in the framework, you may set attribute self.prediction_tasks = ["task"] to a value "task" from ["class", "regression_bin", "regression"]. "class" produces the same behavior as the original framework, i.e., standard object-detection behavior. "regression" on the other hand, swaps the class head of network Mask R-CNN [2] for a regression head. Consequently, objects are identified as fg/bg and then the class is decided by the regressor. For the sake of comparability, "regression_bin" produces a similar behavior but with a classification head. Both methods should be evaluated with the (implemented) Average Viewpoint Precision instead of only Average Precision. Below you will found a description of the general framework operations and handling. Basic framework functionality and description are for the most part identical to the original [medicaldetectiontoolkit](https://github.com/MIC-DKFZ/medicaldetectiontoolkit).
[1] Ramien, Gregor et al., "Reg R-CNN: Lesion Detection and Grading under Noisy Labels". In: UNSURE Workshop at MICCAI, 2019.
[2] He, Kaiming, et al. "Mask R-CNN" ICCV, 2017

## Overview This is a comprehensive framework for object detection featuring: - 2D + 3D implementations of common object detectors: e.g., Mask R-CNN [2], Retina Net [3], Retina U-Net [4]. - Modular and light-weight structure ensuring sharing of all processing steps (incl. backbone architecture) for comparability of models. - training with bounding box and/or pixel-wise annotations. - dynamic patching and tiling of 2D + 3D images (for training and inference). - weighted consolidation of box predictions across patch-overlaps, ensembles, and dimensions [4] or standard non-maximum suppression. - monitoring + evaluation simultaneously on object and patient level. - 2D + 3D output visualizations. - integration of COCO mean average precision metric [5]. - integration of MIC-DKFZ batch generators for extensive data augmentation [6]. - possible evaluation of instance segmentation and/or semantic segmentation by dice scores.
[3] Lin, Tsung-Yi, et al. "Focal Loss for Dense Object Detection" TPAMI, 2018.
[4] Jaeger, Paul et al. "Retina U-Net: Embarrassingly Simple Exploitation of Segmentation Supervision for Medical Object Detection" , 2018 [5] https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py
[6] https://github.com/MIC-DKFZ/batchgenerators

## How to cite this code Please cite the Reg R-CNN publication [1] or the original publication [4] depending on what features you use. ## Installation Setup package in virtual environment ``` git clone https://github.com/MIC-DKFZ/RegRCNN.git. cd RegRCNN -virtualenv -p python3 regrcnn_env +virtualenv -p python3.7 regrcnn_env source regrcnn_env/bin/activate -pip install -e . -``` -We use two cuda functions: Non-Maximum Suppression (taken from [pytorch-faster-rcnn](https://github.com/ruotianluo/pytorch-faster-rcnn) and added adaption for 3D) and RoiAlign (taken from [RoiAlign](https://github.com/longcw/RoIAlign.pytorch), fixed according to [this bug report](https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35), and added adaption for 3D). In this framework, they come pre-compile for TitanX. If you have a different GPU you need to re-compile these functions: - - -| GPU | arch | -| --- | --- | -| TitanX | sm_52 | -| GTX 960M | sm_50 | -| GTX 1070 | sm_61 | -| GTX 1080 (Ti) | sm_61 | - -``` -cd cuda_functions/nms_xD/src/cuda/ -nvcc -c -o nms_kernel.cu.o nms_kernel.cu -x cu -Xcompiler -fPIC -arch=[arch] -cd ../../ -python build.py -cd ../ - -cd cuda_functions/roi_align_xD/roi_align/src/cuda/ -nvcc -c -o crop_and_resize_kernel.cu.o crop_and_resize_kernel.cu -x cu -Xcompiler -fPIC -arch=[arch] -cd ../../ -python build.py -cd ../../ +python setup.py install ``` +This framework uses two custom mixed C++/CUDA extensions: Non-maximum suppression (NMS) and RoIAlign. Both are adapted from the original pytorch extensions (under torchvision.ops.boxes and ops.roialign). +The extensions are automatically compiled from the provided source files under RegRCNN/custom_extensions with above setup.py. +Note: If you'd like to import the raw extensions (not the wrapper modules), be sure to import torch first. ## Prepare the Data This framework is meant for you to be able to train models on your own data sets. In order to include a data set in the framework, create a new folder in RegRCNN/datasets, for instance "example_data". Your data set needs to have a config file in the style of the provided example data sets "lidc" and "toy". It also needs a data loader meeting the same requirements as the provided examples. Likely, you will also need a preprocessing script that transforms your data (once per data set creation, i.e., not a repetitive operation) into a suitable and easily processable format. Important requirements: - The framework expects numpy arrays as data and segmentation ground truth input. - Segmentations need to be suited for object detection, i.e., Regions of Interest (RoIs) need to be marked by integers (RoI-ID) in the segmentation volume (0 is background). Corresponding properties of a RoI, e.g., the "class_targets" need to be provided in a separate array or list with (RoI-ID - 1) corresponding to the index of the property in the list (-1 due to zero-indexing). Example: A data volume contains two RoIs. The second RoI is marked in the segmentation by number 2. The "class_targets" info associated with the data volume holds the list [2, 3]. Hence, RoI-ID 2 is assigned class 3. - This framework uses a modified version of MIC's batchgenerators' segmentation-to-bounding-box conversion tool. In this version, "class_targets", i.e., object classes start at 1, 0 is reserved for background. Thus, if you use "ConvertSegToBoundingBoxCoordinates" classes in your preprocessed data need to start at 1, not 0. Two example data loaders are provided in RegRCNN/datasets. The way I load data is to have a preprocessing script, which after preprocessing saves the data of whatever data type into numpy arrays (this is just run once). During training / testing, the data loader then loads these numpy arrays dynamically. Please note the data input side is meant to be customized by you according to your own needs and the provided data loaders are merely examples: LIDC has a powerful data loader that handles 2D/3D inputs and is optimized for patch-based training and inference. Due to the large data volumes of LIDC, this loader is slow. The provided toy data set, however, is light weight and a good starting point to get familiar with the framework. It is fully creatable from scratch within a few minutes with RegRCNN/datasets/toy/generate_toys.py. ## Execute 1. Set I/O paths, model and training specifics in the configs file: RegRCNN/datasets/_your_dataset_/configs.py 2. i) Train the model: ``` python exec.py --mode train --dataset_name your_dataset --exp_dir path/to/experiment/directory ``` This copies snapshots of configs and model to the specified exp_dir, where all outputs will be saved. By default, the data is split into 60% training and 20% validation and 20% testing data to perform a 5-fold cross validation (can be changed to hold-out test set in configs) and all folds will be trained iteratively. In order to train a single fold, specify it using the folds arg: ``` python exec.py --folds 0 1 2 .... # specify any combination of folds [0-configs.n_cv_splits] ``` ii) Alternatively, train and test consecutively: ``` python exec.py --mode train_test --dataset_name your_dataset --exp_dir path/to/experiment/directory ``` 3. Run inference: ``` python exec.py --mode test --exp_dir path/to/experiment/directory ``` This runs the prediction pipeline and saves all results to exp_dir. 4. Additional settings: - Check the args parser in exec.py to see which arguments and modes are available. - E.g., you may pass ```-d``` or ```--dev``` to enable a short development run of the whole train_test procedure (small batch size, only one epoch, two folds, one test patient, etc.). ## Models This framework features models explored in [4] (implemented in 2D + 3D): The proposed Retina U-Net, a simple but effective Architecture fusing state-of-the-art semantic segmentation with object detection,


also implementations of prevalent object detectors, such as Mask R-CNN, Faster R-CNN+ (Faster R-CNN w\ RoIAlign), Retina Net, Detection U-Net (a U-Net like segmentation architecture with heuristics for object detection.)



## Training annotations This framework features training with pixelwise and/or bounding box annotations. To overcome the issue of box coordinates in data augmentation, we feed the annotation masks through data augmentation (create a pseudo mask, if only bounding box annotations provided) and draw the boxes afterwards.


The framework further handles two types of pixel-wise annotations: 1. A label map with individual ROIs identified by increasing label values, accompanied by a vector containing in each position the class target for the lesion with the corresponding label (for this mode set get_rois_from_seg_flag = False when calling ConvertSegToBoundingBoxCoordinates in your Data Loader). This is usual use case as explained in section "Prepare the data". 2. A binary label map. There is only one foreground class and single lesions are not identified. All lesions have the same class target (foreground). In this case the data loader runs a Connected Component Labelling algorithm to create processable lesion - class target pairs on the fly (for this mode set get_rois_from_seg_flag = True when calling ConvertSegToBoundingBoxCoordinates in your data loader). ## Prediction pipeline This framework provides an inference module, which automatically handles patching of inputs, and tiling, ensembling, and weighted consolidation of output predictions:




## Consolidation of predictions ### Weighted Box Clustering Multiple predictions of the same image (from test time augmentations, tested epochs and overlapping patches), result in a high amount of boxes (or cubes), which need to be consolidated. In semantic segmentation, the final output would typically be obtained by averaging every pixel over all predictions. As described in [4], **weighted box clustering** (WBC) does this for box predictions:





To enable WBC, set self.clustering = "wbc" in your configs file. ### Non-Maximum Suppression Test-time predictions can alternatively be aggregated with standard non-maximum suppression. In your configs file, simply set self.clustering = "nms" instead of "wbc". As a further alternative you may also choose no test-time aggregation by setting self.clustering = None. ## Visualization / Monitoring In opposition to the original framework, this fork uses tensorboard for monitoring training and validation progress. Since, for now, the framework cannot easily be updated to pytorch >= 1.x, we need third-party package [tensorboardX](https://github.com/lanpa/tensorboardX) to use tensorboard with pytorch. You can set an applicable choice of implemented metrics like "ap" for Average Precision or "auc" for patient-level ROC-AUC in the configs under self.metrics = [...]. Metrics are then evaluated by evaluator.py and recorded in monitor_metrics. logger.metrics2tboard sends monitor_metrics to your tensorboard logfiles at the end of each epoch. You need to separately start a virtual tensorboard server, pass it your experiment directory (or directories, but it crashes if its more than ~5 experiments) and navigate to the server address. (You can also read up on tensoardboard usage in the original documentation). ### Example: 1. Activate your virtualenv where tensorboard is installed. 2. Start tensorboard server. For instance, your experiment directory is _yourexpdir_:
```tensorboard --port 6007 --logdir yourexpdir``` 3. Navigate to ```localhost:6007``` in your browser. ### Output monitoring For qualitative monitoring, example plots are saved to _yourexpdir_/plots for training and validation and _yourexpdir_/test/example_plots for testing. Note, that test-time example plots may contain unconsolidated predictions over test-time augmentations, thereby possibly showing many overlapping and/or noisy predictions. You may adapt/use separate file RegRCNN/inference_analysis.py to create clean and nice plots of (consolidated) test-time predictions. ## Balancing Mechanism of Example Data Loader The data loaders of the provided example data sets employ a custom mechanism with the goal of assembling target-balanced batches or training sequences. I.e., the amount of examples shown per target class should be near balance. The mechanism creates a sampling-likelihood distribution, as shown below, over all available patients (PIDs). At batch generation, some patients are drawn according to this distribution, others are drawn completely randomly (according to a uniform distribution across all patients). The ratio of uniformly and target-dependently drawn patients is set in your configs file by configs.batch_random_ratio. configs.balance_target determines which targets are considered for the balancing distribution. While the balancing distribution assigns probability 0 to empty patients (contains no object of desired target kind), the random ratio allows for inclusion of those empty patients in the training exposure. Experience has shown, that showing at least one foreground example in each batch is most critical, other properties have less impact.



## Unittests unittests.py contains some verification and testing procedures, which, however, need you to adjust paths in the TestCase classes before execution. Tests can be used, for instance, to verify if your cross-validation folds have been created correctly, or if separate experiments have the same fold splits. # License This framework is published under the [APACHE 2.0 License](https://github.com/MIC-DKFZ/RegRCNN/blob/master/LICENSE) \ No newline at end of file diff --git a/models/retina_net.py b/models/retina_net.py index d618e5a..5a45849 100644 --- a/models/retina_net.py +++ b/models/retina_net.py @@ -1,779 +1,780 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Retina Net. According to https://arxiv.org/abs/1708.02002""" import utils.model_utils as mutils import utils.exp_utils as utils import sys sys.path.append('../') from custom_extensions.nms import nms import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils class Classifier(nn.Module): def __init__(self, cf, conv): """ Builds the classifier sub-network. """ super(Classifier, self).__init__() self.dim = conv.dim self.n_classes = cf.head_classes n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * cf.head_classes anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: class_logits (b, n_anchors, n_classes) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) class_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) class_logits = class_logits.permute(*axes) class_logits = class_logits.contiguous() class_logits = class_logits.view(x.shape[0], -1, self.n_classes) return [class_logits] class BBRegressor(nn.Module): def __init__(self, cf, conv): """ Builds the bb-regression sub-network. """ super(BBRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * self.dim * 2 anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) bb_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) bb_logits = bb_logits.permute(*axes) bb_logits = bb_logits.contiguous() bb_logits = bb_logits.view(x.shape[0], -1, self.dim * 2) return [bb_logits] class RoIRegressor(nn.Module): def __init__(self, cf, conv, rg_feats): """ Builds the RoI-item-regression sub-network. Regression items can be, e.g., malignancy scores of tumors. """ super(RoIRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features self.rg_feats = rg_feats n_output_channels = cf.n_anchors_per_pos * self.rg_feats anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) x = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) x = x.permute(*axes) x = x.contiguous() x = x.view(x.shape[0], -1, self.rg_feats) return [x] ############################################################ # Loss Functions ############################################################ # def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20): """ :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param shem_poolsize: int. factor of top-k candidates to draw from per negative sample (online-hard-example-mining). :return: loss: torch tensor :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0) neg_indices = torch.nonzero(anchor_matches == -1) # get positive samples and calucalte loss. if not 0 in pos_indices.size(): pos_indices = pos_indices.squeeze(1) roi_logits_pos = class_pred_logits[pos_indices] targets_pos = anchor_matches[pos_indices].detach() pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long()) else: pos_loss = torch.FloatTensor([0]).cuda() # get negative samples, such that the amount matches the number of positive samples, but at least 1. # get high scoring negatives by applying online-hard-example-mining. if not 0 in neg_indices.size(): neg_indices = neg_indices.squeeze(1) roi_logits_neg = class_pred_logits[neg_indices] negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) roi_probs_neg = F.softmax(roi_logits_neg, dim=1) neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) # return the indices of negative samples, who contributed to the loss for monitoring plots. np_neg_ix = neg_ix.cpu().data.numpy() else: neg_loss = torch.FloatTensor([0]).cuda() np_neg_ix = np.array([]).astype('int32') loss = (pos_loss + neg_loss) / 2 return loss, np_neg_ix def compute_bbox_loss(target_deltas, pred_deltas, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unused bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param anchor_matches: tensor (n_anchors). value in [-1, 0, class_ids] for negative, neutral, and positive matched anchors. i.e., positively matched anchors are marked by class_id >0 :return: loss: torch 1D tensor. """ if not 0 in torch.nonzero(anchor_matches>0).shape: indices = torch.nonzero(anchor_matches>0).squeeze(1) # Pick bbox deltas that contribute to the loss pred_deltas = pred_deltas[indices] # Trim target bounding box deltas to the same length as pred_deltas. target_deltas = target_deltas[:pred_deltas.shape[0], :].detach() # Smooth L1 loss loss = F.smooth_l1_loss(pred_deltas, target_deltas) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_rg_loss(tasks, target, pred, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unsed bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :return: loss: torch 1D tensor. """ if not 0 in target.shape and not 0 in torch.nonzero(anchor_matches>0).shape: indices = torch.nonzero(anchor_matches>0).squeeze(1) # Pick rgs that contribute to the loss pred = pred[indices] # Trim target target = target[:pred.shape[0]].detach() if 'regression_bin' in tasks: loss = F.cross_entropy(pred, target.long()) else: loss = F.smooth_l1_loss(pred, target) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_focal_class_loss(anchor_matches, class_pred_logits, gamma=2.): """ Focal Loss FL = -(1-q)^g log(q) with q = pred class probability. :param anchor_matches: (n_anchors). [-1, 0, class] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param gamma: g in above formula, good results with g=2 in original paper. :return: loss: torch tensor :return: focal loss """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0).squeeze(-1) # dim=-1 instead of 1 or 0 to cover empty matches. neg_indices = torch.nonzero(anchor_matches == -1).squeeze(-1) target_classes = torch.cat( (anchor_matches[pos_indices].long(), torch.LongTensor([0] * neg_indices.shape[0]).cuda()) ) non_neutral_indices = torch.cat( (pos_indices, neg_indices) ) q = F.softmax(class_pred_logits[non_neutral_indices], dim=1) # q shape: (n_non_neutral_anchors, n_classes) # one-hot encoded target classes: keep only the pred probs of the correct class. it will receive incentive to be maximized. # log(q_i) where i = target class --> FL shape (n_anchors,) # need to transform to indices into flattened tensor to use torch.take target_locs_flat = q.shape[1] * torch.arange(q.shape[0]).cuda() + target_classes q = torch.take(q, target_locs_flat) FL = torch.log(q) # element-wise log FL *= -(1-q)**gamma # take mean over all considered anchors FL = FL.sum() / FL.shape[0] return FL def refine_detections(anchors, probs, deltas, regressions, batch_ixs, cf): """Refine classified proposals, filter overlaps and return final detections. n_proposals here is typically a very large number: batch_size * n_anchors. This function is hence optimized on trimming down n_proposals. :param anchors: (n_anchors, 2 * dim) :param probs: (n_proposals, n_classes) softmax probabilities for all rois as predicted by classifier head. :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by bbox regressor head. :param regressions: (n_proposals, n_classes, n_rg_feats) :param batch_ixs: (n_proposals) batch element assignemnt info for re-allocation. :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regr)) """ anchors = anchors.repeat(batch_ixs.unique().shape[0], 1) #flatten foreground probabilities, sort and trim down to highest confidences by pre_nms limit. fg_probs = probs[:, 1:].contiguous() flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True) keep_ix = flat_probs_order[:cf.pre_nms_limit] # reshape indices to 2D index array with shape like fg_probs. keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) pre_nms_scores = flat_probs[:cf.pre_nms_limit] pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again. pre_nms_batch_ixs = batch_ixs[keep_arr[:, 0]] pre_nms_anchors = anchors[keep_arr[:, 0]] pre_nms_deltas = deltas[keep_arr[:, 0]] pre_nms_regressions = regressions[keep_arr[:, 0]] keep = torch.arange(pre_nms_scores.size()[0]).long().cuda() # apply bounding box deltas. re-scale to image coordinates. std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() scale = torch.from_numpy(cf.scale).float().cuda() refined_rois = mutils.apply_box_deltas_2D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale \ if cf.dim == 2 else mutils.apply_box_deltas_3D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale # round and cast to int since we're deadling with pixels now refined_rois = mutils.clip_to_window(cf.window, refined_rois) pre_nms_rois = torch.round(refined_rois) for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)): bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] bix_class_ids = pre_nms_class_ids[bixs] bix_rois = pre_nms_rois[bixs] bix_scores = pre_nms_scores[bixs] for i, class_id in enumerate(mutils.unique1d(bix_class_ids)): ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] # nms expects boxes sorted by score. ix_rois = bix_rois[ixs] ix_scores = bix_scores[ixs] ix_scores, order = ix_scores.sort(descending=True) ix_rois = ix_rois[order, :] ix_scores = ix_scores class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold) # map indices back. class_keep = keep[bixs[ixs[order[class_keep]]]] # merge indices over classes for current batch element b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep))) # only keep top-k boxes of current batch-element. top_ids = pre_nms_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] b_keep = b_keep[top_ids] # merge indices over batch elements. batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep))) keep = batch_keep # arrange output. result = torch.cat((pre_nms_rois[keep], pre_nms_batch_ixs[keep].unsqueeze(1).float(), pre_nms_class_ids[keep].unsqueeze(1).float(), pre_nms_scores[keep].unsqueeze(1), pre_nms_regressions[keep]), dim=1) return result def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None, gt_regressions=None): """Given the anchors and GT boxes, compute overlaps and identify positive anchors and deltas to refine them to match their corresponding GT boxes. anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))] gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))] gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN, set all positive matches to 1 (foreground) gt_regressions: [num_gt_rgs, n_rg_feats], if None empty rg_targets are returned Returns: anchor_class_matches: [N] (int32) matches between anchors and GT boxes. class_id = positive anchor, -1 = negative anchor, 0 = neutral. i.e., positively matched anchors are marked by class_id (which is >0). anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas. anchor_rg_targets: [n_anchors, n_rg_feats] """ anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32) anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim)) if gt_regressions is not None: if 'regression_bin' in cf.prediction_tasks: anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image,)) else: anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image, cf.regression_n_features)) else: anchor_rg_targets = np.array([]) anchor_matching_iou = cf.anchor_matching_iou if gt_boxes is None: anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1) return anchor_class_matches, anchor_delta_targets, anchor_rg_targets # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground) if gt_class_ids is None: gt_class_ids = np.array([1] * len(gt_boxes)) # Compute overlaps [num_anchors, num_gt_boxes] overlaps = mutils.compute_overlaps(anchors, gt_boxes) # Match anchors to GT Boxes # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive. # If an anchor overlaps a GT box with IoU < 0.1 then it's negative. # Neutral anchors are those that don't match the conditions above, # and they don't influence the loss function. # However, don't keep any GT box unmatched (rare, but happens). Instead, # match it to the closest anchor (even if its max IoU is < 0.1). # 1. Set negative anchors first. They get overwritten below if a GT box is # matched to them. Skip boxes in crowd areas. anchor_iou_argmax = np.argmax(overlaps, axis=1) anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] if anchors.shape[1] == 4: anchor_class_matches[(anchor_iou_max < 0.1)] = -1 elif anchors.shape[1] == 6: anchor_class_matches[(anchor_iou_max < 0.01)] = -1 else: raise ValueError('anchor shape wrong {}'.format(anchors.shape)) # 2. Set an anchor for each GT box (regardless of IoU value). gt_iou_argmax = np.argmax(overlaps, axis=0) for ix, ii in enumerate(gt_iou_argmax): anchor_class_matches[ii] = gt_class_ids[ix] # 3. Set anchors with high overlap as positive. above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]] # Subsample to balance positive anchors. ids = np.where(anchor_class_matches > 0)[0] extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) if extra > 0: # Reset the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) anchor_class_matches[ids] = 0 # Leave all negative proposals negative for now and sample from them later in online hard example mining. # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. ids = np.where(anchor_class_matches > 0)[0] ix = 0 # index into anchor_delta_targets for i, a in zip(ids, anchors[ids]): # closest gt box (it might have IoU < anchor_matching_iou) gt = gt_boxes[anchor_iou_argmax[i]] # convert coordinates to center plus width/height. gt_h = gt[2] - gt[0] gt_w = gt[3] - gt[1] gt_center_y = gt[0] + 0.5 * gt_h gt_center_x = gt[1] + 0.5 * gt_w # Anchor a_h = a[2] - a[0] a_w = a[3] - a[1] a_center_y = a[0] + 0.5 * a_h a_center_x = a[1] + 0.5 * a_w if cf.dim == 2: anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, np.log(gt_h / a_h), np.log(gt_w / a_w)] else: gt_d = gt[5] - gt[4] gt_center_z = gt[4] + 0.5 * gt_d a_d = a[5] - a[4] a_center_z = a[4] + 0.5 * a_d anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, (gt_center_z - a_center_z) / a_d, np.log(gt_h / a_h), np.log(gt_w / a_w), np.log(gt_d / a_d)] # normalize. anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev if gt_regressions is not None: anchor_rg_targets[ix] = gt_regressions[anchor_iou_argmax[i]] ix += 1 return anchor_class_matches, anchor_delta_targets, anchor_rg_targets ############################################################ # RetinaNet Class ############################################################ class net(nn.Module): """Encapsulates the RetinaNet model functionality. """ def __init__(self, cf, logger): """ cf: A Sub-class of the cf class model_dir: Directory to save training logs and trained weights """ super(net, self).__init__() self.cf = cf self.logger = logger self.build() if self.cf.weight_init is not None: logger.info("using pytorch weight init of type {}".format(self.cf.weight_init)) mutils.initialize_weights(self) else: logger.info("using default pytorch weight init") self.debug_acm = [] def build(self): """Build Retina Net architecture.""" # Image size must be dividable by 2 multiple times. h, w = self.cf.patch_size[:2] if h / 2 ** 5 != int(h / 2 ** 5) or w / 2 ** 5 != int(w / 2 ** 5): raise Exception("Image size must be divisible by 2 at least 5 times " "to avoid fractions when downscaling and upscaling." "For example, use 256, 320, 384, 448, 512, ... etc. ") backbone = utils.import_module('bbone', self.cf.backbone_path) self.logger.info("loaded backbone from {}".format(self.cf.backbone_path)) conv = backbone.ConvGenerator(self.cf.dim) # build Anchors, FPN, Classifier / Bbox-Regressor -head self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) self.anchors = torch.from_numpy(self.np_anchors).float().cuda() self.fpn = backbone.FPN(self.cf, conv, operate_stride1=self.cf.operate_stride1).cuda() self.classifier = Classifier(self.cf, conv).cuda() self.bb_regressor = BBRegressor(self.cf, conv).cuda() if 'regression' in self.cf.prediction_tasks: self.roi_regressor = RoIRegressor(self.cf, conv, self.cf.regression_n_features).cuda() elif 'regression_bin' in self.cf.prediction_tasks: # classify into bins of regression values self.roi_regressor = RoIRegressor(self.cf, conv, len(self.cf.bin_labels)).cuda() else: self.roi_regressor = lambda x: [torch.tensor([]).cuda()] if self.cf.model == 'retina_unet': self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=self.cf.norm, relu=None) def forward(self, img): """ :param img: input img (b, c, y, x, (z)). """ # Feature extraction fpn_outs = self.fpn(img) if self.cf.model == 'retina_unet': seg_logits = self.final_conv(fpn_outs[0]) selected_fmaps = [fpn_outs[i + 1] for i in self.cf.pyramid_levels] else: seg_logits = None selected_fmaps = [fpn_outs[i] for i in self.cf.pyramid_levels] # Loop through pyramid layers class_layer_outputs, bb_reg_layer_outputs, roi_reg_layer_outputs = [], [], [] # list of lists for p in selected_fmaps: class_layer_outputs.append(self.classifier(p)) bb_reg_layer_outputs.append(self.bb_regressor(p)) roi_reg_layer_outputs.append(self.roi_regressor(p)) # Concatenate layer outputs # Convert from list of lists of level outputs to list of lists # of outputs across levels. # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] class_logits = list(zip(*class_layer_outputs)) class_logits = [torch.cat(list(o), dim=1) for o in class_logits][0] bb_outputs = list(zip(*bb_reg_layer_outputs)) bb_outputs = [torch.cat(list(o), dim=1) for o in bb_outputs][0] if not 0 == roi_reg_layer_outputs[0][0].shape[0]: rg_outputs = list(zip(*roi_reg_layer_outputs)) rg_outputs = [torch.cat(list(o), dim=1) for o in rg_outputs][0] else: if self.cf.dim == 2: n_feats = np.array([p.shape[-2] * p.shape[-1] * self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() else: n_feats = np.array([p.shape[-3]*p.shape[-2]*p.shape[-1]*self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() rg_outputs = torch.zeros((selected_fmaps[0].shape[0], n_feats, self.cf.regression_n_features), dtype=torch.float32).fill_(float('NaN')).cuda() # merge batch_dimension and store info in batch_ixs for re-allocation. batch_ixs = torch.arange(class_logits.shape[0]).unsqueeze(1).repeat(1, class_logits.shape[1]).view(-1).cuda() flat_class_softmax = F.softmax(class_logits.view(-1, class_logits.shape[-1]), 1) flat_bb_outputs = bb_outputs.view(-1, bb_outputs.shape[-1]) flat_rg_outputs = rg_outputs.view(-1, rg_outputs.shape[-1]) detections = refine_detections(self.anchors, flat_class_softmax, flat_bb_outputs, flat_rg_outputs, batch_ixs, self.cf) return detections, class_logits, bb_outputs, rg_outputs, seg_logits def get_results(self, img_shape, detections, seg_logits, box_results_list=None): """ Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. :param img_shape: :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regression) :param box_results_list: None or list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. class-specific return of masks will come with implementation of instance segmentation evaluation. """ detections = detections.cpu().data.numpy() batch_ixs = detections[:, self.cf.dim*2] detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] if box_results_list == None: # for test_forward, where no previous list exists. box_results_list = [[] for _ in range(img_shape[0])] for ix in range(img_shape[0]): if not 0 in detections[ix].shape: boxes = detections[ix][:, :2 * self.cf.dim].astype(np.int32) class_ids = detections[ix][:, 2 * self.cf.dim + 1].astype(np.int32) scores = detections[ix][:, 2 * self.cf.dim + 2] regressions = detections[ix][:, 2 * self.cf.dim + 3:] # Filter out detections with zero area. Often only happens in early # stages of training when the network weights are still a bit random. if self.cf.dim == 2: exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] else: exclude_ix = np.where( (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] if exclude_ix.shape[0] > 0: boxes = np.delete(boxes, exclude_ix, axis=0) class_ids = np.delete(class_ids, exclude_ix, axis=0) scores = np.delete(scores, exclude_ix, axis=0) regressions = np.delete(regressions, exclude_ix, axis=0) if not 0 in boxes.shape: for ix2, score in enumerate(scores): if score >= self.cf.model_min_confidence: box = {'box_type': 'det', 'box_coords': boxes[ix2], 'box_score': score, 'box_pred_class_id': class_ids[ix2]} if "regression_bin" in self.cf.prediction_tasks: # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin stands for box['rg_bin'] = regressions[ix2].argmax() box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']] else: box['regression'] = regressions[ix2] if hasattr(self.cf, "rg_val_to_bin_id") and \ any(['regression' in task for task in self.cf.prediction_tasks]): box['rg_bin'] = self.cf.rg_val_to_bin_id(regressions[ix2]) box_results_list[ix].append(box) results_dict = {} results_dict['boxes'] = box_results_list if seg_logits is None: # output dummy segmentation for retina_net. out_logits_shape = list(img_shape) out_logits_shape[1] = self.cf.num_seg_classes results_dict['seg_preds'] = np.zeros(out_logits_shape, dtype=np.float16) #todo: try with seg_preds=None? as to not carry heavy dummy preds. else: # output label maps for retina_unet. results_dict['seg_preds'] = F.softmax(seg_logits, 1).cpu().data.numpy() return results_dict def train_forward(self, batch, is_validation=False): """ train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data for processing, computes losses, and stores outputs in a dictionary. :param batch: dictionary containing 'data', 'seg', etc. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes]. 'torch_loss': 1D torch tensor for backprop. 'class_loss': classification loss for monitoring. """ img = batch['data'] gt_class_ids = batch['class_targets'] gt_boxes = batch['bb_target'] if 'regression' in self.cf.prediction_tasks: gt_regressions = batch["regression_targets"] elif 'regression_bin' in self.cf.prediction_tasks: gt_regressions = batch["rg_bin_targets"] else: gt_regressions = None var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda() var_seg = torch.LongTensor(batch['seg']).cuda() img = torch.from_numpy(img).float().cuda() torch_loss = torch.FloatTensor([0]).cuda() # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. box_results_list = [[] for _ in range(img.shape[0])] detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(img) # loop over batch for b in range(img.shape[0]): # add gt boxes to results dict for monitoring. if len(gt_boxes[b]) > 0: for tix in range(len(gt_boxes[b])): gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]} for name in self.cf.roi_items: gt_box.update({name: batch[name][b][tix]}) box_results_list[b].append(gt_box) # match gt boxes with anchors to generate targets. anchor_class_match, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching( self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b], gt_regressions[b] if gt_regressions is not None else None) # add positive anchors used for loss to results_dict for monitoring. pos_anchors = mutils.clip_boxes_numpy( self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:]) for p in pos_anchors: box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) else: anchor_class_match = np.array([-1]*self.np_anchors.shape[0]) anchor_target_deltas = np.array([]) anchor_target_rgs = np.array([]) anchor_class_match = torch.from_numpy(anchor_class_match).cuda() anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda() anchor_target_rgs = torch.from_numpy(anchor_target_rgs).float().cuda() if self.cf.focal_loss: # compute class loss as focal loss as suggested in original publication, but multi-class. class_loss = compute_focal_class_loss(anchor_class_match, class_logits[b], gamma=self.cf.focal_loss_gamma) # sparing appendix of negative anchors for monitoring as not really relevant else: # compute class loss with SHEM. class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b]) # add negative anchors used for loss to results_dict for monitoring. neg_anchors = mutils.clip_boxes_numpy( self.np_anchors[np.argwhere(anchor_class_match.cpu().numpy() == -1)][neg_anchor_ix, 0], img.shape[2:]) for n in neg_anchors: box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) rg_loss = compute_rg_loss(self.cf.prediction_tasks, anchor_target_rgs, pred_rgs[b], anchor_class_match) bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match) torch_loss += (class_loss + bbox_loss + rg_loss) / img.shape[0] results_dict = self.get_results(img.shape, detections, seg_logits, box_results_list) results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:, np.newaxis] - + # todo error + raise Exception("a test error") if self.cf.model == 'retina_unet': seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),var_seg_ohe) seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0]) torch_loss += (seg_loss_dice + seg_loss_ce) / 2 #self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, " # "mean pixel preds: {5:.5f}".format(torch_loss.item(), batch_class_loss.item(), batch_bbox_loss.item(), # seg_loss_dice.item(), seg_loss_ce.item(), np.mean(results_dict['seg_preds']))) if 'dice' in self.cf.metrics: results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True) #else: #self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}".format( # torch_loss.item(), class_loss.item(), bbox_loss.item())) results_dict['torch_loss'] = torch_loss results_dict['class_loss'] = class_loss.item() return results_dict def test_forward(self, batch, **kwargs): """ test method. wrapper around forward pass of network without usage of any ground truth information. prepares input data for processing and stores outputs in a dictionary. :param batch: dictionary containing 'data' :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': actually contain seg probabilities since evaluated to seg_preds (via argmax) in predictor. or dummy seg logits for real retina net (detection only) """ img = torch.from_numpy(batch['data']).float().cuda() detections, _, _, _, seg_logits = self.forward(img) results_dict = self.get_results(img.shape, detections, seg_logits) return results_dict \ No newline at end of file diff --git a/unittests.py b/unittests.py index 799e67b..f3613e8 100644 --- a/unittests.py +++ b/unittests.py @@ -1,451 +1,470 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import unittest import os import pickle import time from multiprocessing import Pool +import subprocess import numpy as np import pandas as pd import torch import torchvision as tv import tqdm import utils.exp_utils as utils import utils.model_utils as mutils """ Note on unittests: run this file either in the way intended for unittests by starting the script with python -m unittest unittests.py or start it as a normal python file as python unittests.py. You can selective run single tests by calling python -m unittest unittests.TestClassOfYourChoice, where TestClassOfYourChoice is the name of the test defined below, e.g., CompareFoldSplits. """ def inspect_info_df(pp_dir): """ use your debugger to look into the info df of a pp dir. :param pp_dir: preprocessed-data directory """ info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle")) return def generate_boxes(count, dim=2, h=100, w=100, d=20, normalize=False, on_grid=False, seed=0): """ generate boxes of format [y1, x1, y2, x2, (z1, z2)]. :param count: nr of boxes :param dim: dimension of boxes (2 or 3) :return: boxes in format (n_boxes, 4 or 6), scores """ np.random.seed(seed) if on_grid: lower_y = np.random.randint(0, h // 2, (count,)) lower_x = np.random.randint(0, w // 2, (count,)) upper_y = np.random.randint(h // 2, h, (count,)) upper_x = np.random.randint(w // 2, w, (count,)) if dim == 3: lower_z = np.random.randint(0, d // 2, (count,)) upper_z = np.random.randint(d // 2, d, (count,)) else: lower_y = np.random.rand(count) * h / 2. lower_x = np.random.rand(count) * w / 2. upper_y = (np.random.rand(count) + 1.) * h / 2. upper_x = (np.random.rand(count) + 1.) * w / 2. if dim == 3: lower_z = np.random.rand(count) * d / 2. upper_z = (np.random.rand(count) + 1.) * d / 2. if dim == 3: boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x, lower_z, upper_z))) # add an extreme box that tests the boundaries boxes = np.concatenate((boxes, np.array([[0., 0., h, w, 0, d]]))) else: boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x))) boxes = np.concatenate((boxes, np.array([[0., 0., h, w]]))) scores = np.random.rand(count + 1) if normalize: divisor = np.array([h, w, h, w, d, d]) if dim == 3 else np.array([h, w, h, w]) boxes = boxes / divisor return boxes, scores #------- perform integrity checks on data set(s) ----------- class VerifyLIDCSAIntegrity(unittest.TestCase): """ Perform integrity checks on preprocessed single-annotator GTs of LIDC data set. """ @staticmethod def check_patient_sa_gt(pid, pp_dir, check_meta_files, check_info_df): faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids']) all_segs = np.load(os.path.join(pp_dir, pid + "_rois.npz"), mmap_mode='r') all_segs = all_segs[list(all_segs.keys())[0]] all_roi_ids = np.unique(all_segs[all_segs > 0]) assert len(all_roi_ids) == np.max(all_segs), "roi ids not consecutive" if check_meta_files: meta_file = os.path.join(pp_dir, pid + "_meta_info.pickle") with open(meta_file, "rb") as handle: info = pickle.load(handle) assert info["pid"] == pid, "wrong pid in meta_file" all_cl_targets = info["class_target"] if check_info_df: info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle")) pid_info = info_df[info_df.pid == pid] assert len(pid_info) == 1, "found {} entries for pid {} in info df, expected exactly 1".format(len(pid_info), pid) if check_meta_files: assert pid_info[ "class_target"] == all_cl_targets, "meta_info and info_df class targets mismatch:\n{}\n{}".format( pid_info["class_target"], all_cl_targets) all_cl_targets = pid_info["class_target"].iloc[0] assert len(all_roi_ids) == len(all_cl_targets) for rater in range(4): seg = all_segs[rater] roi_ids = np.unique(seg[seg > 0]) cl_targs = np.array([roi[rater] for roi in all_cl_targets]) assert np.count_nonzero(cl_targs) == len(roi_ids), "rater {} has targs {} but roi ids {}".format(rater, cl_targs, roi_ids) assert len(cl_targs) >= len(roi_ids), "not all marked rois have a label" for zeroix_roi_id, rating in enumerate(cl_targs): if not ((rating > 0) == (np.any(seg == zeroix_roi_id + 1))): print("\n\nFAULTY CASE:", end=" ", ) print("pid {}, rater {}, cl_targs {}, ids {}\n".format(pid, rater, cl_targs, roi_ids)) faulty_cases = faulty_cases.append( {'pid': pid, 'rater': rater, 'cl_targets': cl_targs, 'roi_ids': roi_ids}, ignore_index=True) print("finished checking pid {}, {} faulty cases".format(pid, len(faulty_cases))) return faulty_cases - def check_sa_gts(self, pp_dir, pid_subset=None, check_meta_files=False, check_info_df=True, processes=os.cpu_count()): + def check_sa_gts(cf, pp_dir, pid_subset=None, check_meta_files=False, check_info_df=True, processes=os.cpu_count()): report_name = "verify_seg_label_pairings.csv" pids = {file_name.split("_")[0] for file_name in os.listdir(pp_dir) if file_name not in [report_name, "info_df.pickle"]} if pid_subset is not None: pids = [pid for pid in pids if pid in pid_subset] faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids']) p = Pool(processes=processes) mp_args = zip(pids, [pp_dir]*len(pids), [check_meta_files]*len(pids), [check_info_df]*len(pids)) patient_cases = p.starmap(self.check_patient_sa_gt, mp_args) p.close(); p.join() faulty_cases = faulty_cases.append(patient_cases, sort=False) print("\n\nfaulty case count {}".format(len(faulty_cases))) print(faulty_cases) findings_file = os.path.join(pp_dir, "verify_seg_label_pairings.csv") faulty_cases.to_csv(findings_file) assert len(faulty_cases)==0, "there was a faulty case in data set {}.\ncheck {}".format(pp_dir, findings_file) def test(self): pp_root = "/mnt/HDD2TB/Documents/data/" pp_dir = "lidc/pp_20190805" gt_dir = os.path.join(pp_root, pp_dir, "patient_gts_sa") self.check_sa_gts(gt_dir, check_meta_files=True, check_info_df=False, pid_subset=None) # ["0811a", "0812a"]) #------ compare segmentation gts of preprocessed data sets ------ class CompareSegGTs(unittest.TestCase): """ load and compare pre-processed gts by dice scores of segmentations. """ @staticmethod def group_seg_paths(ref_path, comp_paths): # not working recursively ref_files = [fn for fn in os.listdir(ref_path) if os.path.isfile(os.path.join(ref_path, fn)) and 'seg' in fn and fn.endswith('.npy')] comp_files = [[os.path.join(c_path, fn) for c_path in comp_paths] for fn in ref_files] ref_files = [os.path.join(ref_path, fn) for fn in ref_files] return zip(ref_files, comp_files) @staticmethod def load_calc_dice(paths): dices = [] ref_seg = np.load(paths[0])[np.newaxis, np.newaxis] n_classes = len(np.unique(ref_seg)) ref_seg = mutils.get_one_hot_encoding(ref_seg, n_classes) for c_file in paths[1]: c_seg = np.load(c_file)[np.newaxis, np.newaxis] assert n_classes == len(np.unique(c_seg)), "unequal nr of objects/classes betw segs {} {}".format(paths[0], c_file) c_seg = mutils.get_one_hot_encoding(c_seg, n_classes) dice = mutils.dice_per_batch_inst_and_class(c_seg, ref_seg, n_classes, convert_to_ohe=False) dices.append(dice) print("processed ref_path {}".format(paths[0])) return np.mean(dices), np.std(dices) def iterate_files(self, grouped_paths, processes=os.cpu_count()): p = Pool(processes) means_stds = np.array(p.map(self.load_calc_dice, grouped_paths)) p.close(); p.join() min_dice = np.min(means_stds[:, 0]) print("min mean dice {:.2f}, max std {:.4f}".format(min_dice, np.max(means_stds[:, 1]))) assert min_dice > 1-1e5, "compared seg gts have insufficient minimum mean dice overlap of {}".format(min_dice) def test(self): ref_path = '/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071' comp_paths = ['/mnt/HDD2TB/Documents/data/prostate/data_t2_190419_ps384_gs6071', ] paths = self.group_seg_paths(ref_path, comp_paths) self.iterate_files(paths) #------- check if cross-validation fold splits of different experiments are identical ---------- class CompareFoldSplits(unittest.TestCase): """ Find evtl. differences in cross-val file splits across different experiments. """ @staticmethod def group_id_paths(ref_exp_dir, comp_exp_dirs): f_name = 'fold_ids.pickle' ref_paths = os.path.join(ref_exp_dir, f_name) assert os.path.isfile(ref_paths), "ref file {} does not exist.".format(ref_paths) ref_paths = [ref_paths for comp_ed in comp_exp_dirs] comp_paths = [os.path.join(comp_ed, f_name) for comp_ed in comp_exp_dirs] return zip(ref_paths, comp_paths) @staticmethod def comp_fold_ids(mp_input): fold_ids1, fold_ids2 = mp_input with open(fold_ids1, 'rb') as f: fold_ids1 = pickle.load(f) try: with open(fold_ids2, 'rb') as f: fold_ids2 = pickle.load(f) except FileNotFoundError: print("comp file {} does not exist.".format(fold_ids2)) return n_splits = len(fold_ids1) assert n_splits == len(fold_ids2), "mismatch n splits: ref has {}, comp {}".format(n_splits, len(fold_ids2)) split_diffs = [np.setdiff1d(fold_ids1[s], fold_ids2[s]) for s in range(n_splits)] all_equal = np.any(split_diffs) return (split_diffs, all_equal) def iterate_exp_dirs(self, ref_exp, comp_exps, processes=os.cpu_count()): grouped_paths = list(self.group_id_paths(ref_exp, comp_exps)) print("performing {} comparisons of cross-val file splits".format(len(grouped_paths))) p = Pool(processes) split_diffs = p.map(self.comp_fold_ids, grouped_paths) p.close(); p.join() df = pd.DataFrame(index=range(0,len(grouped_paths)), columns=["ref", "comp", "all_equal"])#, "diffs"]) for ix, (ref, comp) in enumerate(grouped_paths): df.iloc[ix] = [ref, comp, split_diffs[ix][1]]#, split_diffs[ix][0]] print("Any splits not equal?", df.all_equal.any()) assert not df.all_equal.any(), "a split set is different from reference split set, {}".format(df[~df.all_equal]) def test(self): exp_parent_dir = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/' ref_exp = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn2d_cl_bs10' comp_exps = [os.path.join(exp_parent_dir, p) for p in os.listdir(exp_parent_dir)] comp_exps = [p for p in comp_exps if os.path.isdir(p) and p != ref_exp] self.iterate_exp_dirs(ref_exp, comp_exps) #------- check if cross-validation fold splits of a single experiment are actually incongruent (as required) ---------- class VerifyFoldSplits(unittest.TestCase): """ Check, for a single fold_ids file, i.e., for a single experiment, if the assigned folds (assignment of data identifiers) is actually incongruent. No overlaps between folds are required for a correct cross validation. """ @staticmethod def verify_fold_ids(splits): for i, split1 in enumerate(splits): for j, split2 in enumerate(splits): if j > i: inter = np.intersect1d(split1, split2) if len(inter) > 0: raise Exception("Split {} and {} intersect by pids {}".format(i, j, inter)) def test(self): exp_dir = "/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/dev" check_file = os.path.join(exp_dir, 'fold_ids.pickle') with open(check_file, 'rb') as handle: splits = pickle.load(handle) self.verify_fold_ids(splits) # -------- check own nms CUDA implement against own numpy implement ------ class CheckNMSImplementation(unittest.TestCase): @staticmethod def assert_res_equality(keep_ics1, keep_ics2, boxes, scores, tolerance=0, names=("res1", "res2")): """ :param keep_ics1: keep indices (results), torch.Tensor of shape (n_ics,) :param keep_ics2: :return: """ keep_ics1, keep_ics2 = keep_ics1.cpu().numpy(), keep_ics2.cpu().numpy() discrepancies = np.setdiff1d(keep_ics1, keep_ics2) try: checks = np.array([ len(discrepancies) <= tolerance ]) except: checks = np.zeros((1,)).astype("bool") msgs = np.array([ """{}: {} \n{}: {} \nboxes: {}\n {}\n""".format(names[0], keep_ics1, names[1], keep_ics2, boxes, scores) ]) assert np.all(checks), "NMS: results mismatch: " + "\n".join(msgs[~checks]) def single_case(self, count=20, dim=3, threshold=0.2, seed=0): boxes, scores = generate_boxes(count, dim, seed=seed, h=320, w=280, d=30) keep_numpy = torch.tensor(mutils.nms_numpy(boxes, scores, threshold)) # for some reason torchvision nms requires box coords as floats. boxes = torch.from_numpy(boxes).type(torch.float32) scores = torch.from_numpy(scores).type(torch.float32) if dim == 2: """need to wait until next pytorch release where they fixed nms on cpu (currently they have >= where it needs to be >. """ # keep_ops = tv.ops.nms(boxes, scores, threshold) # self.assert_res_equality(keep_numpy, keep_ops, boxes, scores, tolerance=0, names=["np", "ops"]) pass boxes = boxes.cuda() scores = scores.cuda() keep = self.nms_ext.nms(boxes, scores, threshold) self.assert_res_equality(keep_numpy, keep, boxes, scores, tolerance=0, names=["np", "cuda"]) def test(self, n_cases=200, box_count=30, threshold=0.5): # dynamically import module so that it doesn't affect other tests if import fails self.nms_ext = utils.import_module("nms_ext", 'custom_extensions/nms/nms.py') # change seed to something fix if you want exactly reproducible test seed0 = np.random.randint(50) print("NMS test progress (done/total box configurations) 2D:", end="\n") for i in tqdm.tqdm(range(n_cases)): self.single_case(count=box_count, dim=2, threshold=threshold, seed=seed0+i) print("NMS test progress (done/total box configurations) 3D:", end="\n") for i in tqdm.tqdm(range(n_cases)): self.single_case(count=box_count, dim=3, threshold=threshold, seed=seed0+i) return class CheckRoIAlignImplementation(unittest.TestCase): def prepare(self, dim=2): b, c, h, w = 1, 3, 50, 50 # feature map, (b, c, h, w(, z)) if dim == 2: fmap = torch.rand(b, c, h, w).cuda() # rois = torch.tensor([[ # [0.1, 0.1, 0.3, 0.3], # [0.2, 0.2, 0.4, 0.7], # [0.5, 0.7, 0.7, 0.9], # ]]).cuda() pool_size = (7, 7) rois = generate_boxes(5, dim=dim, h=h, w=w, on_grid=True, seed=np.random.randint(50))[0] elif dim == 3: d = 20 fmap = torch.rand(b, c, h, w, d).cuda() # rois = torch.tensor([[ # [0.1, 0.1, 0.3, 0.3, 0.1, 0.1], # [0.2, 0.2, 0.4, 0.7, 0.2, 0.4], # [0.5, 0.0, 0.7, 1.0, 0.4, 0.5], # [0.0, 0.0, 0.9, 1.0, 0.0, 1.0], # ]]).cuda() pool_size = (7, 7, 3) rois = generate_boxes(5, dim=dim, h=h, w=w, d=d, on_grid=True, seed=np.random.randint(50), normalize=False)[0] else: raise ValueError("dim needs to be 2 or 3") rois = [torch.from_numpy(rois).type(dtype=torch.float32).cuda(), ] fmap.requires_grad_(True) return fmap, rois, pool_size def check_2d(self): fmap, rois, pool_size = self.prepare(dim=2) align_ops = tv.ops.roi_align(fmap, rois, pool_size) loss_ops = align_ops.sum() loss_ops.backward() ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1) align_ext = ra_object(fmap, rois) loss_ext = align_ext.sum() loss_ext.backward() assert (loss_ops == loss_ext), "sum of roialign ops and extension 2D diverges" assert (align_ops == align_ext).all(), "ROIAlign failed 2D test" def check_3d(self): fmap, rois, pool_size = self.prepare(dim=3) ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1) align_ext = ra_object(fmap, rois) loss_ext = align_ext.sum() loss_ext.backward() align_np = mutils.roi_align_3d_numpy(fmap.cpu().detach().numpy(), [roi.cpu().numpy() for roi in rois], pool_size) align_np = np.squeeze(align_np) # remove singleton batch dim align_ext = align_ext.cpu().detach().numpy() assert np.allclose(align_np, align_ext, rtol=1e-5, atol=1e-8), "RoIAlign differences in numpy and CUDA implement" def test(self): # dynamically import module so that it doesn't affect other tests if import fails self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py') # 2d test self.check_2d() # 3d test self.check_3d() return class CheckRuntimeErrors(unittest.TestCase): """ Check if minimal examples of the exec.py module finish without runtime errors. This check requires a working path to data in the toy-dataset configs. """ def test(self): cf = utils.import_module("toy_cf", 'datasets/toy/configs.py').Configs() - for model in ["retina_net",]: - cf.model = None - - pass + exp_dir = "./unittesting/" + #checks = {"retina_net": False, "mrcnn": False} + #print("Testing for runtime errors with models {}".format(list(checks.keys()))) + #for model in tqdm.tqdm(list(checks.keys())): + # cf.model = model + # cf.model_path = 'models/{}.py'.format(cf.model if not 'retina' in cf.model else 'retina_net') + # cf.model_path = os.path.join(cf.source_dir, cf.model_path) + # {'mrcnn': cf.add_mrcnn_configs, + # 'retina_net': cf.add_mrcnn_configs, 'retina_unet': cf.add_mrcnn_configs, + # 'detection_unet': cf.add_det_unet_configs, 'detection_fpn': cf.add_det_fpn_configs + # }[model]() + # todo change structure of configs-handling with exec.py so that its dynamically parseable instead of needing to + # todo be changed in the file all the time. + checks = {cf.model:False} + completed_process = subprocess.run("python exec.py --dev --dataset_name toy -m train_test --exp_dir {}".format(exp_dir), + shell=True, capture_output=True, text=True) + if completed_process.returncode!=0: + print("Runtime test of model {} failed due to\n{}".format(cf.model, completed_process.stderr)) + else: + checks[cf.model] = True + subprocess.call("rm -rf {}".format(exp_dir), shell=True) + assert all(checks.values()), "A runtime test crashed." if __name__=="__main__": stime = time.time() unittest.main() mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file