diff --git a/datasets/lidc/configs.py b/datasets/lidc/configs.py
index 126300b..d5d687b 100644
--- a/datasets/lidc/configs.py
+++ b/datasets/lidc/configs.py
@@ -1,445 +1,445 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 from collections import namedtuple
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/../..")
 from default_configs import DefaultConfigs
 
 # legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope
 Label = namedtuple("Label", ['id', 'name', 'color', 'm_scores']) # m_scores = malignancy scores
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'm_scores', 'bin_vals'])
 
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #    Preprocessing      #
         #########################
 
         self.root_dir = '/home/gregor/networkdrives/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI'
         self.raw_data_dir = '{}/new_nrrd'.format(self.root_dir)
-        self.pp_dir = '/mnt/HDD2TB/Documents/data/lidc/pp_20190805'
+        self.pp_dir = '/mnt/HDD2TB/Documents/data/lidc/pp_20191219_dev'
         # 'merged' for one gt per image, 'single_annotator' for four gts per image.
         self.gts_to_produce = ["single_annotator", "merged"]
 
         self.target_spacing = (0.7, 0.7, 1.25)
 
         #########################
         #         I/O          #
         #########################
 
         # path to preprocessed data.
         #self.pp_name = 'pp_20190318'
-        self.pp_name = 'pp_20190805'
+        self.pp_name = 'pp_20191219_dev'
 
         self.input_df_name = 'info_df.pickle'
         self.data_sourcedir = '/mnt/HDD2TB/Documents/data/lidc/{}/'.format(self.pp_name)
 
         # settings for deployment on cluster.
         if server_env:
             # path to preprocessed data.
             self.data_sourcedir = '/datasets/data_ramien/lidc/{}_npz/'.format(self.pp_name)
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn'].
-        self.model = 'retina_net'
+        self.model = 'mrcnn'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # dimension the model operates in. one out of [2, 3].
         self.dim = 3
 
         # 'class': standard object classification per roi, pairwise combinable with each of below tasks.
         # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN.
         # 'regression': regress some vector per each roi
         # 'regression_ken_gal': use kendall-gal uncertainty sigma
         # 'regression_bin': classify each roi into a bin related to a regression scale
         self.prediction_tasks = ['class']
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = None # one of None, 'instance_norm', 'batch_norm'
 
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1
 
         #########################
         #      Data Loader      #
         #########################
 
         # distorted gt experiments: train on single-annotator gts in a random fashion to investigate network's
         # handling of noisy gts.
         # choose 'merged' for single, merged gt per image, or 'single_annotator' for four gts per image.
         # validation is always performed on same gt kind as training, testing always on merged gt.
         self.training_gts = "merged"
 
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.pre_crop_size_2D = [320, 320]
         self.patch_size_2D = [320, 320]
         self.pre_crop_size_3D = [160, 160, 96]
         self.patch_size_3D = [160, 160, 96]
 
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
         self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.3
         self.balance_target =  "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets'
 
         # set 2D network to match 3D gt boxes.
         self.merge_2D_to_3D_preds = self.dim==2
 
         self.observables_rois = []
 
         #self.rg_map = {1:1, 2:2, 3:3, 4:4, 5:5}
 
         #########################
         #   Colors and Legends  #
         #########################
         self.plot_frequency = 5
 
         binary_cl_labels = [Label(1, 'benign',  (*self.dark_green, 1.),  (1, 2)),
                             Label(2, 'malignant', (*self.red, 1.),  (3, 4, 5))]
         quintuple_cl_labels = [Label(1, 'MS1',  (*self.dark_green, 1.),      (1,)),
                                Label(2, 'MS2',  (*self.dark_yellow, 1.),     (2,)),
                                Label(3, 'MS3',  (*self.orange, 1.),     (3,)),
                                Label(4, 'MS4',  (*self.bright_red, 1.), (4,)),
                                Label(5, 'MS5',  (*self.red, 1.),        (5,))]
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_cl_labels = quintuple_cl_labels
 
         self.class_labels = [
             #       #id #name     #color              #malignancy score
             Label(  0,  'bg',     (*self.gray, 0.),  (0,))]
         if "class" in self.prediction_tasks:
             self.class_labels += task_spec_cl_labels
 
         else:
             self.class_labels += [Label(1, 'lesion', (*self.orange, 1.), (1,2,3,4,5))]
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0, 'MS0', (*self.gray, 1.), (0,), (0,))]
             self.bin_labels += [binLabel(cll.id, cll.name, cll.color, cll.m_scores,
                                          tuple([ms for ms in cll.m_scores])) for cll in task_spec_cl_labels]
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             self.ms2bin_label = {ms: label for label in self.bin_labels for ms in label.m_scores}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
 
         if self.class_specific_seg:
             self.seg_labels = self.class_labels
         else:
             self.seg_labels = [  # id      #name           #color
                 Label(0, 'bg', (*self.gray, 0.)),
                 Label(1, 'fg', (*self.orange, 1.))
             ]
 
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
         # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be
         # evaluated in debugging
         self.class_cmap = {label.id: label.color for label in self.class_labels}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)  # for instance classification (excl background)
         self.num_seg_classes = len(self.seg_labels)  # incl background
 
 
         #########################
         #   Data Augmentation   #
         #########################
 
         self.da_kwargs={
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': True,
             'alpha':(0., 1500.),
             'sigma':(30., 50.),
             'do_rotation':True,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': True,
             'scale':(0.8, 1.1),
             'random_crop':False,
             'rand_crop_dist':  (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1}
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #################################
         #  Schedule / Selection / Optim #
         #################################
 
         self.num_epochs = 130 if self.dim == 2 else 150
         self.num_train_batches = 200 if self.dim == 2 else 200
         self.batch_size = 20 if self.dim == 2 else 8
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # only 'val_sampling', 'val_patient' not implemented
         if self.val_mode == 'val_patient':
             raise NotImplementedError
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 70
 
         self.save_n_models = 4
         # set a minimum epoch number for saving in case of instabilities in the first phase of training.
         self.min_save_thresh = 0 if self.dim == 2 else 0
         # criteria to average over for saving epochs, 'criterion':weight.
         if "class" in self.prediction_tasks:
             # 'criterion': weight
             if len(self.class_labels)==3:
                 self.model_selection_criteria = {"benign_ap": 0.5, "malignant_ap": 0.5}
             elif len(self.class_labels)==6:
                 self.model_selection_criteria = {str(label.name)+"_ap": 1./5 for label in self.class_labels if label.id!=0}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8}
 
         self.weight_decay = 0
         self.clip_norm = 200 if 'regression_ken_gal' in self.prediction_tasks else None  # number or None
 
         # int in [0, dataset_size]. select n patients from dataset for prototyping. If None, all data is used.
         self.select_prototype_subset = None #self.batch_size
 
         #########################
         #        Testing        #
         #########################
 
         # set the top-n-epochs to be saved for temporal averaging in testing.
         self.test_n_epochs = self.save_n_models
 
         self.test_aug_axes = (0,1,(0,1))  # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x).
         self.held_out_test_set = False
         self.max_test_patients = "all"  # "all" or number
 
         self.report_score_level = ['rois', 'patient']  # choose list from 'patient', 'rois'
         self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1
 
         self.metrics = ['ap', 'auc']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.1]  # list of ious to be evaluated for ap-scoring.
         self.min_det_thresh = 0.1  # minimum confidence value to select predictions for evaluation.
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = 0.1  # has to be larger than desired possible overlap iou of model predictions
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.n_test_plots = 1
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'detection_fpn': self.add_det_fpn_configs,
          'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs,
          'retina_unet': self.add_mrcnn_configs,
         }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         return float(np.digitize(np.mean(rg_val), self.bin_edges))
 
     def add_det_fpn_configs(self):
 
         self.learning_rate = [1e-4] * self.num_epochs
         self.dynamic_lr_scheduling = False
 
         # RoI score assigned to aggregation from pixel prediction (connected component). One of ['max', 'median'].
         self.score_det = 'max'
 
         # max number of roi candidates to identify per batch element and class.
         self.n_roi_candidates = 10 if self.dim == 2 else 30
 
         # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
         self.seg_loss_mode = 'wce'
 
         # if <1, false positive predictions in foreground are penalized less.
         self.fp_dice_weight = 1 if self.dim == 2 else 1
         if len(self.class_labels)==3:
             self.wce_weights = [1., 1., 1.] if self.seg_loss_mode=="dice_wce" else [0.1, 1., 1.]
         elif len(self.class_labels)==6:
             self.wce_weights = [1., 1., 1., 1., 1., 1.] if self.seg_loss_mode == "dice_wce" else [0.1, 1., 1., 1., 1., 1.]
         else:
             raise Exception("mismatch loss weights & nr of classes")
         self.detection_min_confidence = self.min_det_thresh
 
         self.head_classes = self.num_seg_classes
 
     def add_mrcnn_configs(self):
 
         # learning rate is a list with one entry per epoch.
         self.learning_rate = [1e-4] * self.num_epochs
         self.dynamic_lr_scheduling = False
         # disable the re-sampling of mask proposals to original size for speed-up.
         # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
         # mask-outputs are optional.
         self.return_masks_in_train = False
         self.return_masks_in_val = True
         self.return_masks_in_test = False
 
         # set number of proposal boxes to plot after each epoch.
         self.n_plot_rpn_props = 5 if self.dim == 2 else 30
 
         # number of classes for network heads: n_foreground_classes + 1 (background)
         self.head_classes = self.num_classes + 1
 
         self.frcnn_mode = False
 
         # feature map strides per pyramid level are inferred from architecture.
         self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
 
         # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
         # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
         self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
 
         # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
         self.pyramid_levels = [0, 1, 2, 3]
 
         # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
         self.n_rpn_features = 512 if self.dim == 2 else 128
 
         # anchor ratios and strides per position in feature maps.
         self.rpn_anchor_ratios = [0.5, 1, 2]
         self.rpn_anchor_stride = 1
 
         # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
         self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
 
         # loss sampling settings.
         self.rpn_train_anchors_per_image = 6  #per batch element
         self.train_rois_per_image = 6 #per batch element
         self.roi_positive_ratio = 0.5
         self.anchor_matching_iou = 0.7
 
         # factor of top-k candidates to draw from  per negative sample (stochastic-hard-example-mining).
         # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples.
         self.shem_poolsize = 10
 
         self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
         self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
         self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
         self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
         self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                                self.patch_size_3D[2], self.patch_size_3D[2]])
         if self.dim == 2:
             self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
             self.bbox_std_dev = self.bbox_std_dev[:4]
             self.window = self.window[:4]
             self.scale = self.scale[:4]
 
         # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
         self.pre_nms_limit = 3000 if self.dim == 2 else 6000
 
         # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
         # since proposals of the entire batch are forwarded through second stage in as one "batch".
         self.roi_chunk_size = 2500 if self.dim == 2 else 600
         self.post_nms_rois_training = 500 if self.dim == 2 else 75
         self.post_nms_rois_inference = 500
 
         # Final selection of detections (refine_detections)
         self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30  # per batch element and class.
         self.detection_nms_threshold = 1e-5  # needs to be > 0, otherwise all predictions are one cluster.
         self.model_min_confidence = 0.1
 
         if self.dim == 2:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride))]
                  for stride in self.backbone_strides['xy']])
         else:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride)),
                   int(np.ceil(self.patch_size[2] / stride_z))]
                  for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                              )])
 
         if self.model == 'retina_net' or self.model == 'retina_unet':
 
             self.focal_loss = True
 
             # implement extra anchor-scales according to retina-net publication.
             self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                             self.rpn_anchor_scales['xy']]
             self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                            self.rpn_anchor_scales['z']]
             self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
             self.n_rpn_features = 256 if self.dim == 2 else 128
 
             # pre-selection of detections for NMS-speedup. per entire batch.
             self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
             # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
             self.anchor_matching_iou = 0.5
 
             if self.model == 'retina_unet':
                 self.operate_stride1 = True
 
