diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py
index c13c954..abeafd0 100644
--- a/datasets/toy/configs.py
+++ b/datasets/toy/configs.py
@@ -1,491 +1,490 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 from collections import namedtuple
 
 boxLabel = namedtuple('boxLabel', ["name", "color"])
 Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion'])
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals'])
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #         Prepro        #
         #########################
 
-        self.pp_rootdir = os.path.join('/mnt/HDD2TB/Documents/data/toy', "cyl1ps_dev_exact")
+        self.pp_rootdir = os.path.join('/media/gregor/HDD2TB/data/toy', "cyl1ps_dev")
         self.pp_npz_dir = self.pp_rootdir+"_npz"
 
         self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now)
         self.min_2d_radius = 6 #in pixels
-        self.n_train_samples, self.n_test_samples = 80, 80
+        self.n_train_samples, self.n_test_samples = 320, 80
 
         # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes.
         self.pp_create_ohe_seg = False
         self.pp_empty_samples_ratio = 0.1
 
         self.pp_place_radii_mid_bin = True
         self.pp_only_distort_2d = True
         # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing.
         # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity.
         self.pp_blur_min_intensity = 0.2
 
-        self.max_instances_per_sample = 3 #how many max instances over all classes per sample (img if 2d, vol if 3d)
+        self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d)
         self.max_instances_per_class = self.max_instances_per_sample  # how many max instances per image per class
         self.noise_scale = 0.  # std-dev of gaussian noise
 
         self.ambigs_sampling = "gaussian" #"gaussian" or "uniform"
         """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from
             image by distinguishing it from the rest of the object.
             blurring width around edge will be shifted so that symmetric rel to orig radius.
             blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale
             since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius.
             if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2.
         """
         self.ambiguities = {
              #set which classes to apply which ambs to below in class labels
              #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'.
              #kind              #probability   #scale (gaussian std, relative to unperturbed value)
             #"outer_radius":     (1.,            0.5),
             #"outer_radius_xy":  (1.,            0.5),
             #"inner_radius":     (0.5,            0.1),
             #"radii_relations":  (0.5,            0.1),
             "radius_calib":     (1.,            1./6)
         }
 
         # shape choices: 'cylinder', 'block'
         #                        id,    name,       shape,      radius,                 color,              regression,     ambiguities,    gt_distortion
         self.pp_classes = [Label(1,     'cylinder', 'cylinder', ((6,6,1),(40,40,8)),    (*self.blue, 1.),   "radius_2d",    (),             ()),
                            #Label(2,      'block',      'block',        ((6,6,1),(40,40,8)),  (*self.aubergine,1.),  "radii_2d", (), ('radius_calib',))
             ]
 
 
         #########################
         #         I/O           #
         #########################
 
-        self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_exact'
+        self.data_sourcedir = '/media/gregor/HDD2TB/data/toy/cyl1ps_dev'
 
         if server_env:
-            self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_exact_npz'
+            self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_dev_npz'
 
 
         self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test')
         self.data_sourcedir = os.path.join(self.data_sourcedir, "train")
 
         self.info_df_name = 'info_df.pickle'
 
-        # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'detection_fpn'].
-        self.model = 'retinau'
+        # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn'].
+        self.model = 'retina_unet'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # one out of [2, 3]. dimension the model operates in.
-        self.dim = 2
+        self.dim = 3
 
         # 'class', 'regression', 'regression_bin', 'regression_ken_gal'
         # currently only tested mode is a single-task at a time (i.e., only one task in below list)
         # but, in principle, tasks could be combined (e.g., object classes and regression per class)
-        self.prediction_tasks = ['class',]
+        self.prediction_tasks = ['class', 'regression']
 
-        self.start_filts = 36 if self.dim == 2 else 18
+        self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
-        self.norm = None #'instance_norm' # one of None, 'instance_norm', 'batch_norm'
+        self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm'
         self.relu = 'relu'
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1  # length of regressor target vector
 
 
         #########################
         #      Data Loader      #
         #########################
 
         self.num_epochs = 32
-        self.num_train_batches = 100 if self.dim == 2 else 80
-        self.batch_size = 12 if self.dim == 2 else 8
+        self.num_train_batches = 120 if self.dim == 2 else 80
+        self.batch_size = 16 if self.dim == 2 else 8
 
         self.n_cv_splits = 4
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
         self.plot_bg_chan = 0
         self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
         self.patch_size_2D = self.pre_crop_size[:2]
         self.patch_size_3D = self.pre_crop_size[:2]+[8]
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.2
         self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets"
 
         self.observables_patient = []
         self.observables_rois = []
 
-        self.seed = 3 # for generating folds
+        self.seed = 3 #for generating folds
 
         #############################
         # Colors, Classes, Legends  #
         #############################
         self.plot_frequency = 1
 
         binary_bin_labels = [binLabel(1,  'r<=25',      (*self.green, 1.),      (1,25)),
                              binLabel(2,  'r>25',       (*self.red, 1.),        (25,))]
         quintuple_bin_labels = [binLabel(1,  'r2-10',   (*self.green, 1.),      (2,10)),
                                 binLabel(2,  'r10-20',  (*self.yellow, 1.),     (10,20)),
                                 binLabel(3,  'r20-30',  (*self.orange, 1.),     (20,30)),
                                 binLabel(4,  'r30-40',  (*self.bright_red, 1.), (30,40)),
                                 binLabel(5,  'r>40',    (*self.red, 1.), (40,))]
 
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_bin_labels = quintuple_bin_labels
 
         self.class_labels = [
             # regression: regression-task label, either value or "(x,y,z)_radius" or "radii".
             # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables!
             # gt_distortion: name of ambig to apply to gt only; needs to be iterable!
             #      #id  #name   #shape  #radius     #color              #regression #ambiguities    #gt_distortion
             Label(  0,  'bg',   None,   (0, 0, 0),  (*self.white, 0.),  (0, 0, 0),  (),             ())]
         if "class" in self.prediction_tasks:
             self.class_labels += self.pp_classes
         else:
             self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))]
 
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0,  'bg',       (*self.white, 1.),      (0,))]
             self.bin_labels += task_spec_bin_labels
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
             self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
 
         if self.class_specific_seg:
           self.seg_labels = self.class_labels
 
         self.box_type2label = {label.name: label for label in self.box_labels}
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)
         self.num_seg_classes = len(self.seg_labels)
 
         #########################
         #   Data Augmentation   #
         #########################
         self.do_aug = True
         self.da_kwargs = {
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': False,
             'alpha': (500., 1500.),
             'sigma': (40., 45.),
             'do_rotation': False,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': False,
             'scale': (0.8, 1.1),
             'random_crop': False,
             'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #########################
         #  Schedule / Selection #
         #########################
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 220  # if 'all' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 25 if self.dim==2 else 15
 
         self.save_n_models = 2
         self.min_save_thresh = 1 if self.dim == 2 else 1  # =wait time in epochs
         if "class" in self.prediction_tasks:
             self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()}
             self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()})
 
         self.lr_decay_factor = 0.5
         self.scheduling_patience = int(self.num_epochs / 5)
         self.weight_decay = 1e-5
         self.clip_norm = None  # number or None
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1)
         self.held_out_test_set = True
         self.max_test_patients = "all"  # number or "all" for all
 
-        self.test_against_exact_gt = not 'exact' in self.data_sourcedir
+        self.test_against_exact_gt = True # only True implemented
         self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario.
         self.report_score_level = ['rois']  # 'patient' or 'rois' (incl)
         self.patient_class_of_interest = 1
         self.patient_bin_of_interest = 2
 
         self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False
         self.metrics = ['ap', 'auc', 'dice']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.5]  # threshold(s) for considering a prediction as true positive
         self.min_det_thresh = 0.3
 
         self.model_max_iou_resolution = 0.2
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = self.model_max_iou_resolution  # has to be larger than desired possible overlap iou of model predictions
 
         self.merge_2D_to_3D_preds = False
         self.merge_3D_iou = self.model_max_iou_resolution
         self.n_test_plots = 1  # per fold and rank
 
         self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
         # is multiplied by (1 + nr of test augs)
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
          'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
          }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         #only meant for isotropic radii!!
         # only 2D radii (x and y dims) or 1D (x or y) are expected
         return np.round(np.digitize(rg_val, self.bin_edges).mean())
 
 
     def add_det_fpn_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = 'torch_loss'
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       self.n_roi_candidates = 4 if self.dim == 2 else 6
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
 
       self.fp_dice_weight = 1 if self.dim == 2 else 1
       # if <1, false positive predictions in foreground are penalized less.
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
     def add_det_unet_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = "torch_loss"
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
       self.n_roi_candidates = 4 if self.dim == 2 else 6
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
       # if <1, false positive predictions in foreground are penalized less.
       self.fp_dice_weight = 1 if self.dim == 2 else 1
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
       self.init_filts = 32
       self.kernel_size = 3  # ks for horizontal, normal convs
       self.kernel_size_m = 2  # ks for max pool
       self.pad = "same"  # "same" or integer, padding of horizontal convs
 
     def add_mrcnn_configs(self):
