diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py
index 8f81931..fe1686a 100644
--- a/datasets/toy/configs.py
+++ b/datasets/toy/configs.py
@@ -1,490 +1,490 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 from collections import namedtuple
 
 boxLabel = namedtuple('boxLabel', ["name", "color"])
 Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion'])
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals'])
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #         Prepro        #
         #########################
 
         self.pp_rootdir = os.path.join('/mnt/HDD2TB/Documents/data/toy', "cyl1ps_dev")
         self.pp_npz_dir = self.pp_rootdir+"_npz"
 
         self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now)
         self.min_2d_radius = 6 #in pixels
         self.n_train_samples, self.n_test_samples = 80, 80
 
         # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes.
         self.pp_create_ohe_seg = False
         self.pp_empty_samples_ratio = 0.1
 
         self.pp_place_radii_mid_bin = True
         self.pp_only_distort_2d = True
         # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing.
         # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity.
         self.pp_blur_min_intensity = 0.2
 
         self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d)
         self.max_instances_per_class = self.max_instances_per_sample  # how many max instances per image per class
         self.noise_scale = 0.  # std-dev of gaussian noise
 
         self.ambigs_sampling = "gaussian" #"gaussian" or "uniform"
         """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from
             image by distinguishing it from the rest of the object.
             blurring width around edge will be shifted so that symmetric rel to orig radius.
             blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale
             since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius.
             if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2.
         """
         self.ambiguities = {
              #set which classes to apply which ambs to below in class labels
              #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'.
              #kind              #probability   #scale (gaussian std, relative to unperturbed value)
             #"outer_radius":     (1.,            0.5),
             #"outer_radius_xy":  (1.,            0.5),
             #"inner_radius":     (0.5,            0.1),
             #"radii_relations":  (0.5,            0.1),
             "radius_calib":     (1.,            1./6)
         }
 
         # shape choices: 'cylinder', 'block'
         #                        id,    name,       shape,      radius,                 color,              regression,     ambiguities,    gt_distortion
         self.pp_classes = [Label(1,     'cylinder', 'cylinder', ((6,6,1),(40,40,8)),    (*self.blue, 1.),   "radius_2d",    (),             ()),
                            #Label(2,      'block',      'block',        ((6,6,1),(40,40,8)),  (*self.aubergine,1.),  "radii_2d", (), ('radius_calib',))
             ]
 
 
         #########################
         #         I/O           #
         #########################
 
         self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_exact'
 
         if server_env:
             self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_exact_npz'
 
 
         self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test')
         self.data_sourcedir = os.path.join(self.data_sourcedir, "train")
 
         self.info_df_name = 'info_df.pickle'
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn'].
         self.model = 'retina_unet'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # one out of [2, 3]. dimension the model operates in.
         self.dim = 3
 
         # 'class', 'regression', 'regression_bin', 'regression_ken_gal'
         # currently only tested mode is a single-task at a time (i.e., only one task in below list)
         # but, in principle, tasks could be combined (e.g., object classes and regression per class)
         self.prediction_tasks = ['class',]
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm'
         self.relu = 'relu'
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1  # length of regressor target vector
 
 
         #########################
         #      Data Loader      #
         #########################
 
         self.num_epochs = 32
         self.num_train_batches = 120 if self.dim == 2 else 80
         self.batch_size = 16 if self.dim == 2 else 8
 
         self.n_cv_splits = 4
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
         self.plot_bg_chan = 0
         self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
         self.patch_size_2D = self.pre_crop_size[:2]
         self.patch_size_3D = self.pre_crop_size[:2]+[8]
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.2
         self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets"
 
         self.observables_patient = []
         self.observables_rois = []
 
         self.seed = 3 #for generating folds
 
         #############################
         # Colors, Classes, Legends  #
         #############################
         self.plot_frequency = 1
 
         binary_bin_labels = [binLabel(1,  'r<=25',      (*self.green, 1.),      (1,25)),
                              binLabel(2,  'r>25',       (*self.red, 1.),        (25,))]
         quintuple_bin_labels = [binLabel(1,  'r2-10',   (*self.green, 1.),      (2,10)),
                                 binLabel(2,  'r10-20',  (*self.yellow, 1.),     (10,20)),
                                 binLabel(3,  'r20-30',  (*self.orange, 1.),     (20,30)),
                                 binLabel(4,  'r30-40',  (*self.bright_red, 1.), (30,40)),
                                 binLabel(5,  'r>40',    (*self.red, 1.), (40,))]
 
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_bin_labels = quintuple_bin_labels
 
         self.class_labels = [
             # regression: regression-task label, either value or "(x,y,z)_radius" or "radii".
             # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables!
             # gt_distortion: name of ambig to apply to gt only; needs to be iterable!
             #      #id  #name   #shape  #radius     #color              #regression #ambiguities    #gt_distortion
             Label(  0,  'bg',   None,   (0, 0, 0),  (*self.white, 0.),  (0, 0, 0),  (),             ())]
         if "class" in self.prediction_tasks:
             self.class_labels += self.pp_classes
         else:
             self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))]
 
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0,  'bg',       (*self.white, 1.),      (0,))]
             self.bin_labels += task_spec_bin_labels
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
             self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
 
         if self.class_specific_seg:
           self.seg_labels = self.class_labels
 
         self.box_type2label = {label.name: label for label in self.box_labels}
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)
         self.num_seg_classes = len(self.seg_labels)
 
         #########################
         #   Data Augmentation   #
         #########################
         self.do_aug = True
         self.da_kwargs = {
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': False,
             'alpha': (500., 1500.),
             'sigma': (40., 45.),
             'do_rotation': False,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': False,
             'scale': (0.8, 1.1),
             'random_crop': False,
             'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #########################
         #  Schedule / Selection #
         #########################
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is more accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 220  # if 'all' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 25 if self.dim==2 else 15
 
         self.save_n_models = 2
         self.min_save_thresh = 1 if self.dim == 2 else 1  # =wait time in epochs
         if "class" in self.prediction_tasks:
             self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()}
             self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()})
 
         self.lr_decay_factor = 0.5
         self.scheduling_patience = int(self.num_epochs / 5)
         self.weight_decay = 1e-5
         self.clip_norm = None  # number or None
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1)
         self.held_out_test_set = True
         self.max_test_patients = "all"  # number or "all" for all
 
         self.test_against_exact_gt = not 'exact' in self.data_sourcedir
         self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario.
         self.report_score_level = ['rois']  # 'patient' or 'rois' (incl)
         self.patient_class_of_interest = 1
         self.patient_bin_of_interest = 2
 
         self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False
         self.metrics = ['ap', 'auc', 'dice']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.5]  # threshold(s) for considering a prediction as true positive
         self.min_det_thresh = 0.3
 
         self.model_max_iou_resolution = 0.2
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = self.model_max_iou_resolution  # has to be larger than desired possible overlap iou of model predictions
 
         self.merge_2D_to_3D_preds = False
         self.merge_3D_iou = self.model_max_iou_resolution
         self.n_test_plots = 1  # per fold and rank
 
         self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
         # is multiplied by (1 + nr of test augs)
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
          'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
          }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         #only meant for isotropic radii!!
         # only 2D radii (x and y dims) or 1D (x or y) are expected
         return np.round(np.digitize(rg_val, self.bin_edges).mean())
 
 
     def add_det_fpn_configs(self):
 