diff --git a/datasets/lidc/preprocessing.py b/datasets/lidc/preprocessing.py
index 2f5efd4..296a520 100644
--- a/datasets/lidc/preprocessing.py
+++ b/datasets/lidc/preprocessing.py
@@ -1,478 +1,478 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 '''
 This preprocessing script loads nrrd files obtained by the data conversion tool: https://github.com/MIC-DKFZ/LIDC-IDRI-processing/tree/v1.0.1
 After applying preprocessing, images are saved as numpy arrays and the meta information for the corresponding patient is stored
 as a line in the dataframe saved as info_df.pickle.
 '''
 
 import os
 import sys
 import shutil
 import subprocess
 import pickle
 import time
 
 import SimpleITK as sitk
 import numpy as np
 from multiprocessing import Pool
 import pandas as pd
 import numpy.testing as npt
 from skimage.transform import resize
 
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append('../..')
 import data_manager as dmanager
 
 class AttributeDict(dict):
     __getattr__ = dict.__getitem__
     __setattr__ = dict.__setitem__
 
 def load_df(path):
     df = pd.read_pickle(path)
     print(df)
 
     return
 
 def resample_array(src_imgs, src_spacing, target_spacing):
     """ Resample a numpy array.
     :param src_imgs: source image.
     :param src_spacing: source image's spacing.
     :param target_spacing: spacing to resample source image to.
     :return:
     """
     src_spacing = np.round(src_spacing, 3)
     target_shape = [int(src_imgs.shape[ix] * src_spacing[::-1][ix] / target_spacing[::-1][ix]) for ix in range(len(src_imgs.shape))]
     for i in range(len(target_shape)):
         try:
             assert target_shape[i] > 0
         except:
             raise AssertionError("AssertionError:", src_imgs.shape, src_spacing, target_spacing)
 
     img = src_imgs.astype('float64')
     resampled_img = resize(img, target_shape, order=1, clip=True, mode='edge').astype('float32')
 
     return resampled_img
 
 class Preprocessor(object):
     """Preprocessor for LIDC raw data. Set in config: which ground truths to produce, choices are
         - "merged" for a single ground truth per input image, created by merging the given four rater annotations
             into one.
         - "single-annotator" for a four-fold ground truth per input image, created by leaving the each rater annotation
             separately.
     :param cf: config.
     :param exclude_inconsistents: bool or tuple, list, np.array, exclude patients that show technical inconsistencies
         in the raw files, likely due to file-naming mistakes. if bool and True: search for patients that have too many
         ratings per lesion or other inconstencies, exclude findings. if param is list of pids: exclude given pids.
     :param overwrite: look for patients that already exist in the pp dir. if overwrite is False, do not redo existing
         patients, otherwise ignore any existing files.
     :param max_count: maximum number of patients to preprocess.
     :param pids_subset: subset of pids to preprocess.
     """
 
     def __init__(self, cf, exclude_inconsistents=True, overwrite=False, max_count=None, pids_subset=None):
 
         self.cf = cf
 
         assert len(self.cf.gts_to_produce)>0, "need to specify which gts to produce, choices: 'merged', 'single_annotator'"
 
         self.paths = [os.path.join(cf.raw_data_dir, ii) for ii in os.listdir(cf.raw_data_dir)]
         if exclude_inconsistents:
             if isinstance(exclude_inconsistents, bool):
                 exclude_paths = self.exclude_too_many_ratings()
                 exclude_paths += self.verify_seg_label_pairings()
             else:
                 assert isinstance(exclude_inconsistents, (tuple,list,np.ndarray))
                 exclude_paths = exclude_inconsistents
             self.paths = [path for path in self.paths if path not in exclude_paths]
 
 
         if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce:
             self.pp_dir_sa = os.path.join(cf.pp_dir, "patient_gts_sa")
         if 'merged' in self.cf.gts_to_produce:
             self.pp_dir_merged = os.path.join(cf.pp_dir, "patient_gts_merged")
         orig_count = len(self.paths)
         # check if some patients already have ppd versions in destination dir
         if os.path.exists(cf.pp_dir) and not overwrite:
             fs_in_dir = os.listdir(cf.pp_dir)
             already_done =  [file.split("_")[0] for file in fs_in_dir if file.split("_")[1] == "img.npy"]
             if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce:
                 ext = '.npy' if hasattr(self.cf, "save_sa_segs_as") and (
                             self.cf.save_sa_segs_as == "npy" or self.cf.save_sa_segs_as == ".npy") else '.npz'
                 fs_in_dir = os.listdir(self.pp_dir_sa)
                 already_done = [ pid for pid in already_done if pid+"_rois"+ext in fs_in_dir and pid+"_meta_info.pickle" in fs_in_dir]
             if 'merged' in self.cf.gts_to_produce:
                 fs_in_dir = os.listdir(self.pp_dir_merged)
                 already_done = [pid for pid in already_done if
                                 pid + "_rois.npy" in fs_in_dir and pid+"_meta_info.pickle" in fs_in_dir]
 
             self.paths = [p for p in self.paths if not p.split(os.sep)[-1] in already_done]
             if len(self.paths)!=orig_count:
                 print("Due to existing ppd files: Selected a subset of {} patients from originally {}".format(len(self.paths), orig_count))
 
         if pids_subset:
             self.paths = [p for p in self.paths if p.split(os.sep)[-1] in pids_subset]
         if max_count is not None:
             self.paths = self.paths[:max_count]
 
         if not os.path.exists(cf.pp_dir):
             os.mkdir(cf.pp_dir)
         if ('single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce) and \
                 not os.path.exists(self.pp_dir_sa):
             os.mkdir(self.pp_dir_sa)
         if 'merged' in self.cf.gts_to_produce and not os.path.exists(self.pp_dir_merged):
             os.mkdir(self.pp_dir_merged)
 
 
     def exclude_too_many_ratings(self):
         """exclude a patient's full path (the patient folder) from further processing if patient has nodules with
             ratings of more than four raters (which is inconsistent with what the raw data is supposed to comprise,
             also rater ids appear multiple times on the same nodule in these cases motivating the assumption that
             the same rater issued more than one rating / mixed up files or annotations for a nodule).
         :return: paths to be excluded.
         """
         exclude_paths = []
         for path in self.paths:
             roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii])
             found = False
             for roi_id in roi_ids:
                 n_raters = len([ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii])
                 # assert n_raters<=4, "roi {} in path {} has {} raters".format(roi_id, path, n_raters)
                 if n_raters > 4:
                     print("roi {} in path {} has {} raters".format(roi_id, path, n_raters))
                     found = True
             if found:
                 exclude_paths.append(path)
         print("Patients excluded bc of too many raters:\n")
         for p in exclude_paths:
             print(p)
         print()
 
         return exclude_paths
 
     def analyze_lesion(self, pid, nodule_id):
         """print unique seg and counts of nodule nodule_id of patient pid.
         """
         nodule_id = nodule_id.lstrip("0")
         nodule_id_paths = [ii for ii in os.listdir(os.path.join(self.cf.raw_data_dir, pid)) if '.nii' in ii]
         nodule_id_paths = [ii for ii in nodule_id_paths if ii.split('_')[2].lstrip("0")==nodule_id]
         assert len(nodule_id_paths)==1
         nodule_path = nodule_id_paths[0]
 
         roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, nodule_path))
         roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8)
 
         print("pid {}, nodule {}, unique seg & counts: {}".format(pid, nodule_id, np.unique(roi_arr, return_counts=True)))
         return
 
     def verify_seg_label_pairing(self, path):
         """verifies that a nodule's segmentation has malignancy label > 0 if segmentation has foreground (>0 anywhere),
             and vice-versa that it has only background (==0 everywhere) if no malignancy label (==label 0) assigned.
         :param path: path to the patient folder.
         :return: df containing eventual inconsistency findings.
         """
 
         pid = path.split('/')[-1]
 
         df = pd.read_csv(os.path.join(self.cf.root_dir, 'characteristics.csv'), sep=';')
         df = df[df.PatientID == pid]
 
         findings_df = pd.DataFrame(columns=["problem", "pid", "roi_id", "nodule_id", "rater_ix", "seg_unique", "label"])
 
         print('verifying {}'.format(pid))
 
         roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii])
 
         for roi_id in roi_ids:
             roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii]
             nodule_ids = [rp.split('_')[2].lstrip("0") for rp in roi_id_paths]
             rater_ids = [rp.split('_')[1] for rp in roi_id_paths]
             rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0] for ii in nodule_ids]
 
             # check double existence of nodule ids
             uniq, counts = np.unique(nodule_ids, return_counts=True)
             if np.any([count>1 for count in counts]):
                 finding = ("same nodule id exists more than once", pid, roi_id, nodule_ids, "N/A", "N/A", "N/A")
                 print("not unique nodule id", finding)
                 findings_df.loc[findings_df.shape[0]] = finding
 
             # check double gradings of single rater for single roi
             uniq, counts = np.unique(rater_ids, return_counts=True)
             if np.any([count>1 for count in counts]):
                 finding = ("same roi_id exists more than once for a single rater", pid, roi_id, nodule_ids, rater_ids, "N/A", rater_labels)
                 print("more than one grading per roi per single rater", finding)
                 findings_df.loc[findings_df.shape[0]] = finding
 
 
             rater_segs = []
             for rp in roi_id_paths:
                 roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp))
                 roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8)
 
                 rater_segs.append(roi_arr)
             rater_segs = np.array(rater_segs)
             for r in range(rater_segs.shape[0]):
                 if np.sum(rater_segs[r])>0:
                     if rater_labels[r]<=0:
                         finding =  ("non-empty seg w/ bg label", pid, roi_id, nodule_ids[r], rater_ids[r], np.unique(rater_segs[r]), rater_labels[r])
                         print("{}: pid {}, nodule {}, rater {}, seg unique {}, label {}".format(
                             *finding))
                         findings_df.loc[findings_df.shape[0]] = finding
                 else:
                     if rater_labels[r]>0:
                         finding = ("empty seg w/ fg label", pid, roi_id, nodule_ids[r], rater_ids[r], np.unique(rater_segs[r]), rater_labels[r])
                         print("{}: pid {}, nodule {}, rater {}, seg unique {}, label {}".format(
                             *finding))
                         findings_df.loc[findings_df.shape[0]] = finding
 
         return findings_df
 
     def verify_seg_label_pairings(self, processes=os.cpu_count()):
         """wrapper to multi-process verification of seg-label pairings.
         """
 
         pool = Pool(processes=processes)
         findings_dfs = pool.map(self.verify_seg_label_pairing, self.paths, chunksize=1)
         pool.close()
         pool.join()
 
         findings_df = pd.concat(findings_dfs, axis=0)
         findings_df.to_pickle(os.path.join(self.cf.pp_dir, "verification_seg_label_pairings.pickle"))
         findings_df.to_csv(os.path.join(self.cf.pp_dir, "verification_seg_label_pairings.csv"))
 
         return findings_df.pid.tolist()
 
     def produce_sa_gt(self, path, pid, df, img_spacing, img_arr_shape):
         """ Keep annotations separate, i.e., every processed image has four final GTs.
             Images are always saved as npy. For meeting hard-disk-memory constraints, segmentations can optionally be
             saved as .npz instead of .npy. Dataloader is only implemented for reading .npz segs.
         """
 
         final_rois = np.zeros((4, *img_arr_shape), dtype='uint8')
         patient_mal_labels = []
         roi_ids = list(set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii]))
         roi_ids.sort() # just a precaution to have same order of lesions throughout separate runs
 
         rix = 1
         for roi_id in roi_ids:
             roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii]
             assert len(roi_id_paths)>0 and len(roi_id_paths)<=4, "pid {}: should find 0< n_rois <4, but found {}".format(pid, len(roi_id_paths))
 
             """ not strictly necessary precaution: in theory, segmentations of different raters could overlap also for 
                 *different* rois, i.e., a later roi of a rater could (partially) cover up / destroy the roi of another 
                 rater. practically this is unlikely as overlapping lesions of different raters should be regarded as the
                 same lesion, but safety first. hence, the order of raters is maintained across rois, i.e., rater 0 
                 (marked as rater 0 in roi's file name) always has slot 0 in rater_labels and rater_segs, thereby rois
                 are certain to not overlap.
             """
             rater_labels, rater_segs = np.zeros((4,), dtype='uint8'), np.zeros((4,*img_arr_shape), dtype="float32")
             for ix, rp in enumerate(roi_id_paths): # one roi path per rater
                 nodule_id = rp.split('_')[2].lstrip("0")
                 assert not (nodule_id=="5728" or nodule_id=="8840"), "nodule ids {}, {} should be excluded due to seg-mal-label inconsistency.".format(5728, 8840)
                 rater = int(rp.split('_')[1])
                 rater_label = df[df.NoduleID == int(nodule_id)].Malignancy.values[0]
                 rater_labels[rater] = rater_label
 
                 roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp))
                 for dim in range(len(img_arr_shape)):
                     npt.assert_almost_equal(roi.GetSpacing()[dim], img_spacing[dim])
                 roi_arr = sitk.GetArrayFromImage(roi)
                 roi_arr = resample_array(roi_arr, roi.GetSpacing(), self.cf.target_spacing)
                 assert roi_arr.shape == img_arr_shape, [roi_arr.shape, img_arr_shape, pid, roi.GetSpacing()]
                 assert not np.any(rater_segs[rater]), "overwriting existing rater's seg with roi {}".format(rp)
                 rater_segs[rater] = roi_arr
             rater_segs = np.array(rater_segs)
 
             # rename/remap the malignancy to be positive.
             roi_mal_labels = [ii if ii > -1 else 0 for ii in rater_labels]
             assert rater_segs.shape == final_rois.shape, "rater segs shape {}, final rois shp {}".format(rater_segs.shape, final_rois.shape)
 
             # assert non-zero rating has non-zero seg
             for rater in range(4):
                 if roi_mal_labels[rater]>0:
                     assert np.any(rater_segs[rater]>0), "rater {} mal label {} but uniq seg {}".format(rater, roi_mal_labels[rater], np.unique(rater_segs[rater]))
 
             # add the roi to patient. i.e., write current lesion into final labels and seg of whole patient.
             assert np.any(rater_segs), "empty segmentations for all raters should not exist in single-annotator mode, pid {}, rois: {}".format(pid, roi_id_paths)
             patient_mal_labels.append(roi_mal_labels)
             final_rois[rater_segs > 0] = rix
             rix += 1
 
 
         fg_slices = [[ii for ii in np.unique(np.argwhere(final_rois[r] != 0)[:, 0])] for r in range(4)]
         patient_mal_labels = np.array(patient_mal_labels)
         roi_ids = np.unique(final_rois[final_rois>0])
         assert len(roi_ids) == len(patient_mal_labels), "mismatch {} rois in seg, {} rois in mal labels".format(len(roi_ids), len(patient_mal_labels))
 
         if hasattr(self.cf, "save_sa_segs_as") and (self.cf.save_sa_segs_as=="npy" or self.cf.save_sa_segs_as==".npy"):
             np.save(os.path.join(self.pp_dir_sa, '{}_rois.npy'.format(pid)), final_rois)
         else:
             np.savez_compressed(os.path.join(self.cf.pp_dir, 'patient_gts_sa', '{}_rois.npz'.format(pid)), seg=final_rois)
         with open(os.path.join(self.pp_dir_sa, '{}_meta_info.pickle'.format(pid)), 'wb') as handle:
             meta_info_dict = {'pid': pid, 'class_target': patient_mal_labels, 'spacing': img_spacing,
                               'fg_slices': fg_slices}
             pickle.dump(meta_info_dict, handle)
 
     def produce_merged_gt(self, path, pid, df, img_spacing, img_arr_shape):
         """ process patient with merged annotations, i.e., only one final GT per image. save img and seg to npy, rest to
             metadata.
             annotations merging:
                 - segmentations: only regard a pixel as foreground if at least two raters found it be foreground.
                 - malignancy labels: average over all four rater votes. every rater who did not assign a finding or
                     assigned -1 to the RoI contributes to the average with a vote of 0.
 
         :param path: path to patient folder.
         """
 
         final_rois = np.zeros(img_arr_shape, dtype=np.uint8)
         patient_mal_labels = []
         roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii])
 
         rix = 1
         for roi_id in roi_ids:
             roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii]
             nodule_ids = [ii.split('_')[2].lstrip("0") for ii in roi_id_paths]
             rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0] for ii in nodule_ids]
             rater_labels.extend([0] * (4 - len(rater_labels)))
             mal_label = np.mean([ii if ii > -1 else 0 for ii in rater_labels])
             rater_segs = []
             for rp in roi_id_paths:
                 roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp))
                 for dim in range(len(img_arr_shape)):
                     npt.assert_almost_equal(roi.GetSpacing()[dim], img_spacing[dim])
                 roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8)
                 roi_arr = resample_array(roi_arr, roi.GetSpacing(), self.cf.target_spacing)
                 assert roi_arr.shape == img_arr_shape, [roi_arr.shape, img_arr_shape, pid, roi.GetSpacing()]
                 rater_segs.append(roi_arr)
             rater_segs.extend([np.zeros_like(rater_segs[-1])] * (4 - len(roi_id_paths)))
             rater_segs = np.mean(np.array(rater_segs), axis=0)
             # annotations merging: if less than two raters found fg, set segmentation to bg.
             rater_segs[rater_segs < 0.5] = 0
             if np.sum(rater_segs) > 0:
                 patient_mal_labels.append(mal_label)
                 final_rois[rater_segs > 0] = rix
                 rix += 1
             else:
                 # indicate rois suppressed by majority voting of raters
                 print('suppressed roi!', roi_id_paths)
                 with open(os.path.join(self.pp_dir_merged, 'suppressed_rois.txt'), 'a') as handle:
                     handle.write(" ".join(roi_id_paths))
 
         fg_slices = [ii for ii in np.unique(np.argwhere(final_rois != 0)[:, 0])]
         patient_mal_labels = np.array(patient_mal_labels)
         assert len(patient_mal_labels) + 1 == len(np.unique(final_rois)), [len(patient_mal_labels), np.unique(final_rois), pid]
         assert final_rois.dtype == 'uint8'
         np.save(os.path.join(self.pp_dir_merged, '{}_rois.npy'.format(pid)), final_rois)
 
         with open(os.path.join(self.pp_dir_merged, '{}_meta_info.pickle'.format(pid)), 'wb') as handle:
             meta_info_dict = {'pid': pid, 'class_target': patient_mal_labels, 'spacing': img_spacing,
                               'fg_slices': fg_slices}
             pickle.dump(meta_info_dict, handle)
 
     def pp_patient(self, path):
 
         pid = path.split('/')[-1]
         img = sitk.ReadImage(os.path.join(path, '{}_ct_scan.nrrd'.format(pid)))
         img_arr = sitk.GetArrayFromImage(img)
         print('processing {} with GT(s) {}, spacing {} and img shape {}.'.format(
             pid, " and ".join(self.cf.gts_to_produce), img.GetSpacing(), img_arr.shape))
         img_arr = resample_array(img_arr, img.GetSpacing(), self.cf.target_spacing)
         img_arr = np.clip(img_arr, -1200, 600)
         #img_arr = (1200 + img_arr) / (600 + 1200) * 255  # a+x / (b-a) * (c-d) (c, d = new)
         img_arr = img_arr.astype(np.float32)
         img_arr = (img_arr - np.mean(img_arr)) / np.std(img_arr).astype('float16')
 
         df = pd.read_csv(os.path.join(self.cf.root_dir, 'characteristics.csv'), sep=';')
         df = df[df.PatientID == pid]
 
         np.save(os.path.join(self.cf.pp_dir, '{}_img.npy'.format(pid)), img_arr)
         if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce:
             self.produce_sa_gt(path, pid, df, img.GetSpacing(), img_arr.shape)
         if 'merged' in self.cf.gts_to_produce:
             self.produce_merged_gt(path, pid, df, img.GetSpacing(), img_arr.shape)
 
 
     def iterate_patients(self, processes=os.cpu_count()):
         pool = Pool(processes=processes)
         pool.map(self.pp_patient, self.paths, chunksize=1)
         pool.close()
         pool.join()
         print("finished processing raw patient data")
 
 
     def aggregate_meta_info(self):
         self.dfs = {}
         for gt_kind in self.cf.gts_to_produce:
             kind_dir = self.pp_dir_merged if gt_kind == "merged" else self.pp_dir_sa
             files = [os.path.join(kind_dir, f) for f in os.listdir(kind_dir) if 'meta_info.pickle' in f]
             self.dfs[gt_kind] = pd.DataFrame(columns=['pid', 'class_target', 'spacing', 'fg_slices'])
             for f in files:
                 with open(f, 'rb') as handle:
                     self.dfs[gt_kind].loc[len(self.dfs[gt_kind])] = pickle.load(handle)
 
             self.dfs[gt_kind].to_pickle(os.path.join(kind_dir, 'info_df.pickle'))
             print("aggregated meta info to df with length", len(self.dfs[gt_kind]))
 
     def convert_copy_npz(self):
         npz_dir = os.path.join(self.cf.pp_dir+'_npz')
         print("converting to npz dir", npz_dir)
         os.makedirs(npz_dir, exist_ok=True)
 
         dmanager.pack_dataset(self.cf.pp_dir, destination=npz_dir, recursive=True, verbose=False)
         if hasattr(self, 'pp_dir_merged'):
             subprocess.call('rsync -avh --exclude="*.npy" {} {}'.format(self.pp_dir_merged, npz_dir), shell=True)
         if hasattr(self, 'pp_dir_sa'):
             subprocess.call('rsync -avh --exclude="*.npy" {} {}'.format(self.pp_dir_sa, npz_dir), shell=True)
 
 
 if __name__ == "__main__":
     total_stime = time.time()
 
     import configs