-      self.frcnn_mode = False
 
       self.learning_rate = [1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True  # with scheduler set in exec
       self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # number of classes for network heads: n_foreground_classes + 1 (background)
       self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2
 
       # feed +/- n neighbouring slices into channel dimension. set to None for no context.
       self.n_3D_context = None
       if self.n_3D_context is not None and self.dim == 2:
         self.n_channels *= (self.n_3D_context * 2 + 1)
 
       self.detect_while_training = True
       # disable the re-sampling of mask proposals to original size for speed-up.
       # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
       # mask outputs are optional.
       self.return_masks_in_train = True
       self.return_masks_in_val = True
       self.return_masks_in_test = True
 
       # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
       self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
       # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
       # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
       self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
       # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
       self.pyramid_levels = [0, 1, 2, 3]
       # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
-      self.n_rpn_features = 128 if self.dim == 2 else 64
+      self.n_rpn_features = 512 if self.dim == 2 else 64
 
       # anchor ratios and strides per position in feature maps.
       self.rpn_anchor_ratios = [0.5, 1., 2.]
       self.rpn_anchor_stride = 1
       # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
       self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution)
 
       # loss sampling settings.
       self.rpn_train_anchors_per_image = 4
       self.train_rois_per_image = 6 # per batch_instance
       self.roi_positive_ratio = 0.5
       self.anchor_matching_iou = 0.8
 
       # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
       # where k<=#positive examples.
       self.shem_poolsize = 2
 
       self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
       self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
       self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
       self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
       self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                              self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
 
       if self.dim == 2:
         self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
         self.bbox_std_dev = self.bbox_std_dev[:4]
         self.window = self.window[:4]
         self.scale = self.scale[:4]
 
       self.plot_y_max = 1.5
       self.n_plot_rpn_props = 5 if self.dim == 2 else 30  # per batch_instance (slice in 2D / patient in 3D)
 
       # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
       self.pre_nms_limit = 2000 if self.dim == 2 else 4000
 
       # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
       # since proposals of the entire batch are forwarded through second stage as one "batch".
       self.roi_chunk_size = 1300 if self.dim == 2 else 500
       self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400
       self.post_nms_rois_inference = 200 * (self.head_classes-1)
 
       # Final selection of detections (refine_detections)
       self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class.
       self.detection_nms_threshold = self.model_max_iou_resolution  # needs to be > 0, otherwise all predictions are one cluster.
       self.model_min_confidence = 0.2  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
 
       if self.dim == 2:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride))]
            for stride in self.backbone_strides['xy']])
       else:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride)),
             int(np.ceil(self.patch_size[2] / stride_z))]
            for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                        )])
 
       if self.model == 'retina_net' or self.model == 'retina_unet':
         # whether to use focal loss or SHEM for loss-sample selection
         self.focal_loss = False
         # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
         self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                         self.rpn_anchor_scales['xy']]
         self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                        self.rpn_anchor_scales['z']]
         self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
         # pre-selection of detections for NMS-speedup. per entire batch.
         self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
         # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
         self.anchor_matching_iou = 0.7
 
         if self.model == 'retina_unet':
           self.operate_stride1 = True
diff --git a/datasets/toy/data_loader.py b/datasets/toy/data_loader.py
index 0590fae..78fb673 100644
--- a/datasets/toy/data_loader.py
+++ b/datasets/toy/data_loader.py
@@ -1,598 +1,582 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 sys.path.append('../') # works on cluster indep from where sbatch job is started
 import plotting as plg
 
 import numpy as np
 import os
+from multiprocessing import Lock
 from collections import OrderedDict
 import pandas as pd
 import pickle
 import time
 
 # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
 from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
 from batchgenerators.transforms.abstract_transforms import Compose
 from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
 from batchgenerators.transforms.spatial_transforms import SpatialTransform
 from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
 
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import utils.dataloader_utils as dutils
 from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
 
 
 def load_obj(file_path):
     with open(file_path, 'rb') as handle:
         return pickle.load(handle)
 
 class Dataset(dutils.Dataset):
     r""" Load a dict holding memmapped arrays and clinical parameters for each patient,
     evtly subset of those.
         If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to
         cf.data_dir.
     :param cf: config file
     :param folds: number of folds out of @params n_cv folds to include
     :param n_cv: number of total folds
     :return: dict with imgs, segs, pids, class_labels, observables
     """
 
     def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'):
         super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir)
 
         load_exact_gts = (mode=='test' or cf.val_mode=="val_patient") and self.cf.test_against_exact_gt
 
         p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name))
 
         if subset_ids is not None:
             p_df = p_df[p_df.pid.isin(subset_ids)]
             logger.info('subset: selected {} instances from df'.format(len(p_df)))
 
         pids = p_df.pid.tolist()
         #evtly copy data from data_sourcedir to data_dest
         if cf.server_env and not hasattr(cf, "data_dir"):
             file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids]
             file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids]
             file_subset += [cf.info_df_name]
             if load_exact_gts:
                 file_subset += [os.path.join(self.data_dir, '{}_exact_seg.*'.format(pid)) for pid in pids]
             self.copy_data(cf, file_subset=file_subset)
 
         img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids]
         seg_paths = [os.path.join(self.data_dir, '{}_seg.npy'.format(pid)) for pid in pids]
         if load_exact_gts:
             exact_seg_paths = [os.path.join(self.data_dir, '{}_exact_seg.npy'.format(pid)) for pid in pids]
 
         class_targets = p_df['class_ids'].tolist()
         rg_targets = p_df['regression_vectors'].tolist()
         if load_exact_gts:
             exact_rg_targets = p_df['undistorted_rg_vectors'].tolist()
         fg_slices = p_df['fg_slices'].tolist()
 
         self.data = OrderedDict()
         for ix, pid in enumerate(pids):
             self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid,
                               'fg_slices': np.array(fg_slices[ix])}
             if load_exact_gts:
                 self.data[pid]['exact_seg'] = exact_seg_paths[ix]
             if 'class' in self.cf.prediction_tasks:
                 self.data[pid]['class_targets'] = np.array(class_targets[ix], dtype='uint8')
             else:
                 self.data[pid]['class_targets'] = np.ones_like(np.array(class_targets[ix]), dtype='uint8')
             if load_exact_gts:
                 self.data[pid]['exact_class_targets'] = self.data[pid]['class_targets']
             if any(['regression' in task for task in self.cf.prediction_tasks]):
                 self.data[pid]['regression_targets'] = np.array(rg_targets[ix], dtype='float16')
                 self.data[pid]["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in rg_targets[ix]], dtype='uint8')
                 if load_exact_gts:
                     self.data[pid]['exact_regression_targets'] = np.array(exact_rg_targets[ix], dtype='float16')
                     self.data[pid]["exact_rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in exact_rg_targets[ix]],
                                                                 dtype='uint8')
 
 
         cf.roi_items = cf.observables_rois[:]
         cf.roi_items += ['class_targets']
         if any(['regression' in task for task in self.cf.prediction_tasks]):
             cf.roi_items += ['regression_targets']
             cf.roi_items += ['rg_bin_targets']
 
         self.set_ids = np.array(list(self.data.keys()))
         self.df = None
 
 class BatchGenerator(dutils.BatchGenerator):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
-    def __init__(self, cf, data, sample_pids_w_replace=True):
-        super(BatchGenerator, self).__init__(cf, data)
+    def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, seed=0):
+        super(BatchGenerator, self).__init__(cf, data, sample_pids_w_replace=sample_pids_w_replace,
+                                             max_batches=max_batches, raise_stop_iteration=raise_stop_iteration,
+                                             seed=seed)
 
         self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
         assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
 
-
-        self.sample_pids_w_replace = sample_pids_w_replace
-        self.eligible_pids = list(self._data.keys())
-
         self.crop_margin = np.array(self.cf.patch_size) / 8.  # min distance of ROI center to edge of cropped_patch.
         self.p_fg = 0.5
         self.empty_samples_max_ratio = 0.6
-        self.random_count = int(cf.batch_random_ratio * cf.batch_size)
 
         self.balance_target_distribution(plot=sample_pids_w_replace)
         self.stats = {"roi_counts": np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count": 0}
 
-
     def generate_train_batch(self):
         # everything done in here is per batch
         # print statements in here get confusing due to multithreading
-        if self.sample_pids_w_replace:
-            # fully random patients
-            batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
-            # target-balanced patients
-            batch_patient_ids += list(np.random.choice(
-                self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs))
-        else:
-            batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size,
-                                                 replace=False)
-        if self.sample_pids_w_replace == False:
-            self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids]
-            if len(self.eligible_pids) < self.batch_size:
-                self.eligible_pids = self.dataset_pids
+
+        batch_pids = self.get_batch_pids()
 
         batch_data, batch_segs, batch_patient_targets = [], [], []
         batch_roi_items = {name: [] for name in self.cf.roi_items}
         # record roi count of classes in batch
         # empty count for full bg samples (empty slices in 2D/patients in 3D) in slot num_classes (last)
         batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0
 
         for b in range(self.batch_size):