-      self.learning_rate = [5 * 1e-4] * self.num_epochs
+      self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = 'torch_loss'
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       self.n_roi_candidates = 4 if self.dim == 2 else 6
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice'), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
 
       self.fp_dice_weight = 1 if self.dim == 2 else 1
       # if <1, false positive predictions in foreground are penalized less.
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
     def add_det_unet_configs(self):
 
-      self.learning_rate = [5 * 1e-4] * self.num_epochs
+      self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = "torch_loss"
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
       self.n_roi_candidates = 4 if self.dim == 2 else 6
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice'), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
       # if <1, false positive predictions in foreground are penalized less.
       self.fp_dice_weight = 1 if self.dim == 2 else 1
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
       self.init_filts = 32
       self.kernel_size = 3  # ks for horizontal, normal convs
       self.kernel_size_m = 2  # ks for max pool
       self.pad = "same"  # "same" or integer, padding of horizontal convs
 
     def add_mrcnn_configs(self):
 
       self.learning_rate = [1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True  # with scheduler set in exec
       self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # number of classes for network heads: n_foreground_classes + 1 (background)
       self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2
 
       # feed +/- n neighbouring slices into channel dimension. set to None for no context.
       self.n_3D_context = None
       if self.n_3D_context is not None and self.dim == 2:
         self.n_channels *= (self.n_3D_context * 2 + 1)
 
       self.detect_while_training = True
       # disable the re-sampling of mask proposals to original size for speed-up.
       # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
       # mask outputs are optional.
       self.return_masks_in_train = True
       self.return_masks_in_val = True
       self.return_masks_in_test = True
 
       # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
       self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
       # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
       # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
       self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
       # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
       self.pyramid_levels = [0, 1, 2, 3]
       # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
       self.n_rpn_features = 512 if self.dim == 2 else 64
 
       # anchor ratios and strides per position in feature maps.
       self.rpn_anchor_ratios = [0.5, 1., 2.]
       self.rpn_anchor_stride = 1
       # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
       self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution)
 
       # loss sampling settings.
       self.rpn_train_anchors_per_image = 4
       self.train_rois_per_image = 6 # per batch_instance
       self.roi_positive_ratio = 0.5
       self.anchor_matching_iou = 0.8
 
       # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
       # where k<=#positive examples.
       self.shem_poolsize = 2
 
       self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
       self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
       self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
       self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
       self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                              self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
 
       if self.dim == 2:
         self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
         self.bbox_std_dev = self.bbox_std_dev[:4]
         self.window = self.window[:4]
         self.scale = self.scale[:4]
 
       self.plot_y_max = 1.5
       self.n_plot_rpn_props = 5 if self.dim == 2 else 30  # per batch_instance (slice in 2D / patient in 3D)
 
       # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
       self.pre_nms_limit = 2000 if self.dim == 2 else 4000
 
       # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
       # since proposals of the entire batch are forwarded through second stage as one "batch".
       self.roi_chunk_size = 1300 if self.dim == 2 else 500
       self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400
       self.post_nms_rois_inference = 200 * (self.head_classes-1)
 
       # Final selection of detections (refine_detections)
       self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class.
       self.detection_nms_threshold = self.model_max_iou_resolution  # needs to be > 0, otherwise all predictions are one cluster.
       self.model_min_confidence = 0.2  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
 
       if self.dim == 2:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride))]
            for stride in self.backbone_strides['xy']])
       else:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride)),
             int(np.ceil(self.patch_size[2] / stride_z))]
            for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                        )])
 
       if self.model == 'retina_net' or self.model == 'retina_unet':
         # whether to use focal loss or SHEM for loss-sample selection
         self.focal_loss = False
         # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
         self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                         self.rpn_anchor_scales['xy']]
         self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                        self.rpn_anchor_scales['z']]
         self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
         # pre-selection of detections for NMS-speedup. per entire batch.
         self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
         # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
         self.anchor_matching_iou = 0.7
 
         if self.model == 'retina_unet':
           self.operate_stride1 = True