-    cf = configs.configs()
+    cf = configs.Configs()
 
     # analysis finding: the following patients have unclear annotations. some raters gave more than one judgement
     # on the same roi.
     patients_to_exclude = ["0137a", "0404a", "0204a", "0252a", "0366a", "0863a", "0815a", "0060a", "0249a", "0436a", "0865a"]
     # further finding: the following patients contain nodules with segmentation-label inconsistencies
     # running Preprocessor.verify_seg_label_pairings() produces a data frame with detailed findings.
     patients_to_exclude += ["0305a", "0447a"]
     exclude_paths = [os.path.join(cf.raw_data_dir, pid) for pid in patients_to_exclude]
     # These pids are automatically found and excluded, when setting exclude_inconsistents=True at Preprocessor
     # initialization instead of passing the pre-compiled list.
 
 
-    pp = Preprocessor(cf, overwrite=True, exclude_inconsistents=exclude_paths, max_count=None, pids_subset=None)#["0998a"])
+    pp = Preprocessor(cf, overwrite=True, exclude_inconsistents=exclude_paths, max_count=80, pids_subset=None)#["0998a"])
     #pp.analyze_lesion("0305a", "5728")
     #pp.analyze_lesion("0305a", "5741")
     #pp.analyze_lesion("0447a", "8840")
 
     #pp.verify_seg_label_pairings()
     #load_df(os.path.join(cf.pp_dir, "verification_seg_label_pairings.pickle"))
     pp.iterate_patients(processes=8)
     # for i in ["/mnt/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI/new_nrrd/0305a",
     #           "/mnt/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI/new_nrrd/0447a"]:  #pp.paths[:1]:
     #      pp.pp_patient(i)
     pp.aggregate_meta_info()
     pp.convert_copy_npz()
 
 
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py
index f7b3762..6ae0db0 100644
--- a/datasets/toy/configs.py
+++ b/datasets/toy/configs.py
@@ -1,490 +1,490 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 from collections import namedtuple
 
 boxLabel = namedtuple('boxLabel', ["name", "color"])
 Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion'])
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals'])
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #         Prepro        #
         #########################
 
         self.pp_rootdir = os.path.join('/mnt/HDD2TB/Documents/data/toy', "cyl1ps_dev_exact")
         self.pp_npz_dir = self.pp_rootdir+"_npz"
 
         self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now)
         self.min_2d_radius = 6 #in pixels
         self.n_train_samples, self.n_test_samples = 80, 80
 
         # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes.
         self.pp_create_ohe_seg = False
         self.pp_empty_samples_ratio = 0.1
 
         self.pp_place_radii_mid_bin = True
         self.pp_only_distort_2d = True
         # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing.
         # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity.
         self.pp_blur_min_intensity = 0.2
 
         self.max_instances_per_sample = 3 #how many max instances over all classes per sample (img if 2d, vol if 3d)
         self.max_instances_per_class = self.max_instances_per_sample  # how many max instances per image per class
         self.noise_scale = 0.  # std-dev of gaussian noise
 
         self.ambigs_sampling = "gaussian" #"gaussian" or "uniform"
         """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from
             image by distinguishing it from the rest of the object.
             blurring width around edge will be shifted so that symmetric rel to orig radius.
             blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale
             since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius.
             if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2.
         """
         self.ambiguities = {
              #set which classes to apply which ambs to below in class labels
              #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'.
              #kind              #probability   #scale (gaussian std, relative to unperturbed value)
             #"outer_radius":     (1.,            0.5),
             #"outer_radius_xy":  (1.,            0.5),
             #"inner_radius":     (0.5,            0.1),
             #"radii_relations":  (0.5,            0.1),
             "radius_calib":     (1.,            1./6)
         }
 
         # shape choices: 'cylinder', 'block'
         #                        id,    name,       shape,      radius,                 color,              regression,     ambiguities,    gt_distortion
         self.pp_classes = [Label(1,     'cylinder', 'cylinder', ((6,6,1),(40,40,8)),    (*self.blue, 1.),   "radius_2d",    (),             ()),
                            #Label(2,      'block',      'block',        ((6,6,1),(40,40,8)),  (*self.aubergine,1.),  "radii_2d", (), ('radius_calib',))
             ]
 
 
         #########################
         #         I/O           #
         #########################
 
-        self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_dev_exact'
+        self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_exact'
 
         if server_env:
             self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_exact_npz'
 
 
         self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test')
         self.data_sourcedir = os.path.join(self.data_sourcedir, "train")
 
         self.info_df_name = 'info_df.pickle'
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn'].
         self.model = 'mrcnn'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # one out of [2, 3]. dimension the model operates in.
-        self.dim = 3
+        self.dim = 2
 
         # 'class', 'regression', 'regression_bin', 'regression_ken_gal'
         # currently only tested mode is a single-task at a time (i.e., only one task in below list)
         # but, in principle, tasks could be combined (e.g., object classes and regression per class)
         self.prediction_tasks = ['class',]
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm'
         self.relu = 'relu'
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1  # length of regressor target vector
 
 
         #########################
         #      Data Loader      #
         #########################
 
         self.num_epochs = 32
         self.num_train_batches = 120 if self.dim == 2 else 80
         self.batch_size = 16 if self.dim == 2 else 8
 
         self.n_cv_splits = 4
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
         self.plot_bg_chan = 0
         self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
         self.patch_size_2D = self.pre_crop_size[:2]
         self.patch_size_3D = self.pre_crop_size[:2]+[8]
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.2
         self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets"
 
         self.observables_patient = []
         self.observables_rois = []
 
         self.seed = 3 #for generating folds
 
         #############################
         # Colors, Classes, Legends  #
         #############################
         self.plot_frequency = 1
 
         binary_bin_labels = [binLabel(1,  'r<=25',      (*self.green, 1.),      (1,25)),
                              binLabel(2,  'r>25',       (*self.red, 1.),        (25,))]
         quintuple_bin_labels = [binLabel(1,  'r2-10',   (*self.green, 1.),      (2,10)),
                                 binLabel(2,  'r10-20',  (*self.yellow, 1.),     (10,20)),
                                 binLabel(3,  'r20-30',  (*self.orange, 1.),     (20,30)),
                                 binLabel(4,  'r30-40',  (*self.bright_red, 1.), (30,40)),
                                 binLabel(5,  'r>40',    (*self.red, 1.), (40,))]
 
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_bin_labels = quintuple_bin_labels
 
         self.class_labels = [
             # regression: regression-task label, either value or "(x,y,z)_radius" or "radii".
             # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables!
             # gt_distortion: name of ambig to apply to gt only; needs to be iterable!
             #      #id  #name   #shape  #radius     #color              #regression #ambiguities    #gt_distortion
             Label(  0,  'bg',   None,   (0, 0, 0),  (*self.white, 0.),  (0, 0, 0),  (),             ())]
         if "class" in self.prediction_tasks:
             self.class_labels += self.pp_classes
         else:
             self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))]
 
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0,  'bg',       (*self.white, 1.),      (0,))]
             self.bin_labels += task_spec_bin_labels
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
             self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
 
         if self.class_specific_seg:
           self.seg_labels = self.class_labels
 
         self.box_type2label = {label.name: label for label in self.box_labels}
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)
         self.num_seg_classes = len(self.seg_labels)
 
         #########################
         #   Data Augmentation   #
         #########################
         self.do_aug = True
         self.da_kwargs = {
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': False,
             'alpha': (500., 1500.),
             'sigma': (40., 45.),
             'do_rotation': False,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': False,
             'scale': (0.8, 1.1),
             'random_crop': False,
             'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #########################
         #  Schedule / Selection #
         #########################
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 220  # if 'all' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 25 if self.dim==2 else 15
 
         self.save_n_models = 2
         self.min_save_thresh = 1 if self.dim == 2 else 1  # =wait time in epochs
         if "class" in self.prediction_tasks:
             self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()}
             self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()})
 
         self.lr_decay_factor = 0.5
         self.scheduling_patience = int(self.num_epochs / 5)
         self.weight_decay = 1e-5
         self.clip_norm = None  # number or None
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1)
         self.held_out_test_set = True
         self.max_test_patients = "all"  # number or "all" for all
 
         self.test_against_exact_gt = not 'exact' in self.data_sourcedir
         self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario.
         self.report_score_level = ['rois']  # 'patient' or 'rois' (incl)
         self.patient_class_of_interest = 1
         self.patient_bin_of_interest = 2
 
         self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False
         self.metrics = ['ap', 'auc', 'dice']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.5]  # threshold(s) for considering a prediction as true positive
         self.min_det_thresh = 0.3
 
         self.model_max_iou_resolution = 0.2
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = self.model_max_iou_resolution  # has to be larger than desired possible overlap iou of model predictions
 
         self.merge_2D_to_3D_preds = False
         self.merge_3D_iou = self.model_max_iou_resolution
         self.n_test_plots = 1  # per fold and rank
 
         self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
         # is multiplied by (1 + nr of test augs)
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
          'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
          }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         #only meant for isotropic radii!!
         # only 2D radii (x and y dims) or 1D (x or y) are expected
         return np.round(np.digitize(rg_val, self.bin_edges).mean())
 
 
     def add_det_fpn_configs(self):
 
       self.learning_rate = [5 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = 'torch_loss'
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       self.n_roi_candidates = 4 if self.dim == 2 else 6
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
 
       self.fp_dice_weight = 1 if self.dim == 2 else 1
       # if <1, false positive predictions in foreground are penalized less.
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
     def add_det_unet_configs(self):
 
       self.learning_rate = [5 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = "torch_loss"
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
       self.n_roi_candidates = 4 if self.dim == 2 else 6
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
       # if <1, false positive predictions in foreground are penalized less.
       self.fp_dice_weight = 1 if self.dim == 2 else 1
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
       self.init_filts = 32
       self.kernel_size = 3  # ks for horizontal, normal convs
       self.kernel_size_m = 2  # ks for max pool
       self.pad = "same"  # "same" or integer, padding of horizontal convs
 
     def add_mrcnn_configs(self):
 
       self.learning_rate = [1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True  # with scheduler set in exec
       self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # number of classes for network heads: n_foreground_classes + 1 (background)
       self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2
 
       # feed +/- n neighbouring slices into channel dimension. set to None for no context.
       self.n_3D_context = None
       if self.n_3D_context is not None and self.dim == 2:
         self.n_channels *= (self.n_3D_context * 2 + 1)
 
       self.detect_while_training = True
       # disable the re-sampling of mask proposals to original size for speed-up.
       # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
       # mask outputs are optional.
       self.return_masks_in_train = True
       self.return_masks_in_val = True
       self.return_masks_in_test = True
 
       # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
       self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
       # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
       # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
       self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
       # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
       self.pyramid_levels = [0, 1, 2, 3]
       # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
       self.n_rpn_features = 512 if self.dim == 2 else 64
 
       # anchor ratios and strides per position in feature maps.
       self.rpn_anchor_ratios = [0.5, 1., 2.]
       self.rpn_anchor_stride = 1
       # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
       self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution)
 
       # loss sampling settings.
       self.rpn_train_anchors_per_image = 4
       self.train_rois_per_image = 6 # per batch_instance
       self.roi_positive_ratio = 0.5
       self.anchor_matching_iou = 0.8
 
       # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
       # where k<=#positive examples.
       self.shem_poolsize = 2
 
       self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
       self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
       self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
       self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
       self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                              self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
 
       if self.dim == 2:
         self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
         self.bbox_std_dev = self.bbox_std_dev[:4]
         self.window = self.window[:4]
         self.scale = self.scale[:4]
 
       self.plot_y_max = 1.5
       self.n_plot_rpn_props = 5 if self.dim == 2 else 30  # per batch_instance (slice in 2D / patient in 3D)
 
       # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
       self.pre_nms_limit = 2000 if self.dim == 2 else 4000
 
       # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
       # since proposals of the entire batch are forwarded through second stage as one "batch".
       self.roi_chunk_size = 1300 if self.dim == 2 else 500
       self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400
       self.post_nms_rois_inference = 200 * (self.head_classes-1)
 
       # Final selection of detections (refine_detections)
       self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class.
       self.detection_nms_threshold = self.model_max_iou_resolution  # needs to be > 0, otherwise all predictions are one cluster.
       self.model_min_confidence = 0.2  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
 
       if self.dim == 2:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride))]
            for stride in self.backbone_strides['xy']])
       else:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride)),
             int(np.ceil(self.patch_size[2] / stride_z))]
            for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                        )])
 
       if self.model == 'retina_net' or self.model == 'retina_unet':
         # whether to use focal loss or SHEM for loss-sample selection
         self.focal_loss = False
         # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
         self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                         self.rpn_anchor_scales['xy']]
         self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                        self.rpn_anchor_scales['z']]
         self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
         # pre-selection of detections for NMS-speedup. per entire batch.
         self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
         # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
         self.anchor_matching_iou = 0.7
 
         if self.model == 'retina_unet':
           self.operate_stride1 = True