-            patient = self._data[batch_patient_ids[b]]
+            patient = self._data[batch_pids[b]]
 
             data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
             seg =  np.load(patient['seg'], mmap_mode='r').astype('uint8')
 
             (c, y, x, z) = data.shape
             if self.cf.dim == 2:
                 elig_slices, choose_fg = [], False
                 if len(patient['fg_slices']) > 0:
                     if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand(
                             1) <= self.p_fg:
                         # fg is to be picked
                         for tix in np.argsort(batch_roi_counts):
                             # pick slices of patient that have roi of sought-for target
                             # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                             elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                 patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                                 self.unique_ts[tix]) > 0]
                             if len(elig_slices) > 0:
                                 choose_fg = True
                                 break
                     else:
                         # pick bg
                         elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
                 if len(elig_slices) > 0:
                     sl_pick_ix = np.random.choice(elig_slices, size=None)
                 else:
                     sl_pick_ix = np.random.choice(z, size=None)
                 data = data[..., sl_pick_ix]
                 seg = seg[..., sl_pick_ix]
 
             spatial_shp = data[0].shape
             assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg"
             if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
                 new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
                 data = dutils.pad_nd_image(data, (len(data), *new_shape))
                 seg = dutils.pad_nd_image(seg, new_shape)
 
             # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center,
             # if possible, to that pixel, so that img still contains ROI after pre-cropping
             dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))]
             if np.any(dim_cropflags):
                 # sample pixel from random ROI and shift center, if possible, to that pixel
                 if self.cf.dim==3:
                     choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg
                 if choose_fg and np.any(seg):
                     available_roi_ids = np.unique(seg)[1:]
                     for tix in np.argsort(batch_roi_counts):
                         elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]]
                         if len(elig_roi_ids)>0:
                             seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                             break
                     roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                     assert seg[tuple(roi_anchor_pixel)] > 0
 
                     # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and
                     # distance to the selected ROI < patch_size /2
                     def get_cropped_centercoords(dim):
                         low = np.max((self.cf.pre_crop_size[dim] // 2,
                                       roi_anchor_pixel[dim] - (
                                                   self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                         high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2,
                                        roi_anchor_pixel[dim] + (
                                                    self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                         if low >= high:  # happens if lesion on the edge of the image.
                             low = self.cf.pre_crop_size[dim] // 2
                             high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2
 
                         assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format(
                             dim,
                             spatial_shp, patient['pid'], low, high)
                         return np.random.randint(low=low, high=high)
                 else:
                     # sample crop center regardless of ROIs, not guaranteed to be empty
                     def get_cropped_centercoords(dim):
                         return np.random.randint(low=self.cf.pre_crop_size[dim] // 2,
                                                  high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2)
 
                 sample_seg_center = {}
                 for dim in np.where(dim_cropflags)[0]:
                     sample_seg_center[dim] = get_cropped_centercoords(dim)
                     min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2)
                     max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2)
                     data = np.take(data, indices=range(min_, max_), axis=dim + 1)  # +1 for channeldim
                     seg = np.take(seg, indices=range(min_, max_), axis=dim)
 
             batch_data.append(data)
             batch_segs.append(seg[np.newaxis])
 
             for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable
                     batch_roi_items[o].append(patient[o])
 
             if self.cf.dim == 3:
                 for tix in range(len(self.unique_ts)):
                     batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
             elif self.cf.dim == 2:
                 for tix in range(len(self.unique_ts)):
                     batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
             if not np.any(seg):
                 empty_samples_count += 1
 
         batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'),
-                 'pid': batch_patient_ids,
+                 'pid': batch_pids,
                  'roi_counts': batch_roi_counts, 'empty_samples_count': empty_samples_count}
         for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic
             batch[key] = np.array(val)
 
         return batch
 
 class PatientBatchIterator(dutils.PatientBatchIterator):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D),
     if willing to accept speed-loss during training.
     Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are
     exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by
     the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs.
 
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D .
     """
 
     def __init__(self, cf, data, mode='test'):
         super(PatientBatchIterator, self).__init__(cf, data)
 
         self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D
         self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
         assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
 
         if (mode=="validation" and hasattr(self.cf, 'val_against_exact_gt') and self.cf.val_against_exact_gt) or \
                 (mode == 'test' and self.cf.test_against_exact_gt):
             self.gt_prefix = 'exact_'
             print("PatientIterator: Loading exact Ground Truths.")
         else:
             self.gt_prefix = ''
 
         self.patient_ix = 0  # running index over all patients in set
 
     def generate_train_batch(self, pid=None):
 
         if pid is None:
             pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
 
         # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling
         data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
         seg =  np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis]
 
         data_shp_raw = data.shape
         plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
         data = data[self.chans]
         discarded_chans = len(
             [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
         spatial_shp = data[0].shape  # spatial dims need to be in order x,y,z
         assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg"
 
         if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
             new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
             data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
             seg = dutils.pad_nd_image(seg, new_shape)
             if plot_bg is not None:
                 plot_bg = dutils.pad_nd_image(plot_bg, new_shape)
 
         if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
             # adds the batch dim here bc won't go through MTaugmenter
             out_data = data[np.newaxis]
             out_seg = seg[np.newaxis]
             if plot_bg is not None:
                out_plot_bg = plot_bg[np.newaxis]
             # data and seg shape: (1,c,x,y,z), where c=1 for seg
 
             batch_3D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_3D[o] = np.array([patient[self.gt_prefix+o]])
             converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_3D = converter(**batch_3D)
             batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
             for o in self.cf.roi_items:
                 batch_3D["patient_" + o] = batch_3D[o]
 
         if self.cf.dim == 2:
             out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32')  # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim
             out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8')  # (c,y,x,z) to (b=z,c,x,y)
 
             batch_2D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0)
             converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_2D = converter(**batch_2D)
 
             if plot_bg is not None:
                 out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32')
 
             if self.cf.merge_2D_to_3D_preds:
                 batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_3D[o]
             else:
                 batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_2D[o]
 
         out_batch = batch_3D if self.cf.dim == 3 else batch_2D
         out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})
 
         if self.cf.plot_bg_chan in self.chans and discarded_chans > 0:  # len(self.chans[:self.cf.plot_bg_chan])<data_shp_raw[0]:
             assert plot_bg is None
             plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
             out_plot_bg = plot_bg
         if plot_bg is not None:
             out_batch['plot_bg'] = out_plot_bg
 
         # eventual tiling into patches
         spatial_shp = out_batch["data"].shape[2:]
         if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
             patient_batch = out_batch
             print("patientiterator produced patched batch!")
             patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
             new_img_batch, new_seg_batch = [], []
 
             for c in patch_crop_coords_list:
                 new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                 seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                 new_seg_batch.append(seg_patch)
             shps = []
             for arr in new_img_batch:
                 shps.append(arr.shape)
 
             data = np.array(new_img_batch)  # (patches, c, x, y, z)
             seg = np.array(new_seg_batch)
             if self.cf.dim == 2:
                 # all patches have z dimension 1 (slices). discard dimension
                 data = data[..., 0]
                 seg = seg[..., 0]
             patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                            'pid': np.array([patient['pid']] * data.shape[0])}
             for o in self.cf.roi_items:
                 patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0)
             #patient-wise (orig) batch info for putting the patches back together after prediction
             for o in self.cf.roi_items:
                 patch_batch["patient_"+o] = patient_batch["patient_"+o]
                 if self.cf.dim == 2:
                     # this could also be named "unpatched_2d_roi_items"
                     patch_batch["patient_" + o + "_2d"] = patient_batch[o]
             patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
             patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
             if self.cf.dim == 2:
                 patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
             patch_batch['patient_data'] = patient_batch['data']
             patch_batch['patient_seg'] = patient_batch['seg']
             patch_batch['original_img_shape'] = patient_batch['original_img_shape']
             if plot_bg is not None:
                 patch_batch['patient_plot_bg'] = patient_batch['plot_bg']
 
             converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False,
                                                            class_specific_seg=self.cf.class_specific_seg)
 
             patch_batch = converter(**patch_batch)
             out_batch = patch_batch
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return out_batch
 
 
 def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True):
     """
     create mutli-threaded train/val/test batch generation and augmentation pipeline.
     :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
     :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
     :return: multithreaded_generator
     """
 
     # create instance of batch generator as first element in pipeline.
     data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace)
 
     my_transforms = []
     if do_aug:
         if cf.da_kwargs["mirror"]:
             mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
             my_transforms.append(mirror_transform)
 
         spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                              patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                              do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                              alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                              do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                              angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                              do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                              random_crop=cf.da_kwargs['random_crop'])
 
         my_transforms.append(spatial_transform)
     else:
         my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
 
     my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
     all_transforms = Compose(my_transforms)
     # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
     multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
     return multithreaded_generator
 
 def get_train_generators(cf, logger, data_statistics=False):
     """
     wrapper function for creating the training batch generator pipeline. returns the train/val generators.
     selects patients according to cv folds (generated by first run/fold of experiment):
     splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set)
     If cf.hold_out_test_set is True, adds the test split to the training data.
     """
     dataset = Dataset(cf, logger)
     dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
     dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
     set_splits = dataset.fg.splits
 
     test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
     train_ids = np.concatenate(set_splits, axis=0)
 
     if cf.held_out_test_set:
         train_ids = np.concatenate((train_ids, test_ids), axis=0)
         test_ids = []
 
     train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids}
     val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids}
 
     logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
                                                                                     len(test_ids)))
     if data_statistics:
         dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=
         os.path.join(cf.plot_dir,"dataset"))
 
     batch_gen = {}
     batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True)
     batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False)
 
     if cf.val_mode == 'val_patient':
         batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation')
         batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else min(len(val_ids), cf.max_val_patients)
     elif cf.val_mode == 'val_sampling':
         batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches != "all" else len(val_data)
 
     return batch_gen
 
 def get_test_generator(cf, logger):
     """
     if get_test_generators is possibly called multiple times in server env, every time of
     Dataset initiation rsync will check for copying the data; this should be okay
     since rsync will not copy if files already exist in destination.
     """
 
     if cf.held_out_test_set:
         sourcedir = cf.test_data_sourcedir
         test_ids = None
     else:
         sourcedir = None
         with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
             set_splits = pickle.load(handle)
         test_ids = set_splits[cf.fold]
 
     test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test')
     logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids)))
     batch_gen = {}
     batch_gen['test'] = PatientBatchIterator(cf, test_set.data)
     batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \
         min(cf.max_test_patients, len(test_set.set_ids))
 
     return batch_gen
 
 
 if __name__=="__main__":
 
     import utils.exp_utils as utils
     from configs import Configs
 
     cf = Configs()
 
     total_stime = time.time()
     times = {}
 
     # cf.server_env = True
     # cf.data_dir = "experiments/dev_data"
 
     cf.exp_dir = "experiments/dev/"
     cf.plot_dir = cf.exp_dir + "plots"
     os.makedirs(cf.exp_dir, exist_ok=True)
     cf.fold = 0
     logger = utils.get_logger(cf.exp_dir)
     gens = get_train_generators(cf, logger)
     train_loader = gens['train']
-    for i in range(1):
+    for i in range(0):
         stime = time.time()
         print("producing training batch nr ", i)
         ex_batch = next(train_loader)
         times["train_batch"] = time.time() - stime
         #experiments/dev/dev_exbatch_{}.png".format(i)
-        print(ex_batch.keys())
         plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False)
 
 
     val_loader = gens['val_sampling']
     stime = time.time()
     for i in range(1):
         ex_batch = next(val_loader)
         times["val_batch"] = time.time() - stime
         stime = time.time()
         #"experiments/dev/dev_exvalbatch_{}.png"
         plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True)
         times["val_plot"] = time.time() - stime
-    import IPython; IPython.embed()
     #
     test_loader = get_test_generator(cf, logger)["test"]
     stime = time.time()
     ex_batch = test_loader.generate_train_batch(pid=None)
     times["test_batch"] = time.time() - stime
     stime = time.time()
     plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0)
     times["test_patchbatch_plot"] = time.time() - stime
 
 
 
     print("Times recorded throughout:")
     for (k, v) in times.items():
         print(k, "{:.2f}".format(v))
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/datasets/toy/generate_toys.py b/datasets/toy/generate_toys.py
index d860b01..0a3faeb 100644
--- a/datasets/toy/generate_toys.py
+++ b/datasets/toy/generate_toys.py
@@ -1,388 +1,399 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """ Generate a data set of toy examples. Examples can be cylinders, spheres, blocks, diamonds.
     Distortions may be applied, e.g., noise to the radius ground truths.
     Settings are configured in configs file.
 """
 
 import plotting as plg
 import os
 import shutil
 import warnings
 import time
 from multiprocessing import Pool
 
 import numpy as np
 import pandas as pd
 
 import data_manager as dmanager
 
 
 for msg in ["RuntimeWarning: divide by zero encountered in true_divide.*",]:
     warnings.filterwarnings("ignore", msg)
 
 
 class ToyGenerator(object):
     """ Generator of toy data set.
         A train and a test split with certain nr of samples are created and saved to disk. Samples can contain varying
         number of objects. Objects have shapes cylinder or block (diamond, ellipsoid, torus not fully implemented).
 
         self.mp_args holds image split and id, objects are then randomly drawn into each image. Multi-processing is
         enabled for parallel creation of images, final .npy-files can then be converted to .npz.
     """
     def __init__(self, cf):
         """
         :param cf: configs file holding object specifications and output directories.
         """
 
         self.cf = cf
 
         self.n_train, self.n_test = cf.n_train_samples, cf.n_test_samples
         self.sample_size = cf.pre_crop_size
         self.dim = len(self.sample_size)
         self.class_radii = np.array([label.radius for label in self.cf.pp_classes if label.id!=0])
         self.class_id2label = {label.id: label for label in self.cf.pp_classes}
 
         self.mp_args = []
         # count sample ids consecutively over train, test splits within on dataset (one shape kind)
         self.last_s_id = 0
         for split in ["train", "test"]:
             self.set_splits_info(split)
 
     def set_splits_info(self, split):
         """ Set info for data set splits, i.e., directory and nr of samples.
         :param split: name of split, in {"train", "test"}.
         """
         out_dir = os.path.join(self.cf.pp_rootdir, split)
         os.makedirs(out_dir, exist_ok=True)
 
         n_samples = self.n_train if "train" in split else self.n_test
+        req_exact_gt = "test" in split
 
-        self.mp_args+= [[out_dir, self.last_s_id+running_id] for running_id in range(n_samples)]
+        self.mp_args += [[out_dir, self.last_s_id+running_id, req_exact_gt] for running_id in range(n_samples)]
         self.last_s_id+= n_samples
 
     def generate_sample_radii(self, class_ids, shapes):
 
         # the radii set in labels are ranges to sample from in the form [(min_x,min_y,min_z), (max_x,max_y,max_z)]
         all_radii = []
         for ix, cl_radii in enumerate([self.class_radii[cl_id - 1].transpose() for cl_id in class_ids]):
             if "cylinder" in shapes[ix] or "block" in shapes[ix]:
                 # maintain 2D aspect ratio
                 sample_radii = [np.random.uniform(*cl_radii[0])] * 2
                 assert len(sample_radii) == 2, "upper sr {}, cl_radii {}".format(sample_radii, cl_radii)
                 if self.cf.pp_place_radii_mid_bin:
                     bef_conv_r = np.copy(sample_radii)
                     bin_id =  self.cf.rg_val_to_bin_id(bef_conv_r)
                     assert np.isscalar(bin_id)
                     sample_radii = self.cf.bin_id2rg_val[bin_id]*2
                     assert len(sample_radii) == 2, "mid before sr {}, sr {}, rgv2bid {}, cl_radii {},  bid2rgval {}".format(bef_conv_r, sample_radii, bin_id, cl_radii,
                                                                                                              self.cf.bin_id2rg_val[bin_id])
             else:
                 raise NotImplementedError("requested object shape {}".format(shapes[ix]))
             if self.dim == 3:
                 assert len(sample_radii) == 2, "lower sr {}, cl_radii {}".format(sample_radii, cl_radii)
                 #sample_radii += [np.random.uniform(*cl_radii[2])]
                 sample_radii = np.concatenate((sample_radii, np.random.uniform(*cl_radii[2], size=1)))
             all_radii.append(sample_radii)
 
         return all_radii
 
     def apply_gt_distort(self, class_id, radii, radii_divs, outer_min_radii=None, outer_max_radii=None):
         """ Apply a distortion to the ground truth (gt). This is motivated by investigating the effects of noisy labels.
             GTs that can be distorted are the object radii and ensuing GT quantities like segmentation and regression
             targets.
         :param class_id: class id of object.
         :param radii: radii of object. This is in the abstract sense, s.t. for a block-shaped object radii give the side
             lengths.
         :param radii_divs: radii divisors, i.e., fractions to take from radii to get inner radii of hole-shaped objects,
             like a torus.
         :param outer_min_radii: min radii assignable when distorting gt.
         :param outer_max_radii: max radii assignable when distorting gt.
         :return:
         """
         applied_gt_distort = False
         for ambig in self.class_id2label[class_id].gt_distortion:
             if self.cf.ambiguities[ambig][0] > np.random.rand():
                 if ambig == "outer_radius":
                     radii = radii * abs(np.random.normal(1., self.cf.ambiguities["outer_radius"][1]))
                     applied_gt_distort = True
                 if ambig == "radii_relations":
                     radii = radii * abs(np.random.normal(1.,self.cf.ambiguities["radii_relations"][1],size=len(radii)))
                     applied_gt_distort = True
                 if ambig == "inner_radius":
                     radii_divs = radii_divs * abs(np.random.normal(1., self.cf.ambiguities["inner_radius"][1]))
                     applied_gt_distort = True
                 if ambig == "radius_calib":
                     if self.cf.ambigs_sampling=="uniform":
                         radii = abs(np.random.uniform(outer_min_radii, outer_max_radii))
                     elif self.cf.ambigs_sampling=="gaussian":
                         distort = abs(np.random.normal(1, scale=self.cf.ambiguities["radius_calib"][1], size=None))
                         assert len(radii) == self.dim, "radii {}".format(radii)
                         radii *= [distort, distort, 1.] if self.cf.pp_only_distort_2d else distort
                     applied_gt_distort = True
         return radii, radii_divs, applied_gt_distort
 
     def draw_object(self, img, seg, undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort,
                                  roi_ix, class_id, shape, radii, center):
         """ Draw a single object into the given image and add it to the corresponding ground truths.
         :param img: image (volume) to hold the object.
         :param seg: pixel-wise labelling of the image, possibly distorted if gt distortions are applied.
         :param undistorted_seg: certainly undistorted, i.e., exact segmentation of object.
         :param ics: indices which mark the positions within the image.
         :param regress_targets: regression targets (e.g., 2D radii of object), evtly distorted.
         :param undistorted_rg_targets: undistorted regression targets.
         :param applied_gt_distort: boolean, whether or not gt distortion was applied.
         :param roi_ix: running index of object in whole image.
         :param class_id: class id of object.
         :param shape: shape of object (e.g., whether to draw a cylinder, or block, or ...).
         :param radii: radii of object (in an abstract sense, i.e., radii are side lengths in case of block shape).
         :param center: center of object in image coordinates.
         :return: img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort, which are now
             extended are amended to reflect the new object.
         """
 
         radii_blur = hasattr(self.cf, "ambiguities") and hasattr(self.class_id2label[class_id],
                                                                  "gt_distortion") and 'radius_calib' in \
                      self.class_id2label[class_id].gt_distortion
 
         if radii_blur:
             blur_width = self.cf.ambiguities['radius_calib'][1]
             if self.cf.ambigs_sampling == "uniform":
                 blur_width *= np.sqrt(12)
             if self.cf.pp_only_distort_2d:
                 outer_max_radii = np.concatenate((radii[:2] + blur_width * radii[:2], [radii[2]]))
                 outer_min_radii = np.concatenate((radii[:2] - blur_width * radii[:2], [radii[2]]))
                 #print("belt width ", outer_max_radii - outer_min_radii)
             else:
                 outer_max_radii = radii + blur_width * radii
                 outer_min_radii = radii - blur_width * radii
         else:
             outer_max_radii, outer_min_radii = radii, radii
 
         if "ellipsoid" in shape or "torus" in shape:
             # sphere equation: (x-h)**2 + (y-k)**2 - (z-l)**2 = r**2
             # ellipsoid equation: ((x-h)/a)**2+((y-k)/b)**2+((z-l)/c)**2 <= 1; a, b, c the "radii"/ half-length of principal axes
             obj = ((ics - center) / radii) ** 2
         elif "diamond" in shape:
             # diamond equation: (|x-h|)/a+(|y-k|)/b+(|z-l|)/c <= 1
             obj = abs(ics - center) / radii
         elif "cylinder" in shape:
             # cylinder equation:((x-h)/a)**2 + ((y-k)/b)**2 <= 1 while |z-l| <= c
             obj = ((ics - center).astype("float64") / radii) ** 2
             # set z values s.t. z slices outside range are sorted out
             obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1)
             if radii_blur:
                 inner_obj = ((ics - center).astype("float64") / outer_min_radii) ** 2
                 inner_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_min_radii[2], 0., 1.1)
                 outer_obj = ((ics - center).astype("float64") / outer_max_radii) ** 2
                 outer_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_max_radii[2], 0., 1.1)
                 # radial dists: sqrt( (x-h)**2 + (y-k)**2 + (z-l)**2 )
                 obj_radial_dists = np.sqrt(np.sum((ics - center).astype("float64")**2, axis=1))
         elif "block" in shape:
             # block equation: (|x-h|)/a+(|y-k|)/b <= 1 while  |z-l| <= c
             obj = abs(ics - center) / radii
             obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1)
             if radii_blur:
                 inner_obj = abs(ics - center) / outer_min_radii
                 inner_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_min_radii[2], 0., 1.1)
                 outer_obj = abs(ics - center) / outer_max_radii
                 outer_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_max_radii[2], 0., 1.1)
                 obj_radial_dists = np.sum(abs(ics - center), axis=1).astype("float64")
         else:
             raise Exception("Invalid object shape '{}'".format(shape))
 
         # create the "original" GT, i.e., the actually true object and draw it into undistorted seg.
         obj = (np.sum(obj, axis=1) <= 1)
         obj = obj.reshape(seg[0].shape)
         slices_to_discard = np.where(np.count_nonzero(np.count_nonzero(obj, axis=0), axis=0) <= self.cf.min_2d_radius)[0]
         obj[..., slices_to_discard] = 0
         undistorted_radii = np.copy(radii)
         undistorted_seg[class_id][obj] = roi_ix + 1
         obj = obj.astype('float64')
 
         if radii_blur:
             inner_obj = np.sum(inner_obj, axis=1) <= 1
             outer_obj = (np.sum(outer_obj, axis=1) <= 1) & ~inner_obj
             obj_radial_dists[outer_obj] = obj_radial_dists[outer_obj] / max(obj_radial_dists[outer_obj])
             intensity_slope = self.cf.pp_blur_min_intensity - 1.
             # intensity(r) = (i(r_max)-i(0))/r_max * r + i(0), where i(0)==1.
             obj_radial_dists[outer_obj] = obj_radial_dists[outer_obj] * intensity_slope + 1.
             inner_obj = inner_obj.astype('float64')
             #outer_obj, obj_radial_dists = outer_obj.reshape(seg[0].shape), obj_radial_dists.reshape(seg[0].shape)
             inner_obj += np.where(outer_obj, obj_radial_dists, 0.)
             obj = inner_obj.reshape(seg[0].shape)
         if not np.any(obj):
             print("An object was completely discarded due to min 2d radius requirement, discarded slices: {}.".format(
                 slices_to_discard))
         # draw the evtly blurred obj into image.
         img += obj * (class_id + 1.)
 
         if hasattr(self.cf, "ambiguities") and hasattr(self.class_id2label[class_id], "gt_distortion"):
             radii_divs = [None]  # dummy since not implemented yet
             radii, radii_divs, applied_gt_distort = self.apply_gt_distort(class_id, radii, radii_divs,
                                                                           outer_min_radii, outer_max_radii)
             if applied_gt_distort:
                 if "ellipsoid" in shape or "torus" in shape:
                     obj = ((ics - center) / radii) ** 2
                 elif 'diamond' in shape:
                     obj = abs(ics - center) / radii
                 elif "cylinder" in shape:
                     obj = ((ics - center) / radii) ** 2
                     obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1)
                 elif "block" in shape:
                     obj = abs(ics - center) / radii
                     obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1)
                 obj = (np.sum(obj, axis=1) <= 1).reshape(seg[0].shape)
                 obj[..., slices_to_discard] = False
 
         if self.class_id2label[class_id].regression == "radii":
             regress_targets.append(radii)
             undistorted_rg_targets.append(undistorted_radii)
         elif self.class_id2label[class_id].regression == "radii_2d":
             regress_targets.append(radii[:2])
             undistorted_rg_targets.append(undistorted_radii[:2])
         elif self.class_id2label[class_id].regression == "radius_2d":
             regress_targets.append(radii[:1])
             undistorted_rg_targets.append(undistorted_radii[:1])
         else:
             regress_targets.append(self.class_id2label[class_id].regression)
             undistorted_rg_targets.append(self.class_id2label[class_id].regression)
 
         seg[class_id][obj.astype('bool')] = roi_ix + 1
 
         return  img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort
 
     def create_sample(self, args):
         """ Create a single sample and save to file. One sample is one image (volume) containing none, one, or multiple
             objects.
         :param args: out_dir: directory where to save sample, s_id: id of the sample.
         :return: specs that identify this single created image
         """
-        out_dir, s_id = args
+        out_dir, s_id, req_exact_gt = args
 
         print('processing {} {}'.format(out_dir, s_id))
         img = np.random.normal(loc=0.0, scale=self.cf.noise_scale, size=self.sample_size)
         img[img<0.] = 0.
         # one-hot-encoded seg
         seg = np.zeros((self.cf.num_classes+1, *self.sample_size)).astype('uint8')
         undistorted_seg = np.copy(seg)
         applied_gt_distort = False
 
         if hasattr(self.cf, "pp_empty_samples_ratio") and self.cf.pp_empty_samples_ratio >= np.random.rand():
             # generate fully empty sample
             class_ids, regress_targets, undistorted_rg_targets = [], [], []
         else:
             class_choices = np.repeat(np.arange(1, self.cf.num_classes+1), self.cf.max_instances_per_class)
             n_insts = np.random.randint(1, self.cf.max_instances_per_sample + 1)
             class_ids = np.random.choice(class_choices, size=n_insts, replace=False)
             shapes = np.array([self.class_id2label[cl_id].shape for cl_id in class_ids])
             all_radii = self.generate_sample_radii(class_ids, shapes)
 
             # reorder s.t. larger objects are drawn first (in order to not fully cover smaller objects)
             order = np.argsort(-1*np.prod(all_radii,axis=1))
             class_ids = class_ids[order]; all_radii = np.array(all_radii)[order]; shapes = shapes[order]
 
             regress_targets, undistorted_rg_targets = [], []
             # indices ics equal positions within img/volume
             ics = np.argwhere(np.ones(seg[0].shape))
             for roi_ix, class_id in enumerate(class_ids):
                 radii = all_radii[roi_ix]
                 # enforce distance between object center and image edge relative to radii.
                 margin_r_divisor = (2, 2, 4)
                 center = [np.random.randint(radii[dim] / margin_r_divisor[dim], img.shape[dim] -
                                             radii[dim] / margin_r_divisor[dim]) for dim in range(len(img.shape))]
 
                 img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort = \
                     self.draw_object(img, seg, undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort,
                                  roi_ix, class_id, shapes[roi_ix], radii, center)
 
         fg_slices = np.where(np.sum(np.sum(np.sum(seg,axis=0), axis=0), axis=0))[0]
         if self.cf.pp_create_ohe_seg:
             img = img[np.newaxis]
         else:
             # choosing rois to keep by smaller radius==higher prio needs to be ensured during roi generation,
             # smaller objects need to be drawn later (==higher roi id)
             seg = seg.max(axis=0)
             seg_ids = np.unique(seg)
             if len(seg_ids) != len(class_ids) + 1:
                 # in this case an object was completely covered by a succeeding object
                 print("skipping corrupt sample")
                 print("seg ids {}, class_ids {}".format(seg_ids, class_ids))
                 return None
             if not applied_gt_distort:
                 assert np.all(np.flatnonzero(img>0) == np.flatnonzero(seg>0))
                 assert np.all(np.array(regress_targets).flatten()==np.array(undistorted_rg_targets).flatten())
 
+        # save the img
         out_path = os.path.join(out_dir, '{}.npy'.format(s_id))
-        np.save(out_path, img.astype('float16')); np.save(os.path.join(out_dir, '{}_seg.npy'.format(s_id)), seg)
-        if hasattr(self.cf, 'ambiguities') and \
-            np.any([hasattr(label, "gt_distortion") and len(label.gt_distortion)>0 for label in self.class_id2label.values()]):
-            undist_out_path = os.path.join(out_dir, '{}_exact_seg.npy'.format(s_id))
+        np.save(out_path, img.astype('float16'))
+
+        # exact GT
+        if req_exact_gt:
             if not self.cf.pp_create_ohe_seg:
                 undistorted_seg = undistorted_seg.max(axis=0)
-            np.save(undist_out_path, undistorted_seg)
+            np.save(os.path.join(out_dir, '{}_exact_seg.npy'.format(s_id)), undistorted_seg)
+        else:
+            # if hasattr(self.cf, 'ambiguities') and \
+            #     np.any([hasattr(label, "gt_distortion") and len(label.gt_distortion)>0 for label in self.class_id2label.values()]):
+            # save (evtly) distorted GT
+            np.save(os.path.join(out_dir, '{}_seg.npy'.format(s_id)), seg)
+
 
         return [out_dir, out_path, class_ids, regress_targets, fg_slices, undistorted_rg_targets, str(s_id)]
 
     def create_sets(self, processes=os.cpu_count()):
         """ Create whole training and test set, save to files under given directory cf.out_dir.
         :param processes: nr of parallel processes.
         """
-        print('starting creation of {} images'.format(len(self.mp_args)))
+
+
+        print('starting creation of {} images.'.format(len(self.mp_args)))
         shutil.copyfile("configs.py", os.path.join(self.cf.pp_rootdir, 'applied_configs.py'))
         pool = Pool(processes=processes)
-        imgs_info = pool.map(self.create_sample, self.mp_args, chunksize=1)
+        imgs_info = pool.map(self.create_sample, self.mp_args)
+        imgs_info = [img for img in imgs_info if img is not None]
         pool.close()
         pool.join()
-        imgs_info = [img for img in imgs_info if img is not None]
         print("created a total of {} samples.".format(len(imgs_info)))
+
         self.df = pd.DataFrame.from_records(imgs_info, columns=['out_dir', 'path', 'class_ids', 'regression_vectors',
                                                                 'fg_slices', 'undistorted_rg_vectors', 'pid'])
 
         for out_dir, group_df in self.df.groupby("out_dir"):
             group_df.to_pickle(os.path.join(out_dir, 'info_df.pickle'))
 
 
     def convert_copy_npz(self):
         """ Convert a copy of generated .npy-files to npz and save in .npz-directory given in configs.
         """
         if hasattr(self.cf, "pp_npz_dir") and self.cf.pp_npz_dir:
             for out_dir, group_df in self.df.groupby("out_dir"):
                 rel_dir = os.path.relpath(out_dir, self.cf.pp_rootdir).split(os.sep)
                 npz_out_dir = os.path.join(self.cf.pp_npz_dir, str(os.sep).join(rel_dir))
                 print("npz out dir: ", npz_out_dir)
                 os.makedirs(npz_out_dir, exist_ok=True)
                 group_df.to_pickle(os.path.join(npz_out_dir, 'info_df.pickle'))
                 dmanager.pack_dataset(out_dir, npz_out_dir, recursive=True, verbose=False)
         else:
             print("Did not convert .npy-files to .npz because npz directory not set in configs.")
 
 
 if __name__ == '__main__':
     import configs as cf
     cf = cf.Configs()
     total_stime = time.time()
 
     toy_gen = ToyGenerator(cf)
     toy_gen.create_sets()
-    #toy_gen.convert_copy_npz()
+    toy_gen.convert_copy_npz()
 
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
diff --git a/utils/dataloader_utils.py b/utils/dataloader_utils.py
index 362711c..4806fc2 100644
--- a/utils/dataloader_utils.py
+++ b/utils/dataloader_utils.py
@@ -1,650 +1,699 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 import plotting as plg
 
 import os
-from multiprocessing import Pool
+from multiprocessing import Pool, Lock
 import pickle
 import warnings
 
 import numpy as np
 import pandas as pd
 from batchgenerators.transforms.abstract_transforms import AbstractTransform
 from scipy.ndimage.measurements import label as lb
 from torch.utils.data import Dataset as torchDataset
 from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
 
 import utils.exp_utils as utils
 import data_manager as dmanager
 
 
 for msg in ["This figure includes Axes that are not compatible with tight_layout",
             "Data has no positive values, and therefore cannot be log-scaled."]:
     warnings.filterwarnings("ignore", msg)
 
 
 class AttributeDict(dict):
     __getattr__ = dict.__getitem__
     __setattr__ = dict.__setitem__
 
 ##################################
 #  data loading, organisation  #
 ##################################
 
 
 class fold_generator:
     """
     generates splits of indices for a given length of a dataset to perform n-fold cross-validation.
     splits each fold into 3 subsets for training, validation and testing.
     This form of cross validation uses an inner loop test set, which is useful if test scores shall be reported on a
     statistically reliable amount of patients, despite limited size of a dataset.
     If hold out test set is provided and hence no inner loop test set needed, just add test_idxs to the training data in the dataloader.
     This creates straight-forward train-val splits.
     :returns names list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix.
     """
     def __init__(self, seed, n_splits, len_data):
         """
         :param seed: Random seed for splits.
         :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation
         :param len_data: number of elements in the dataset.
         """
         self.tr_ix = []
         self.val_ix = []
         self.te_ix = []
         self.slicer = None
         self.missing = 0
         self.fold = 0
         self.len_data = len_data
         self.n_splits = n_splits
         self.myseed = seed
         self.boost_val = 0
 
     def init_indices(self):
 
         t = list(np.arange(self.l))
         # round up to next splittable data amount.
         split_length = int(np.ceil(len(t) / float(self.n_splits)))
         self.slicer = split_length
         self.mod = len(t) % self.n_splits
         if self.mod > 0:
             # missing is the number of folds, in which the new splits are reduced to account for missing data.
             self.missing = self.n_splits - self.mod
 
         self.te_ix = t[:self.slicer]
         self.tr_ix = t[self.slicer:]
         self.val_ix = self.tr_ix[:self.slicer]
         self.tr_ix = self.tr_ix[self.slicer:]
 
     def new_fold(self):
 
         slicer = self.slicer
         if self.fold < self.missing :
             slicer = self.slicer - 1
 
         temp = self.te_ix
 
         # catch exception mod == 1: test set collects 1+ data since walk through both roudned up splits.
         # account for by reducing last fold split by 1.
         if self.fold == self.n_splits-2 and self.mod ==1:
             temp += self.val_ix[-1:]
             self.val_ix = self.val_ix[:-1]
 
         self.te_ix = self.val_ix
         self.val_ix = self.tr_ix[:slicer]
         self.tr_ix = self.tr_ix[slicer:] + temp
 
 
     def get_fold_names(self):
         names_list = []
         rgen = np.random.RandomState(self.myseed)
         cv_names = np.arange(self.len_data)
 
         rgen.shuffle(cv_names)
         self.l = len(cv_names)
         self.init_indices()
 
         for split in range(self.n_splits):
             train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix]
             names_list.append([train_names, val_names, test_names, self.fold])
             self.new_fold()
             self.fold += 1
 
         return names_list
 
 
 
 class FoldGenerator():
     r"""takes a set of elements (identifiers) and randomly splits them into the specified amt of subsets.
     """
 
     def __init__(self, identifiers, seed, n_splits=5):
         self.ids = np.array(identifiers)
         self.n_splits = n_splits
         self.seed = seed
 
     def generate_splits(self, n_splits=None):
         if n_splits is None:
             n_splits = self.n_splits
 
         rgen = np.random.RandomState(self.seed)
         rgen.shuffle(self.ids)
         self.splits = list(np.array_split(self.ids, n_splits, axis=0))  # already returns list, but to be sure
         return self.splits
 
 
 class Dataset(torchDataset):
     r"""Parent Class for actual Dataset classes to inherit from!
     """
     def __init__(self, cf, data_sourcedir=None):
         super(Dataset, self).__init__()
         self.cf = cf
 
         self.data_sourcedir = cf.data_sourcedir if data_sourcedir is None else data_sourcedir
         self.data_dir = cf.data_dir if hasattr(cf, 'data_dir') else self.data_sourcedir
 
         self.data_dest = cf.data_dest if hasattr(cf, "data_dest") else self.data_sourcedir
 
         self.data = {}
         self.set_ids = []
 
     def copy_data(self, cf, file_subset, keep_packed=False, del_after_unpack=False):
         if os.path.normpath(self.data_sourcedir) != os.path.normpath(self.data_dest):
             self.data_sourcedir = os.path.join(self.data_sourcedir, '')
             args = AttributeDict({
                     "source" :  self.data_sourcedir,
                     "destination" : self.data_dest,
                     "recursive" : True,
                     "cp_only_npz" : False,
                     "keep_packed" : keep_packed,
                     "del_after_unpack" : del_after_unpack,
                     "threads" : 16 if self.cf.server_env else os.cpu_count()
                     })
             dmanager.copy(args, file_subset=file_subset)
             self.data_dir = self.data_dest
 
 
 
     def __len__(self):
         return len(self.data)
     def __getitem__(self, id):
         """Return a sample of the dataset, i.e.,the dict of the id
         """
         return self.data[id]
     def __iter__(self):
         return self.data.__iter__()
 
     def init_FoldGenerator(self, seed, n_splits):
         self.fg = FoldGenerator(self.set_ids, seed=seed, n_splits=n_splits)
 
     def generate_splits(self, check_file):
         if not os.path.exists(check_file):
             self.fg.generate_splits()
             with open(check_file, 'wb') as handle:
                 pickle.dump(self.fg.splits, handle)
         else:
             with open(check_file, 'rb') as handle:
                 self.fg.splits = pickle.load(handle)
 
     def calc_statistics(self, subsets=None, plot_dir=None, overall_stats=True):
 
         if self.df is None:
             self.df = pd.DataFrame()
             balance_t = self.cf.balance_target if hasattr(self.cf, "balance_target") else "class_targets"
             self.df._metadata.append(balance_t)
             if balance_t=="class_targets":
                 mapper = lambda cl_id: self.cf.class_id2label[cl_id]
                 labels = self.cf.class_id2label.values()
             elif balance_t=="rg_bin_targets":
                 mapper = lambda rg_bin: self.cf.bin_id2label[rg_bin]
                 labels = self.cf.bin_id2label.values()
             # elif balance_t=="regression_targets":
             #     # todo this wont work
             #     mapper = lambda rg_val: AttributeDict({"name":rg_val}) #self.cf.bin_id2label[self.cf.rg_val_to_bin_id(rg_val)]
             #     labels = self.cf.bin_id2label.values()
             elif balance_t=="lesion_gleasons":
                 mapper = lambda gs: self.cf.gs2label[gs]
                 labels = self.cf.gs2label.values()
             else:
                 mapper = lambda x: AttributeDict({"name":x})
                 labels = None
             for pid, subj_data in self.data.items():
                 unique_ts, counts = np.unique(subj_data[balance_t], return_counts=True)
                 self.df = self.df.append(pd.DataFrame({"pid": [pid],
                                                        **{mapper(unique_ts[i]).name: [counts[i]] for i in
                                                           range(len(unique_ts))}}), ignore_index=True, sort=True)
             self.df = self.df.fillna(0)
 
         if overall_stats:
             df = self.df.drop("pid", axis=1)
             df = df.reindex(sorted(df.columns), axis=1).astype('uint32')
             print("Overall dataset roi counts per target kind:"); print(df.sum())
         if subsets is not None:
             self.df["subset"] = np.nan
             self.df["display_order"] = np.nan
             for ix, (subset, pids) in enumerate(subsets.items()):
                 self.df.loc[self.df.pid.isin(pids), "subset"] = subset
                 self.df.loc[self.df.pid.isin(pids), "display_order"] = ix
             df = self.df.groupby("subset").agg("sum").drop("pid", axis=1, errors='ignore').astype('int64')
             df = df.sort_values(by=['display_order']).drop('display_order', axis=1)
             df = df.reindex(sorted(df.columns), axis=1)
 
             print("Fold {} dataset roi counts per target kind:".format(self.cf.fold)); print(df)
         if plot_dir is not None:
             os.makedirs(plot_dir, exist_ok=True)
             if subsets is not None:
                 plg.plot_fold_stats(self.cf, df, labels, os.path.join(plot_dir, "data_stats_fold_" + str(self.cf.fold))+".pdf")
             if overall_stats:
                 plg.plot_data_stats(self.cf, df, labels, os.path.join(plot_dir, 'data_stats_overall.pdf'))
 
         return df, labels
 
 
 def get_class_balanced_patients(all_pids, class_targets, batch_size, num_classes, random_ratio=0):
     '''
     samples towards equilibrium of classes (on basis of total RoI counts). for highly imbalanced dataset, this might be a too strong requirement.
     :param class_targets: dic holding {patient_specifier : ROI class targets}, list position of ROI target corresponds to respective seg label - 1
     :param batch_size:
     :param num_classes:
     :return:
     '''
     # assert len(all_pids)>=batch_size, "not enough eligible pids {} to form a single batch of size {}".format(len(all_pids), batch_size)
     class_counts = {k: 0 for k in range(1,num_classes+1)}
     not_picked = np.array(all_pids)
     batch_patients = np.empty((batch_size,), dtype=not_picked.dtype)
     rarest_class = np.random.randint(1,num_classes+1)
 
     for ix in range(batch_size):
         if len(not_picked) == 0:
             warnings.warn("Dataset too small to generate batch with unique samples; => recycling.")
             not_picked = np.array(all_pids)
 
         np.random.shuffle(not_picked) #this could actually go outside(above) the loop.
         pick = not_picked[0]
         for cand in not_picked:
             if np.count_nonzero(class_targets[cand] == rarest_class) > 0:
                 pick = cand
                 cand_rarest_class = np.argmin([np.count_nonzero(class_targets[cand] == cl) for cl in
                                                range(1,num_classes+1)])+1
                 # if current batch already bigger than the batch random ratio, then
                 # check that weakest class in this patient is not the weakest in current batch (since needs to be boosted)
                 # also that at least one roi of this patient belongs to weakest class. If True, keep patient, else keep looking.
                 if (cand_rarest_class != rarest_class and np.count_nonzero(class_targets[cand] == rarest_class) > 0) \
                         or ix < int(batch_size * random_ratio):
                     break
 
         for c in range(1,num_classes+1):
             class_counts[c] += np.count_nonzero(class_targets[pick] == c)
         if not ix < int(batch_size * random_ratio) and class_counts[rarest_class] == 0:  # means searched thru whole set without finding rarest class
             print("Class {} not represented in current dataset.".format(rarest_class))
         rarest_class = np.argmin(([class_counts[c] for c in range(1,num_classes+1)]))+1
         batch_patients[ix] = pick
         not_picked = not_picked[not_picked != pick]  # removes pick
 
     return batch_patients
 
 
 class BatchGenerator(SlimDataLoaderBase):
     """
     create the training/validation batch generator. Randomly sample batch_size patients
     from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array.
     :param data: data dictionary as provided by 'load_dataset'
     :param img_modalities: list of strings ['adc', 'b1500'] from config
     :param batch_size: number of patients to sample for the batch
     :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.)
     :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array.
     """
 
-    def __init__(self, cf, data, n_batches=None):
-        super(BatchGenerator, self).__init__(data, cf.batch_size, n_batches)
+    def __init__(self, cf, data, sample_pids_w_replace, max_batches=None, raise_stop_iteration=False, n_threads=None, seed=0):
+        if n_threads is None:
+            n_threads = cf.n_workers
+        super(BatchGenerator, self).__init__(data, cf.batch_size, number_of_threads_in_multithreaded=n_threads)
         self.cf = cf
+        self.random_count = int(cf.batch_random_ratio * cf.batch_size)
         self.plot_dir = os.path.join(self.cf.plot_dir, 'train_generator')
+        self.max_batches = max_batches
+        self.raise_stop = raise_stop_iteration
+        self.thread_id = 0
+        self.batches_produced = 0
 
         self.dataset_length = len(self._data)
         self.dataset_pids = list(self._data.keys())
-        self.eligible_pids = self.dataset_pids
+        self.rgen = np.random.RandomState(seed=seed)
+        self.eligible_pids = self.rgen.permutation(self.dataset_pids.copy())
+        self.eligible_pids = np.array_split(self.eligible_pids, self.number_of_threads_in_multithreaded)
+        self.eligible_pids = sorted(self.eligible_pids, key=len, reverse=True)
+        self.sample_pids_w_replace = sample_pids_w_replace
+        if not self.sample_pids_w_replace:
+            assert len(self.dataset_pids) / self.number_of_threads_in_multithreaded >= self.batch_size, \
+                "at least one batch needed per thread. dataset size: {}, n_threads: {}, batch_size: {}.".format(
+                    len(self.dataset_pids), self.number_of_threads_in_multithreaded, self.batch_size)
+            self.lock = Lock()
 
         self.stats = {"roi_counts": np.zeros((self.cf.num_classes,), dtype='uint32'), "empty_samples_count": 0}
 
         if hasattr(cf, "balance_target"):
             # WARNING: "balance targets are only implemented for 1-d targets (or 1-component vectors)"
             self.balance_target = cf.balance_target
         else:
             self.balance_target = "class_targets"
         self.targets = {k:v[self.balance_target] for (k,v) in self._data.items()}
 
+    def set_thread_id(self, thread_id):
+        self.thread_ids = self.eligible_pids[thread_id]
+        self.thread_id  = thread_id
+
+    def reset(self):
+        self.batches_produced = 0
+        self.thread_ids = self.rgen.permutation(self.eligible_pids[self.thread_id])
+
     def balance_target_distribution(self, plot=False):
         """
         :param all_pids:
         :param self.targets:  dic holding {patient_specifier : patient-wise-unique ROI targets}
         :return: probability distribution over all pids. draw without replace from this.
         """
         # get unique foreground targets per patient, assign -1 to an "empty" patient (has no foreground)
         patient_ts = [np.unique(lst) if len([t for t in lst if np.any(t>0)])>0 else [-1] for lst in self.targets.values()]
         #bg_mask = np.array([np.all(lst == [-1]) for lst in patient_ts])
         unique_ts, t_counts = np.unique([t for lst in patient_ts for t in lst if t!=-1], return_counts=True)
         t_probs = t_counts.sum() / t_counts
         t_probs /= t_probs.sum()
         t_probs = {t : t_probs[ix] for ix, t in enumerate(unique_ts)}
         t_probs[-1] = 0.
         # fail if balance target is not a number (i.e., a vector)
         self.p_probs = np.array([ max([t_probs[t] for t in lst]) for lst in patient_ts ])
         #normalize
         self.p_probs /= self.p_probs.sum()
         # rescale probs of empty samples
         # if not 0 == self.p_probs[bg_mask].shape[0]:
         #     #rescale_f = (1 - self.cf.empty_samples_ratio) / self.p_probs[~bg_mask].sum()
         #     rescale_f = 1 / self.p_probs[~bg_mask].sum()
         #     self.p_probs *= rescale_f
         #     self.p_probs[bg_mask] = 0. #self.cf.empty_samples_ratio/self.p_probs[bg_mask].shape[0]
 
         self.unique_ts = unique_ts
 
         if plot:
             os.makedirs(self.plot_dir, exist_ok=True)
             plg.plot_batchgen_distribution(self.cf, self.dataset_pids, self.p_probs, self.balance_target,
                                            out_file=os.path.join(self.plot_dir,
                                                                  "train_gen_distr_"+str(self.cf.fold)+".png"))
         return self.p_probs
 
+    def get_batch_pids(self):
+        if self.max_batches is not None and self.batches_produced * self.number_of_threads_in_multithreaded \
+                + self.thread_id >= self.max_batches:
+            self.reset()
+            raise StopIteration
+
+        if self.sample_pids_w_replace:
+            # fully random patients
+            batch_pids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
+            # target-balanced patients
+            batch_pids += list(np.random.choice(
+                self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs))
+        else:
+            with self.lock:
+                if len(self.thread_ids) == 0:
+                    if self.raise_stop:
+                        self.reset()
+                        raise StopIteration
+                    else:
+                        self.thread_ids = self.rgen.permutation(self.eligible_pids[self.thread_id])
+                batch_pids = self.thread_ids[:self.batch_size]
+                # batch_pids = np.random.choice(self.thread_ids, size=self.batch_size, replace=False)
+                self.thread_ids = [pid for pid in self.thread_ids if pid not in batch_pids]
+        self.batches_produced += 1
+
+        return batch_pids
 
     def generate_train_batch(self):
         # to be overriden by child
         # everything done in here is per batch
         # print statements in here get confusing due to multithreading
-
-        return
+        raise NotImplementedError
 
     def print_stats(self, logger=None, file=None, plot_file=None, plot=True):
         print_f = utils.CombinedPrinter(logger, file)
 
         print_f('\n***Final Training Stats***')
         total_count = np.sum(self.stats['roi_counts'])
         for tix, count in enumerate(self.stats['roi_counts']):
             #name = self.cf.class_dict[tix] if self.balance_target=="class_targets" else str(self.unique_ts[tix])
             name=str(self.unique_ts[tix])
             print_f('{}: {} rois seen ({:.1f}%).'.format(name, count, count / total_count * 100))
         total_samples = self.cf.num_epochs*self.cf.num_train_batches*self.cf.batch_size
         print_f('empty samples seen: {} ({:.1f}%).\n'.format(self.stats['empty_samples_count'],
                                                          self.stats['empty_samples_count']/total_samples*100))
         if plot:
             if plot_file is None:
                 plot_file = os.path.join(self.plot_dir, "train_gen_stats_{}.png".format(self.cf.fold))
                 os.makedirs(self.plot_dir, exist_ok=True)
             plg.plot_batchgen_stats(self.cf, self.stats, self.balance_target, self.unique_ts, plot_file)
 
 class PatientBatchIterator(SlimDataLoaderBase):
     """
     creates a val/test generator. Step through the dataset and return dictionaries per patient.
     2D is a special case of 3D patching with patch_size[2] == 1 (slices)
     Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets.
     Appends patient targets anyway for evaluation.
     For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions.
 
     This iterator/these batches are not intended to go through MTaugmenter afterwards
     """
 
     def __init__(self, cf, data):
         super(PatientBatchIterator, self).__init__(data, 0)
         self.cf = cf
 
         self.dataset_length = len(self._data)
         self.dataset_pids = list(self._data.keys())
 
     def generate_train_batch(self, pid=None):
         # to be overriden by child
 
         return
 
 ###################################
 #  transforms, image manipulation #
 ###################################
 
 def get_patch_crop_coords(img, patch_size, min_overlap=30):
     """
     _:param img (y, x, (z))
     _:param patch_size: list of len 2 (2D) or 3 (3D).
     _:param min_overlap: minimum required overlap of patches.
     If too small, some areas are poorly represented only at edges of single patches.
     _:return ndarray: shape (n_patches, 2*dim). crop coordinates for each patch.
     """
     crop_coords = []
     for dim in range(len(img.shape)):
         n_patches = int(np.ceil(img.shape[dim] / patch_size[dim]))
 
         # no crops required in this dimension, add image shape as coordinates.
         if n_patches == 1:
             crop_coords.append([(0, img.shape[dim])])
             continue
 
         # fix the two outside patches to coords patchsize/2 and interpolate.
         center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
 
         if (patch_size[dim] - center_dists) < min_overlap:
             n_patches += 1
             center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
 
         patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)])
         dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers]
         crop_coords.append(dim_crop_coords)
 
     coords_mesh_grid = []
     for ymin, ymax in crop_coords[0]:
         for xmin, xmax in crop_coords[1]:
             if len(crop_coords) == 3 and patch_size[2] > 1:
                 for zmin, zmax in crop_coords[2]:
                     coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax])
             elif len(crop_coords) == 3 and patch_size[2] == 1:
                 for zmin in range(img.shape[2]):
                     coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1])
             else:
                 coords_mesh_grid.append([ymin, ymax, xmin, xmax])
     return np.array(coords_mesh_grid).astype(int)
 
 def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
     """
     one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee
 
     :param image: nd image. can be anything
     :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
     len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of
     the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
     Example:
     image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
     image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
 
     :param mode: see np.pad for documentation
     :param return_slicer: if True then this function will also return what coords you will need to use when cropping back
     to original shape
     :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
     divisibly by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
     be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
     :param kwargs: see np.pad for documentation
     """
     if kwargs is None:
         kwargs = {}
 
     if new_shape is not None:
         old_shape = np.array(image.shape[-len(new_shape):])
     else:
         assert shape_must_be_divisible_by is not None
         assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
         new_shape = image.shape[-len(shape_must_be_divisible_by):]
         old_shape = new_shape
 
     num_axes_nopad = len(image.shape) - len(new_shape)
 
     new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]
 
     if not isinstance(new_shape, np.ndarray):
         new_shape = np.array(new_shape)
 
     if shape_must_be_divisible_by is not None:
         if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
             shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
         else:
             assert len(shape_must_be_divisible_by) == len(new_shape)
 
         for i in range(len(new_shape)):
             if new_shape[i] % shape_must_be_divisible_by[i] == 0:
                 new_shape[i] -= shape_must_be_divisible_by[i]
 
         new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))])
 
     difference = new_shape - old_shape
     pad_below = difference // 2
     pad_above = difference // 2 + difference % 2
     pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])
     res = np.pad(image, pad_list, mode, **kwargs)
     if not return_slicer:
         return res
     else:
         pad_list = np.array(pad_list)
         pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
         slicer = list(slice(*i) for i in pad_list)
         return res, slicer
 
 def convert_seg_to_bounding_box_coordinates(data_dict, dim, roi_item_keys, get_rois_from_seg=False,
                                                 class_specific_seg=False):
     '''adapted from batchgenerators
 
     :param data_dict: seg: segmentation with labels indicating roi_count (get_rois_from_seg=False) or classes (get_rois_from_seg=True),
         class_targets: list where list index corresponds to roi id (roi_count)
     :param dim:
     :param roi_item_keys: keys of the roi-wise items in data_dict to process
     :param n_rg_feats: nr of regression vector features
     :param get_rois_from_seg:
     :return: coords (y1,x1,y2,x2 (,z1,z2)) where the segmentation GT is framed by +1 voxel, i.e., for an object with
         z-extensions z1=0 through z2=5, bbox target coords will be z1=-1, z2=6. (analogically for x,y).
         data_dict['roi_masks']: (b, n(b), c, h(n), w(n) (z(n))) list like roi_labels but with arrays (masks) inplace of
         integers. c==1 if segmentation not one-hot encoded.
     '''
 
     bb_target = []
     roi_masks = []
     roi_items = {name:[] for name in roi_item_keys}
     out_seg = np.copy(data_dict['seg'])
     for b in range(data_dict['seg'].shape[0]):
 
         p_coords_list = [] #p for patient?
         p_roi_masks_list = []
         p_roi_items_lists = {name:[] for name in roi_item_keys}
 
         if np.sum(data_dict['seg'][b] != 0) > 0:
             if get_rois_from_seg:
                 clusters, n_cands = lb(data_dict['seg'][b])
                 data_dict['class_targets'][b] = [data_dict['class_targets'][b]] * n_cands
             else:
                 n_cands = int(np.max(data_dict['seg'][b]))
 
             rois = np.array(
                 [(data_dict['seg'][b] == ii) * 1 for ii in range(1, n_cands + 1)], dtype='uint8')  # separate clusters
 
             for rix, r in enumerate(rois):
                 if np.sum(r != 0) > 0:  # check if the roi survived slicing (3D->2D) and data augmentation (cropping etc.)
                     seg_ixs = np.argwhere(r != 0)
                     coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1,
                                   np.max(seg_ixs[:, 2]) + 1]
                     if dim == 3:
                         coord_list.extend([np.min(seg_ixs[:, 3]) - 1, np.max(seg_ixs[:, 3]) + 1])
 
                     p_coords_list.append(coord_list)
                     p_roi_masks_list.append(r)
                     # add background class = 0. rix is a patient wide index of lesions. since 'class_targets' is
                     # also patient wide, this assignment is not dependent on patch occurrences.
                     for name in roi_item_keys:
                         p_roi_items_lists[name].append(data_dict[name][b][rix])
 
                     assert data_dict["class_targets"][b][rix]>=1, "convertsegtobbox produced bg roi w cl targ {} and unique roi seg {}".format(data_dict["class_targets"][b][rix], np.unique(r))
 
 
                 if class_specific_seg:
                     out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_targets'][b][rix]
 
             if not class_specific_seg:
                 out_seg[b][data_dict['seg'][b] > 0] = 1
 
             bb_target.append(np.array(p_coords_list))
             roi_masks.append(np.array(p_roi_masks_list))
             for name in roi_item_keys:
                 roi_items[name].append(np.array(p_roi_items_lists[name]))
 
 
         else:
             bb_target.append([])
             roi_masks.append(np.zeros_like(data_dict['seg'][b], dtype='uint8')[None])
             for name in roi_item_keys:
                 roi_items[name].append(np.array([]))
 
     if get_rois_from_seg:
         data_dict.pop('class_targets', None)
 
     data_dict['bb_target'] = np.array(bb_target)
     data_dict['roi_masks'] = np.array(roi_masks)
     data_dict['seg'] = out_seg
     for name in roi_item_keys:
         data_dict[name] = np.array(roi_items[name])
 
 
     return data_dict
 
 class ConvertSegToBoundingBoxCoordinates(AbstractTransform):
     """ Converts segmentation masks into bounding box coordinates.
     """
 
     def __init__(self, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False):
         self.dim = dim
         self.roi_item_keys = roi_item_keys
         self.get_rois_from_seg = get_rois_from_seg
         self.class_specific_seg = class_specific_seg
 
     def __call__(self, **data_dict):
         return convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.roi_item_keys, self.get_rois_from_seg,
                                                        self.class_specific_seg)
 
 
 
 
 
 #############################
 #  data packing / unpacking # not used, data_manager.py used instead
 #############################
 
 def get_case_identifiers(folder):
     case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")]
     return case_identifiers
 
 
 def convert_to_npy(npz_file):
     if not os.path.isfile(npz_file[:-3] + "npy"):
         a = np.load(npz_file)['data']
         np.save(npz_file[:-3] + "npy", a)
 
 
 def unpack_dataset(folder, threads=8):
     case_identifiers = get_case_identifiers(folder)
     p = Pool(threads)
     npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers]
     p.map(convert_to_npy, npz_files)
     p.close()
     p.join()
 
 
 def delete_npy(folder):
     case_identifiers = get_case_identifiers(folder)
     npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers]
     npy_files = [i for i in npy_files if os.path.isfile(i)]
     for n in npy_files:
         os.remove(n)
\ No newline at end of file