diff --git a/experiments/lidc_exp/configs.py b/experiments/lidc_exp/configs.py
index aba901b..1bf3237 100644
--- a/experiments/lidc_exp/configs.py
+++ b/experiments/lidc_exp/configs.py
@@ -1,341 +1,341 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 
 class configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
 
         #########################
         #    Preprocessing      #
         #########################
 
         self.root_dir = '/home/gregor/networkdrives/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI'
         self.raw_data_dir = '{}/new_nrrd'.format(self.root_dir)
         self.pp_dir = '/media/gregor/HDD2TB/data/lidc/lidc_mdt'
         self.target_spacing = (0.7, 0.7, 1.25)
 
         #########################
         #         I/O           #
         #########################
 
 
         # one out of [2, 3]. dimension the model operates in.
         self.dim = 2
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn'].
-        self.model = 'mrcnn'
+        self.model = 'retina_unet'
 
         DefaultConfigs.__init__(self, self.model, server_env, self.dim)
 
         # int [0 < dataset_size]. select n patients from dataset for prototyping. If None, all data is used.
         self.select_prototype_subset = None
 
         # path to preprocessed data.
         self.pp_name = 'lidc_mdt'
         self.input_df_name = 'info_df.pickle'
         self.pp_data_path = '/media/gregor/HDD2TB/data/lidc/{}'.format(self.pp_name)
         self.pp_test_data_path = self.pp_data_path #change if test_data in separate folder.
 
         # settings for deployment in cloud.
         if server_env:
             # path to preprocessed data.
             self.pp_name = 'lidc_mdt_npz'
             self.crop_name = 'pp_fg_slices_packed'
             self.pp_data_path = '/datasets/datasets_ramien/lidc_exp/data/{}'.format(self.pp_name)
             self.pp_test_data_path = self.pp_data_path
             self.select_prototype_subset = None
 
         #########################
         #      Data Loader      #
         #########################
 
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.pre_crop_size_2D = [300, 300]
         self.patch_size_2D = [288, 288]
         self.pre_crop_size_3D = [156, 156, 96]
         self.patch_size_3D = [128, 128, 64]
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
         self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
 
         # ratio of freely sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_sample_slack = 0.2
 
         # set 2D network to operate in 3D images.
         self.merge_2D_to_3D_preds = self.dim == 2
 
         # feed +/- n neighbouring slices into channel dimension. set to None for no context.
         self.n_3D_context = None
         if self.n_3D_context is not None and self.dim == 2:
             self.n_channels *= (self.n_3D_context * 2 + 1)
 
 
         #########################
         #      Architecture      #
         #########################
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
-        self.norm = "instance_norm" # one of None, 'instance_norm', 'batch_norm'
+        self.norm = None # one of None, 'instance_norm', 'batch_norm'
         self.weight_decay = 1e-5
 
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         #########################
         #  Schedule / Selection #
         #########################
 
         self.num_epochs = 100
         self.num_train_batches = 200 if self.dim == 2 else 300
         self.batch_size = 20 if self.dim == 2 else 8
 
         self.do_validation = True
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is more accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 50  # if 'None' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 50
 
         # set dynamic_lr_scheduling to True to apply LR scheduling with below settings.
         self.dynamic_lr_scheduling = True
         self.lr_decay_factor = 0.5
         self.scheduling_patience = np.ceil(6000 / (self.num_train_batches * self.batch_size))
         self.scheduling_criterion = 'malignant_ap'
         self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
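A worked check of the patience value above, as a sketch using the 2D defaults (200 batches of size 20) and the 3D values (300 of size 8):

    np.ceil(6000 / (200 * 20))   # = ceil(1.5) -> 2 epochs without 'malignant_ap' improvement (2D)
    np.ceil(6000 / (300 * 8))    # = ceil(2.5) -> 3 epochs (3D)

after which the learning rate is multiplied by lr_decay_factor = 0.5.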
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         # set the top-n-epochs to be saved for temporal averaging in testing.
         self.save_n_models = 5
         self.test_n_epochs = 5
         # set a minimum epoch number for saving in case of instabilities in the first phase of training.
         self.min_save_thresh = 1 if self.dim == 2 else 1
 
         self.report_score_level = ['patient', 'rois']  # choose list from 'patient', 'rois'
         self.class_dict = {1: 'benign', 2: 'malignant'}  # 0 is background.
         self.patient_class_of_interest = 2  # patient metrics are only plotted for one class.
         self.ap_match_ious = [0.1]  # list of ious to be evaluated for ap-scoring.
 
         self.model_selection_criteria = ['malignant_ap', 'benign_ap'] # criteria to average over for saving epochs.
         self.min_det_thresh = 0.1  # minimum confidence value to select predictions for evaluation.
 
         # threshold for clustering predictions together (wcs = weighted cluster scoring).
         # needs to be >= the expected overlap of predictions coming from one model (typically NMS threshold).
         # if too high, preds of the same object are separate clusters.
         self.wcs_iou = 1e-5
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
 
         #########################
         #   Data Augmentation   #
         #########################
 
         self.da_kwargs={
         'do_elastic_deform': True,
         'alpha':(0., 1500.),
         'sigma':(30., 50.),
         'do_rotation':True,
         'angle_x': (0., 2 * np.pi),
         'angle_y': (0., 0),
         'angle_z': (0., 0),
         'do_scale': True,
         'scale':(0.8, 1.1),
         'random_crop':False,
         'rand_crop_dist':  (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
         'border_mode_data': 'constant',
         'border_cval_data': 0,
         'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
 
         #########################
         #   Add model specifics #
         #########################
 
         {'detection_unet': self.add_det_unet_configs,
          'mrcnn': self.add_mrcnn_configs,
          'ufrcnn': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs,
          'retina_unet': self.add_mrcnn_configs,
         }[self.model]()
 
 
     def add_det_unet_configs(self):
 
-        self.learning_rate = [3e-4] * self.num_epochs
+        self.learning_rate = [1e-4] * self.num_epochs
 
         # aggregation from pixel prediction to object scores (connected component). One of ['max', 'median']
         self.aggregation_operation = 'max'
 
         # max number of roi candidates to identify per batch element and class.
         self.n_roi_candidates = 10 if self.dim == 2 else 30
 
         # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice'), or the sum of both ('dice_wce')
-        self.seg_loss_mode = 'wce'
+        self.seg_loss_mode = 'dice_wce'
 
         # if <1, false positive predictions in foreground are penalized less.
         self.fp_dice_weight = 1 if self.dim == 2 else 1
 
-        self.wce_weights = [0.1, 1, 1]
+        self.wce_weights = [0.3, 1, 1]
         self.detection_min_confidence = self.min_det_thresh
 
         # if 'True', loss distinguishes all classes, else only foreground vs. background (class agnostic).
         self.class_specific_seg_flag = True
         self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
         self.head_classes = self.num_seg_classes
 
     def add_mrcnn_configs(self):
 
         # learning rate is a list with one entry per epoch.
         self.learning_rate = [3e-4] * self.num_epochs
 
         # disable the re-sampling of mask proposals to original size for speed-up.
         # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
         # mask-outputs are optional.
         self.return_masks_in_val = True
         self.return_masks_in_test = False
 
         # set number of proposal boxes to plot after each epoch.
         self.n_plot_rpn_props = 5 if self.dim == 2 else 30
 
         # number of classes for head networks: n_foreground_classes + 1 (background)
         self.head_classes = 3
 
         # seg_classes here refers to the first stage classifier (RPN)
         self.num_seg_classes = 2  # foreground vs. background
 
         # feature map strides per pyramid level are inferred from architecture.
         self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
 
         # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
         # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
         self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
 
         # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
         self.pyramid_levels = [0, 1, 2, 3]
 
         # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
         self.n_rpn_features = 512 if self.dim == 2 else 128
 
         # anchor ratios and strides per position in feature maps.
         self.rpn_anchor_ratios = [0.5, 1, 2]
         self.rpn_anchor_stride = 1
 
         # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
         self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
 
         # loss sampling settings.
         self.rpn_train_anchors_per_image = 32  #per batch element
         self.train_rois_per_image = 6 #per batch element
         self.roi_positive_ratio = 0.5
         self.anchor_matching_iou = 0.7
 
         # factor of top-k candidates to draw from, per negative sample (stochastic hard-example mining).
         # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples.
         self.shem_poolsize = 10
 
         self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
         self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
         self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
         self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
         self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                                self.patch_size_3D[2], self.patch_size_3D[2]])
         if self.dim == 2:
             self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
             self.bbox_std_dev = self.bbox_std_dev[:4]
             self.window = self.window[:4]
             self.scale = self.scale[:4]
 
         # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
         self.pre_nms_limit = 3000 if self.dim == 2 else 6000
 
         # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
         # since proposals of the entire batch are forwarded through the second stage as one "batch".
         self.roi_chunk_size = 2500 if self.dim == 2 else 600
         self.post_nms_rois_training = 500 if self.dim == 2 else 75
         self.post_nms_rois_inference = 500
 
         # Final selection of detections (refine_detections)
         self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30  # per batch element and class.
         self.detection_nms_threshold = 1e-5  # needs to be > 0, otherwise all predictions are one cluster.
         self.model_min_confidence = 0.1
 
         if self.dim == 2:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride))]
                  for stride in self.backbone_strides['xy']])
         else:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride)),
                   int(np.ceil(self.patch_size[2] / stride_z))]
                  for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                              )])
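For reference, with the patch sizes configured above this evaluates to

    # 2D: patch_size [288, 288], xy strides [4, 8, 16, 32]
    #   -> backbone_shapes = [[72, 72], [36, 36], [18, 18], [9, 9]]
    # 3D: patch_size [128, 128, 64], xy strides [4, 8, 16, 32], z strides [1, 2, 4, 8]
    #   -> backbone_shapes = [[32, 32, 64], [16, 16, 32], [8, 8, 16], [4, 4, 8]]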
 
         if self.model == 'ufrcnn':
             self.operate_stride1 = True
             self.class_specific_seg_flag = True
             self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
             self.frcnn_mode = True
 
         if self.model == 'retina_net' or self.model == 'retina_unet' or self.model == 'prob_detector':
             # implement extra anchor-scales according to retina-net publication.
             self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                             self.rpn_anchor_scales['xy']]
             self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                            self.rpn_anchor_scales['z']]
             self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
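With the base scales defined further up, this expansion yields per-level xy scales of approximately

    [8, 10.1, 12.7], [16, 20.2, 25.4], [32, 40.3, 50.8], [64, 80.6, 101.6]

(factors 2**(1/3) ≈ 1.26 and 2**(2/3) ≈ 1.59), so n_anchors_per_pos = 3 ratios * 3 scales = 9 anchors per feature-map position.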
 
             self.n_rpn_features = 256 if self.dim == 2 else 64
 
             # pre-selection of detections for NMS-speedup. per entire batch.
             self.pre_nms_limit = 10000 if self.dim == 2 else 50000
 
             # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
             self.anchor_matching_iou = 0.5
 
             # if 'True', seg loss distinguishes all classes, else only foreground vs. background (class agnostic).
             self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
 
             if self.model == 'retina_unet':
                 self.operate_stride1 = True
diff --git a/experiments/toy_exp/configs.py b/experiments/toy_exp/configs.py
index 1b1870b..807cf1c 100644
--- a/experiments/toy_exp/configs.py
+++ b/experiments/toy_exp/configs.py
@@ -1,351 +1,351 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 
 class configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
 
         #########################
         #    Preprocessing      #
         #########################
 
         self.root_dir = '/home/gregor/datasets/toy_mdt'
 
         #########################
         #         I/O           #
         #########################
 
 
         # one out of [2, 3]. dimension the model operates in.
         self.dim = 2
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn'].
         self.model = 'mrcnn'
 
         DefaultConfigs.__init__(self, self.model, server_env, self.dim)
 
         # int [0 < dataset_size]. select n patients from dataset for prototyping.
         self.select_prototype_subset = None
         self.hold_out_test_set = True
         # including val set. the data loader splits this into 2/3 train, 1/3 val.
         self.n_train_val_data = 2500
 
         # choose one of the 3 toy experiments described in https://arxiv.org/pdf/1811.08661.pdf
         # one of ['donuts_shape', 'donuts_pattern', 'circles_scale'].
         toy_mode = 'donuts_shape_noise'
 
         # path to preprocessed data.
         self.input_df_name = 'info_df.pickle'
         self.pp_name = os.path.join(toy_mode, 'train')
         self.pp_data_path = os.path.join(self.root_dir, self.pp_name)
         self.pp_test_name = os.path.join(toy_mode, 'test')
         self.pp_test_data_path = os.path.join(self.root_dir, self.pp_test_name)
 
         # settings for deployment in cloud.
         if server_env:
             # path to preprocessed data.
             pp_root_dir = '/datasets/datasets_ramien/toy_exp/data'
             self.pp_name = os.path.join(toy_mode, 'train')
             self.pp_data_path = os.path.join(pp_root_dir, self.pp_name)
             self.pp_test_name = os.path.join(toy_mode, 'test')
             self.pp_test_data_path = os.path.join(pp_root_dir, self.pp_test_name)
             self.select_prototype_subset = None
 
         #########################
         #      Data Loader      #
         #########################
 
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.pre_crop_size_2D = [320, 320]
         self.patch_size_2D = [320, 320]
 
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
         self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
 
         # ratio of freely sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_sample_slack = 0.2
 
         # set 2D network to operate in 3D images.
         self.merge_2D_to_3D_preds = False
 
         # feed +/- n neighbouring slices into channel dimension. set to None for no context.
         self.n_3D_context = None
         if self.n_3D_context is not None and self.dim == 2:
             self.n_channels *= (self.n_3D_context * 2 + 1)
 
 
         #########################
         #      Architecture      #
         #########################
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
-        self.norm = "instance_norm" # one of None, 'instance_norm', 'batch_norm'
+        self.norm = None # one of None, 'instance_norm', 'batch_norm'
         self.weight_decay = 3e-5
 
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         #########################
         #  Schedule / Selection #
         #########################
 
         self.num_epochs = 24
         self.num_train_batches = 100 if self.dim == 2 else 200
         self.batch_size = 20 if self.dim == 2 else 8
 
         self.do_validation = True
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is more accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_patient' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = None  # if 'None' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 50
 
         # set dynamic_lr_scheduling to True to apply LR scheduling with below settings.
         self.dynamic_lr_scheduling = True
         self.lr_decay_factor = 0.5
         self.scheduling_patience = np.ceil(3600 / (self.num_train_batches * self.batch_size))
         self.scheduling_criterion = 'malignant_ap'
         self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         # set the top-n-epochs to be saved for temporal averaging in testing.
         self.save_n_models = 5
         self.test_n_epochs = 5
 
         # set a minimum epoch number for saving in case of instabilities in the first phase of training.
         self.min_save_thresh = 0 if self.dim == 2 else 0
 
         self.report_score_level = ['patient', 'rois']  # choose list from 'patient', 'rois'
         self.class_dict = {1: 'benign', 2: 'malignant'}  # 0 is background.
         self.patient_class_of_interest = 2  # patient metrics are only plotted for one class.
         self.ap_match_ious = [0.1]  # list of ious to be evaluated for ap-scoring.
 
         self.model_selection_criteria = ['benign_ap', 'malignant_ap'] # criteria to average over for saving epochs.
         self.min_det_thresh = 0.1  # minimum confidence value to select predictions for evaluation.
 
         # threshold for clustering predictions together (wcs = weighted cluster scoring).
         # needs to be >= the expected overlap of predictions coming from one model (typically NMS threshold).
         # if too high, preds of the same object are separate clusters.
         self.wcs_iou = 1e-5
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
 
         #########################
         #   Data Augmentation   #
         #########################
 
         self.da_kwargs={
         'do_elastic_deform': True,
         'alpha':(0., 1500.),
         'sigma':(30., 50.),
         'do_rotation':True,
         'angle_x': (0., 2 * np.pi),
         'angle_y': (0., 0),
         'angle_z': (0., 0),
         'do_scale': True,
         'scale':(0.8, 1.1),
         'random_crop':False,
         'rand_crop_dist':  (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
         'border_mode_data': 'constant',
         'border_cval_data': 0,
         'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
 
         #########################
         #   Add model specifics #
         #########################
 
         {'detection_unet': self.add_det_unet_configs,
          'mrcnn': self.add_mrcnn_configs,
          'ufrcnn': self.add_mrcnn_configs,
          'ufrcnn_surrounding': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs,
          'retina_unet': self.add_mrcnn_configs,
          'prob_detector': self.add_mrcnn_configs,
         }[self.model]()
 
 
     def add_det_unet_configs(self):
 
-        self.learning_rate = [3e-4] * self.num_epochs
+        self.learning_rate = [1e-4] * self.num_epochs
 
         # aggregation from pixel prediction to object scores (connected component). One of ['max', 'median']
         self.aggregation_operation = 'max'
 
         # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
         self.n_roi_candidates = 3 if self.dim == 2 else 8
 
         # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice'), or the sum of both ('dice_wce')
-        self.seg_loss_mode = 'wce'
+        self.seg_loss_mode = 'dice_wce'
 
         # if <1, false positive predictions in foreground are penalized less.
         self.fp_dice_weight = 1 if self.dim == 2 else 1
 
-        self.wce_weights = [0.1, 1, 1]
+        self.wce_weights = [0.3, 1, 1]
         self.detection_min_confidence = self.min_det_thresh
 
         # if 'True', loss distinguishes all classes, else only foreground vs. background (class agnostic).
         self.class_specific_seg_flag = True
         self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
         self.head_classes = self.num_seg_classes
 
     def add_mrcnn_configs(self):
 
         # learning rate is a list with one entry per epoch.
         self.learning_rate = [3e-4] * self.num_epochs
 
         # disable mask head loss. (e.g. if no pixelwise annotations available)
         self.frcnn_mode = False
 
         # disable the re-sampling of mask proposals to original size for speed-up.
         # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
         # mask-outputs are optional.
         self.return_masks_in_val = True
         self.return_masks_in_test = False
 
         # set number of proposal boxes to plot after each epoch.
         self.n_plot_rpn_props = 0 if self.dim == 2 else 0
 
         # number of classes for head networks: n_foreground_classes + 1 (background)
         self.head_classes = 3
 
         # seg_classes here refers to the first stage classifier (RPN)
         self.num_seg_classes = 2  # foreground vs. background
 
         # feature map strides per pyramid level are inferred from architecture.
         self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
 
         # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
         # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
         self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
 
         # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
         self.pyramid_levels = [0, 1, 2, 3]
 
         # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
         self.n_rpn_features = 512 if self.dim == 2 else 128
 
         # anchor ratios and strides per position in feature maps.
         self.rpn_anchor_ratios = [0.5, 1., 2.]
         self.rpn_anchor_stride = 1
 
         # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
         self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
 
         # loss sampling settings.
         self.rpn_train_anchors_per_image = 64 #per batch element
         self.train_rois_per_image = 2 #per batch element
         self.roi_positive_ratio = 0.5
         self.anchor_matching_iou = 0.7
 
         # factor of top-k candidates to draw from, per negative sample (stochastic hard-example mining).
         # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples.
         self.shem_poolsize = 10
 
         self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
         self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
         self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
         self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1]])
         self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1]])
 
         if self.dim == 2:
             self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
             self.bbox_std_dev = self.bbox_std_dev[:4]
             self.window = self.window[:4]
             self.scale = self.scale[:4]
 
         # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
         self.pre_nms_limit = 3000 if self.dim == 2 else 6000
 
         # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
         # since proposals of the entire batch are forwarded through the second stage as one "batch".
         self.roi_chunk_size = 800 if self.dim == 2 else 600
         self.post_nms_rois_training = 500 if self.dim == 2 else 75
         self.post_nms_rois_inference = 500
 
         # Final selection of detections (refine_detections)
         self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30  # per batch element and class.
         self.detection_nms_threshold = 1e-5  # needs to be > 0, otherwise all predictions are one cluster.
         self.model_min_confidence = 0.1
 
         if self.dim == 2:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride))]
                  for stride in self.backbone_strides['xy']])
         else:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride)),
                   int(np.ceil(self.patch_size[2] / stride_z))]
                  for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                              )])
         if self.model == 'ufrcnn':
             self.operate_stride1 = True
             self.class_specific_seg_flag = True
             self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
             self.frcnn_mode = True
 
         if self.model == 'retina_net' or self.model == 'retina_unet' or self.model == 'prob_detector':
             # implement extra anchor-scales according to retina-net publication.
             self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                             self.rpn_anchor_scales['xy']]
             self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                            self.rpn_anchor_scales['z']]
             self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
             self.n_rpn_features = 256 if self.dim == 2 else 64
 
             # pre-selection of detections for NMS-speedup. per entire batch.
             self.pre_nms_limit = 10000 if self.dim == 2 else 50000
 
             # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
             self.anchor_matching_iou = 0.5
 
             # if 'True', seg loss distinguishes all classes, else only foreground vs. background (class agnostic).
             self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
 
             if self.model == 'retina_unet':
                 self.operate_stride1 = True
diff --git a/experiments/toy_exp/data_loader.py b/experiments/toy_exp/data_loader.py
index b5b8509..c123011 100644
--- a/experiments/toy_exp/data_loader.py
+++ b/experiments/toy_exp/data_loader.py
@@ -1,312 +1,312 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import numpy as np
 import os
 from collections import OrderedDict
 import pandas as pd
 import pickle
 import time
 import subprocess
 import utils.dataloader_utils as dutils
 
 # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
 from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
 from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
 from batchgenerators.transforms.abstract_transforms import Compose
 from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
 from batchgenerators.dataloading import SingleThreadedAugmenter
 from batchgenerators.transforms.spatial_transforms import SpatialTransform
 from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
 from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates
 
 
 
 def get_train_generators(cf, logger):
     """
     wrapper function for creating the training batch generator pipeline. returns the train/val generators.
     takes the first cf.n_train_val_data patients of the dataset and splits them into roughly 2/3 train and 1/3 val.
     The hold-out test set of the toy experiments lives in a separate folder and is loaded by get_test_generator.
     """
     all_data = load_dataset(cf, logger)
     all_pids_list = np.unique([v['pid'] for (k, v) in all_data.items()])
 
     assert cf.n_train_val_data <= len(all_pids_list), \
         "requested {} train val samples, but dataset only has {} train val samples.".format(
             cf.n_train_val_data, len(all_pids_list))
     train_pids = all_pids_list[:int(2*cf.n_train_val_data//3)]
     val_pids = all_pids_list[int(np.ceil(2*cf.n_train_val_data//3)):cf.n_train_val_data]
 
     train_data = {k: v for (k, v) in all_data.items() if any(p == v['pid'] for p in train_pids)}
     val_data = {k: v for (k, v) in all_data.items() if any(p == v['pid'] for p in val_pids)}
 
     logger.info("data set loaded with: {} train / {} val patients".format(len(train_pids), len(val_pids)))
     batch_gen = {}
     batch_gen['train'] = create_data_gen_pipeline(train_data, cf=cf, do_aug=False)
     batch_gen['val_sampling'] = create_data_gen_pipeline(val_data, cf=cf, do_aug=False)
     if cf.val_mode == 'val_patient':
         batch_gen['val_patient'] = PatientBatchIterator(val_data, cf=cf)
         batch_gen['n_val'] = len(val_pids) if cf.max_val_patients is None else min(len(val_pids), cf.max_val_patients)
     else:
         batch_gen['n_val'] = cf.num_val_batches
 
     return batch_gen
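As a concrete example of the split above: with the toy default n_train_val_data = 2500, train_pids covers indices 0..1665 (int(2 * 2500 // 3) = 1666 patients) and val_pids covers indices 1666..2499 (834 patients), i.e. roughly a 2/3 train / 1/3 val split.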
 
 
 def get_test_generator(cf, logger):
     """
     wrapper function for creating the test batch generator pipeline.
     selects patients according to cv folds (generated by first run/fold of experiment)
     If cf.hold_out_test_set is True, gets the data from an external folder instead.
     """
     if cf.hold_out_test_set:
         pp_name = cf.pp_test_name
         test_ix = None
     else:
         pp_name = None
         with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
             fold_list = pickle.load(handle)
         _, _, test_ix, _ = fold_list[cf.fold]
         # warnings.warn('WARNING: using validation set for testing!!!')
 
     test_data = load_dataset(cf, logger, test_ix, pp_data_path=cf.pp_test_data_path, pp_name=pp_name)
     logger.info("data set loaded with: {} test patients from {}".format(len(test_data.keys()), cf.pp_test_data_path))
     batch_gen = {}
     batch_gen['test'] = PatientBatchIterator(test_data, cf=cf)
     batch_gen['n_test'] = len(test_data.keys()) if cf.max_test_patients=="all" else \
         min(cf.max_test_patients, len(test_data.keys()))
 
     return batch_gen
 
 
 
 def load_dataset(cf, logger, subset_ixs=None, pp_data_path=None, pp_name=None):
     """
     loads the dataset. if deployed in cloud also copies and unpacks the data to the working directory.
     :param subset_ixs: subset indices to be loaded from the dataset. used e.g. for testing to only load the test folds.
     :return: data: dictionary with one entry per patient. each entry is a dictionary containing respective meta-info
     as well as paths to the preprocessed
     numpy arrays to be loaded during batch-generation
     """
     if pp_data_path is None:
         pp_data_path = cf.pp_data_path
     if pp_name is None:
         pp_name = cf.pp_name
     if cf.server_env:
         copy_data = True
         target_dir = os.path.join(cf.data_dest, pp_name)
         if not os.path.exists(target_dir):
             cf.data_source_dir = pp_data_path
             os.makedirs(target_dir)
             subprocess.call('rsync -av {} {}'.format(
                 os.path.join(cf.data_source_dir, cf.input_df_name), os.path.join(target_dir, cf.input_df_name)), shell=True)
             logger.info('created target dir and info df at {}'.format(os.path.join(target_dir, cf.input_df_name)))
 
         elif subset_ixs is None:
             copy_data = False
 
         pp_data_path = target_dir
 
 
     p_df = pd.read_pickle(os.path.join(pp_data_path, cf.input_df_name))
 
 
     if subset_ixs is not None:
         subset_pids = [np.unique(p_df.pid.tolist())[ix] for ix in subset_ixs]
         p_df = p_df[p_df.pid.isin(subset_pids)]
         logger.info('subset: selected {} instances from df'.format(len(p_df)))
 
     if cf.server_env:
         if copy_data:
             copy_and_unpack_data(logger, p_df.pid.tolist(), cf.fold_dir, cf.data_source_dir, target_dir)
 
     class_targets = p_df['class_id'].tolist()
     pids = p_df.pid.tolist()
     imgs = [os.path.join(pp_data_path, '{}.npy'.format(pid)) for pid in pids]
     segs = [os.path.join(pp_data_path,'{}.npy'.format(pid)) for pid in pids]
 
     data = OrderedDict()
     for ix, pid in enumerate(pids):
 
         data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid, 'class_target': [class_targets[ix]]}
 
     return data
 
 
 
 def create_data_gen_pipeline(patient_data, cf, do_aug=True):
     """
     create multi-threaded train/val/test batch generation and augmentation pipeline.
     :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
     :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
     :return: multithreaded_generator
     """
 
     # create instance of batch generator as first element in pipeline.
     data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size, cf=cf)
 
     # add transformations to pipeline.
     my_transforms = []
     if do_aug:
         mirror_transform = Mirror(axes=np.arange(2, cf.dim+2, 1))
         my_transforms.append(mirror_transform)
         spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                              patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                              do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                              alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                              do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                              angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                              do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                              random_crop=cf.da_kwargs['random_crop'])
 
         my_transforms.append(spatial_transform)
     else:
         my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
 
     my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=cf.class_specific_seg_flag))
     all_transforms = Compose(my_transforms)
     # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
     multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
     return multithreaded_generator
 
 
 class BatchGenerator(SlimDataLoaderBase):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
     def __init__(self, data, batch_size, cf):
         super(BatchGenerator, self).__init__(data, batch_size)
 
         self.cf = cf
 
     def generate_train_batch(self):
 
         batch_data, batch_segs, batch_pids, batch_targets = [], [], [], []
         class_targets_list =  [v['class_target'] for (k, v) in self._data.items()]
 
         # samples patients towards an equilibrium of foreground classes on a roi-level (after randomly sampling the ratio "batch_sample_slack").
         batch_ixs = dutils.get_class_balanced_patients(
             class_targets_list, self.batch_size, self.cf.head_classes - 1, slack_factor=self.cf.batch_sample_slack)
         patients = list(self._data.items())
 
         for b in batch_ixs:
 
             patient = patients[b][1]
             all_data = np.load(patient['data'], mmap_mode='r')
             data = all_data[0]
             seg = all_data[1].astype('uint8')
             batch_pids.append(patient['pid'])
             batch_targets.append(patient['class_target'])
             batch_data.append(data[np.newaxis])
             batch_segs.append(seg[np.newaxis])
 
         data = np.array(batch_data)
         seg = np.array(batch_segs).astype(np.uint8)
         class_target = np.array(batch_targets)
         return {'data': data, 'seg': seg, 'pid': batch_pids, 'class_target': class_target}
 
 
 
 class PatientBatchIterator(SlimDataLoaderBase):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'val_patient', i.e. for validation closer to the actual evaluation (done in 3D),
     if willing to accept the speed loss during training.
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D.
     """
     def __init__(self, data, cf): #threads in augmenter
         super(PatientBatchIterator, self).__init__(data, 0)
         self.cf = cf
         self.patient_ix = 0
         self.dataset_pids = [v['pid'] for (k, v) in data.items()]
         self.patch_size = cf.patch_size
         if len(self.patch_size) == 2:
             self.patch_size = self.patch_size + [1]
 
 
     def generate_train_batch(self):
 
         pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
         all_data = np.load(patient['data'], mmap_mode='r')
         data = all_data[0]
         seg = all_data[1].astype('uint8')
         batch_class_targets = np.array([patient['class_target']])
 
         out_data = data[None, None]
         out_seg = seg[None, None]
 
         #print('check patient data loader', out_data.shape, out_seg.shape)
         batch_2D = {'data': out_data, 'seg': out_seg, 'class_target': batch_class_targets, 'pid': pid}
         converter = ConvertSegToBoundingBoxCoordinates(dim=2, get_rois_from_seg_flag=False, class_specific_seg_flag=self.cf.class_specific_seg_flag)
         batch_2D = converter(**batch_2D)
 
         batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                          'patient_roi_labels': batch_2D['roi_labels'],
                          'original_img_shape': out_data.shape})
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return batch_2D
 
 def copy_and_unpack_data(logger, pids, fold_dir, source_dir, target_dir):
 
 
     start_time = time.time()
     with open(os.path.join(fold_dir, 'file_list.txt'), 'w') as handle:
         for pid in pids:
             handle.write('{}.npy\n'.format(pid))
 
     subprocess.call('rsync -ahv --files-from {} {} {}'.format(os.path.join(fold_dir, 'file_list.txt'),
         source_dir, target_dir), shell=True)
     # dutils.unpack_dataset(target_dir)
     copied_files = os.listdir(target_dir)
-    logger.info("copying and unpacking data set finished : {} files in target dir: {}. took {} sec".format(
+    logger.info("copying data set finished : {} files in target dir: {}. took {} sec".format(
         len(copied_files), target_dir, np.round(time.time() - start_time, 0)))
 
 if __name__=="__main__":
     import utils.exp_utils as utils
 
     total_stime = time.time()
     cf_file = utils.import_module("cf", "configs.py")
     cf = cf_file.configs()
 
     logger = utils.get_logger("dev")
     batch_gen = get_train_generators(cf, logger)
 
     train_batch = next(batch_gen["train"])
     pids = []
     total = 100
     for i in range(total):
         print("\r producing batch {}/{}.".format(i, total), end="", flush=True)
         train_batch = next(batch_gen["train"])
         pids.append(train_batch["pid"])
     print()
 
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/utils/dataloader_utils.py b/utils/dataloader_utils.py
index 062af62..b328985 100644
--- a/utils/dataloader_utils.py
+++ b/utils/dataloader_utils.py
@@ -1,278 +1,280 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import numpy as np
 import os
 from multiprocessing import Pool
 
 
 
 def get_class_balanced_patients(class_targets, batch_size, num_classes, slack_factor=0.1):
     '''
     samples patients towards an equilibrium of classes on a roi-level. For highly imbalanced datasets, this might be too strong a requirement.
     Hence a slack factor determines the ratio of the batch that is randomly sampled before class balancing is triggered.
     :param class_targets: list of patient targets. where each patient target is a list of class labels of respective rois.
     :param batch_size:
     :param num_classes:
     :param slack_factor:
     :return: batch_ixs: list of indices referring to a subset in class_targets-list, sampled to build one batch.
     '''
     batch_ixs = []
     class_count = {k: 0 for k in range(num_classes)}
     weakest_class = 0
     for ix in range(batch_size):
 
         keep_looking = True
         while keep_looking:
             #choose a random patient.
             cand = np.random.choice(len(class_targets), 1)[0]
             # check the least occurring class among this patient's rois.
             tmp_weakest_class = np.argmin([class_targets[cand].count(ii) for ii in range(num_classes)])
             # if the current batch is already bigger than the slack_factor ratio, then
             # check that the weakest class of this patient is not the weakest class in the current batch (since that one needs to be boosted)
             # and that at least one roi of this patient belongs to the weakest class. If True, keep the patient, else keep looking.
             if (tmp_weakest_class != weakest_class and class_targets[cand].count(weakest_class) > 0) or ix < int(batch_size * slack_factor):
                 keep_looking = False
 
         for c in range(num_classes):
             class_count[c] += class_targets[cand].count(c)
         weakest_class = np.argmin(([class_count[c] for c in range(num_classes)]))
         batch_ixs.append(cand)
 
     return batch_ixs
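A minimal usage sketch with hypothetical roi-level targets (the batch generators call this with num_classes = cf.head_classes - 1, i.e. the number of foreground classes):

    # 6 hypothetical patients, each a list of roi class labels for 2 foreground classes (0 and 1)
    class_targets = [[0], [0, 0], [1], [0, 1], [1, 1], [0]]
    batch_ixs = get_class_balanced_patients(class_targets, batch_size=4, num_classes=2, slack_factor=0.2)
    # -> list of 4 patient indices (sampled with replacement), chosen so that the roi counts
    #    of class 0 and class 1 stay roughly balanced across the batch.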
 
 
 
 class fold_generator:
     """
     generates splits of indices for a given length of a dataset to perform n-fold cross-validation.
     splits each fold into 3 subsets for training, validation and testing.
     This form of cross validation uses an inner-loop test set, which is useful if test scores shall be reported on a
     statistically reliable number of patients despite the limited size of a dataset.
     If a hold-out test set is provided and hence no inner-loop test set is needed, just add test_idxs to the training data in the dataloader.
     This creates straightforward train/val splits.
     :returns names_list: list of len n_splits. each element is a list of len 4 holding train_ix, val_ix, test_ix, fold.
     """
     def __init__(self, seed, n_splits, len_data):
         """
         :param seed: Random seed for splits.
         :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation
         :param len_data: number of elements in the dataset.
         """
         self.tr_ix = []
         self.val_ix = []
         self.te_ix = []
         self.slicer = None
         self.missing = 0
         self.fold = 0
         self.len_data = len_data
         self.n_splits = n_splits
         self.myseed = seed
         self.boost_val = 0
 
     def init_indices(self):
 
         t = list(np.arange(self.l))
         # round up to next splittable data amount.
         split_length = int(np.ceil(len(t) / float(self.n_splits)))
         self.slicer = split_length
         self.mod = len(t) % self.n_splits
         if self.mod > 0:
             # missing is the number of folds in which the new splits are reduced by one to account for missing data.
             self.missing = self.n_splits - self.mod
 
         self.te_ix = t[:self.slicer]
         self.tr_ix = t[self.slicer:]
         self.val_ix = self.tr_ix[:self.slicer]
         self.tr_ix = self.tr_ix[self.slicer:]
 
     def new_fold(self):
 
         slicer = self.slicer
         if self.fold < self.missing :
             slicer = self.slicer - 1
 
         temp = self.te_ix
 
         # catch the exception mod == 1: the test set collects one extra element since it walks through both rounded-up splits.
         # account for this by reducing the last fold's split by 1.
         if self.fold == self.n_splits-2 and self.mod ==1:
             temp += self.val_ix[-1:]
             self.val_ix = self.val_ix[:-1]
 
         self.te_ix = self.val_ix
         self.val_ix = self.tr_ix[:slicer]
         self.tr_ix = self.tr_ix[slicer:] + temp
 
 
     def get_fold_names(self):
         names_list = []
         rgen = np.random.RandomState(self.myseed)
         cv_names = np.arange(self.len_data)
 
         rgen.shuffle(cv_names)
         self.l = len(cv_names)
         self.init_indices()
 
         for split in range(self.n_splits):
             train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix]
             names_list.append([train_names, val_names, test_names, self.fold])
             self.new_fold()
             self.fold += 1
 
         return names_list
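A short usage sketch with hypothetical numbers:

    fg = fold_generator(seed=42, n_splits=5, len_data=100)
    folds = fg.get_fold_names()                      # list of 5 entries, one per fold
    train_ix, val_ix, test_ix, fold_nr = folds[0]    # shuffled index arrays plus the fold number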
 
 
 
 def get_patch_crop_coords(img, patch_size, min_overlap=30):
     """
 
     _:param img (y, x, (z))
     _:param patch_size: list of len 2 (2D) or 3 (3D).
     _:param min_overlap: minimum required overlap of patches.
     If too small, some areas are poorly represented only at edges of single patches.
     _:return ndarray: shape (n_patches, 2*dim). crop coordinates for each patch.
     """
     crop_coords = []
     for dim in range(len(img.shape)):
         n_patches = int(np.ceil(img.shape[dim] / patch_size[dim]))
 
         # no crops required in this dimension, add image shape as coordinates.
         if n_patches == 1:
             crop_coords.append([(0, img.shape[dim])])
             continue
 
         # fix the two outside patches to coords patchsize/2 and interpolate.
         center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
 
         if (patch_size[dim] - center_dists) < min_overlap:
             n_patches += 1
             center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
 
         patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)])
         dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers]
         crop_coords.append(dim_crop_coords)
 
     coords_mesh_grid = []
     for ymin, ymax in crop_coords[0]:
         for xmin, xmax in crop_coords[1]:
             if len(crop_coords) == 3 and patch_size[2] > 1:
                 for zmin, zmax in crop_coords[2]:
                     coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax])
             elif len(crop_coords) == 3 and patch_size[2] == 1:
                 for zmin in range(img.shape[2]):
                     coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1])
             else:
                 coords_mesh_grid.append([ymin, ymax, xmin, xmax])
     return np.array(coords_mesh_grid).astype(int)
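A quick sketch of the 2D case with a hypothetical image shape:

    img = np.zeros((500, 400))
    coords = get_patch_crop_coords(img, patch_size=[288, 288])
    # -> array of shape (4, 4); rows are [ymin, ymax, xmin, xmax] for 2 x 2 overlapping patches,
    #    here (0, 288) / (212, 500) along y and (0, 288) / (112, 400) along x.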
 
 
 
 def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
     """
     one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee
 
     :param image: nd image. can be anything
     :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
     len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of
     the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
     Example:
     image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
     image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
 
     :param mode: see np.pad for documentation
     :param return_slicer: if True then this function will also return what coords you will need to use when cropping back
     to original shape
     :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
     divisible by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
     be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
     :param kwargs: see np.pad for documentation
     """
     if kwargs is None:
         kwargs = {}
 
     if new_shape is not None:
         old_shape = np.array(image.shape[-len(new_shape):])
     else:
         assert shape_must_be_divisible_by is not None
         assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
         new_shape = image.shape[-len(shape_must_be_divisible_by):]
         old_shape = new_shape
 
     num_axes_nopad = len(image.shape) - len(new_shape)
 
     new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]
 
     if not isinstance(new_shape, np.ndarray):
         new_shape = np.array(new_shape)
 
     if shape_must_be_divisible_by is not None:
         if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
             shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
         else:
             assert len(shape_must_be_divisible_by) == len(new_shape)
 
         for i in range(len(new_shape)):
             if new_shape[i] % shape_must_be_divisible_by[i] == 0:
                 new_shape[i] -= shape_must_be_divisible_by[i]
 
         new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))])
 
     difference = new_shape - old_shape
     pad_below = difference // 2
     pad_above = difference // 2 + difference % 2
     pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])
     res = np.pad(image, pad_list, mode, **kwargs)
     if not return_slicer:
         return res
     else:
         pad_list = np.array(pad_list)
         pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
         slicer = list(slice(*i) for i in pad_list)
         return res, slicer
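 
 
 # Minimal usage sketch for pad_nd_image (illustrative shapes): pad a batch of 2D maps to a minimum,
 # divisible spatial shape and crop back via the returned slicer.
 def _example_pad_nd_image():
     image = np.random.rand(2, 1, 300, 300)
     padded, slicer = pad_nd_image(image, new_shape=(310, 310), mode="constant",
                                   return_slicer=True, shape_must_be_divisible_by=32)
     # padded.shape == (2, 1, 320, 320); padded[tuple(slicer)] recovers the original 300x300 maps.
     return padded[tuple(slicer)]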
 
 
 #############################
 #  data packing / unpacking #
 #############################
 
 def get_case_identifiers(folder):
     case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")]
     return case_identifiers
 
 
-def convert_to_npy(npz_file):
+def convert_to_npy(npz_file, remove=False):
     identifier = os.path.split(npz_file)[1][:-4]
     if not os.path.isfile(npz_file[:-4] + ".npy"):
         a = np.load(npz_file)[identifier]
         np.save(npz_file[:-4] + ".npy", a)
+    if remove:
+        os.remove(npz_file)
 
 
 def unpack_dataset(folder, threads=8):
     case_identifiers = get_case_identifiers(folder)
     p = Pool(threads)
     npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers]
-    p.map(convert_to_npy, npz_files)
+    p.starmap(convert_to_npy, [(f, True) for f in npz_files])
     p.close()
     p.join()
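 
 
 # Usage sketch (assumes a folder holding one <case_id>.npz per case; the path is illustrative):
 # convert a single packed case or unpack the whole folder in parallel before training.
 def _example_unpacking(folder):
     npz_files = [os.path.join(folder, i + ".npz") for i in get_case_identifiers(folder)]
     if len(npz_files) > 0:
         convert_to_npy(npz_files[0], remove=False)  # single case, keep the source .npz
     unpack_dataset(folder, threads=8)  # all cases in parallel (removes the .npz files, see starmap call above)
     # delete_npy(folder) removes the unpacked .npy copies again once they are no longer needed.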
 
 
 def delete_npy(folder):
     case_identifiers = get_case_identifiers(folder)
     npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers]
     npy_files = [i for i in npy_files if os.path.isfile(i)]
     for n in npy_files:
         os.remove(n)
\ No newline at end of file
diff --git a/utils/model_utils.py b/utils/model_utils.py
index 70c1fae..7d74d20 100644
--- a/utils/model_utils.py
+++ b/utils/model_utils.py
@@ -1,1012 +1,1012 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """
 Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
 published under MIT license.
 """
 
 import numpy as np
 import scipy.misc
 import scipy.ndimage
 import scipy.interpolate
 import torch
 from torch.autograd import Variable
 import torch.nn as nn
 
 import tqdm
 ############################################################
 #  Bounding Boxes
 ############################################################
 
 
 def compute_iou_2D(box, boxes, box_area, boxes_area):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2] THIS IS THE GT BOX
     boxes: [boxes_count, (y1, x1, y2, x2)]
     box_area: float. the area of 'box'
     boxes_area: array of length boxes_count.
 
     Note: the areas are passed in rather than calculated here for
           efficiency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
     union = box_area + boxes_area[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 
 def compute_iou_3D(box, boxes, box_volume, boxes_volume):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2, z1, z2] (typically gt box)
     boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)]
     box_volume: float. the volume of 'box'
     boxes_volume: array of length boxes_count.
 
     Note: the volumes are passed in rather than calculated here for
           efficiency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     z1 = np.maximum(box[4], boxes[:, 4])
     z2 = np.minimum(box[5], boxes[:, 5])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) * np.maximum(z2 - z1, 0)
     union = box_volume + boxes_volume[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 
 def compute_overlaps(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)] / 3D: [N, (y1, x1, y2, x2, z1, z2)].
     For better performance, pass the largest set first and the smaller second.
     """
     # Areas of anchors and GT boxes
     if boxes1.shape[1] == 4:
         area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
         area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(overlaps.shape[1]):
             box2 = boxes2[i] #this is the gt box
             overlaps[:, i] = compute_iou_2D(box2, boxes1, area2[i], area1)
         return overlaps
 
     else:
         # Areas of anchors and GT boxes
         volume1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) * (boxes1[:, 5] - boxes1[:, 4])
         volume2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) * (boxes2[:, 5] - boxes2[:, 4])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(overlaps.shape[1]):
             box2 = boxes2[i]  # this is the gt box
             overlaps[:, i] = compute_iou_3D(box2, boxes1, volume2[i], volume1)
         return overlaps
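 
 
 # Minimal sketch of compute_overlaps on made-up 2D boxes: one anchor matches the GT box
 # exactly (IoU 1.0), the other does not touch it (IoU 0.0).
 def _example_compute_overlaps():
     anchors = np.array([[0., 0., 10., 10.], [20., 20., 30., 30.]])
     gt = np.array([[0., 0., 10., 10.]])
     return compute_overlaps(anchors, gt)  # -> array of shape (2, 1), values [[1.], [0.]]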
 
 
 
 def box_refinement(box, gt_box):
     """Compute refinement needed to transform box to gt_box.
     box and gt_box are [N, (y1, x1, y2, x2)] / 3D: [N, (y1, x1, y2, x2, z1, z2)].
     """
     height = box[:, 2] - box[:, 0]
     width = box[:, 3] - box[:, 1]
     center_y = box[:, 0] + 0.5 * height
     center_x = box[:, 1] + 0.5 * width
 
     gt_height = gt_box[:, 2] - gt_box[:, 0]
     gt_width = gt_box[:, 3] - gt_box[:, 1]
     gt_center_y = gt_box[:, 0] + 0.5 * gt_height
     gt_center_x = gt_box[:, 1] + 0.5 * gt_width
 
     dy = (gt_center_y - center_y) / height
     dx = (gt_center_x - center_x) / width
     dh = torch.log(gt_height / height)
     dw = torch.log(gt_width / width)
     result = torch.stack([dy, dx, dh, dw], dim=1)
 
     if box.shape[1] > 4:
         depth = box[:, 5] - box[:, 4]
         center_z = box[:, 4] + 0.5 * depth
         gt_depth = gt_box[:, 5] - gt_box[:, 4]
         gt_center_z = gt_box[:, 4] + 0.5 * gt_depth
         dz = (gt_center_z - center_z) / depth
         dd = torch.log(gt_depth / depth)
         result = torch.stack([dy, dx, dz, dh, dw, dd], dim=1)
 
     return result
 
 
 
 def unmold_mask_2D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to its original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2 = bbox
     out_zoom = [y2 - y1, x2 - x1]
     zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:2])
     full_mask[y1:y2, x1:x2] = mask
     return full_mask
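 
 
 # Sketch (illustrative sizes): resize a small 28x28 mask prediction into its box location
 # inside a 128x128 image; outside the box the full mask stays zero.
 def _example_unmold_mask_2D():
     small_mask = np.ones((28, 28), dtype=np.float32)
     full = unmold_mask_2D(small_mask, bbox=[10, 20, 60, 70], image_shape=(128, 128))
     return full  # shape (128, 128), ones inside [10:60, 20:70], zeros elsewhere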
 
 
 
 def unmold_mask_3D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to its original shape.
     mask: [height, width, depth] of type float. A small, low-resolution mask.
     bbox: [y1, x1, y2, x2, z1, z2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2, z1, z2 = bbox
     out_zoom = [y2 - y1, x2 - x1, z2 - z1]
     zoom_factor = [i/j for i,j in zip(out_zoom, mask.shape)]
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:3])
     full_mask[y1:y2, x1:x2, z1:z2] = mask
     return full_mask
 
 
 ############################################################
 #  Anchors
 ############################################################
 
 def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
     """
     scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width] spatial shape of the feature map over which
             to generate anchors.
     feature_stride: Stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
     scales, ratios = np.meshgrid(np.array(scales), np.array(ratios))
     scales = scales.flatten()
     ratios = ratios.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales / np.sqrt(ratios)
     widths = scales * np.sqrt(ratios)
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride
     shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
 
     # Reshape to get a list of (y, x) and a list of (h, w)
     box_centers = np.stack(
         [box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
     box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])
 
     # Convert to corner coordinates (y1, x1, y2, x2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                             box_centers + 0.5 * box_sizes], axis=1)
     return boxes
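 
 
 # Sketch with toy values: anchors for a 2x2 feature map at stride 16, with one scale and one
 # ratio, give 4 boxes of 32x32 pixels centered on the feature map cells.
 def _example_generate_anchors():
     boxes = generate_anchors(scales=[32], ratios=[1], shape=[2, 2],
                              feature_stride=16, anchor_stride=1)
     return boxes  # -> array of shape (4, 4); rows are (y1, x1, y2, x2), e.g. (-16, -16, 16, 16)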
 
 
 
 def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride):
     """
     scales_xy: 1D array of in-plane anchor sizes in pixels. Example: [32, 64, 128]
     scales_z: 1D array of anchor sizes along z in pixels.
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width, depth] spatial shape of the feature map over which
             to generate anchors.
     feature_stride_xy / feature_stride_z: stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
 
     scales_xy, ratios_meshed = np.meshgrid(np.array(scales_xy), np.array(ratios))
     scales_xy = scales_xy.flatten()
     ratios_meshed = ratios_meshed.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales_xy / np.sqrt(ratios_meshed)
     widths = scales_xy * np.sqrt(ratios_meshed)
     depths = np.tile(np.array(scales_z), len(ratios_meshed)//np.array(scales_z)[..., None].shape[0])
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy #translate from fm positions to input coords.
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy
     shifts_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z)
     shifts_x, shifts_y, shifts_z = np.meshgrid(shifts_x, shifts_y, shifts_z)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
     box_depths, box_centers_z = np.meshgrid(depths, shifts_z)
 
     # Reshape to get a list of (y, x, z) and a list of (h, w, d)
     box_centers = np.stack(
         [box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3])
     box_sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3])
 
     # Convert to corner coordinates (y1, x1, y2, x2, z1, z2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                             box_centers + 0.5 * box_sizes], axis=1)
 
     boxes = np.transpose(np.array([boxes[:, 0], boxes[:, 1], boxes[:, 3], boxes[:, 4], boxes[:, 2], boxes[:, 5]]), axes=(1, 0))
     return boxes
 
 
 def generate_pyramid_anchors(logger, cf):
     """Generate anchors at different levels of a feature pyramid. Each scale
     is associated with a level of the pyramid, but each ratio is used in
     all levels of the pyramid.
 
     from configs:
     :param scales: cf.RPN_ANCHOR_SCALES , e.g. [4, 8, 16, 32]
     :param ratios: cf.RPN_ANCHOR_RATIOS , e.g. [0.5, 1, 2]
     :param feature_shapes: cf.BACKBONE_SHAPES , e.g.  [array of shapes per feature map] [80, 40, 20, 10, 5]
     :param feature_strides: cf.BACKBONE_STRIDES , e.g. [2, 4, 8, 16, 32, 64]
     :param anchors_stride: cf.RPN_ANCHOR_STRIDE , e.g. 1
     :return anchors: (N, (y1, x1, y2, x2, (z1), (z2))). All generated anchors in one array. Sorted
     with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on.
     """
     scales = cf.rpn_anchor_scales
     ratios = cf.rpn_anchor_ratios
     feature_shapes = cf.backbone_shapes
     anchor_stride = cf.rpn_anchor_stride
     pyramid_levels = cf.pyramid_levels
     feature_strides = cf.backbone_strides
 
     anchors = []
     logger.info("feature map shapes: {}".format(feature_shapes))
     logger.info("anchor scales: {}".format(scales))
 
     expected_anchors = [np.prod(feature_shapes[ii]) * len(ratios) * len(scales['xy'][ii]) for ii in pyramid_levels]
 
     for lix, level in enumerate(pyramid_levels):
         if len(feature_shapes[level]) == 2:
             anchors.append(generate_anchors(scales['xy'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], anchor_stride))
         else:
             anchors.append(generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], feature_strides['z'][level], anchor_stride))
 
         logger.info("level {}: built anchors {} / expected anchors {} ||| total build {} / total expected {}".format(
             level, anchors[-1].shape, expected_anchors[lix], np.concatenate(anchors).shape, np.sum(expected_anchors)))
 
     out_anchors = np.concatenate(anchors, axis=0)
     return out_anchors
 
 
 
 def apply_box_deltas_2D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 4] where each row is y1, x1, y2, x2
     deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     height *= torch.exp(deltas[:, 2])
     width *= torch.exp(deltas[:, 3])
     # Convert back to y1, x1, y2, x2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     y2 = y1 + height
     x2 = x1 + width
     result = torch.stack([y1, x1, y2, x2], dim=1)
     return result
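 
 
 # Consistency sketch (illustrative torch tensors): box_refinement followed by
 # apply_box_deltas_2D recovers the GT box up to float precision.
 def _example_delta_roundtrip():
     boxes = torch.tensor([[10., 10., 50., 50.]])
     gt = torch.tensor([[12., 8., 60., 52.]])
     deltas = box_refinement(boxes, gt)
     return apply_box_deltas_2D(boxes, deltas)  # approximately equal to gt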
 
 
 
 def apply_box_deltas_3D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2
     deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     depth = boxes[:, 5] - boxes[:, 4]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     center_z = boxes[:, 4] + 0.5 * depth
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     center_z += deltas[:, 2] * depth
     height *= torch.exp(deltas[:, 3])
     width *= torch.exp(deltas[:, 4])
     depth *= torch.exp(deltas[:, 5])
     # Convert back to y1, x1, y2, x2, z1, z2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     z1 = center_z - 0.5 * depth
     y2 = y1 + height
     x2 = x1 + width
     z2 = z1 + depth
     result = torch.stack([y1, x1, y2, x2, z1, z2], dim=1)
     return result
 
 
 
 def clip_boxes_2D(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2
     window: [4] in the form y1, x1, y2, x2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3]))], 1)
     return boxes
 
 def clip_boxes_3D(boxes, window):
     """
     boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2
     window: [6] in the form y1, x1, y2, x2, z1, z2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3])),
          boxes[:, 4].clamp(float(window[4]), float(window[5])),
          boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1)
     return boxes
 
 
 
 def clip_boxes_numpy(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D.
     window: image shape (y, x, (z))
     """
     # columns follow the documented (y1, x1, y2, x2(, z1, z2)) order: y-coords are clipped to
     # window[0], x-coords to window[1], z-coords to window[2].
     if boxes.shape[1] == 4:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
             np.clip(boxes[:, 1], 0, window[1])[:, None],
             np.clip(boxes[:, 2], 0, window[0])[:, None],
             np.clip(boxes[:, 3], 0, window[1])[:, None]), 1
         )
 
     else:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
              np.clip(boxes[:, 1], 0, window[1])[:, None],
              np.clip(boxes[:, 2], 0, window[0])[:, None],
              np.clip(boxes[:, 3], 0, window[1])[:, None],
              np.clip(boxes[:, 4], 0, window[2])[:, None],
              np.clip(boxes[:, 5], 0, window[2])[:, None]), 1
         )
 
     return boxes
 
 
 
 def bbox_overlaps_2D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)].
     """
     # 1. Tile boxes2 and repeat boxes1. This allows us to compare
     # every box in boxes1 against every box in boxes2 without loops.
     # (Carried over from the TF implementation; emulated here with torch repeat() and view().)
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,4)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2 = boxes1.chunk(4, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2 = boxes2.chunk(4, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     zeros = Variable(torch.zeros(y1.size()[0]), requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros)
 
     # 3. Compute unions
     b1_area = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)
     b2_area = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)
     union = b1_area[:,0] + b2_area[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     overlaps = iou.view(boxes2_repeat, boxes1_repeat)
     return overlaps
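 
 
 # Torch counterpart of the numpy IoU sketch above (illustrative values): overlaps has shape
 # (len(boxes1), len(boxes2)).
 def _example_bbox_overlaps_2D():
     boxes1 = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
     boxes2 = torch.tensor([[0., 0., 10., 10.]])
     return bbox_overlaps_2D(boxes1, boxes2)  # -> shape (2, 1), values ~[[1.0], [0.143]]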
 
 
 
 def bbox_overlaps_3D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2, z1, z2)].
     """
     # 1. Tile boxes2 and repeat boxes1. This allows us to compare
     # every box in boxes1 against every box in boxes2 without loops.
     # (Carried over from the TF implementation; emulated here with torch repeat() and view().)
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,6)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2, b1_z1, b1_z2 = boxes1.chunk(6, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2, b2_z1, b2_z2 = boxes2.chunk(6, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     z1 = torch.max(b1_z1, b2_z1)[:, 0]
     z2 = torch.min(b1_z2, b2_z2)[:, 0]
     zeros = Variable(torch.zeros(y1.size()[0]), requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros)
 
     # 3. Compute unions
     b1_volume = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)  * (b1_z2 - b1_z1)
     b2_volume = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)  * (b2_z2 - b2_z1)
     union = b1_volume[:,0] + b2_volume[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     overlaps = iou.view(boxes2_repeat, boxes1_repeat)
     return overlaps
 
 
 
 def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None):
     """Given the anchors and GT boxes, compute overlaps and identify positive
     anchors and deltas to refine them to match their corresponding GT boxes.
 
     anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))]
     gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))]
     gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN,
     set all positive matches to 1 (foreground)
 
     Returns:
     anchor_class_matches: [N] (int32) matches between anchors and GT boxes.
                1 = positive anchor, -1 = negative anchor, 0 = neutral.
                In case of one stage detectors like RetinaNet/RetinaUNet this flag takes
                class_ids as positive anchor values, i.e. values >= 1!
     anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas.
     """
 
     anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32)
     anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim))
     anchor_matching_iou = cf.anchor_matching_iou
 
     if gt_boxes is None:
         anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1)
         return anchor_class_matches, anchor_delta_targets
 
     # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground)
     if gt_class_ids is None:
         gt_class_ids = np.array([1] * len(gt_boxes))
 
     # Compute overlaps [num_anchors, num_gt_boxes]
     overlaps = compute_overlaps(anchors, gt_boxes)
 
     # Match anchors to GT Boxes
     # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive.
     # If an anchor overlaps a GT box with IoU < 0.1 (2D) / < 0.01 (3D) then it's negative.
     # Neutral anchors are those that don't match the conditions above,
     # and they don't influence the loss function.
     # However, don't keep any GT box unmatched (rare, but happens). Instead,
     # match it to the closest anchor (even if its max IoU is < 0.1).
 
     # 1. Set negative anchors first. They get overwritten below if a GT box is
     # matched to them. Skip boxes in crowd areas.
     anchor_iou_argmax = np.argmax(overlaps, axis=1)
     anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
     if anchors.shape[1] == 4:
         anchor_class_matches[(anchor_iou_max < 0.1)] = -1
     elif anchors.shape[1] == 6:
         anchor_class_matches[(anchor_iou_max < 0.01)] = -1
     else:
         raise ValueError('anchor shape wrong {}'.format(anchors.shape))
 
     # 2. Set an anchor for each GT box (regardless of IoU value).
     gt_iou_argmax = np.argmax(overlaps, axis=0)
     for ix, ii in enumerate(gt_iou_argmax):
         anchor_class_matches[ii] = gt_class_ids[ix]
 
     # 3. Set anchors with high overlap as positive.
     above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou)
     anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]]
 
     # Subsample to balance positive anchors.
     ids = np.where(anchor_class_matches > 0)[0]
     # extra == number of surplus positive anchors --> reset them to neutral.
     extra = len(ids) - (cf.rpn_train_anchors_per_image // 2)
     if extra > 0:
         # Reset the extra ones to neutral
         extra_ids = np.random.choice(ids, extra, replace=False)
         anchor_class_matches[extra_ids] = 0
 
     # Leave all negative proposals negative now and sample from them in online hard example mining.
     # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes.
-    #ids = np.where(anchor_class_matches > 0)[0]
+    ids = np.where(anchor_class_matches > 0)[0]
     ix = 0  # index into anchor_delta_targets
     for i, a in zip(ids, anchors[ids]):
         # closest gt box (it might have IoU < anchor_matching_iou)
         gt = gt_boxes[anchor_iou_argmax[i]]
 
         # convert coordinates to center plus width/height.
         gt_h = gt[2] - gt[0]
         gt_w = gt[3] - gt[1]
         gt_center_y = gt[0] + 0.5 * gt_h
         gt_center_x = gt[1] + 0.5 * gt_w
         # Anchor
         a_h = a[2] - a[0]
         a_w = a[3] - a[1]
         a_center_y = a[0] + 0.5 * a_h
         a_center_x = a[1] + 0.5 * a_w
 
         if cf.dim == 2:
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
             ]
 
         else:
             gt_d = gt[5] - gt[4]
             gt_center_z = gt[4] + 0.5 * gt_d
             a_d = a[5] - a[4]
             a_center_z = a[4] + 0.5 * a_d
 
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 (gt_center_z - a_center_z) / a_d,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
                 np.log(gt_d / a_d)
             ]
 
         # normalize.
         anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev
         ix += 1
 
     return anchor_class_matches, anchor_delta_targets
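 
 
 # Sketch with a minimal stand-in config (attribute names follow the cf fields used above; the
 # values are illustrative, not the experiment defaults): one anchor matches the single GT box.
 def _example_gt_anchor_matching():
     class _CF:
         dim = 2
         anchor_matching_iou = 0.7
         rpn_train_anchors_per_image = 6
         rpn_bbox_std_dev = np.array([0.1, 0.1, 0.2, 0.2])
     anchors = np.array([[0., 0., 10., 10.], [30., 30., 40., 40.]])
     gt_boxes = np.array([[1., 1., 11., 11.]])
     # first anchor becomes positive (class 1), second negative (-1); deltas are filled for the positive one.
     return gt_anchor_matching(_CF(), anchors, gt_boxes)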
 
 
 
 def clip_to_window(window, boxes):
     """
         window: (y1, x1, y2, x2) / 3D: (z1, z2). The window in the image we want to clip to.
         boxes: [N, (y1, x1, y2, x2)]  / 3D: (z1, z2)
     """
     boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[2]))
     boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[3]))
     boxes[:, 2] = boxes[:, 2].clamp(float(window[0]), float(window[2]))
     boxes[:, 3] = boxes[:, 3].clamp(float(window[1]), float(window[3]))
 
     if boxes.shape[1] > 5:
         boxes[:, 4] = boxes[:, 4].clamp(float(window[4]), float(window[5]))
         boxes[:, 5] = boxes[:, 5].clamp(float(window[4]), float(window[5]))
 
     return boxes
 
 
 def nms_numpy(box_coords, scores, thresh):
     """ non-maximum suppression on 2D or 3D boxes in numpy.
     :param box_coords: [y1,x1,y2,x2 (,z1,z2)] with y1<=y2, x1<=x2, z1<=z2.
     :param scores: ranking scores (higher score == higher rank) of boxes.
     :param thresh: IoU threshold for clustering.
     :return:
     """
     y1 = box_coords[:, 0]
     x1 = box_coords[:, 1]
     y2 = box_coords[:, 2]
     x2 = box_coords[:, 3]
     assert np.all(y1 <= y2) and np.all(x1 <= x2), """the definition of the coordinates is crucially important here:
             coordinates of which maxima are taken need to be the lower coordinates"""
     areas = (x2 - x1) * (y2 - y1)
 
     is_3d = box_coords.shape[1] == 6
     if is_3d: # 3-dim case
         z1 = box_coords[:, 4]
         z2 = box_coords[:, 5]
         assert np.all(z1 <= z2), """the definition of the coordinates is crucially important here:
            coordinates of which maxima are taken need to be the lower coordinates"""
         areas *= (z2 - z1)
 
     order = scores.argsort()[::-1]
 
     keep = []
     while order.size > 0:  # order is the sorted index.  maps order to index: order[1] = 24 means (rank1, ix 24)
         i = order[0] # highest scoring element
         yy1 = np.maximum(y1[i], y1[order])  # highest scoring element still in >order<, is compared to itself, that is okay.
         xx1 = np.maximum(x1[i], x1[order])
         yy2 = np.minimum(y2[i], y2[order])
         xx2 = np.minimum(x2[i], x2[order])
 
         h = np.maximum(0.0, yy2 - yy1)
         w = np.maximum(0.0, xx2 - xx1)
         inter = h * w
 
         if is_3d:
             zz1 = np.maximum(z1[i], z1[order])
             zz2 = np.minimum(z2[i], z2[order])
             d = np.maximum(0.0, zz2 - zz1)
             inter *= d
 
         iou = inter / (areas[i] + areas[order] - inter)
 
         non_matches = np.nonzero(iou <= thresh)[0]  # get all elements that were not matched and discard all others.
         order = order[non_matches]
         keep.append(i)
 
     return keep
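 
 
 # Sketch (made-up boxes): two heavily overlapping boxes and one separate box; with an IoU
 # threshold of 0.5, NMS keeps the higher scoring box of the overlapping pair plus the lone box.
 def _example_nms_numpy():
     boxes = np.array([[0., 0., 10., 10.], [1., 1., 11., 11.], [50., 50., 60., 60.]])
     scores = np.array([0.9, 0.8, 0.7])
     return nms_numpy(boxes, scores, thresh=0.5)  # -> [0, 2]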
 
 def roi_align_3d_numpy(input: np.ndarray, rois, output_size: tuple,
                        spatial_scale: float = 1., sampling_ratio: int = -1) -> np.ndarray:
     """ This fct mainly serves as a verification method for 3D CUDA implementation of RoIAlign, it's highly
         inefficient due to the nested loops.
     :param input:  (ndarray[N, C, H, W, D]): input feature map
     :param rois: list (N,K(n), 6), K(n) = nr of rois in batch-element n, single roi of format (y1,x1,y2,x2,z1,z2)
     :param output_size:
     :param spatial_scale:
     :param sampling_ratio:
     :return: (List[N, K(n), C, output_size[0], output_size[1], output_size[2]])
     """
 
     out_height, out_width, out_depth = output_size
 
     coord_grid = tuple([np.linspace(0, input.shape[dim] - 1, num=input.shape[dim]) for dim in range(2, 5)])
     pooled_rois = [[] for _ in range(len(rois))]  # independent per-batch lists ([[]] * n would alias the same list)
     assert len(rois) == input.shape[0], "batch dim mismatch, rois: {}, input: {}".format(len(rois), input.shape[0])
     print("Numpy 3D RoIAlign progress:", end="\n")
     for b in range(input.shape[0]):
         for roi in tqdm.tqdm(rois[b]):
             y1, x1, y2, x2, z1, z2 = np.array(roi) * spatial_scale
             roi_height = max(float(y2 - y1), 1.)
             roi_width = max(float(x2 - x1), 1.)
             roi_depth = max(float(z2 - z1), 1.)
 
             if sampling_ratio <= 0:
                 sampling_ratio_h = int(np.ceil(roi_height / out_height))
                 sampling_ratio_w = int(np.ceil(roi_width / out_width))
                 sampling_ratio_d = int(np.ceil(roi_depth / out_depth))
             else:
                 sampling_ratio_h = sampling_ratio_w = sampling_ratio_d = sampling_ratio  # == n points per bin
 
             bin_height = roi_height / out_height
             bin_width = roi_width / out_width
             bin_depth = roi_depth / out_depth
 
             n_points = sampling_ratio_h * sampling_ratio_w * sampling_ratio_d
             pooled_roi = np.empty((input.shape[1], out_height, out_width, out_depth), dtype="float32")
             for chan in range(input.shape[1]):
                 lin_interpolator = scipy.interpolate.RegularGridInterpolator(coord_grid, input[b, chan],
                                                                              method="linear")
                 for bin_iy in range(out_height):
                     for bin_ix in range(out_width):
                         for bin_iz in range(out_depth):
 
                             bin_val = 0.
                             for i in range(sampling_ratio_h):
                                 for j in range(sampling_ratio_w):
                                     for k in range(sampling_ratio_d):
                                         loc_ijk = [
                                             y1 + bin_iy * bin_height + (i + 0.5) * (bin_height / sampling_ratio_h),
                                             x1 + bin_ix * bin_width + (j + 0.5) * (bin_width / sampling_ratio_w),
                                             z1 + bin_iz * bin_depth + (k + 0.5) * (bin_depth / sampling_ratio_d)]
                                         # print("loc_ijk", loc_ijk)
                                         if not (np.any([c < -1.0 for c in loc_ijk]) or loc_ijk[0] > input.shape[2] or
                                                 loc_ijk[1] > input.shape[3] or loc_ijk[2] > input.shape[4]):
                                             for catch_case in range(3):
                                                 # catch on-border cases
                                                 if int(loc_ijk[catch_case]) == input.shape[catch_case + 2] - 1:
                                                     loc_ijk[catch_case] = input.shape[catch_case + 2] - 1
                                             bin_val += lin_interpolator(loc_ijk)
                             pooled_roi[chan, bin_iy, bin_ix, bin_iz] = bin_val / n_points
 
             pooled_rois[b].append(pooled_roi)
 
     return np.array(pooled_rois)
 
 
 ############################################################
 #  Pytorch Utility Functions
 ############################################################
 
 
 def unique1d(tensor):
     if tensor.shape[0] == 0 or tensor.shape[0] == 1:
         return tensor
     tensor = tensor.sort()[0]
     unique_bool = tensor[1:] != tensor[:-1]
     first_element = torch.tensor([True], dtype=torch.bool, requires_grad=False)
     if tensor.is_cuda:
         first_element = first_element.cuda()
     unique_bool = torch.cat((first_element, unique_bool),dim=0)
     return tensor[unique_bool]
 
 
 
 def log2(x):
     """Implementatin of Log2. Pytorch doesn't have a native implemenation."""
     ln2 = Variable(torch.log(torch.FloatTensor([2.0])), requires_grad=False)
     if x.is_cuda:
         ln2 = ln2.cuda()
     return torch.log(x) / ln2
 
 
 
 def intersect1d(tensor1, tensor2):
     aux = torch.cat((tensor1, tensor2), dim=0)
     aux = aux.sort(descending=True)[0]
     return aux[:-1][(aux[1:] == aux[:-1]).data]
 
 
 
 def shem(roi_probs_neg, negative_count, ohem_poolsize):
     """
     stochastic hard example mining: from a list of indices (referring to non-matched predictions),
     determine a pool of highest scoring (worst false positives) of size negative_count*ohem_poolsize.
     Then, sample n (= negative_count) predictions of this pool as negative examples for loss.
     :param roi_probs_neg: tensor of shape (n_predictions, n_classes).
     :param negative_count: int.
     :param ohem_poolsize: int.
     :return: (negative_count).  indices refer to the positions in roi_probs_neg. If the pool is smaller than expected
     due to a limited number of available negative proposals, this function returns fewer than negative_count sampled
     indices without throwing an error.
     """
     # sort according to highest foreground score.
     probs, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True)
     select = torch.tensor((ohem_poolsize * int(negative_count), order.size()[0])).min().int()
     pool_indices = order[:select]
     rand_idx = torch.randperm(pool_indices.size()[0])
     return pool_indices[rand_idx[:negative_count].cuda()]
 
 
 
 def initialize_weights(net):
     """
    Initialize model weights. Current Default in Pytorch (version 0.4.1) is initialization from a uniform distriubtion.
    Will expectably be changed to kaiming_uniform in future versions.
    """
     init_type = net.cf.weight_init
 
     for m in [module for module in net.modules() if type(module) in [nn.Conv2d, nn.Conv3d,
                                                                      nn.ConvTranspose2d,
                                                                      nn.ConvTranspose3d,
                                                                      nn.Linear]]:
         if init_type == 'xavier_uniform':
             nn.init.xavier_uniform_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == 'xavier_normal':
             nn.init.xavier_normal_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == "kaiming_uniform":
             nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 nn.init.uniform_(m.bias, -bound, bound)
 
         elif init_type == "kaiming_normal":
             nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 nn.init.normal_(m.bias, -bound, bound)
 
 
 
 class NDConvGenerator(object):
     """
     generic wrapper around conv layers to avoid distinguishing between 2D and 3D in the model code.
     """
     def __init__(self, dim):
         self.dim = dim
 
     def __call__(self, c_in, c_out, ks, pad=0, stride=1, norm=None, relu='relu'):
         """
         :param c_in: number of in_channels.
         :param c_out: number of out_channels.
         :param ks: kernel size.
         :param pad: pad size.
         :param stride: kernel stride.
         :param norm: string specifying type of feature map normalization. If None, no normalization is applied.
         :param relu: string specifying type of nonlinearity. If None, no nonlinearity is applied.
         :return: conv layer, optionally wrapped in a Sequential with the requested norm and nonlinearity.
         """
         if self.dim == 2:
             conv = nn.Conv2d(c_in, c_out, kernel_size=ks, padding=pad, stride=stride)
             if norm is not None:
                 if norm == 'instance_norm':
                     norm_layer = nn.InstanceNorm2d(c_out)
                 elif norm == 'batch_norm':
                     norm_layer = nn.BatchNorm2d(c_out)
                 else:
                     raise ValueError('norm type as specified in configs is not implemented...')
                 conv = nn.Sequential(conv, norm_layer)
 
         else:
             conv = nn.Conv3d(c_in, c_out, kernel_size=ks, padding=pad, stride=stride)
             if norm is not None:
                 if norm == 'instance_norm':
                     norm_layer = nn.InstanceNorm3d(c_out)
                 elif norm == 'batch_norm':
                     norm_layer = nn.BatchNorm3d(c_out)
                 else:
                     raise ValueError('norm type as specified in configs is not implemented... {}'.format(norm))
                 conv = nn.Sequential(conv, norm_layer)
 
         if relu is not None:
             if relu == 'relu':
                 relu_layer = nn.ReLU(inplace=True)
             elif relu == 'leaky_relu':
                 relu_layer = nn.LeakyReLU(inplace=True)
             else:
                 raise ValueError('relu type as specified in configs is not implemented...')
             conv = nn.Sequential(conv, relu_layer)
 
         return conv
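 
 
 # Usage sketch (illustrative channel numbers): build a 2D conv block with instance norm and
 # ReLU via the generator, similar to how the model code instantiates NDConvGenerator(cf.dim).
 def _example_nd_conv():
     conv = NDConvGenerator(dim=2)
     block = conv(c_in=1, c_out=32, ks=3, pad=1, norm='instance_norm', relu='relu')
     return block(torch.zeros(1, 1, 64, 64))  # -> tensor of shape (1, 32, 64, 64)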
 
 
 
 def get_one_hot_encoding(y, n_classes):
     """
     transform a numpy label array to a one-hot array of the same shape.
     :param y: array of shape (b, 1, y, x, (z)).
     :param n_classes: int, number of classes to unfold in one-hot encoding.
     :return y_ohe: array of shape (b, n_classes, y, x, (z))
     """
     dim = len(y.shape) - 2
     if dim == 2:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32')
     elif dim == 3:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32')
     else:
         raise ValueError("expected y of shape (b, 1, y, x, (z)), got {} spatial dims".format(dim))
     for cl in range(n_classes):
         y_ohe[:, cl][y[:, 0] == cl] = 1
     return y_ohe
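 
 
 # Sketch (toy segmentation with 3 classes): turns integer labels of shape (b, 1, y, x) into a
 # one-hot array of shape (b, 3, y, x).
 def _example_one_hot():
     y = np.random.randint(0, 3, size=(2, 1, 8, 8))
     return get_one_hot_encoding(y, n_classes=3)  # shape (2, 3, 8, 8), entries in {0, 1}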
 
 
 
 def get_dice_per_batch_and_class(pred, y, n_classes):
     '''
     computes dice scores per batch instance and class.
     :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1)
     :param y: ground truth array of shape (b, 1, y, x, (z)) (contains integer class ids in [0, ..., n_classes - 1])
     :param n_classes: int
     :return: dice scores of shape (b, c)
     '''
     pred = get_one_hot_encoding(pred, n_classes)
     y = get_one_hot_encoding(y, n_classes)
     axes = tuple(range(2, len(pred.shape)))
     intersect = np.sum(pred*y, axis=axes)
     denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes) + 1e-8
     dice = 2.0*intersect / denominator
     return dice
 
 
 
 def sum_tensor(input, axes, keepdim=False):
     axes = np.unique(axes)
     if keepdim:
         for ax in axes:
             input = input.sum(ax, keepdim=True)
     else:
         for ax in sorted(axes, reverse=True):
             input = input.sum(int(ax))
     return input
 
 
 
 def batch_dice(pred, y, false_positive_weight=1.0, smooth=1e-6):
     '''
     compute soft dice over batch. this is a differentiable score and can be used as a loss function.
     only dice scores of foreground classes are returned, since training typically
     does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of.
     This way, single patches with missing foreground classes can not produce faulty gradients.
     :param pred: (b, c, y, x, (z)), softmax probabilities (network output). (c==classes)
     :param y: (b, c, y, x, (z)), one-hot-encoded segmentation mask.
     :param false_positive_weight: float [0,1]. For weighting of imbalanced classes,
     reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances.
     :return: soft dice score (float). This function discards the background score and returns the mean of foreground scores.
     '''
     if len(pred.size()) == 4:
         axes = (0, 2, 3)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2 * intersect + smooth) / (denom + smooth) )[1:]) # only fg dice here.
 
     elif len(pred.size()) == 5:
         axes = (0, 2, 3, 4)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth) )[1:]) # only fg dice here.
 
     else:
         raise ValueError('wrong input dimension in dice loss')
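 
 
 # Sketch: with a perfect one-hot "prediction" the foreground soft dice is ~1.0. Shapes are
 # illustrative (batch 2, 3 classes, 16x16 pixels).
 def _example_batch_dice():
     y = torch.zeros(2, 3, 16, 16)
     y[:, 1] = 1.  # every pixel belongs to foreground class 1
     return batch_dice(y.clone(), y)  # -> tensor close to 1.0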
 
 
 
 
 def batch_dice_mask(pred, y, mask, false_positive_weight=1.0, smooth=1e-6):
     '''
     compute soft dice over batch. this is a differentiable score and can be used as a loss function.
     only dice scores of foreground classes are returned, since training typically
     does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of.
     This way, single patches with missing foreground classes can not produce faulty gradients.
     :param pred: (b, c, y, x, (z)), softmax probabilities (network output).
     :param y: (b, c, y, x, (z)), one-hot-encoded segmentation mask.
     :param false_positive_weight: float [0,1]. For weighting of imbalanced classes,
     reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances.
     :return: soft dice score (float). This function discards the background score and returns the mean of foreground scores.
     '''
 
     mask = mask.unsqueeze(1).repeat(1, 2, 1, 1)
 
     if len(pred.size()) == 4:
         axes = (0, 2, 3)
         intersect = sum_tensor(pred * y * mask, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred * mask + y * mask, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth))[1:]) # only fg dice here.
 
     elif len(pred.size()) == 5:
         axes = (0, 2, 3, 4)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth) )[1:]) # only fg dice here.
 
     else:
         raise ValueError('wrong input dimension in dice loss')
\ No newline at end of file