diff --git a/models/mrcnn.py b/models/mrcnn.py
index 8e85ba3..a9535fe 100644
--- a/models/mrcnn.py
+++ b/models/mrcnn.py
@@ -1,755 +1,751 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """
 Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
 published under MIT license.
 """
 import os
 from multiprocessing import  Pool
 import time
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils
 
 import utils.model_utils as mutils
 import utils.exp_utils as utils
 
 
 
 class RPN(nn.Module):
     """
     Region Proposal Network.
     """
 
     def __init__(self, cf, conv):
 
         super(RPN, self).__init__()
         self.dim = conv.dim
 
         self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu)
         self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
         self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
 
 
     def forward(self, x):
         """
         :param x: input feature maps (b, in_channels, y, x, (z))
         :return: rpn_class_logits (b, 2, n_anchors)
         :return: rpn_probs_logits (b, 2, n_anchors)
         :return: rpn_bbox (b, 2 * dim, n_anchors)
         """
 
         # Shared convolutional base of the RPN.
         x = self.conv_shared(x)
 
         # Anchor Score. (batch, anchors per location * 2, y, x, (z)).
         rpn_class_logits = self.conv_class(x)
         # Reshape to (batch, 2, anchors)
         axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
         rpn_class_logits = rpn_class_logits.permute(*axes)
         rpn_class_logits = rpn_class_logits.contiguous()
         rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2)
 
         # Softmax on last dimension (fg vs. bg).
         rpn_probs = F.softmax(rpn_class_logits, dim=2)
 
         # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z))
         rpn_bbox = self.conv_bbox(x)
 
         # Reshape to (batch, 2*dim, anchors)
         rpn_bbox = rpn_bbox.permute(*axes)
         rpn_bbox = rpn_bbox.contiguous()
         rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2)
 
         return [rpn_class_logits, rpn_probs, rpn_bbox]
 
 
 
 class Classifier(nn.Module):
     """
     Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a
     shared convolutional base and finally branches off the classifier- and regression head.
     """
     def __init__(self, cf, conv):
         super(Classifier, self).__init__()
 
         self.cf = cf
         self.dim = conv.dim
         self.in_channels = cf.end_filts
         self.pool_size = cf.pool_size
         self.pyramid_levels = cf.pyramid_levels
         # instance_norm does not work with spatial dims (1, 1, (1))
         norm = cf.norm if cf.norm != 'instance_norm' else None
 
         self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu)
         self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu)
         self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim)
 
 
         if 'regression' in self.cf.prediction_tasks:
             self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * cf.regression_n_features)
             self.rg_n_feats = cf.regression_n_features
         #classify into bins of regression values
         elif 'regression_bin' in self.cf.prediction_tasks:
             self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * len(cf.bin_labels))
             self.rg_n_feats = len(cf.bin_labels)
         else:
             self.linear_regressor = lambda x: torch.zeros((x.shape[0], cf.head_classes * 1), dtype=torch.float32).fill_(float('NaN')).cuda()
             self.rg_n_feats = 1 #cf.regression_n_features
         if 'class' in self.cf.prediction_tasks:
             self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes)
         else:
             assert cf.head_classes == 2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes".format(cf.head_classes)
             self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda()
 
 
     def forward(self, x, rois):
         """
         :param x: input feature maps (b, in_channels, y, x, (z))
         :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
         the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements
         have been merged to one vector, while the origin info has been stored for re-allocation.
         :return: mrcnn_class_logits (n_proposals, n_head_classes)
         :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
         """
         x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
         x = self.conv1(x)
         x = self.conv2(x)
         x = x.view(-1, self.in_channels * 4)
 
         mrcnn_bbox = self.linear_bbox(x)
         mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2)
         mrcnn_class_logits = self.linear_class(x)
         mrcnn_regress = self.linear_regressor(x)
         mrcnn_regress = mrcnn_regress.view(mrcnn_regress.size()[0], -1, self.rg_n_feats)
 
         return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress]
 
 
 class Mask(nn.Module):
     """
     Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the
     output logits to allow for overlapping classes.
     """
     def __init__(self, cf, conv):
         super(Mask, self).__init__()
         self.pool_size = cf.mask_pool_size
         self.pyramid_levels = cf.pyramid_levels
         self.dim = conv.dim
         self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
         self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
         self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
         self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
         if conv.dim == 2:
             self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
         else:
             self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
 
         self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True)
         self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None)
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, x, rois):
         """
         :param x: input feature maps (b, in_channels, y, x, (z))
         :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
         the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements
         have been merged to one vector, while the origin info has been stored for re-allocation.
         :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z))
         """
         x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
         x = self.conv4(x)
         x = self.relu(self.deconv(x))
         x = self.conv5(x)
         x = self.sigmoid(x)
         return x
 
 
 ############################################################
 #  Loss Functions
 ############################################################
 
 def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize):
     """
     :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
     :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier.
     :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining).
     :return: loss: torch tensor
     :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
     """
 
     # Filter out netural anchors
     pos_indices = torch.nonzero(rpn_match == 1)
     neg_indices = torch.nonzero(rpn_match == -1)
 
     # loss for positive samples
     if not 0 in pos_indices.size():
         pos_indices = pos_indices.squeeze(1)
         roi_logits_pos = rpn_class_logits[pos_indices]
         pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda())
     else:
         pos_loss = torch.FloatTensor([0]).cuda()
 
     # loss for negative samples: draw hard negative examples (SHEM)
     # that match the number of positive samples, but at least 1.
     if not 0 in neg_indices.size():
         neg_indices = neg_indices.squeeze(1)
         roi_logits_neg = rpn_class_logits[neg_indices]
         negative_count = np.max((1, pos_indices.cpu().data.numpy().size))
         roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
         neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
         neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda())
         np_neg_ix = neg_ix.cpu().data.numpy()
         #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count)
     else:
         neg_loss = torch.FloatTensor([0]).cuda()
         np_neg_ix = np.array([]).astype('int32')
 
     loss = (pos_loss + neg_loss) / 2
     return loss, np_neg_ix
 
 
 def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match):
     """
     :param rpn_target_deltas:   (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))).
     Uses 0 padding to fill in unsed bbox deltas.
     :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
     :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
     :return: loss: torch 1D tensor.
     """
     if not 0 in torch.nonzero(rpn_match == 1).size():
 
         indices = torch.nonzero(rpn_match == 1).squeeze(1)
         # Pick bbox deltas that contribute to the loss
         rpn_pred_deltas = rpn_pred_deltas[indices]
         # Trim target bounding box deltas to the same length as rpn_bbox.
         target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :]
         # Smooth L1 loss
         loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas)
     else:
         loss = torch.FloatTensor([0]).cuda()
 
     return loss
 
 def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids):
     """
     :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dh)))
     :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dh)))
     :param target_class_ids: (n_sampled_rois)
     :return: loss: torch 1D tensor.
     """
     if not 0 in torch.nonzero(target_class_ids > 0).size():
         positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
         positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
         target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach()
         pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :]
         loss = F.smooth_l1_loss(pred_bbox, target_bbox)
     else:
         loss = torch.FloatTensor([0]).cuda()
 
     return loss
 
 def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids):
     """
     :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array.
     :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1].
     :param target_class_ids: (n_sampled_rois)
     :return: loss: torch 1D tensor.
     """
+    #print("targ masks", target_masks.unique(return_counts=True))
     if not 0 in torch.nonzero(target_class_ids > 0).size():
         # Only positive ROIs contribute to the loss. And only
         # the class-specific mask of each ROI.
         positive_ix = torch.nonzero(target_class_ids > 0)[:, 0]
         positive_class_ids = target_class_ids[positive_ix].long()
         y_true = target_masks[positive_ix, :, :].detach()
         y_pred = pred_masks[positive_ix, positive_class_ids, :, :]
         loss = F.binary_cross_entropy(y_pred, y_true)
     else:
         loss = torch.FloatTensor([0]).cuda()
 
     return loss
 
 def compute_mrcnn_class_loss(tasks, pred_class_logits, target_class_ids):
     """
     :param pred_class_logits: (n_sampled_rois, n_classes)
     :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension.
     :return: loss: torch 1D tensor.
     """
     if 'class' in tasks and not 0 in target_class_ids.size():
         loss = F.cross_entropy(pred_class_logits, target_class_ids.long())
     else:
         loss = torch.FloatTensor([0.]).cuda()
 
     return loss
 
 def compute_mrcnn_regression_loss(tasks, pred, target, target_class_ids):
     """regression loss is a distance metric between target vector and predicted regression vector.
     :param pred: (n_sampled_rois, n_classes, [n_rg_feats if real regression or 1 if rg_bin task)
     :param target: (n_sampled_rois, [n_rg_feats or n_rg_bins])
     :return: differentiable loss, torch 1D tensor on cuda
     """
 
     if not 0 in target.shape and not 0 in torch.nonzero(target_class_ids > 0).shape:
         positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
         positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
         target = target[positive_roi_ix].detach()
         pred = pred[positive_roi_ix, positive_roi_class_ids]
         if "regression_bin" in tasks:
             loss = F.cross_entropy(pred, target.long())
         else:
             loss = F.smooth_l1_loss(pred, target)
             #loss = F.mse_loss(pred, target)
     else:
         loss = torch.FloatTensor([0.]).cuda()
 
     return loss
 
 ############################################################
 #  Detection Layer
 ############################################################
 
 def compute_roi_scores(tasks, batch_rpn_proposals, mrcnn_cl_logits):
     """ Depending on the predicition tasks: if no class prediction beyong fg/bg (--> means no additional class
         head was applied) use RPN objectness scores as roi scores, otherwise class head scores.
     :param cf:
     :param batch_rpn_proposals:
     :param mrcnn_cl_logits:
     :return:
     """
     if not 'class' in tasks:
         scores = batch_rpn_proposals[:, :, -1].view(-1, 1)
         scores = torch.cat((1 - scores, scores), dim=1)
     else:
         scores = F.softmax(mrcnn_cl_logits, dim=1)
 
     return scores
 
 ############################################################
 #  MaskRCNN Class
 ############################################################
 
 class net(nn.Module):
 
 
     def __init__(self, cf, logger):
 
         super(net, self).__init__()
         self.cf = cf
         self.logger = logger
         self.build()
 
         loss_order = ['rpn_class', 'rpn_bbox', 'mrcnn_bbox', 'mrcnn_mask', 'mrcnn_class', 'mrcnn_rg']
         if hasattr(cf, "mrcnn_loss_weights"):
             # bring into right order
             self.loss_weights = np.array([cf.mrcnn_loss_weights[k] for k in loss_order])
         else:
             self.loss_weights = np.array([1.]*len(loss_order))
 
         if self.cf.weight_init=="custom":
             logger.info("Tried to use custom weight init which is not defined. Using pytorch default.")
         elif self.cf.weight_init:
             mutils.initialize_weights(self)
         else:
             logger.info("using default pytorch weight init")
 
     def build(self):
         """Build Mask R-CNN architecture."""
 
         # Image size must be dividable by 2 multiple times.
         h, w = self.cf.patch_size[:2]
         if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5):
             raise Exception("Image size must be divisible by 2 at least 5 times "
                             "to avoid fractions when downscaling and upscaling."
                             "For example, use 256, 288, 320, 384, 448, 512, ... etc.,i.e.,"
                             "any number x*32 will do!")
 
         # instantiate abstract multi-dimensional conv generator and load backbone module.
         backbone = utils.import_module('bbone', self.cf.backbone_path)
         self.logger.info("loaded backbone from {}".format(self.cf.backbone_path))
         conv = backbone.ConvGenerator(self.cf.dim)
 
         # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head
         self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
         self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
         self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda()
         self.rpn = RPN(self.cf, conv)
         self.classifier = Classifier(self.cf, conv)
         self.mask = Mask(self.cf, conv)
 
     def forward(self, img, is_training=True):
         """
         :param img: input images (b, c, y, x, (z)).
         :return: rpn_pred_logits: (b, n_anchors, 2)
         :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
         :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting.
         :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
         :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
         """
         # extract features.
         fpn_outs = self.fpn(img)
         rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels]
         self.mrcnn_feature_maps = rpn_feature_maps
 
         # loop through pyramid layers and apply RPN.
         layer_outputs = [ self.rpn(p_feats) for p_feats in rpn_feature_maps ]
 
         # concatenate layer outputs.
         # convert from list of lists of level outputs to list of lists of outputs across levels.
         # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
         outputs = list(zip(*layer_outputs))
         outputs = [torch.cat(list(o), dim=1) for o in outputs]
         rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs
         #
         # # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier.
         proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference
         batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas,
                                                                             proposal_count, self.anchors, self.cf)
 
         # merge batch dimension of proposals while storing allocation info in coordinate dimension.
         batch_ixs = torch.arange(
             batch_normed_props.shape[0]).cuda().unsqueeze(1).repeat(1,batch_normed_props.shape[1]).view(-1).float()
         rpn_rois = batch_normed_props[:, :, :-1].view(-1, batch_normed_props[:, :, :-1].shape[2])
         self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1)
 
         # this is the first of two forward passes in the second stage, where no activations are stored for backprop.
         # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.)
         # for inference/monitoring as well as sampling of rois for the loss functions.
         # processed in chunks of roi_chunk_size to re-adjust to gpu-memory.
         chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size)
         bboxes_list, class_logits_list, regressions_list = [], [], []
         with torch.no_grad():
             for chunk in chunked_rpn_rois:
                 chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk)
                 bboxes_list.append(chunk_bboxes)
                 class_logits_list.append(chunk_class_logits)
                 regressions_list.append(chunk_regressions)
         mrcnn_bbox = torch.cat(bboxes_list, 0)
         mrcnn_class_logits = torch.cat(class_logits_list, 0)
         mrcnn_regressions = torch.cat(regressions_list, 0)
         self.mrcnn_roi_scores = compute_roi_scores(self.cf.prediction_tasks, batch_normed_props, mrcnn_class_logits)
 
         # refine classified proposals, filter and return final detections.
         # returns (cf.max_inst_per_batch_element, n_coords+1+...)
         detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores,
                                        mrcnn_regressions)
 
         # forward remaining detections through mask-head to generate corresponding masks.
         scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2
         scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda()
 
         # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ix
         detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale
         with torch.no_grad():
             detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes)
 
         return [rpn_pred_logits, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks]
 
 
     def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions=None):
         """
         this is the second forward pass through the second stage (features from stage one are re-used).
         samples few rois in loss_example_mining and forwards only those for loss computation.
         :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
         :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
-        :param batch_gt_masks: (b,n(b),y,x (,z), c) list over batch elements. Each element holds n_gt_rois(b)
-                (i.e., dependent on the batch element) binary masks of shape (y, x, (z), c).
+        :param batch_gt_masks: (b, n(b), c, y, x (,z)) list over batch elements. Each element holds n_gt_rois(b)
+                (i.e., dependent on the batch element) binary masks of shape (c, y, x, (z)).
         :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores.
         :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
         :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal.
         :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals.
         :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement.
         :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals.
         :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting.
         """
         # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses.
         sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \
             mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks,
                                        self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions)
 
         # re-use feature maps and RPN output from first forward pass.
         sample_proposals = self.rpn_rois_batch_info[sample_ics]
         if not 0 in sample_proposals.size():
             sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals)
             sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals)
         else:
             sample_logits = torch.FloatTensor().cuda()
             sample_deltas = torch.FloatTensor().cuda()
             sample_regressions = torch.FloatTensor().cuda()
             sample_mask = torch.FloatTensor().cuda()
 
         return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals,
                 sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions]
 
     def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True):
         """
         Restores batch dimension of merged detections, unmolds detections, creates and fills results dict.
         :param img_shape:
         :param detections: shape (n_final_detections, len(info)), where
             info=( y1, x1, y2, x2, (z1,z2), batch_ix, pred_class_id, pred_score )
         :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
         :param box_results_list: None or list of output boxes for monitoring/plotting.
         each element is a list of boxes per batch element.
         :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
         :return: results_dict: dictionary with keys:
                  'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                           [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                  'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now.
                  class-specific return of masks will come with implementation of instance segmentation evaluation.
         """
 
         detections = detections.cpu().data.numpy()
         if self.cf.dim == 2:
             detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy()
         else:
             detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy()
         # det masks shape now (n_dets, y,x(,z), n_classes)
         # restore batch dimension of merged detections using the batch_ix info.
         batch_ixs = detections[:, self.cf.dim*2]
         detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])]
         mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])]
         # mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size
 
         if box_results_list == None: # for test_forward, where no previous list exists.
             box_results_list =  [[] for _ in range(img_shape[0])]
         # seg_logits == seg_probs in mrcnn since mask head finishes with sigmoid (--> image space = [0,1])
         seg_probs = []
         # loop over batch and unmold detections.
         for ix in range(img_shape[0]):
 
             # final masks are one-hot encoded (b, n_classes, y, x, (z))
             final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:]))
             #+1 for bg, 0.5 bc mask head classifies only bg/fg with logits between 0,1--> bg is <0.5
             if self.cf.num_classes + 1 != self.cf.num_seg_classes:
                 self.logger.warning("n of roi-classifier head classes {} doesnt match cf.num_seg_classes {}".format(
                     self.cf.num_classes + 1, self.cf.num_seg_classes))
 
             if not 0 in detections[ix].shape:
                 boxes = detections[ix][:, :self.cf.dim*2].astype(np.int32)
                 class_ids = detections[ix][:, self.cf.dim*2 + 1].astype(np.int32)
                 scores = detections[ix][:, self.cf.dim*2 + 2]
                 masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids]
                 regressions = detections[ix][:,self.cf.dim*2+3:]
 
                 # Filter out detections with zero area. Often only happens in early
                 # stages of training when the network weights are still a bit random.
                 if self.cf.dim == 2:
                     exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
                 else:
                     exclude_ix = np.where(
                         (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0]
 
                 if exclude_ix.shape[0] > 0:
                     boxes = np.delete(boxes, exclude_ix, axis=0)
                     masks = np.delete(masks, exclude_ix, axis=0)
                     class_ids = np.delete(class_ids, exclude_ix, axis=0)
                     scores = np.delete(scores, exclude_ix, axis=0)
                     regressions = np.delete(regressions, exclude_ix, axis=0)
 
                 # Resize masks to original image size and set boundary threshold.
                 if return_masks:
                     for i in range(masks.shape[0]): #masks per this batch instance/element/image
                         # Convert neural network mask to full size mask
                         if self.cf.dim == 2:
                             full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:])
                         else:
                             full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:])
                         # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class
                         # has the max seg_logit value over all instances of that class in one sample
                         final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0)
                     final_masks[0] = np.full(final_masks[0].shape, 0.49999999) #effectively min_det_thres at 0.5 per pixel
 
                 # add final predictions to results.
                 if not 0 in boxes.shape:
                     for ix2, coords in enumerate(boxes):
                         box = {'box_coords': coords, 'box_type': 'det', 'box_score': scores[ix2],
                                'box_pred_class_id': class_ids[ix2]}
                         #if (hasattr(self.cf, "convert_cl_to_rg") and self.cf.convert_cl_to_rg):
                         if "regression_bin" in self.cf.prediction_tasks:
                             # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin represents
                             box['rg_bin'] = regressions[ix2].argmax()
                             box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']]
                         else:
                             box['regression'] = regressions[ix2]
                             if hasattr(self.cf, "rg_val_to_bin_id") and \
                                     any(['regression' in task for task in self.cf.prediction_tasks]):
                                 box.update({'rg_bin': self.cf.rg_val_to_bin_id(regressions[ix2])})
 
                         box_results_list[ix].append(box)
 
             # if no detections were made--> keep full bg mask (zeros).
             seg_probs.append(final_masks)
 
         # create and fill results dictionary.
         results_dict = {}
         results_dict['boxes'] = box_results_list
         results_dict['seg_preds'] = np.array(seg_probs)
 
         return results_dict
 
     def train_forward(self, batch, is_validation=False):
         """
         train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
         for processing, computes losses, and stores outputs in a dictionary.
         :param batch: dictionary containing 'data', 'seg', etc.
+            batch['roi_masks']: (b, n(b), c, h(n), w(n) (z(n))) list like roi_labels but with arrays (masks) inplace of
+        integers. c==channels of the raw segmentation.
         :return: results_dict: dictionary with keys:
                 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                         [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
                 'torch_loss': 1D torch tensor for backprop.
                 'class_loss': classification loss for monitoring.
         """
         img = batch['data']
         gt_boxes = batch['bb_target']