diff --git a/models/detection_fpn.py b/models/detection_fpn.py
index b59e51d..4cf58ef 100644
--- a/models/detection_fpn.py
+++ b/models/detection_fpn.py
@@ -1,176 +1,176 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """
 Unet-like Backbone architecture, with non-parametric heuristics for box detection on semantic segmentation outputs.
 """
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
 from scipy.ndimage.measurements import label as lb
 
 import utils.exp_utils as utils
 import utils.model_utils as mutils
 
 
 class net(nn.Module):
 
     def __init__(self, cf, logger):
 
         super(net, self).__init__()
         self.cf = cf
         self.logger = logger
         backbone = utils.import_module('bbone', cf.backbone_path)
         self.logger.info("loaded backbone from {}".format(self.cf.backbone_path))
         conv_gen = backbone.ConvGenerator(cf.dim)
 
         # set operate_stride1=True to generate a unet-like FPN.)
         self.fpn = backbone.FPN(cf, conv=conv_gen, relu_enc=cf.relu, operate_stride1=True)
-        self.conv_final = conv_gen(cf.end_filts, cf.num_seg_classes, ks=1, pad=0, norm=cf.norm, relu=None)
+        self.conv_final = conv_gen(cf.end_filts, cf.num_seg_classes, ks=1, pad=0, norm=None, relu=None)
 
         #initialize parameters
         if self.cf.weight_init=="custom":
             logger.info("Tried to use custom weight init which is not defined. Using pytorch default.")
         elif self.cf.weight_init:
             mutils.initialize_weights(self)
         else:
             logger.info("using default pytorch weight init")
 
 
     def forward(self, x):
         """
         forward pass of network.
         :param x: input image. shape (b, c, y, x, (z))
         :return: seg_logits: shape (b, n_classes, y, x, (z))
         :return: out_box_coords: list over n_classes. elements are arrays(b, n_rois, (y1, x1, y2, x2, (z1), (z2)))
         :return: out_max_scores: list over n_classes. elements are arrays(b, n_rois)
         """
 
         out_features = self.fpn(x)[0] #take only pyramid output of stride 1
 
         seg_logits = self.conv_final(out_features)
         out_box_coords, out_max_scores = [], []
         smax = F.softmax(seg_logits.detach(), dim=1).cpu().data.numpy()
 
         for cl in range(1, len(self.cf.class_dict.keys()) + 1):
             hard_mask = np.copy(smax).argmax(1)
             hard_mask[hard_mask != cl] = 0
             hard_mask[hard_mask == cl] = 1
             # perform connected component analysis on argmaxed predictions,
             # draw boxes around components and return coordinates.
             box_coords, rois = mutils.get_coords(hard_mask, self.cf.n_roi_candidates, self.cf.dim)
 
             # for each object, choose the highest softmax score (in the respective class)
             # of all pixels in the component as object score.
             max_scores = [[] for _ in range(x.shape[0])]
             for bix, broi in enumerate(rois):
                 for nix, nroi in enumerate(broi):
                     score_det = np.max if self.cf.score_det=="max" else np.median #score determination
                     max_scores[bix].append(score_det(smax[bix, cl][nroi > 0]))
             out_box_coords.append(box_coords)
             out_max_scores.append(max_scores)
         return seg_logits, out_box_coords, out_max_scores
 
     def train_forward(self, batch, **kwargs):
         """
         train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
         for processing, computes losses, and stores outputs in a dictionary.
         :param batch: dictionary containing 'data', 'seg', etc.
         :param kwargs:
         :return: results_dict: dictionary with keys:
                 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                         [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
                 'torch_loss': 1D torch tensor for backprop.
                 'class_loss': classification loss for monitoring. here: dummy array, since no classification conducted.
         """
 
         img = torch.from_numpy(batch['data']).cuda().float()
         seg = torch.from_numpy(batch['seg']).cuda().long()
         seg_ohe = torch.from_numpy(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
         results_dict = {}
         seg_logits, box_coords, max_scores = self.forward(img)
 
         # no extra class loss applied in this model. pass dummy tensor for monitoring.
         results_dict['class_loss'] = np.nan
 
         results_dict['boxes'] = [[] for _ in range(img.shape[0])]
         for cix in range(len(self.cf.class_dict.keys())):
             for bix in range(img.shape[0]):
                 for rix in range(len(max_scores[cix][bix])):
                     if max_scores[cix][bix][rix] > self.cf.detection_min_confidence:
                         results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]),
                                     'box_score': max_scores[cix][bix][rix],
                                     'box_pred_class_id': cix + 1, # add 0 for background.
                                     'box_type': 'det'})
 
         for bix in range(img.shape[0]):
             for tix in range(len(batch['bb_target'][bix])):
                 gt_box = {'box_coords': batch['bb_target'][bix][tix], 'box_type': 'gt'}
                 for name in self.cf.roi_items:
                     gt_box.update({name: batch[name][bix][tix]})
 
                 results_dict['boxes'][bix].append(gt_box)
 
         # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
         loss = torch.tensor([0.], dtype=torch.float, requires_grad=False).cuda()
         seg_pred = F.softmax(seg_logits, dim=1)
         if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce':
             loss += 1 - mutils.batch_dice(seg_pred, seg_ohe.float(), false_positive_weight=float(self.cf.fp_dice_weight))
 
         if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce':
             loss += F.cross_entropy(seg_logits, seg[:, 0], weight=torch.FloatTensor(self.cf.wce_weights).cuda())
 
         results_dict['torch_loss'] = loss
         seg_pred = seg_pred.argmax(dim=1).unsqueeze(dim=1).cpu().data.numpy()
         results_dict['seg_preds'] = seg_pred
         if 'dice' in self.cf.metrics:
             results_dict['batch_dices'] = mutils.dice_per_batch_and_class(seg_pred, batch["seg"],
                                                                            self.cf.num_seg_classes, convert_to_ohe=True)
         #self.logger.info("loss: {0:.2f}".format(loss.item()))
         return results_dict
 
 
     def test_forward(self, batch, **kwargs):
         """
         test method. wrapper around forward pass of network without usage of any ground truth information.
         prepares input data for processing and stores outputs in a dictionary.
         :param batch: dictionary containing 'data'
         :param kwargs:
         :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
         """
         img = torch.FloatTensor(batch['data']).cuda()
         seg_logits, box_coords, max_scores = self.forward(img)
 
         results_dict = {}
         results_dict['boxes'] = [[] for _ in range(img.shape[0])]
         for cix in range(len(box_coords)):
             for bix in range(img.shape[0]):
                 for rix in range(len(max_scores[cix][bix])):
                     if max_scores[cix][bix][rix] > self.cf.detection_min_confidence:
                         results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]),
                                     'box_score': max_scores[cix][bix][rix],
                                     'box_pred_class_id': cix + 1,
                                     'box_type': 'det'})
         results_dict['seg_preds'] = F.softmax(seg_logits, dim=1).cpu().data.numpy()
 
         return results_dict
 