-        axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
-        gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))]
+        #axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
+        #gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))]
+        gt_masks = batch['roi_masks']
         gt_class_ids = batch['class_targets']
         if 'regression' in self.cf.prediction_tasks:
             gt_regressions = batch["regression_targets"]
         elif 'regression_bin' in self.cf.prediction_tasks:
             gt_regressions = batch["rg_bin_targets"]
         else:
             gt_regressions = None
 
         img = torch.from_numpy(img).cuda().float()
         batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
         batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()
 
         # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
         box_results_list = [[] for _ in range(img.shape[0])]
 
         #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
         # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop.
         rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img)
 
         mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \
         mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \
             self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions)
-
         # loop over batch
         for b in range(img.shape[0]):
             if len(gt_boxes[b]) > 0:
                 # add gt boxes to output list
                 for tix in range(len(gt_boxes[b])):
                     gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]}
                     for name in self.cf.roi_items:
                         gt_box.update({name: batch[name][b][tix]})
                     box_results_list[b].append(gt_box)
 
                 # match gt boxes with anchors to generate targets for RPN losses.
                 rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])
 
                 # add positive anchors used for loss to output list for monitoring.
                 pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
                 for p in pos_anchors:
                     box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
 
             else:
                 rpn_match = np.array([-1]*self.np_anchors.shape[0])
                 rpn_target_deltas = np.array([0])
 
             rpn_match_gpu = torch.from_numpy(rpn_match).cuda()
             rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()
 
             # compute RPN losses.
             rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match_gpu, self.cf.shem_poolsize)
             rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match_gpu)
             batch_rpn_class_loss += rpn_class_loss /img.shape[0]
             batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0]
 
             # add negative anchors used for loss to output list for monitoring.
             # neg_anchor_ix = neg_ix come from shem and mark positions in roi_probs_neg = rpn_class_logits[neg_indices]
             # with neg_indices = rpn_match == -1
             neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[rpn_match == -1][neg_anchor_ix], img.shape[2:])
             for n in neg_anchors:
                 box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})
 
             # add highest scoring proposals to output list for monitoring.
             rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
             for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
                 box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})
 
         # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
         if not 0 in sample_proposals.shape:
             rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
             for ix, r in enumerate(rois):
                 box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
                                             'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})
 
         # compute mrcnn losses.
         mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids)
         mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids)
         mrcnn_regressions_loss = compute_mrcnn_regression_loss(self.cf.prediction_tasks, mrcnn_regressions, target_regressions, target_class_ids)
         # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
         # In this case, the mask_loss is taken out of training.
-        if not self.cf.frcnn_mode:
-            mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids)
-        else:
+        if self.cf.frcnn_mode:
             mrcnn_mask_loss = torch.FloatTensor([0]).cuda()
+        else:
+            mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids)
 
         loss = batch_rpn_class_loss + batch_rpn_bbox_loss +\
                mrcnn_bbox_loss + mrcnn_mask_loss +  mrcnn_class_loss + mrcnn_regressions_loss
 
-
-        # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
-        #dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]]
-        #self.logger.info("regression loss {:.3f}".format(mrcnn_regressions_loss.item()))
-        #self.logger.info("loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, "
-        #      "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(),
-        #      batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(), mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount))
-
         # run unmolding of predictions for monitoring and merge all results to one dictionary.
         return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train
         results_dict = self.get_results(img.shape, detections, detection_masks, box_results_list,
                                         return_masks=return_masks)
+
         results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis]
         if 'dice' in self.cf.metrics:
             results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                 results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)
 
         results_dict['torch_loss'] = loss
         results_dict['class_loss'] = mrcnn_class_loss.item()
         results_dict['bbox_loss'] = mrcnn_bbox_loss.item()
         results_dict['rg_loss'] = mrcnn_regressions_loss.item()
         results_dict['rpn_class_loss'] = rpn_class_loss.item()
         results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item()
 
         return results_dict
 
 
     def test_forward(self, batch, return_masks=True):
         """
         test method. wrapper around forward pass of network without usage of any ground truth information.
         prepares input data for processing and stores outputs in a dictionary.
         :param batch: dictionary containing 'data'
         :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
         :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
         """
         img = batch['data']
         img = torch.from_numpy(img).float().cuda()
         _, _, _, detections, detection_masks = self.forward(img)
         results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks)
 
         return results_dict
\ No newline at end of file
diff --git a/utils/model_utils.py b/utils/model_utils.py
index 40286b7..8a2346b 100644
--- a/utils/model_utils.py
+++ b/utils/model_utils.py
@@ -1,1526 +1,1529 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """
 Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
 published under MIT license.
 """
 import warnings
 warnings.filterwarnings('ignore', '.*From scipy 0.13.0, the output shape of zoom()*')
 
 import numpy as np
 import scipy.misc
 import scipy.ndimage
 import scipy.interpolate
 from scipy.ndimage.measurements import label as lb
 import torch
 
 import tqdm
 
 from custom_extensions.nms import nms
 from custom_extensions.roi_align import roi_align
 
+import torchvision as tv
 
 ############################################################
 #  Segmentation Processing
 ############################################################
 
 def sum_tensor(input, axes, keepdim=False):
     axes = np.unique(axes)
     if keepdim:
         for ax in axes:
             input = input.sum(ax, keepdim=True)
     else:
         for ax in sorted(axes, reverse=True):
             input = input.sum(int(ax))
     return input
 
 def get_one_hot_encoding(y, n_classes):
     """
     transform a numpy label array to a one-hot array of the same shape.
     :param y: array of shape (b, 1, y, x, (z)).
     :param n_classes: int, number of classes to unfold in one-hot encoding.
     :return y_ohe: array of shape (b, n_classes, y, x, (z))
     """
 
     dim = len(y.shape) - 2
     if dim == 2:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32')
     elif dim == 3:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32')
     else:
         raise Exception("invalid dimensions {} encountered".format(y.shape))
     for cl in np.arange(n_classes):
         y_ohe[:, cl][y[:, 0] == cl] = 1
     return y_ohe
 
 def dice_per_batch_inst_and_class(pred, y, n_classes, convert_to_ohe=True, smooth=1e-8):
     '''
     computes dice scores per batch instance and class.
     :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1)
     :param y: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes]
     :param n_classes: int
     :return: dice scores of shape (b, c)
     '''
     if convert_to_ohe:
         pred = get_one_hot_encoding(pred, n_classes)
         y = get_one_hot_encoding(y, n_classes)
     axes = tuple(range(2, len(pred.shape)))
     intersect = np.sum(pred*y, axis=axes)
     denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes)
     dice = (2.0*intersect + smooth) / (denominator + smooth)
     return dice
 
 def dice_per_batch_and_class(pred, targ, n_classes, convert_to_ohe=True, smooth=1e-8):
     '''
     computes dice scores per batch and class.
     :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1)
     :param targ: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes])
     :param n_classes: int
     :param smooth: Laplacian smooth, https://en.wikipedia.org/wiki/Additive_smoothing
     :return: dice scores of shape (b, c)
     '''
     if convert_to_ohe:
         pred = get_one_hot_encoding(pred, n_classes)
         targ = get_one_hot_encoding(targ, n_classes)
     axes = (0, *list(range(2, len(pred.shape)))) #(0,2,3(,4))
 
     intersect = np.sum(pred * targ, axis=axes)
 
     denominator = np.sum(pred, axis=axes) + np.sum(targ, axis=axes)
     dice = (2.0 * intersect + smooth) / (denominator + smooth)
 
     assert dice.shape==(n_classes,), "dice shp {}".format(dice.shape)
     return dice
 
 
 def batch_dice(pred, y, false_positive_weight=1.0, smooth=1e-6):
     '''
     compute soft dice over batch. this is a differentiable score and can be used as a loss function.
     only dice scores of foreground classes are returned, since training typically
     does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of.
     This way, single patches with missing foreground classes can not produce faulty gradients.
     :param pred: (b, c, y, x, (z)), softmax probabilities (network output).
     :param y: (b, c, y, x, (z)), one hote encoded segmentation mask.
     :param false_positive_weight: float [0,1]. For weighting of imbalanced classes,
     reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances.
     :return: soft dice score (float).This function discards the background score and returns the mena of foreground scores.
     '''
 
     if len(pred.size()) == 4:
         axes = (0, 2, 3)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth))[1:]) #only fg dice here.
 
     elif len(pred.size()) == 5:
         axes = (0, 2, 3, 4)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth))[1:]) #only fg dice here.
     else:
         raise ValueError('wrong input dimension in dice loss')
 
 
 ############################################################
 #  Bounding Boxes
 ############################################################
 
 def compute_iou_2D(box, boxes, box_area, boxes_area):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2] THIS IS THE GT BOX
     boxes: [boxes_count, (y1, x1, y2, x2)]
     box_area: float. the area of 'box'
     boxes_area: array of length boxes_count.
 
     Note: the areas are passed in rather than calculated here for
           efficency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
     union = box_area + boxes_area[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 def compute_iou_3D(box, boxes, box_volume, boxes_volume):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2, z1, z2] (typically gt box)
     boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)]
     box_area: float. the area of 'box'
     boxes_area: array of length boxes_count.
 
     Note: the areas are passed in rather than calculated here for
           efficency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     z1 = np.maximum(box[4], boxes[:, 4])
     z2 = np.minimum(box[5], boxes[:, 5])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) * np.maximum(z2 - z1, 0)
     union = box_volume + boxes_volume[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 
 def compute_overlaps(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)]. / 3D: (z1, z2))
     For better performance, pass the largest set first and the smaller second.
     :return: (#boxes1, #boxes2), ious of each box of 1 machted with each of 2
     """
     # Areas of anchors and GT boxes
     if boxes1.shape[1] == 4:
         area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
         area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(overlaps.shape[1]):
             box2 = boxes2[i] #this is the gt box
             overlaps[:, i] = compute_iou_2D(box2, boxes1, area2[i], area1)
         return overlaps
 
     else:
         # Areas of anchors and GT boxes
         volume1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) * (boxes1[:, 5] - boxes1[:, 4])
         volume2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) * (boxes2[:, 5] - boxes2[:, 4])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(boxes2.shape[0]):
             box2 = boxes2[i]  # this is the gt box
             overlaps[:, i] = compute_iou_3D(box2, boxes1, volume2[i], volume1)
         return overlaps
 
 
 
 def box_refinement(box, gt_box):
     """Compute refinement needed to transform box to gt_box.
     box and gt_box are [N, (y1, x1, y2, x2)] / 3D: (z1, z2))
     """
     height = box[:, 2] - box[:, 0]
     width = box[:, 3] - box[:, 1]
     center_y = box[:, 0] + 0.5 * height
     center_x = box[:, 1] + 0.5 * width
 
     gt_height = gt_box[:, 2] - gt_box[:, 0]
     gt_width = gt_box[:, 3] - gt_box[:, 1]
     gt_center_y = gt_box[:, 0] + 0.5 * gt_height
     gt_center_x = gt_box[:, 1] + 0.5 * gt_width
 
     dy = (gt_center_y - center_y) / height
     dx = (gt_center_x - center_x) / width
     dh = torch.log(gt_height / height)
     dw = torch.log(gt_width / width)
     result = torch.stack([dy, dx, dh, dw], dim=1)
 
     if box.shape[1] > 4:
         depth = box[:, 5] - box[:, 4]
         center_z = box[:, 4] + 0.5 * depth
         gt_depth = gt_box[:, 5] - gt_box[:, 4]
         gt_center_z = gt_box[:, 4] + 0.5 * gt_depth
         dz = (gt_center_z - center_z) / depth
         dd = torch.log(gt_depth / depth)
         result = torch.stack([dy, dx, dz, dh, dw, dd], dim=1)
 
     return result
 
 
 
 def unmold_mask_2D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2 = bbox
     out_zoom = [y2 - y1, x2 - x1]
     zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]
 
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:2]) #only y,x
     full_mask[y1:y2, x1:x2] = mask
     return full_mask
 
 
 def unmold_mask_2D_torch(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2 = bbox
     out_zoom = [(y2 - y1).float(), (x2 - x1).float()]
     zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]
 
     mask = mask.unsqueeze(0).unsqueeze(0)
     mask = torch.nn.functional.interpolate(mask, scale_factor=zoom_factor)
     mask = mask[0][0]
     #mask = scipy.ndimage.zoom(mask.cpu().numpy(), zoom_factor, order=1).astype(np.float32)
     #mask = torch.from_numpy(mask).cuda()
     # Put the mask in the right location.
     full_mask = torch.zeros(image_shape[:2])  # only y,x
     full_mask[y1:y2, x1:x2] = mask
     return full_mask
 
 
 
 def unmold_mask_3D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2, z1, z2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2, z1, z2 = bbox
     out_zoom = [y2 - y1, x2 - x1, z2 - z1]
     zoom_factor = [i/j for i,j in zip(out_zoom, mask.shape)]
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:3])
     full_mask[y1:y2, x1:x2, z1:z2] = mask
     return full_mask
 
 def nms_numpy(box_coords, scores, thresh):
     """ non-maximum suppression on 2D or 3D boxes in numpy.
     :param box_coords: [y1,x1,y2,x2 (,z1,z2)] with y1<=y2, x1<=x2, z1<=z2.
     :param scores: ranking scores (higher score == higher rank) of boxes.
     :param thresh: IoU threshold for clustering.
     :return:
     """
     y1 = box_coords[:, 0]
     x1 = box_coords[:, 1]
     y2 = box_coords[:, 2]
     x2 = box_coords[:, 3]
     assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: 
             coordinates of which maxima are taken need to be the lower coordinates"""
     areas = (x2 - x1) * (y2 - y1)
 
     is_3d = box_coords.shape[1] == 6
     if is_3d: # 3-dim case
         z1 = box_coords[:, 4]
         z2 = box_coords[:, 5]
         assert np.all(z1<=z2), """"the definition of the coordinates is crucially important here: 
            coordinates of which maxima are taken need to be the lower coordinates"""
         areas *= (z2 - z1)
 
     order = scores.argsort()[::-1]
 
     keep = []
     while order.size > 0:  # order is the sorted index.  maps order to index: order[1] = 24 means (rank1, ix 24)
         i = order[0] # highest scoring element
         yy1 = np.maximum(y1[i], y1[order])  # highest scoring element still in >order<, is compared to itself, that is okay.
         xx1 = np.maximum(x1[i], x1[order])
         yy2 = np.minimum(y2[i], y2[order])
         xx2 = np.minimum(x2[i], x2[order])
 
         h = np.maximum(0.0, yy2 - yy1)
         w = np.maximum(0.0, xx2 - xx1)
         inter = h * w
 
         if is_3d:
             zz1 = np.maximum(z1[i], z1[order])
             zz2 = np.minimum(z2[i], z2[order])
             d = np.maximum(0.0, zz2 - zz1)
             inter *= d
 
         iou = inter / (areas[i] + areas[order] - inter)
 
         non_matches = np.nonzero(iou <= thresh)[0]  # get all elements that were not matched and discard all others.
         order = order[non_matches]
         keep.append(i)
 
     return keep
 
 
 
 ############################################################
 #  M-RCNN
 ############################################################
 
 def refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, batch_anchors, cf):
     """
     Receives anchor scores and selects a subset to pass as proposals
     to the second stage. Filtering is done based on anchor scores and
     non-max suppression to remove overlaps. It also applies bounding
     box refinment details to anchors.
     :param rpn_pred_probs: (b, n_anchors, 2)
     :param rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
     :return: batch_normalized_props: Proposals in normalized coordinates (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
     :return: batch_out_proposals: Box coords + RPN foreground scores
     for monitoring/plotting (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
     """
     std_dev = torch.from_numpy(cf.rpn_bbox_std_dev[None]).float().cuda()
     norm = torch.from_numpy(cf.scale).float().cuda()
     anchors = batch_anchors.clone()
 
 
 
     batch_scores = rpn_pred_probs[:, :, 1]
     # norm deltas
     batch_deltas = rpn_pred_deltas * std_dev
     batch_normalized_props = []
     batch_out_proposals = []
 
     # loop over batch dimension.
     for ix in range(batch_scores.shape[0]):
 
         scores = batch_scores[ix]
         deltas = batch_deltas[ix]
 
         # improve performance by trimming to top anchors by score
         # and doing the rest on the smaller subset.
         pre_nms_limit = min(cf.pre_nms_limit, anchors.size()[0])
         scores, order = scores.sort(descending=True)
         order = order[:pre_nms_limit]
         scores = scores[:pre_nms_limit]
         deltas = deltas[order, :]
 
         # apply deltas to anchors to get refined anchors and filter with non-maximum suppression.
         if batch_deltas.shape[-1] == 4:
             boxes = apply_box_deltas_2D(anchors[order, :], deltas)
             boxes = clip_boxes_2D(boxes, cf.window)
         else:
             boxes = apply_box_deltas_3D(anchors[order, :], deltas)
             boxes = clip_boxes_3D(boxes, cf.window)
         # boxes are y1,x1,y2,x2, torchvision-nms requires x1,y1,x2,y2, but consistent swap x<->y is irrelevant.
         keep = nms.nms(boxes, scores, cf.rpn_nms_threshold)
 
 
         keep = keep[:proposal_count]
         boxes = boxes[keep, :]
         rpn_scores = scores[keep][:, None]
 
         # pad missing boxes with 0.
         if boxes.shape[0] < proposal_count:
             n_pad_boxes = proposal_count - boxes.shape[0]
             zeros = torch.zeros([n_pad_boxes, boxes.shape[1]]).cuda()
             boxes = torch.cat([boxes, zeros], dim=0)
             zeros = torch.zeros([n_pad_boxes, rpn_scores.shape[1]]).cuda()
             rpn_scores = torch.cat([rpn_scores, zeros], dim=0)
 
         # concat box and score info for monitoring/plotting.
         batch_out_proposals.append(torch.cat((boxes, rpn_scores), 1).cpu().data.numpy())
         # normalize dimensions to range of 0 to 1.
         normalized_boxes = boxes / norm
         assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found"
 
         # add again batch dimension
         batch_normalized_props.append(torch.cat((normalized_boxes, rpn_scores), 1).unsqueeze(0))
 
     batch_normalized_props = torch.cat(batch_normalized_props)
     batch_out_proposals = np.array(batch_out_proposals)
 
     return batch_normalized_props, batch_out_proposals
 
 def pyramid_roi_align(feature_maps, rois, pool_size, pyramid_levels, dim):
     """
     Implements ROI Pooling on multiple levels of the feature pyramid.
     :param feature_maps: list of feature maps, each of shape (b, c, y, x , (z))
     :param rois: proposals (normalized coords.) as returned by RPN. contain info about original batch element allocation.
     (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs)
     :param pool_size: list of poolsizes in dims: [x, y, (z)]
     :param pyramid_levels: list. [0, 1, 2, ...]
     :return: pooled: pooled feature map rois (n_proposals, c, poolsize_y, poolsize_x, (poolsize_z))
 
     Output:
     Pooled regions in the shape: [num_boxes, height, width, channels].
     The width and height are those specific in the pool_shape in the layer
     constructor.
     """
     boxes = rois[:, :dim*2]
     batch_ixs = rois[:, dim*2]
 
     # Assign each ROI to a level in the pyramid based on the ROI area.
     if dim == 2:
         y1, x1, y2, x2 = boxes.chunk(4, dim=1)
     else:
         y1, x1, y2, x2, z1, z2 = boxes.chunk(6, dim=1)
 
     h = y2 - y1
     w = x2 - x1
 
     # Equation 1 in https://arxiv.org/abs/1612.03144. Account for
     # the fact that our coordinates are normalized here.
     # divide sqrt(h*w) by 1 instead image_area.
     roi_level = (4 + torch.log2(torch.sqrt(h*w))).round().int().clamp(pyramid_levels[0], pyramid_levels[-1])
     # if Pyramid contains additional level P6, adapt the roi_level assignment accordingly.
     if len(pyramid_levels) == 5:
         roi_level[h*w > 0.65] = 5
 
     # Loop through levels and apply ROI pooling to each.
     pooled = []
     box_to_level = []
     fmap_shapes = [f.shape for f in feature_maps]
     for level_ix, level in enumerate(pyramid_levels):
         ix = roi_level == level
         if not ix.any():
             continue
         ix = torch.nonzero(ix)[:, 0]
         level_boxes = boxes[ix, :]
         # re-assign rois to feature map of original batch element.
         ind = batch_ixs[ix].int()
 
         # Keep track of which box is mapped to which level
         box_to_level.append(ix)
 
         # Stop gradient propogation to ROI proposals
         level_boxes = level_boxes.detach()
         if len(pool_size) == 2:
             # remap to feature map coordinate system
             y_exp, x_exp = fmap_shapes[level_ix][2:]  # exp = expansion
             level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
             pooled_features = roi_align.roi_align_2d(feature_maps[level_ix],
                                                      torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                      pool_size)
         else:
             y_exp, x_exp, z_exp = fmap_shapes[level_ix][2:]
             level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
             pooled_features = roi_align.roi_align_3d(feature_maps[level_ix],
                                                      torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                      pool_size)
         pooled.append(pooled_features)
 
 
     # Pack pooled features into one tensor
     pooled = torch.cat(pooled, dim=0)
 
     # Pack box_to_level mapping into one array and add another
     # column representing the order of pooled boxes
     box_to_level = torch.cat(box_to_level, dim=0)
 
     # Rearrange pooled features to match the order of the original boxes
     _, box_to_level = torch.sort(box_to_level)
     pooled = pooled[box_to_level, :, :]
 
     return pooled
 
 
 def roi_align_3d_numpy(input: np.ndarray, rois, output_size: tuple,
                        spatial_scale: float = 1., sampling_ratio: int = -1) -> np.ndarray:
     """ This fct mainly serves as a verification method for 3D CUDA implementation of RoIAlign, it's highly
         inefficient due to the nested loops.
     :param input:  (ndarray[N, C, H, W, D]): input feature map
     :param rois: list (N,K(n), 6), K(n) = nr of rois in batch-element n, single roi of format (y1,x1,y2,x2,z1,z2)
     :param output_size:
     :param spatial_scale:
     :param sampling_ratio:
     :return: (List[N, K(n), C, output_size[0], output_size[1], output_size[2]])
     """
 
     out_height, out_width, out_depth = output_size
 
     coord_grid = tuple([np.linspace(0, input.shape[dim] - 1, num=input.shape[dim]) for dim in range(2, 5)])
     pooled_rois = [[]] * len(rois)
     assert len(rois) == input.shape[0], "batch dim mismatch, rois: {}, input: {}".format(len(rois), input.shape[0])
     print("Numpy 3D RoIAlign progress:", end="\n")
     for b in range(input.shape[0]):
         for roi in tqdm.tqdm(rois[b]):
             y1, x1, y2, x2, z1, z2 = np.array(roi) * spatial_scale
             roi_height = max(float(y2 - y1), 1.)
             roi_width = max(float(x2 - x1), 1.)
             roi_depth = max(float(z2 - z1), 1.)
 
             if sampling_ratio <= 0:
                 sampling_ratio_h = int(np.ceil(roi_height / out_height))
                 sampling_ratio_w = int(np.ceil(roi_width / out_width))
                 sampling_ratio_d = int(np.ceil(roi_depth / out_depth))
             else:
                 sampling_ratio_h = sampling_ratio_w = sampling_ratio_d = sampling_ratio  # == n points per bin
 
             bin_height = roi_height / out_height
             bin_width = roi_width / out_width
             bin_depth = roi_depth / out_depth
 
             n_points = sampling_ratio_h * sampling_ratio_w * sampling_ratio_d
             pooled_roi = np.empty((input.shape[1], out_height, out_width, out_depth), dtype="float32")
             for chan in range(input.shape[1]):
                 lin_interpolator = scipy.interpolate.RegularGridInterpolator(coord_grid, input[b, chan],
                                                                              method="linear")
                 for bin_iy in range(out_height):
                     for bin_ix in range(out_width):
                         for bin_iz in range(out_depth):
 
                             bin_val = 0.
                             for i in range(sampling_ratio_h):
                                 for j in range(sampling_ratio_w):
                                     for k in range(sampling_ratio_d):
                                         loc_ijk = [
                                             y1 + bin_iy * bin_height + (i + 0.5) * (bin_height / sampling_ratio_h),
                                             x1 + bin_ix * bin_width + (j + 0.5) * (bin_width / sampling_ratio_w),
                                             z1 + bin_iz * bin_depth + (k + 0.5) * (bin_depth / sampling_ratio_d)]
                                         # print("loc_ijk", loc_ijk)
                                         if not (np.any([c < -1.0 for c in loc_ijk]) or loc_ijk[0] > input.shape[2] or
                                                 loc_ijk[1] > input.shape[3] or loc_ijk[2] > input.shape[4]):
                                             for catch_case in range(3):
                                                 # catch on-border cases
                                                 if int(loc_ijk[catch_case]) == input.shape[catch_case + 2] - 1:
                                                     loc_ijk[catch_case] = input.shape[catch_case + 2] - 1
                                             bin_val += lin_interpolator(loc_ijk)
                             pooled_roi[chan, bin_iy, bin_ix, bin_iz] = bin_val / n_points
 
             pooled_rois[b].append(pooled_roi)
 
     return np.array(pooled_rois)
 
 def refine_detections(cf, batch_ixs, rois, deltas, scores, regressions):
     """
     Refine classified proposals (apply deltas to rpn rois), filter overlaps (nms) and return final detections.
 
     :param rois: (n_proposals, 2 * dim) normalized boxes as proposed by RPN. n_proposals = batch_size * POST_NMS_ROIS
     :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by mrcnn bbox regressor.
     :param batch_ixs: (n_proposals) batch element assignment info for re-allocation.
     :param scores: (n_proposals, n_classes) probabilities for all classes per roi as predicted by mrcnn classifier.
     :param regressions: (n_proposals, n_classes, regression_features (+1 for uncertainty if predicted) regression vector
     :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, *regression vector features))
     """
     # class IDs per ROI. Since scores of all classes are of interest (not just max class), all are kept at this point.
     class_ids = []
     fg_classes = cf.head_classes - 1
     # repeat vectors to fill in predictions for all foreground classes.
     for ii in range(1, fg_classes + 1):
         class_ids += [ii] * rois.shape[0]
     class_ids = torch.from_numpy(np.array(class_ids)).cuda()
 
     batch_ixs = batch_ixs.repeat(fg_classes)
     rois = rois.repeat(fg_classes, 1)
     deltas = deltas.repeat(fg_classes, 1, 1)
     scores = scores.repeat(fg_classes, 1)
     regressions = regressions.repeat(fg_classes, 1, 1)
 
     # get class-specific scores and  bounding box deltas
     idx = torch.arange(class_ids.size()[0]).long().cuda()
     # using idx instead of slice [:,] squashes first dimension.
     #len(class_ids)>scores.shape[1] --> probs is broadcasted by expansion from fg_classes-->len(class_ids)
     batch_ixs = batch_ixs[idx]
     deltas_specific = deltas[idx, class_ids]
     class_scores = scores[idx, class_ids]
     regressions = regressions[idx, class_ids]
 
     # apply bounding box deltas. re-scale to image coordinates.
     std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda()
     scale = torch.from_numpy(cf.scale).float().cuda()
     refined_rois = apply_box_deltas_2D(rois, deltas_specific * std_dev) * scale if cf.dim == 2 else \
         apply_box_deltas_3D(rois, deltas_specific * std_dev) * scale
 
     # round and cast to int since we're dealing with pixels now
     refined_rois = clip_to_window(cf.window, refined_rois)
     refined_rois = torch.round(refined_rois)
 
     # filter out low confidence boxes
     keep = idx
     keep_bool = (class_scores >= cf.model_min_confidence)
     if not 0 in torch.nonzero(keep_bool).size():
 
         score_keep = torch.nonzero(keep_bool)[:, 0]
         pre_nms_class_ids = class_ids[score_keep]
         pre_nms_rois = refined_rois[score_keep]
         pre_nms_scores = class_scores[score_keep]
         pre_nms_batch_ixs = batch_ixs[score_keep]
 
         for j, b in enumerate(unique1d(pre_nms_batch_ixs)):
 
             bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0]
             bix_class_ids = pre_nms_class_ids[bixs]
             bix_rois = pre_nms_rois[bixs]
             bix_scores = pre_nms_scores[bixs]
 
             for i, class_id in enumerate(unique1d(bix_class_ids)):
 
                 ixs = torch.nonzero(bix_class_ids == class_id)[:, 0]
                 # nms expects boxes sorted by score.
                 ix_rois = bix_rois[ixs]
                 ix_scores = bix_scores[ixs]
                 ix_scores, order = ix_scores.sort(descending=True)
                 ix_rois = ix_rois[order, :]
 
                 class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold)
 
                 # map indices back.
                 class_keep = keep[score_keep[bixs[ixs[order[class_keep]]]]]
                 # merge indices over classes for current batch element
                 b_keep = class_keep if i == 0 else unique1d(torch.cat((b_keep, class_keep)))
 
             # only keep top-k boxes of current batch-element
             top_ids = class_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element]
             b_keep = b_keep[top_ids]
 
             # merge indices over batch elements.
             batch_keep = b_keep  if j == 0 else unique1d(torch.cat((batch_keep, b_keep)))
 
         keep = batch_keep
 
     else:
         keep = torch.tensor([0]).long().cuda()
 
     # arrange output
     output = [refined_rois[keep], batch_ixs[keep].unsqueeze(1)]
     output += [class_ids[keep].unsqueeze(1).float(), class_scores[keep].unsqueeze(1)]
     output += [regressions[keep]]
 
     result = torch.cat(output, dim=1)
     # shape: (n_keeps, catted feats), catted feats: [0:dim*2] are box_coords, [dim*2] are batch_ics,
     # [dim*2+1] are class_ids, [dim*2+2] are scores, [dim*2+3:] are regression vector features (incl uncertainty)
     return result
 
 
 def loss_example_mining(cf, batch_proposals, batch_gt_boxes, batch_gt_masks, batch_roi_scores,
                            batch_gt_class_ids, batch_gt_regressions):
     """
     Subsamples proposals for mrcnn losses and generates targets. Sampling is done per batch element, seems to have positive
     effects on training, as opposed to sampling over entire batch. Negatives are sampled via stochastic hard-example mining
     (SHEM), where a number of negative proposals is drawn from larger pool of highest scoring proposals for stochasticity.
     Scoring is obtained here as the max over all foreground probabilities as returned by mrcnn_classifier (worked better than
     loss-based class-balancing methods like "online hard-example mining" or "focal loss".)
 
     Classification-regression duality: regressions can be given along with classes (at least fg/bg, only class scores
     are used for ranking).
 
     :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs).
     boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS.
     :param mrcnn_class_logits: (n_proposals, n_classes)
     :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
-    :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c)
+    :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, c, y, x, (z))
     :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
         if no classes predicted (only fg/bg from RPN): expected as pseudo classes [0, 1] for bg, fg.
     :param batch_gt_regressions: list over b elements. Each element is a regression target vector. if None--> pseudo
     :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions.
     :return: target_class_ids: (n_sampled_rois)containing target class labels of sampled proposals.
     :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement.
     :return: target_masks: (n_sampled_rois, y, x, (z)) containing target masks of sampled proposals.
     """
     # normalization of target coordinates
     #global sample_regressions
     if cf.dim == 2:
         h, w = cf.patch_size
         scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda()
     else:
         h, w, z = cf.patch_size
         scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda()
 