diff --git a/models/detection_unet.py b/models/detection_unet.py
index 20394ba..dd7e293 100644
--- a/models/detection_unet.py
+++ b/models/detection_unet.py
@@ -1,545 +1,545 @@
 import warnings
 import os
 import shutil
 import time
 
 import math
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 
 import utils.exp_utils as utils
 import utils.model_utils as mutils
 
 '''
 Use nn.DataParallel to use more than one GPU
 '''
 
 def center_crop_2D_image_batched(img, crop_size):
     # from batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
     # dim 0 is batch, dim 1 is channel, dim 2 and 3 are x y
     center = np.array(img.shape[2:]) / 2.
     if not hasattr(crop_size, "__iter__"):
         center_crop = [int(crop_size)] * (len(img.shape) - 2)
     else:
         center_crop = np.array(crop_size)
         assert len(center_crop) == (len(
             img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (2d)"
     return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.),
            int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.)]
 
 def center_crop_3D_image_batched(img, crop_size):
     # dim 0 is batch, dim 1 is channel, dim 2, 3 and 4 are x y z
     center = np.array(img.shape[2:]) / 2.
     if not hasattr(crop_size, "__iter__"):
         center_crop = np.array([int(crop_size)] * (len(img.shape) - 2))
     else:
         center_crop = np.array(crop_size)
         assert len(center_crop) == (len(
             img.shape) - 2), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (3d)"
     return img[:, :, int(center[0] - center_crop[0] / 2.):int(center[0] + center_crop[0] / 2.),
            int(center[1] - center_crop[1] / 2.):int(center[1] + center_crop[1] / 2.),
            int(center[2] - center_crop[2] / 2.):int(center[2] + center_crop[2] / 2.)]
 
 
 def centercrop_vol(tensor, size):
     """:param tensor: tensor whose last two dimensions should be centercropped to size
     :param size: 2- or 3-int tuple of target (height, width(,depth))
     """
     dim = len(size)
     if dim==2:
         center_crop_2D_image_batched(tensor, size)
     elif dim==3:
         center_crop_2D_image_batched(tensor, size)
     else:
         raise Exception("invalid size argument {} encountered in centercrop".format(size))
 
     """this below worked so fine, when optional z-dim was first spatial dim instead of last
     h_, w_ = size[0], size[1] #target size
     (h,w) = tensor.size()[-2:] #orig size
     dh, dw = h-h_, w-w_ #deltas
     if dim == 3:
         d_ = size[2]
         d  = tensor.size()[-3]
         dd = d-d_
         
     if h_<h:
         tensor = tensor[...,dh//2:-int(math.ceil(dh/2.)),:] #crop height
     elif h_>=h:
         print("no h crop")
         warn.warn("no height crop applied since target dims larger equal orig dims")
     if w_<w:
         tensor = tensor[...,dw//2:-int(math.ceil(dw/2.))]
     elif w_>=w:
         warn.warn("no width crop applied since target dims larger equal orig dims")
     if dim == 3:
         if d_ < d:
             tensor = tensor[..., dd // 2:-int(math.ceil(dd / 2.)),:,:]
         elif d_ >= d:
             warn.warn("no depth crop applied since target dims larger equal orig dims")
     """
 
     return tensor
     
 def dimcalc_conv2D(dims,F=3,s=1,pad="same"):
     r"""
     :param dims: orig width, height as (2,)-np.array
     :param F: quadratic kernel size
     :param s: stride
     :param pad: pad
     """
     if pad=="same":
         pad = (F-1)//2
     h, w = dims[0], dims[1] 
     return np.floor([(h + 2*pad-F)/s+1, (w+ 2*pad-F)/s+1])
 
 def dimcalc_transconv2D(dims,F=2,s=2):
     r"""
     :param dims: orig width, height as (2,)-np.array
     :param F: quadratic kernel size
     :param s: stride
     """    
 
     h, w = dims[0], dims[1]
     return np.array([(h-1)*s+F, (w-1)*s+F])
 
 def dimcalc_Unet_std(init_dims, F=3, F_pool=2, F_up=2, s=1, s_pool=2, s_up=2, pad=0):
     r"""Calculate theoretic dimensions of feature maps throughout layers of this U-net.
     """
     dims = np.array(init_dims)
     print("init dims: ", dims)
     
     def down(dims):
         for i in range(2):
             dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad)       
         dims = dimcalc_conv2D(dims, F=F_pool, s=s_pool)     
         return dims.astype(int)    
     def up(dims):
         for i in range(2):
             dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad)
         dims = dimcalc_transconv2D(dims, F=F_up,s=s_up)
         return dims.astype(int)
     
     stage = 1
     for i in range(4):
         dims = down(dims)
         print("stage ", stage, ": ", dims)
         stage+=1
     for i in range(4):
         dims = up(dims)
         print("stage ", stage, ": ", dims)
         stage+=1
     for i in range(2):
         dims = dimcalc_conv2D(dims,F=F,s=s, pad=pad).astype(int)
     print("final output size: ", dims)
     return dims
 
 def dimcalc_Unet(init_dims, F=3, F_pool=2, F_up=2, s=1, s_pool=2, s_up=2, pad=0):
     r"""Calculate theoretic dimensions of feature maps throughout layers of this U-net.
     """
     dims = np.array(init_dims)
     print("init dims: ", dims)
     
     def down(dims):
         for i in range(3):
             dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad)       
         dims = dimcalc_conv2D(dims, F=F_pool, s=s_pool)     
         return dims.astype(int)    
     def up(dims):
         dims = dimcalc_transconv2D(dims, F=F_up,s=s_up)
         for i in range(3):
             dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad)
         return dims.astype(int)
     
     stage = 1
     for i in range(6):
         dims = down(dims)
         print("stage ", stage, ": ", dims)
         stage+=1
     for i in range(3):
         dims = dimcalc_conv2D(dims, F=F, s=s, pad=pad)
     for i in range(6):
         dims = up(dims)
         print("stage ", stage, ": ", dims)
         stage+=1
     dims = dims.astype(int)
     print("final output size: ", dims)
     return dims
 
 
 
 class horiz_conv(nn.Module):
     def __init__(self, in_chans, out_chans, kernel_size, c_gen, norm, pad=0, relu="relu", bottleneck=True):
         super(horiz_conv, self).__init__()
         #TODO maybe make res-block?
         if bottleneck:
             bottleneck = int(np.round((in_chans+out_chans)*3/8))
             #print("bottleneck:", bottleneck)
         else:
             bottleneck = out_chans
         self.conv = nn.Sequential(
             c_gen(in_chans, bottleneck, kernel_size, pad=pad, norm=norm, relu=relu), #TODO maybe use norm only on last conv?
             c_gen(bottleneck, out_chans, kernel_size, pad=pad, norm=norm, relu=relu), #TODO maybe make bottleneck?
             #c_gen(out_chans, out_chans, kernel_size, pad=pad, norm=norm, relu=relu),
             )
     def forward(self, x):
         x = self.conv(x)
         return x
 
 class up(nn.Module):
     def __init__(self, in_chans, out_chans, kernel_size, interpol, c_gen, norm, pad=0, relu="relu", stride_ip=2):
         super(up, self).__init__()
         self.dim = c_gen.dim
         self.upsample = interpol(stride_ip, "bilinear") if self.dim==2 else interpol(stride_ip, "trilinear") #TODO check if fits with spatial dims order in data
         self.reduce_chans = c_gen(in_chans, out_chans, ks=1, norm=norm, relu=None)
         self.horiz = horiz_conv(out_chans*2, out_chans, kernel_size, c_gen, norm=norm, pad=pad, relu=relu)
 
     def forward(self, x, skip_inp):
         #TODO maybe add highway weights in skips?
         x = self.upsample(x)
         x = self.reduce_chans(x)
         #print("shape x, skip", x.shape, skip_inp.shape)
         targ_size = x.size()[-self.dim:] #ft map x,y,z (spatial)
         skip_inp = centercrop_vol(skip_inp, targ_size)
         assert targ_size == skip_inp.size()[-self.dim:], "corresp. skip and forward dimensions don't match"
         x = torch.cat((x,skip_inp),dim=1)
         x = self.horiz(x)
         return x
 
    
 class net(nn.Module):
     r"""U-Net with few more steps than standard.
     
     Dimensions: 
         feature maps have dims ...xhxwxd, d=feature map depth, h, w = orig 
         img height, width. h,w each are downsized by unpadded forward-convs and pooling,
         upsized by upsampling or upconvolution.
         If :math:`F\times F` is the single kernel_size and stride is :math:`s\geq 1`, 
         :math:`k` is the number of kernels in the conv, i.e. the resulting feature map depth,
         (all may differ between operations), then
     
     :Forward Conv: input  :math:`h \times w \times d` is converted to
     .. math:: \left[ (h-F)//s+1 \right] \times \left[ (w-F)//s+1 \right] \times k
     
     :Pooling: input  :math:`h \times w \times d` is converted to
     .. math:: \left[ (h-F)//s+1 \right] \times \left[ (w-F)//s+1 \right] \times d,
     pooling filters have no depths => orig depths preserved.
 
     :Up-Conv.: input  :math:`h \times w \times d` is converted to
     .. math:: \left[ (h-1)s + F \right] \times \left[ (w-1)s + F \right] \times k
     """
 
 
     def down(self, in_chans, out_chans, kernel_size, kernel_size_m, pad=0, relu="relu",maintain_z=False):
         """generate encoder block
         :param in_chans:
         :param out_chans:
         :param kernel_size:
         :param pad:
         :return:
         """
         if maintain_z and self.dim==3:
             stride_pool = (2,2,1)
             if not hasattr(kernel_size_m, "__iter__"):
                 kernel_size_m = [kernel_size_m]*self.dim
             kernel_size_m = (*kernel_size_m[:-1], 1)
         else:
             stride_pool = 2
         module = nn.Sequential(
             nn.MaxPool2d(kernel_size_m, stride=stride_pool) if self.dim == 2 else nn.MaxPool3d(
                 kernel_size_m, stride=stride_pool),
             #--> needs stride 2 in z in upsampling as well!
             horiz_conv(in_chans, out_chans, kernel_size, self.c_gen, self.norm, pad, relu=relu)
         )
         return module
 
     def up(self, in_chans, out_chans, kernel_size, pad=0, relu="relu", maintain_z=False):
         """generate decoder block
         :param in_chans:
         :param out_chans:
         :param kernel_size:
         :param pad:
         :param relu:
         :return:
         """
         if maintain_z and self.dim==3:
             stride_ip = (2,2,1)
         else:
             stride_ip = 2
 
         module = up(in_chans, out_chans, kernel_size, self.Interpolator, self.c_gen, norm=self.norm, pad=pad,
                     relu=relu, stride_ip=stride_ip)
 
         return module
 
 
     def __init__(self, cf, logger):
         super(net, self).__init__()
 
         self.cf = cf
         self.dim = cf.dim
         self.norm = cf.norm
         self.logger = logger
         backbone = utils.import_module('bbone', cf.backbone_path)
         self.c_gen = backbone.ConvGenerator(cf.dim)
         self.Interpolator = backbone.Interpolate
 
         #down = DownBlockGen(cf.dim)
         #up = UpBlockGen(cf.dim, backbone.Interpolate)
         down = self.down
         up = self.up
 
         pad = cf.pad
         if pad=="same":
             pad = (cf.kernel_size-1)//2
 
         
         self.dims = "not yet recorded"
         self.is_cuda = False
               
         self.init = horiz_conv(len(cf.channels), cf.init_filts, cf.kernel_size, self.c_gen, self.norm, pad=pad,
                                relu=cf.relu)
         
         self.down1 = down(cf.init_filts,    cf.init_filts*2,  cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu)
         self.down2 = down(cf.init_filts*2,  cf.init_filts*4,  cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu)
         self.down3 = down(cf.init_filts*4,  cf.init_filts*6,  cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu)
         self.down4 = down(cf.init_filts*6,  cf.init_filts*8,  cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu,
                           maintain_z=True)
         self.down5 = down(cf.init_filts*8,  cf.init_filts*12, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu,
                           maintain_z=True)
         #self.down6 = down(cf.init_filts*10, cf.init_filts*14, cf.kernel_size, cf.kernel_size_m, pad=pad, relu=cf.relu)
         
         #self.up1 = up(cf.init_filts*14, cf.init_filts*10, cf.kernel_size, pad=pad, relu=cf.relu)
         self.up2 = up(cf.init_filts*12, cf.init_filts*8,  cf.kernel_size, pad=pad, relu=cf.relu, maintain_z=True)
         self.up3 = up(cf.init_filts*8,  cf.init_filts*6,  cf.kernel_size, pad=pad, relu=cf.relu, maintain_z=True)
         self.up4 = up(cf.init_filts*6,  cf.init_filts*4,  cf.kernel_size, pad=pad, relu=cf.relu)
         self.up5 = up(cf.init_filts*4,  cf.init_filts*2,  cf.kernel_size, pad=pad, relu=cf.relu)
         self.up6 = up(cf.init_filts*2,  cf.init_filts,    cf.kernel_size, pad=pad, relu=cf.relu)
         