-
     positive_count = 0
     negative_count = 0
     sample_positive_indices = []
     sample_negative_indices = []
     sample_deltas = []
     sample_masks = []
     sample_class_ids = []
     if batch_gt_regressions is not None:
         sample_regressions = []
     else:
         target_regressions = torch.FloatTensor().cuda()
 
     std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda()
 
     # loop over batch and get positive and negative sample rois.
     for b in range(len(batch_gt_boxes)):
 
         gt_masks = torch.from_numpy(batch_gt_masks[b]).float().cuda()
         gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda()
         if batch_gt_regressions is not None:
             gt_regressions = torch.from_numpy(batch_gt_regressions[b]).float().cuda()
 
         #if np.any(batch_gt_class_ids[b] > 0):  # skip roi selection for no gt images.
         if np.any([len(coords)>0 for coords in batch_gt_boxes[b]]):
             gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale
         else:
             gt_boxes = torch.FloatTensor().cuda()
 
         # get proposals and indices of current batch element.
         proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1]
         batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1)
 
         # Compute overlaps matrix [proposals, gt_boxes]
         if not 0 in gt_boxes.size():
             if gt_boxes.shape[1] == 4:
                 assert cf.dim == 2, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim)
                 overlaps = bbox_overlaps_2D(proposals, gt_boxes)
             else:
                 assert cf.dim == 3, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim)
                 overlaps = bbox_overlaps_3D(proposals, gt_boxes)
 
             # Determine positive and negative ROIs
             roi_iou_max = torch.max(overlaps, dim=1)[0]
             # 1. Positive ROIs are those with >= 0.5 IoU with a GT box
             positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3)
             # 2. Negative ROIs are those with < 0.1 with every GT box.
             negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01)
         else:
             positive_roi_bool = torch.FloatTensor().cuda()
             negative_roi_bool = torch.from_numpy(np.array([1]*proposals.shape[0])).cuda()
 
         # Sample Positive ROIs
         if not 0 in torch.nonzero(positive_roi_bool).size():
             positive_indices = torch.nonzero(positive_roi_bool).squeeze(1)
             positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio)
             rand_idx = torch.randperm(positive_indices.size()[0])
             rand_idx = rand_idx[:positive_samples].cuda()
             positive_indices = positive_indices[rand_idx]
             positive_samples = positive_indices.size()[0]
             positive_rois = proposals[positive_indices, :]
             # Assign positive ROIs to GT boxes.
             positive_overlaps = overlaps[positive_indices, :]
             roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1]
             roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :]
             roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment]
             if batch_gt_regressions is not None:
                 roi_gt_regressions = gt_regressions[roi_gt_box_assignment]
 
             # Compute bbox refinement targets for positive ROIs
             deltas = box_refinement(positive_rois, roi_gt_boxes)
             deltas /= std_dev
 
-            assert gt_masks[roi_gt_box_assignment].shape[-1] == 1
-            roi_masks = gt_masks[roi_gt_box_assignment].unsqueeze(1).squeeze(-1)
-
+            roi_masks = gt_masks[roi_gt_box_assignment]
+            #print("roi_masks[b] in ex mining pre align", roi_masks.unique(return_counts=True))
+            assert roi_masks.shape[1] == 1, "gt masks have more than one channel --> is this desired?"
             # Compute mask targets
             boxes = positive_rois
             box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).float()
 
             if len(cf.mask_shape) == 2:
-                # todo what are the dims of roi_masks? (n_matched_boxes_with_gts, 1 (dummy channel dim), y,x, c (WHY?))
+                y_exp, x_exp = roi_masks.shape[2:]  # exp = expansion
+                boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
                 masks = roi_align.roi_align_2d(roi_masks,
                                                torch.cat((box_ids, boxes), dim=1),
                                                cf.mask_shape)
             else:
+                y_exp, x_exp, z_exp = roi_masks.shape[2:]  # exp = expansion
+                boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
                 masks = roi_align.roi_align_3d(roi_masks,
                                                torch.cat((box_ids, boxes), dim=1),
                                                cf.mask_shape)
-
+            #print("roi_masks[b] in ex mining POST align", masks.unique(return_counts=True))
 
             masks = masks.squeeze(1)
             # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
             # binary cross entropy loss.
             masks = torch.round(masks)
 
             sample_positive_indices.append(batch_element_indices[positive_indices])
             sample_deltas.append(deltas)
             sample_masks.append(masks)
             sample_class_ids.append(roi_gt_class_ids)
             if batch_gt_regressions is not None:
                 sample_regressions.append(roi_gt_regressions)
             positive_count += positive_samples
         else:
             positive_samples = 0
 
         # Sample negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM.
         if not 0 in torch.nonzero(negative_roi_bool).size():
             negative_indices = torch.nonzero(negative_roi_bool).squeeze(1)
             r = 1.0 / cf.roi_positive_ratio
             b_neg_count = np.max((int(r * positive_samples - positive_samples), 1))
             roi_scores_neg = batch_roi_scores[batch_element_indices[negative_indices]]
             raw_sampled_indices = shem(roi_scores_neg, b_neg_count, cf.shem_poolsize)
             sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]])
             negative_count  += raw_sampled_indices.size()[0]
 
     if len(sample_positive_indices) > 0:
         target_deltas = torch.cat(sample_deltas)
         target_masks = torch.cat(sample_masks)
         target_class_ids = torch.cat(sample_class_ids)
         if batch_gt_regressions is not None:
             target_regressions = torch.cat(sample_regressions)
 
     # Pad target information with zeros for negative ROIs.
     if positive_count > 0 and negative_count > 0:
         sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0)
         zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
         target_deltas = torch.cat([target_deltas, zeros], dim=0)
         zeros = torch.zeros(negative_count, *cf.mask_shape).cuda()
         target_masks = torch.cat([target_masks, zeros], dim=0)
         zeros = torch.zeros(negative_count).int().cuda()
         target_class_ids = torch.cat([target_class_ids, zeros], dim=0)
         if batch_gt_regressions is not None:
             # regression targets need to have 0 as background/negative with below practice
             if 'regression_bin' in cf.prediction_tasks:
                 zeros = torch.zeros(negative_count, dtype=torch.float).cuda()
             else:
                 zeros = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda()
             target_regressions = torch.cat([target_regressions, zeros], dim=0)
 
     elif positive_count > 0:
         sample_indices = torch.cat(sample_positive_indices)
     elif negative_count > 0:
         sample_indices = torch.cat(sample_negative_indices)
         target_deltas = torch.zeros(negative_count, cf.dim * 2).cuda()
         target_masks = torch.zeros(negative_count, *cf.mask_shape).cuda()
         target_class_ids = torch.zeros(negative_count).int().cuda()
         if batch_gt_regressions is not None:
             if 'regression_bin' in cf.prediction_tasks:
                 target_regressions = torch.zeros(negative_count, dtype=torch.float).cuda()
             else:
                 target_regressions = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda()
     else:
         sample_indices = torch.LongTensor().cuda()
         target_class_ids = torch.IntTensor().cuda()
         target_deltas = torch.FloatTensor().cuda()
         target_masks = torch.FloatTensor().cuda()
         target_regressions = torch.FloatTensor().cuda()
 
     return sample_indices, target_deltas, target_masks, target_class_ids, target_regressions
 
 ############################################################
 #  Anchors
 ############################################################
 
 def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
     """
     scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width] spatial shape of the feature map over which
             to generate anchors.
     feature_stride: Stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
     scales, ratios = np.meshgrid(np.array(scales), np.array(ratios))
     scales = scales.flatten()
     ratios = ratios.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales / np.sqrt(ratios)
     widths = scales * np.sqrt(ratios)
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride
     shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
 
     # Reshape to get a list of (y, x) and a list of (h, w)
     box_centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
     box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])
 
     # Convert to corner coordinates (y1, x1, y2, x2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1)
     return boxes
 
 
 
 def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride):
     """
     scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width] spatial shape of the feature map over which
             to generate anchors.
     feature_stride: Stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
 
     scales_xy, ratios_meshed = np.meshgrid(np.array(scales_xy), np.array(ratios))
     scales_xy = scales_xy.flatten()
     ratios_meshed = ratios_meshed.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales_xy / np.sqrt(ratios_meshed)
     widths = scales_xy * np.sqrt(ratios_meshed)
     depths = np.tile(np.array(scales_z), len(ratios_meshed)//np.array(scales_z)[..., None].shape[0])
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy #translate from fm positions to input coords.
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy
     shifts_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z)
     shifts_x, shifts_y, shifts_z = np.meshgrid(shifts_x, shifts_y, shifts_z)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
     box_depths, box_centers_z = np.meshgrid(depths, shifts_z)
 
     # Reshape to get a list of (y, x, z) and a list of (h, w, d)
     box_centers = np.stack(
         [box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3])
     box_sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3])
 
     # Convert to corner coordinates (y1, x1, y2, x2, z1, z2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                             box_centers + 0.5 * box_sizes], axis=1)
 
     boxes = np.transpose(np.array([boxes[:, 0], boxes[:, 1], boxes[:, 3], boxes[:, 4], boxes[:, 2], boxes[:, 5]]), axes=(1, 0))
     return boxes
 
 
 def generate_pyramid_anchors(logger, cf):
     """Generate anchors at different levels of a feature pyramid. Each scale
     is associated with a level of the pyramid, but each ratio is used in
     all levels of the pyramid.
 
     from configs:
     :param scales: cf.RPN_ANCHOR_SCALES , for conformity with retina nets: scale entries need to be list, e.g. [[4], [8], [16], [32]]
     :param ratios: cf.RPN_ANCHOR_RATIOS , e.g. [0.5, 1, 2]
     :param feature_shapes: cf.BACKBONE_SHAPES , e.g.  [array of shapes per feature map] [80, 40, 20, 10, 5]
     :param feature_strides: cf.BACKBONE_STRIDES , e.g. [2, 4, 8, 16, 32, 64]
     :param anchors_stride: cf.RPN_ANCHOR_STRIDE , e.g. 1
     :return anchors: (N, (y1, x1, y2, x2, (z1), (z2)). All generated anchors in one array. Sorted
     with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on.
     """
     scales = cf.rpn_anchor_scales
     ratios = cf.rpn_anchor_ratios
     feature_shapes = cf.backbone_shapes
     anchor_stride = cf.rpn_anchor_stride
     pyramid_levels = cf.pyramid_levels
     feature_strides = cf.backbone_strides
 
     logger.info("anchor scales {} and feature map shapes {}".format(scales, feature_shapes))
     expected_anchors = [np.prod(feature_shapes[level]) * len(ratios) * len(scales['xy'][level]) for level in pyramid_levels]
 
     anchors = []
     for lix, level in enumerate(pyramid_levels):
         if len(feature_shapes[level]) == 2:
             anchors.append(generate_anchors(scales['xy'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], anchor_stride))
         elif len(feature_shapes[level]) == 3:
             anchors.append(generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], feature_strides['z'][level], anchor_stride))
         else:
             raise Exception("invalid feature_shapes[{}] size {}".format(level, feature_shapes[level]))
         logger.info("level {}: expected anchors {}, built anchors {}.".format(level, expected_anchors[lix], anchors[-1].shape))
 
     out_anchors = np.concatenate(anchors, axis=0)
     logger.info("Total: expected anchors {}, built anchors {}.".format(np.sum(expected_anchors), out_anchors.shape))
 
     return out_anchors
 
 
 
 def apply_box_deltas_2D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 4] where each row is y1, x1, y2, x2
     deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     height *= torch.exp(deltas[:, 2])
     width *= torch.exp(deltas[:, 3])
     # Convert back to y1, x1, y2, x2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     y2 = y1 + height
     x2 = x1 + width
     result = torch.stack([y1, x1, y2, x2], dim=1)
     return result
 
 
 
 def apply_box_deltas_3D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2
     deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     depth = boxes[:, 5] - boxes[:, 4]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     center_z = boxes[:, 4] + 0.5 * depth
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     center_z += deltas[:, 2] * depth
     height *= torch.exp(deltas[:, 3])
     width *= torch.exp(deltas[:, 4])
     depth *= torch.exp(deltas[:, 5])
     # Convert back to y1, x1, y2, x2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     z1 = center_z - 0.5 * depth
     y2 = y1 + height
     x2 = x1 + width
     z2 = z1 + depth
     result = torch.stack([y1, x1, y2, x2, z1, z2], dim=1)
     return result
 
 
 
 def clip_boxes_2D(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2
     window: [4] in the form y1, x1, y2, x2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3]))], 1)
     return boxes
 
 def clip_boxes_3D(boxes, window):
     """
     boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2
     window: [6] in the form y1, x1, y2, x2, z1, z2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3])),
          boxes[:, 4].clamp(float(window[4]), float(window[5])),
          boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1)
     return boxes
 
 from matplotlib import pyplot as plt
 
 
 def clip_boxes_numpy(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D.
     window: iamge shape (y, x, (z))
     """
     if boxes.shape[1] == 4:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
             np.clip(boxes[:, 1], 0, window[0])[:, None],
             np.clip(boxes[:, 2], 0, window[1])[:, None],
             np.clip(boxes[:, 3], 0, window[1])[:, None]), 1
         )
 
     else:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
              np.clip(boxes[:, 1], 0, window[0])[:, None],
              np.clip(boxes[:, 2], 0, window[1])[:, None],
              np.clip(boxes[:, 3], 0, window[1])[:, None],
              np.clip(boxes[:, 4], 0, window[2])[:, None],
              np.clip(boxes[:, 5], 0, window[2])[:, None]), 1
         )
 
     return boxes
 
 
 
 def bbox_overlaps_2D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)].
     """
     # 1. Tile boxes2 and repeate boxes1. This allows us to compare
     # every boxes1 against every boxes2 without loops.
     # TF doesn't have an equivalent to np.repeate() so simulate it
     # using tf.tile() and tf.reshape.
 
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
 
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,4)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2 = boxes1.chunk(4, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2 = boxes2.chunk(4, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     #--> expects x1<x2 & y1<y2
     zeros = torch.zeros(y1.size()[0], requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros)
 
     # 3. Compute unions
     b1_area = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)
     b2_area = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)
     union = b1_area[:,0] + b2_area[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     assert torch.all(iou<=1), "iou score>1 produced in bbox_overlaps_2D"
     overlaps = iou.view(boxes2_repeat, boxes1_repeat) #--> per gt box: ious of all proposal boxes with that gt box
 
     return overlaps
 
 def bbox_overlaps_3D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2, z1, z2)].
     """
     # 1. Tile boxes2 and repeate boxes1. This allows us to compare
     # every boxes1 against every boxes2 without loops.
     # TF doesn't have an equivalent to np.repeate() so simulate it
     # using tf.tile() and tf.reshape.
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,6)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2, b1_z1, b1_z2 = boxes1.chunk(6, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2, b2_z1, b2_z2 = boxes2.chunk(6, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     z1 = torch.max(b1_z1, b2_z1)[:, 0]
     z2 = torch.min(b1_z2, b2_z2)[:, 0]
     zeros = torch.zeros(y1.size()[0], requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros)
 
     # 3. Compute unions
     b1_volume = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)  * (b1_z2 - b1_z1)
     b2_volume = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)  * (b2_z2 - b2_z1)
     union = b1_volume[:,0] + b2_volume[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     overlaps = iou.view(boxes2_repeat, boxes1_repeat)
     return overlaps
 
 def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None):
     """Given the anchors and GT boxes, compute overlaps and identify positive
     anchors and deltas to refine them to match their corresponding GT boxes.
 
     anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))]
     gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))]
     gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN,
     set all positive matches to 1 (foreground)
 
     Returns:
     anchor_class_matches: [N] (int32) matches between anchors and GT boxes.
                1 = positive anchor, -1 = negative anchor, 0 = neutral
     anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas.
     """
 
     anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32)
     anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim))
     anchor_matching_iou = cf.anchor_matching_iou
 
     if gt_boxes is None:
         anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1)
         return anchor_class_matches, anchor_delta_targets
 
     # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground)
     if gt_class_ids is None:
         gt_class_ids = np.array([1] * len(gt_boxes))
 
     # Compute overlaps [num_anchors, num_gt_boxes]
     overlaps = compute_overlaps(anchors, gt_boxes)
 
     # Match anchors to GT Boxes
     # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive.
     # If an anchor overlaps a GT box with IoU < 0.1 then it's negative.
     # Neutral anchors are those that don't match the conditions above,
     # and they don't influence the loss function.
     # However, don't keep any GT box unmatched (rare, but happens). Instead,
     # match it to the closest anchor (even if its max IoU is < 0.1).
 
     # 1. Set negative anchors first. They get overwritten below if a GT box is
     # matched to them. Skip boxes in crowd areas.
     anchor_iou_argmax = np.argmax(overlaps, axis=1)
     anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
     if anchors.shape[1] == 4:
         anchor_class_matches[(anchor_iou_max < 0.1)] = -1
     elif anchors.shape[1] == 6:
         anchor_class_matches[(anchor_iou_max < 0.01)] = -1
     else:
         raise ValueError('anchor shape wrong {}'.format(anchors.shape))
 
     # 2. Set an anchor for each GT box (regardless of IoU value).
     gt_iou_argmax = np.argmax(overlaps, axis=0)
     for ix, ii in enumerate(gt_iou_argmax):
         anchor_class_matches[ii] = gt_class_ids[ix]
 
     # 3. Set anchors with high overlap as positive.
     above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou)
     anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]]
 
     # Subsample to balance positive anchors.
     ids = np.where(anchor_class_matches > 0)[0]
     extra = len(ids) - (cf.rpn_train_anchors_per_image // 2)
     if extra > 0:
         # Reset the extra ones to neutral
         ids = np.random.choice(ids, extra, replace=False)
         anchor_class_matches[ids] = 0
 
     # Leave all negative proposals negative for now and sample from them later in online hard example mining.
     # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes.
     ids = np.where(anchor_class_matches > 0)[0]
     ix = 0  # index into anchor_delta_targets
     for i, a in zip(ids, anchors[ids]):
         # closest gt box (it might have IoU < anchor_matching_iou)
         gt = gt_boxes[anchor_iou_argmax[i]]
 
         # convert coordinates to center plus width/height.
         gt_h = gt[2] - gt[0]
         gt_w = gt[3] - gt[1]
         gt_center_y = gt[0] + 0.5 * gt_h
         gt_center_x = gt[1] + 0.5 * gt_w
         # Anchor
         a_h = a[2] - a[0]
         a_w = a[3] - a[1]
         a_center_y = a[0] + 0.5 * a_h
         a_center_x = a[1] + 0.5 * a_w
 
         if cf.dim == 2:
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
             ]
 
         else:
             gt_d = gt[5] - gt[4]
             gt_center_z = gt[4] + 0.5 * gt_d
             a_d = a[5] - a[4]
             a_center_z = a[4] + 0.5 * a_d
 
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 (gt_center_z - a_center_z) / a_d,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
                 np.log(gt_d / a_d)
             ]
 
         # normalize.
         anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev
         ix += 1
 
     return anchor_class_matches, anchor_delta_targets
 
 
 
 def clip_to_window(window, boxes):
     """
         window: (y1, x1, y2, x2) / 3D: (z1, z2). The window in the image we want to clip to.
         boxes: [N, (y1, x1, y2, x2)]  / 3D: (z1, z2)
     """
     boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[2]))
     boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[3]))
     boxes[:, 2] = boxes[:, 2].clamp(float(window[0]), float(window[2]))
     boxes[:, 3] = boxes[:, 3].clamp(float(window[1]), float(window[3]))
 
     if boxes.shape[1] > 5:
         boxes[:, 4] = boxes[:, 4].clamp(float(window[4]), float(window[5]))
         boxes[:, 5] = boxes[:, 5].clamp(float(window[4]), float(window[5]))
 
     return boxes
 
 ############################################################
 #  Connected Componenent Analysis
 ############################################################
 
 def get_coords(binary_mask, n_components, dim):
     """
     loops over batch to perform connected component analysis on binary input mask. computes box coordinates around
     n_components - biggest components (rois).
     :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class.
     :param n_components: int. number of components to extract per batch element and class.
     :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2))
     :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2))
     """
     assert len(binary_mask.shape)==dim+1
     binary_mask = binary_mask.astype('uint8')
     batch_coords = []
     batch_components = []
     for ix,b in enumerate(binary_mask):
         clusters, n_cands = lb(b)  # performs connected component analysis.
         uniques, counts = np.unique(clusters, return_counts=True)
         keep_uniques = uniques[1:][np.argsort(counts[1:])[::-1]][:n_components] #only keep n_components largest components
         p_components = np.array([(clusters == ii) * 1 for ii in keep_uniques])  # separate clusters and concat
         p_coords = []
         if p_components.shape[0] > 0:
             for roi in p_components:
                 mask_ixs = np.argwhere(roi != 0)
 
                 # get coordinates around component.
                 roi_coords = [np.min(mask_ixs[:, 0]) - 1, np.min(mask_ixs[:, 1]) - 1, np.max(mask_ixs[:, 0]) + 1,
                                np.max(mask_ixs[:, 1]) + 1]
                 if dim == 3:
                     roi_coords += [np.min(mask_ixs[:, 2]), np.max(mask_ixs[:, 2])+1]
                 p_coords.append(roi_coords)
 
             p_coords = np.array(p_coords)
 
             #clip coords.
             p_coords[p_coords < 0] = 0
             p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2]
             if dim == 3:
                 p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1]
 
         batch_coords.append(p_coords)
         batch_components.append(p_components)
     return batch_coords, batch_components
 
 
 # noinspection PyCallingNonCallable
 def get_coords_gpu(binary_mask, n_components, dim):
     """
     loops over batch to perform connected component analysis on binary input mask. computes box coordiantes around
     n_components - biggest components (rois).
     :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class.
     :param n_components: int. number of components to extract per batch element and class.
     :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2))
     :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2))
     """
     raise Exception("throws floating point exception")
     assert len(binary_mask.shape)==dim+1
     binary_mask = binary_mask.type(torch.uint8)
     batch_coords = []
     batch_components = []
     for ix,b in enumerate(binary_mask):
         clusters, n_cands = lb(b.cpu().data.numpy())  # peforms connected component analysis.
         clusters = torch.from_numpy(clusters).cuda()
         uniques = torch.unique(clusters)
         counts = torch.stack([(clusters==unique).sum() for unique in uniques])
         keep_uniques = uniques[1:][torch.sort(counts[1:])[1].flip(0)][:n_components] #only keep n_components largest components
         p_components = torch.cat([(clusters == ii).unsqueeze(0) for ii in keep_uniques]).cuda()  # separate clusters and concat
         p_coords = []
         if p_components.shape[0] > 0:
             for roi in p_components:
                 mask_ixs = torch.nonzero(roi)
 
                 # get coordinates around component.
                 roi_coords = [torch.min(mask_ixs[:, 0]) - 1, torch.min(mask_ixs[:, 1]) - 1,
                               torch.max(mask_ixs[:, 0]) + 1,
                               torch.max(mask_ixs[:, 1]) + 1]
                 if dim == 3:
                     roi_coords += [torch.min(mask_ixs[:, 2]), torch.max(mask_ixs[:, 2])+1]
                 p_coords.append(roi_coords)
 
             p_coords = torch.tensor(p_coords)
 
             #clip coords.
             p_coords[p_coords < 0] = 0
             p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2]
             if dim == 3:
                 p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1]
 
         batch_coords.append(p_coords)
         batch_components.append(p_components)
     return batch_coords, batch_components
 
 
 ############################################################
 #  Pytorch Utility Functions
 ############################################################
 
 def unique1d(tensor):
     """discard all elements of tensor that occur more than once; make tensor unique.
     :param tensor:
     :return:
     """
     if tensor.size()[0] == 0 or tensor.size()[0] == 1:
         return tensor
     tensor = tensor.sort()[0]
     unique_bool = tensor[1:] != tensor[:-1]
     first_element = torch.tensor([True], dtype=torch.bool, requires_grad=False)
     if tensor.is_cuda:
         first_element = first_element.cuda()
     unique_bool = torch.cat((first_element, unique_bool), dim=0)
     return tensor[unique_bool.data]
 
 
 def intersect1d(tensor1, tensor2):
     aux = torch.cat((tensor1, tensor2), dim=0)
     aux = aux.sort(descending=True)[0]
     return aux[:-1][(aux[1:] == aux[:-1]).data]
 
 
 
 def shem(roi_probs_neg, negative_count, poolsize):
     """
     stochastic hard example mining: from a list of indices (referring to non-matched predictions),
     determine a pool of highest scoring (worst false positives) of size negative_count*poolsize.
     Then, sample n (= negative_count) predictions of this pool as negative examples for loss.
     :param roi_probs_neg: tensor of shape (n_predictions, n_classes).
     :param negative_count: int.
     :param poolsize: int.
     :return: (negative_count).  indices refer to the positions in roi_probs_neg. If pool smaller than expected due to
     limited negative proposals availabel, this function will return sampled indices of number < negative_count without
     throwing an error.
     """
     # sort according to higehst foreground score.
     probs, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True)
     select = torch.tensor((poolsize * int(negative_count), order.size()[0])).min().int()
 
     pool_indices = order[:select]
     rand_idx = torch.randperm(pool_indices.size()[0])
     return pool_indices[rand_idx[:negative_count].cuda()]
 
 
 ############################################################
 #  Weight Init
 ############################################################
 
 
 def initialize_weights(net):
     """Initialize model weights. Current Default in Pytorch (version 0.4.1) is initialization from a uniform distriubtion.
     Will expectably be changed to kaiming_uniform in future versions.
     """
     init_type = net.cf.weight_init
 
     for m in [module for module in net.modules() if type(module) in [torch.nn.Conv2d, torch.nn.Conv3d,
                                                                      torch.nn.ConvTranspose2d,
                                                                      torch.nn.ConvTranspose3d,
                                                                      torch.nn.Linear]]:
         if init_type == 'xavier_uniform':
             torch.nn.init.xavier_uniform_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == 'xavier_normal':
             torch.nn.init.xavier_normal_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == "kaiming_uniform":
             torch.nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 torch.nn.init.uniform_(m.bias, -bound, bound)
 
         elif init_type == "kaiming_normal":
             torch.nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 torch.nn.init.normal_(m.bias, -bound, bound)
     net.logger.info("applied {} weight init.".format(init_type))
\ No newline at end of file