-        self.seg = self.c_gen(cf.init_filts, cf.num_seg_classes, 1, norm=None, relu=None) #TODO maybe apply norm too?
+        self.seg = self.c_gen(cf.init_filts, cf.num_seg_classes, 1, norm=None, relu=None)
 
 
         # initialize parameters
         if self.cf.weight_init == "custom":
             logger.info("Tried to use custom weight init which is not defined. Using pytorch default.")
         elif self.cf.weight_init:
             mutils.initialize_weights(self)
         else:
             logger.info("using default pytorch weight init")
         
     
     def forward(self, x):
         r'''Forward application of network-function.
         
         :param x: input to the network, expected as torch.tensor of dims
         .. math:: batch\_size \times channels \times height \times width
         requires_grad should be True for training
         '''
         #self.dims = np.array([x.size()[-self.dim-1:]])
         
         x1 = self.init(x)
         #self.dims = np.vstack((self.dims, x1.size()[-self.dim-1:]))
         
         #---downwards---
         x2 = self.down1(x1)
         #self.dims = np.vstack((self.dims, x2.size()[-self.dim-1:]))
         x3 = self.down2(x2)
         #self.dims = np.vstack((self.dims, x3.size()[-self.dim-1:]))
         x4 = self.down3(x3)
         #self.dims = np.vstack((self.dims, x4.size()[-self.dim-1:]))
         x5 = self.down4(x4)
         #self.dims = np.vstack((self.dims, x5.size()[-self.dim-1:]))
         #x6 = self.down5(x5)
         #self.dims = np.vstack((self.dims, x6.size()[-self.dim-1:]))
         
         #---bottom---
         x = self.down5(x5)
         #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:]))
         
         #---upwards---
         #x = self.up1(x, x6)
         #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:]))
         x = self.up2(x, x5)
         #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:]))
         x = self.up3(x, x4)
         #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:]))
         x = self.up4(x, x3)
         #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:]))
         x = self.up5(x, x2)
         #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:]))
 
         x = self.up6(x, x1)
         #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:]))
 
         # ---final---
         x = self.seg(x)
         #self.dims = np.vstack((self.dims, x.size()[-self.dim-1:]))
 
         seg_logits = x
         out_box_coords, out_scores = [], []
         seg_probs = F.softmax(seg_logits.detach(), dim=1).cpu().data.numpy()
         #seg_probs = F.softmax(seg_logits, dim=1)
 
         assert seg_logits.shape[1]==self.cf.num_seg_classes
         for cl in range(1, seg_logits.shape[1]):
             hard_mask = np.copy(seg_probs).argmax(1)
             #hard_mask = seg_probs.clone().argmax(1)
             hard_mask[hard_mask != cl] = 0
             hard_mask[hard_mask == cl] = 1
             # perform connected component analysis on argmaxed predictions,
             # draw boxes around components and return coordinates.
             box_coords, rois = mutils.get_coords(hard_mask, self.cf.n_roi_candidates, self.cf.dim)
 
             # for each object, choose the highest softmax score (in the respective class)
             # of all pixels in the component as object score.
             scores = [[] for b_inst in range(x.shape[0])]  # np.zeros((out_features.shape[0], self.cf.n_roi_candidates))
             for b_inst, brois in enumerate(rois):
                 for nix, nroi in enumerate(brois):
                     score_det = np.max if self.cf.score_det == "max" else np.median  # score determination
                     scores[b_inst].append(score_det(seg_probs[b_inst, cl][nroi > 0]))
             out_box_coords.append(box_coords)
             out_scores.append(scores)
 
         return seg_logits, out_box_coords, out_scores
 
     # noinspection PyCallingNonCallable
     def train_forward(self, batch, **kwargs):
         """
         train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
         for processing, computes losses, and stores outputs in a dictionary.
         :param batch: dictionary containing 'data', 'seg', etc.
         :param kwargs:
         :return: results_dict: dictionary with keys:
                 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                         [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
                 'torch_loss': 1D torch tensor for backprop.
                 'class_loss': classification loss for monitoring. here: dummy array, since no classification conducted.
         """
 
         img = torch.from_numpy(batch["data"]).float().cuda()
         seg = torch.from_numpy(batch["seg"]).long().cuda()
         seg_ohe = torch.from_numpy(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).float().cuda()
 
         results_dict = {}
         seg_logits, box_coords, scores = self.forward(img)
 
         # no extra class loss applied in this model. pass dummy tensor for monitoring.
         results_dict['class_loss'] = np.nan
 
         results_dict['boxes'] = [[] for _ in range(img.shape[0])]
         for cix in range(len(self.cf.class_dict.keys())):
             for bix in range(img.shape[0]):
                 for rix in range(len(scores[cix][bix])):
                     if scores[cix][bix][rix] > self.cf.detection_min_confidence:
                         results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]),
                                                            'box_score': scores[cix][bix][rix],
                                                            'box_pred_class_id': cix + 1,  # add 0 for background.
                                                            'box_type': 'det',
                                                            })
 
         for bix in range(img.shape[0]): #bix = batch-element index
             for tix in range(len(batch['bb_target'][bix])): #target index
                 gt_box = {'box_coords': batch['bb_target'][bix][tix], 'box_type': 'gt'}
                 for name in self.cf.roi_items:
                     gt_box.update({name: batch[name][bix][tix]})
                 results_dict['boxes'][bix].append(gt_box)
 
         # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
         seg_pred = F.softmax(seg_logits, 1)
         loss = torch.tensor([0.], dtype=torch.float, requires_grad=False).cuda()
         if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce':
             loss += 1 - mutils.batch_dice(seg_pred, seg_ohe.float(),
                                          false_positive_weight=float(self.cf.fp_dice_weight))
 
         if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce':
             loss += F.cross_entropy(seg_logits, seg[:, 0], weight=torch.FloatTensor(self.cf.wce_weights).cuda(),
                                     reduction='mean')
 
         results_dict['torch_loss'] = loss
         seg_pred = seg_pred.argmax(dim=1).unsqueeze(dim=1).cpu().data.numpy()
         results_dict['seg_preds'] = seg_pred
         if 'dice' in self.cf.metrics:
             results_dict['batch_dices'] = mutils.dice_per_batch_and_class(seg_pred, batch["seg"],
                                                                            self.cf.num_seg_classes, convert_to_ohe=True)
             #print("batch dice scores ", results_dict['batch_dices'] )
         # self.logger.info("loss: {0:.2f}".format(loss.item()))
         return results_dict
 
     def test_forward(self, batch, **kwargs):
         """
         test method. wrapper around forward pass of network without usage of any ground truth information.
         prepares input data for processing and stores outputs in a dictionary.
         :param batch: dictionary containing 'data'
         :param kwargs:
         :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
         """
         img = torch.FloatTensor(batch['data']).cuda()
         seg_logits, box_coords, scores = self.forward(img)
 
         results_dict = {}
         results_dict['boxes'] = [[] for b_inst in range(img.shape[0])]
         for cix in range(len(box_coords)): #class index
             for bix in range(img.shape[0]): #batch instance
                 for rix in range(len(scores[cix][bix])): #range(self.cf.n_roi_candidates):
                     if scores[cix][bix][rix] > self.cf.detection_min_confidence:
                         results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]),
                                     'box_score': scores[cix][bix][rix],
                                     'box_pred_class_id': cix + 1,
                                     'box_type': 'det'})
         # carry probs instead of preds to use for multi-model voting in predictor
         results_dict['seg_preds'] = F.softmax(seg_logits, dim=1).cpu().data.numpy()
 
 
         return results_dict
 
 
     def actual_dims(self, print_=True):
         r"""Return dimensions of actually calculated layers at beginning of each block.
         """
         if print_:
             print("dimensions as recorded in forward pass: ")
             for stage in range(len(self.dims)):
                 print("Stage ", stage, ": ", self.dims[stage])
         return self.dims
         
     def cuda(self, device=None):
         r"""Moves all model parameters and buffers to the GPU.
 
         This also makes associated parameters and buffers different objects. So
         it should be called before constructing optimizer if the module will
         live on GPU while being optimized.
 
         Arguments:
             device (int, optional): if specified, all parameters will be
                 copied to that device
 
         Returns:
             Module: self
         """
         try:
             self.loss_f = self.loss_f.cuda()
         except:
             pass
         self.is_cuda = True
         return self._apply(lambda t: t.cuda(device))
     
     def cpu(self):
         r"""Moves all model parameters and buffers to the CPU.
 
         Returns:
             Module: self
         """
         self.is_cuda = False
         return self._apply(lambda t: t.cpu()) 
 
 
 
 
         
\ No newline at end of file