diff --git a/datasets/lidc/configs.py b/datasets/lidc/configs.py
index e037756..413ce8f 100644
--- a/datasets/lidc/configs.py
+++ b/datasets/lidc/configs.py
@@ -1,445 +1,445 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 from collections import namedtuple
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/../..")
 from default_configs import DefaultConfigs
 
 # legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope
 Label = namedtuple("Label", ['id', 'name', 'color', 'm_scores']) # m_scores = malignancy scores
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'm_scores', 'bin_vals'])
 
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #    Preprocessing      #
         #########################
 
         self.root_dir = '/home/gregor/networkdrives/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI'
         self.raw_data_dir = '{}/new_nrrd'.format(self.root_dir)
         self.pp_dir = '/media/gregor/HDD2TB/data/lidc/pp_20200309_dev'
         # 'merged' for one gt per image, 'single_annotator' for four gts per image.
         self.gts_to_produce = ["single_annotator", "merged"]
 
         self.target_spacing = (0.7, 0.7, 1.25)
 
         #########################
         #         I/O          #
         #########################
 
         # path to preprocessed data.
         #self.pp_name = 'pp_20190318'
         self.pp_name = 'pp_20200309_dev'
 
         self.input_df_name = 'info_df.pickle'
         self.data_sourcedir = '/media/gregor/HDD2TB/data/lidc/{}/'.format(self.pp_name)
 
         # settings for deployment on cluster.
         if server_env:
             # path to preprocessed data.
             self.data_sourcedir = '/datasets/data_ramien/lidc/{}_npz/'.format(self.pp_name)
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn'].
         self.model = 'mrcnn'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # dimension the model operates in. one out of [2, 3].
         self.dim = 2
 
         # 'class': standard object classification per roi, pairwise combinable with each of below tasks.
         # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN.
         # 'regression': regress some vector per each roi
         # 'regression_ken_gal': use kendall-gal uncertainty sigma
         # 'regression_bin': classify each roi into a bin related to a regression scale
-        self.prediction_tasks = ['regression']
+        self.prediction_tasks = ['class']
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = None # one of None, 'instance_norm', 'batch_norm'
 
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1
 
         #########################
         #      Data Loader      #
         #########################
 
         # distorted gt experiments: train on single-annotator gts in a random fashion to investigate network's
         # handling of noisy gts.
         # choose 'merged' for single, merged gt per image, or 'single_annotator' for four gts per image.
         # validation is always performed on same gt kind as training, testing always on merged gt.
-        self.training_gts = "sa"
+        self.training_gts = "merged"
 
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.pre_crop_size_2D = [320, 320]
         self.patch_size_2D = [320, 320]
         self.pre_crop_size_3D = [160, 160, 96]
         self.patch_size_3D = [160, 160, 96]
 
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
         self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.3
         self.balance_target =  "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets'
 
         # set 2D network to match 3D gt boxes.
         self.merge_2D_to_3D_preds = self.dim==2
 
         self.observables_rois = []
 
         #self.rg_map = {1:1, 2:2, 3:3, 4:4, 5:5}
 
         #########################
         #   Colors and Legends  #
         #########################
         self.plot_frequency = 5
 
         binary_cl_labels = [Label(1, 'benign',  (*self.dark_green, 1.),  (1, 2)),
                             Label(2, 'malignant', (*self.red, 1.),  (3, 4, 5))]
         quintuple_cl_labels = [Label(1, 'MS1',  (*self.dark_green, 1.),      (1,)),
                                Label(2, 'MS2',  (*self.dark_yellow, 1.),     (2,)),
                                Label(3, 'MS3',  (*self.orange, 1.),     (3,)),
                                Label(4, 'MS4',  (*self.bright_red, 1.), (4,)),
                                Label(5, 'MS5',  (*self.red, 1.),        (5,))]
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_cl_labels = quintuple_cl_labels
 
         self.class_labels = [
             #       #id #name     #color              #malignancy score
             Label(  0,  'bg',     (*self.gray, 0.),  (0,))]
         if "class" in self.prediction_tasks:
             self.class_labels += task_spec_cl_labels
 
         else:
             self.class_labels += [Label(1, 'lesion', (*self.orange, 1.), (1,2,3,4,5))]
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0, 'MS0', (*self.gray, 1.), (0,), (0,))]
             self.bin_labels += [binLabel(cll.id, cll.name, cll.color, cll.m_scores,
                                          tuple([ms for ms in cll.m_scores])) for cll in task_spec_cl_labels]
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             self.ms2bin_label = {ms: label for label in self.bin_labels for ms in label.m_scores}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
 
         if self.class_specific_seg:
             self.seg_labels = self.class_labels
         else:
             self.seg_labels = [  # id      #name           #color
                 Label(0, 'bg', (*self.gray, 0.)),
                 Label(1, 'fg', (*self.orange, 1.))
             ]
 
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
         # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be
         # evaluated in debugging
         self.class_cmap = {label.id: label.color for label in self.class_labels}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)  # for instance classification (excl background)
         self.num_seg_classes = len(self.seg_labels)  # incl background
 
 
         #########################
         #   Data Augmentation   #
         #########################
 
         self.da_kwargs={
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': True,
             'alpha':(0., 1500.),
             'sigma':(30., 50.),
             'do_rotation':True,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': True,
             'scale':(0.8, 1.1),
             'random_crop':False,
             'rand_crop_dist':  (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1}
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #################################
         #  Schedule / Selection / Optim #
         #################################
 
         self.num_epochs = 130 if self.dim == 2 else 150
         self.num_train_batches = 200 if self.dim == 2 else 200
         self.batch_size = 20 if self.dim == 2 else 8
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # only 'val_sampling', 'val_patient' not implemented
         if self.val_mode == 'val_patient':
             raise NotImplementedError
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 70
 
         self.save_n_models = 4
         # set a minimum epoch number for saving in case of instabilities in the first phase of training.
         self.min_save_thresh = 0 if self.dim == 2 else 0
         # criteria to average over for saving epochs, 'criterion':weight.
         if "class" in self.prediction_tasks:
             # 'criterion': weight
             if len(self.class_labels)==3:
                 self.model_selection_criteria = {"benign_ap": 0.5, "malignant_ap": 0.5}
             elif len(self.class_labels)==6:
                 self.model_selection_criteria = {str(label.name)+"_ap": 1./5 for label in self.class_labels if label.id!=0}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8}
 
         self.weight_decay = 0
         self.clip_norm = 200 if 'regression_ken_gal' in self.prediction_tasks else None  # number or None
 
         # int in [0, dataset_size]. select n patients from dataset for prototyping. If None, all data is used.
         self.select_prototype_subset = None #self.batch_size
 
         #########################
         #        Testing        #
         #########################
 
         # set the top-n-epochs to be saved for temporal averaging in testing.
         self.test_n_epochs = self.save_n_models
 
         self.test_aug_axes = (0,1,(0,1))  # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x).
         self.held_out_test_set = False
         self.max_test_patients = "all"  # "all" or number
 
         self.report_score_level = ['rois', 'patient']  # choose list from 'patient', 'rois'
         self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1
 
         self.metrics = ['ap', 'auc']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.1]  # list of ious to be evaluated for ap-scoring.
         self.min_det_thresh = 0.1  # minimum confidence value to select predictions for evaluation.
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = 0.1  # has to be larger than desired possible overlap iou of model predictions
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.n_test_plots = 1
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'detection_fpn': self.add_det_fpn_configs,
          'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs,
          'retina_unet': self.add_mrcnn_configs,
         }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         return float(np.digitize(np.mean(rg_val), self.bin_edges))
 
     def add_det_fpn_configs(self):
 
         self.learning_rate = [1e-4] * self.num_epochs
         self.dynamic_lr_scheduling = False
 
         # RoI score assigned to aggregation from pixel prediction (connected component). One of ['max', 'median'].
         self.score_det = 'max'
 
         # max number of roi candidates to identify per batch element and class.
         self.n_roi_candidates = 10 if self.dim == 2 else 30
 
         # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
         self.seg_loss_mode = 'wce'
 
         # if <1, false positive predictions in foreground are penalized less.
         self.fp_dice_weight = 1 if self.dim == 2 else 1
         if len(self.class_labels)==3:
             self.wce_weights = [1., 1., 1.] if self.seg_loss_mode=="dice_wce" else [0.1, 1., 1.]
         elif len(self.class_labels)==6:
             self.wce_weights = [1., 1., 1., 1., 1., 1.] if self.seg_loss_mode == "dice_wce" else [0.1, 1., 1., 1., 1., 1.]
         else:
             raise Exception("mismatch loss weights & nr of classes")
         self.detection_min_confidence = self.min_det_thresh
 
         self.head_classes = self.num_seg_classes
 
     def add_mrcnn_configs(self):
 
         # learning rate is a list with one entry per epoch.
         self.learning_rate = [1e-4] * self.num_epochs
         self.dynamic_lr_scheduling = False
         # disable the re-sampling of mask proposals to original size for speed-up.
         # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
         # mask-outputs are optional.
         self.return_masks_in_train = False
         self.return_masks_in_val = True
         self.return_masks_in_test = False
 
         # set number of proposal boxes to plot after each epoch.
         self.n_plot_rpn_props = 5 if self.dim == 2 else 30
 
         # number of classes for network heads: n_foreground_classes + 1 (background)
         self.head_classes = self.num_classes + 1
 
         self.frcnn_mode = False
 
         # feature map strides per pyramid level are inferred from architecture.
         self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
 
         # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
         # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
         self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
 
         # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
         self.pyramid_levels = [0, 1, 2, 3]
 
         # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
         self.n_rpn_features = 512 if self.dim == 2 else 128
 
         # anchor ratios and strides per position in feature maps.
         self.rpn_anchor_ratios = [0.5, 1, 2]
         self.rpn_anchor_stride = 1
 
         # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
         self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
 
         # loss sampling settings.
         self.rpn_train_anchors_per_image = 6  #per batch element
         self.train_rois_per_image = 6 #per batch element
         self.roi_positive_ratio = 0.5
         self.anchor_matching_iou = 0.7
 
         # factor of top-k candidates to draw from  per negative sample (stochastic-hard-example-mining).
         # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples.
         self.shem_poolsize = 10
 
         self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
         self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
         self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
         self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
         self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                                self.patch_size_3D[2], self.patch_size_3D[2]])
         if self.dim == 2:
             self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
             self.bbox_std_dev = self.bbox_std_dev[:4]
             self.window = self.window[:4]
             self.scale = self.scale[:4]
 
         # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
         self.pre_nms_limit = 3000 if self.dim == 2 else 6000
 
         # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
         # since proposals of the entire batch are forwarded through second stage in as one "batch".
         self.roi_chunk_size = 2500 if self.dim == 2 else 600
         self.post_nms_rois_training = 500 if self.dim == 2 else 75
         self.post_nms_rois_inference = 500
 
         # Final selection of detections (refine_detections)
         self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30  # per batch element and class.
         self.detection_nms_threshold = 1e-5  # needs to be > 0, otherwise all predictions are one cluster.
         self.model_min_confidence = 0.1
 
         if self.dim == 2:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride))]
                  for stride in self.backbone_strides['xy']])
         else:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride)),
                   int(np.ceil(self.patch_size[2] / stride_z))]
                  for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                              )])
 
         if self.model == 'retina_net' or self.model == 'retina_unet':
 
             self.focal_loss = True
 
             # implement extra anchor-scales according to retina-net publication.
             self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                             self.rpn_anchor_scales['xy']]
             self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                            self.rpn_anchor_scales['z']]
             self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
             self.n_rpn_features = 256 if self.dim == 2 else 128
 
             # pre-selection of detections for NMS-speedup. per entire batch.
             self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
             # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
             self.anchor_matching_iou = 0.5
 
             if self.model == 'retina_unet':
                 self.operate_stride1 = True
 
diff --git a/datasets/lidc/data_loader.py b/datasets/lidc/data_loader.py
index fad15fc..fb97815 100644
--- a/datasets/lidc/data_loader.py
+++ b/datasets/lidc/data_loader.py
@@ -1,1024 +1,1024 @@
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 '''
 Data Loader for the LIDC data set. This dataloader expects preprocessed data in .npy or .npz files per patient and
 a pandas dataframe containing the meta info e.g. file paths, and some ground-truth info like labels, foreground slice ids.
 
 LIDC 4-fold annotations storage capacity problem: keep segmentation gts compressed (npz), unpack at each batch generation.
 
 '''
 
 import plotting as plg
 
 import os
 import pickle
 import time
 from multiprocessing import Pool
 
 import numpy as np
 import pandas as pd
 from collections import OrderedDict
 
 # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
 from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
 from batchgenerators.transforms.abstract_transforms import Compose
 from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
 from batchgenerators.transforms.spatial_transforms import SpatialTransform
 from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
 
 
 import utils.dataloader_utils as dutils
 from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
-from utils.dataloader_utils import BatchGenerator as BatchGeneratorParent
+
 
 def save_obj(obj, name):
     """Pickle a python object."""
     with open(name + '.pkl', 'wb') as f:
         pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
 
 def vector(item):
     """ensure item is vector-like (list or array or tuple)
     :param item: anything
     """
     if not isinstance(item, (list, tuple, np.ndarray)):
         item = [item]
     return item
 
 
 class Dataset(dutils.Dataset):
     r"""Load a dict holding memmapped arrays and clinical parameters for each patient,
     evtly subset of those.
         If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to
         cf.data_dest.
     :param cf: config object.
     :param logger: logger.
     :param subset_ids: subset of patient/sample identifiers to load from whole set.
     :param data_sourcedir: directory in which to find data, defaults to cf.data_sourcedir if None.
     :return: dict with imgs, segs, pids, class_labels, observables
     """
 
     def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None, mode='train'):
         super(Dataset,self).__init__(cf, data_sourcedir)
         if mode == 'train' and not cf.training_gts == "merged":
             self.gt_dir = "patient_gts_sa"
             self.gt_kind = cf.training_gts
         else:
             self.gt_dir = "patient_gts_merged"
             self.gt_kind = "merged"
         if logger is not None:
             logger.info("loading {} ground truths for {}".format(self.gt_kind, 'training and validation' if mode=='train'
         else 'testing'))
 
         p_df = pd.read_pickle(os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name))
         #exclude_pids = ["0305a", "0447a"]  # due to non-bg segmentation but bg mal label in nodules 5728, 8840
         #p_df = p_df[~p_df.pid.isin(exclude_pids)]
 
         if subset_ids is not None:
             p_df = p_df[p_df.pid.isin(subset_ids)]
             if logger is not None:
                 logger.info('subset: selected {} instances from df'.format(len(p_df)))
         if cf.select_prototype_subset is not None:
             prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset]
             p_df = p_df[p_df.pid.isin(prototype_pids)]
             if logger is not None:
                 logger.warning('WARNING: using prototyping data subset of length {}!!!'.format(len(p_df)))
 
         pids = p_df.pid.tolist()
 
         # evtly copy data from data_sourcedir to data_dest
         if cf.server_env and not hasattr(cf, 'data_dir') and hasattr(cf, "data_dest"):
                 # copy and unpack images
                 file_subset = ["{}_img.npz".format(pid) for pid in pids if not
                 os.path.isfile(os.path.join(cf.data_dest,'{}_img.npy'.format(pid)))]
                 file_subset += [os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)]
                 self.copy_data(cf, file_subset=file_subset, keep_packed=False, del_after_unpack=True)
                 # copy and do not unpack segmentations
                 file_subset = [os.path.join(self.gt_dir, "{}_rois.np*".format(pid)) for pid in pids]
                 keep_packed = not cf.training_gts == "merged"
                 self.copy_data(cf, file_subset=file_subset, keep_packed=keep_packed, del_after_unpack=(not keep_packed))
         else:
             cf.data_dir = self.data_sourcedir
 
         ext = 'npy' if self.gt_kind == "merged" else 'npz'
         imgs = [os.path.join(self.data_dir, '{}_img.npy'.format(pid)) for pid in pids]
         segs = [os.path.join(self.data_dir, self.gt_dir, '{}_rois.{}'.format(pid, ext)) for pid in pids]
         orig_class_targets = p_df['class_target'].tolist()
 
         data = OrderedDict()
 
         if self.gt_kind == 'merged':
             for ix, pid in enumerate(pids):
                 data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid}
                 data[pid]['fg_slices'] = np.array(p_df['fg_slices'].tolist()[ix])
                 if 'class' in cf.prediction_tasks:
                     if len(cf.class_labels)==3:
                         # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2)
                         data[pid]['class_targets'] = np.array([2 if ii >= 3 else 1 for ii in orig_class_targets[ix]],
                                                               dtype='uint8')
                     elif len(cf.class_labels)==6:
                         # classify each malignancy score
                         data[pid]['class_targets'] = np.array([1 if ii==0.5 else np.round(ii) for ii in orig_class_targets[ix]], dtype='uint8')
                     else:
                         raise Exception("mismatch class labels and data-loading implementations.")
                 else:
                     data[pid]['class_targets'] = np.ones_like(np.array(orig_class_targets[ix]), dtype='uint8')
                 if any(['regression' in task for task in cf.prediction_tasks]):
                     data[pid]["regression_targets"] = np.array([vector(v) for v in orig_class_targets[ix]],
                                                                dtype='float16')
                     data[pid]["rg_bin_targets"] = np.array(
                         [cf.rg_val_to_bin_id(v) for v in data[pid]["regression_targets"]], dtype='uint8')
         else:
             for ix, pid in enumerate(pids):
                 data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid}
                 data[pid]['fg_slices'] = np.array(p_df['fg_slices'].values[ix])
                 if 'class' in cf.prediction_tasks:
                     # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2)
                     raise NotImplementedError
                     # todo need to consider bg
                     # data[pid]['class_targets'] = np.array(
                     #     [[2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]])
                 else:
                     data[pid]['class_targets'] = np.array(
                         [[1 if ii > 0 else 0 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='uint8')
                 if any(['regression' in task for task in cf.prediction_tasks]):
                     data[pid]["regression_targets"] = np.array(
                         [[vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='float16')
                     data[pid]["rg_bin_targets"] = np.array(
                         [[cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in data[pid]["regression_targets"]], dtype='uint8')
 
         cf.roi_items = cf.observables_rois[:]
         cf.roi_items += ['class_targets']
         if any(['regression' in task for task in cf.prediction_tasks]):
             cf.roi_items += ['regression_targets']
             cf.roi_items += ['rg_bin_targets']
 
         self.data = data
         self.set_ids = np.array(list(self.data.keys()))
         self.df = None
 
 # merged GTs
 class BatchGenerator_merged(dutils.BatchGenerator):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
     def __init__(self, cf, data, name="train"):
         super(BatchGenerator_merged, self).__init__(cf, data)
 
         self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch.
         self.p_fg = 0.5
         self.empty_samples_max_ratio = 0.6
 
         self.random_count = int(cf.batch_random_ratio * cf.batch_size)
         self.class_targets = {k: v["class_targets"] for (k, v) in self._data.items()}
 
 
         self.balance_target_distribution(plot=name=="train")
 
     def generate_train_batch(self):
 
         # samples patients towards equilibrium of foreground classes on a roi-level after sampling a random ratio
         # fully random patients
         batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
         # target-balanced patients
         batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count,
                                                    replace=False, p=self.p_probs))
 
         batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], []
         batch_roi_items = {name: [] for name in self.cf.roi_items}
         # record roi count of classes in batch
         batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         # empty count for full bg samples (empty slices in 2D/patients in 3D) per class
 
 
         for sample in range(self.batch_size):
             patient = self._data[batch_patient_ids[sample]]
 
             data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]
             seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
             batch_pids.append(patient['pid'])
             (c, y, x, z) = data.shape
 
             if self.cf.dim == 2:
 
                 elig_slices, choose_fg = [], False
                 if len(patient['fg_slices']) > 0:
                     if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                             np.random.rand(1)<=self.p_fg:
                         # fg is to be picked
                         for tix in np.argsort(batch_roi_counts):
                             # pick slices of patient that have roi of sought-for target
                             # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                             elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                 patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0])-1] ==
                                 self.unique_ts[tix]) > 0]
                             if len(elig_slices) > 0:
                                 choose_fg = True
                                 break
                     else:
                         # pick bg
                         elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
 
                 if len(elig_slices)>0:
                     sl_pick_ix = np.random.choice(elig_slices, size=None)
                 else:
                     sl_pick_ix = np.random.choice(z, size=None)
 
                 data = data[..., sl_pick_ix]
                 seg = seg[..., sl_pick_ix]
 
             # pad data if smaller than pre_crop_size.
             if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
                 new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)]
                 data = dutils.pad_nd_image(data, new_shape, mode='constant')
                 seg = dutils.pad_nd_image(seg, new_shape, mode='constant')
 
             # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
             crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps]
             if len(crop_dims) > 0:
                 if self.cf.dim == 3:
                     choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio)\
                                 or np.random.rand(1) <= self.p_fg
                 if choose_fg and np.any(seg):
                     available_roi_ids = np.unique(seg)[1:]
                     for tix in np.argsort(batch_roi_counts):
                         elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]]
                         if len(elig_roi_ids)>0:
                             seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                             break
                     roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                     assert seg[tuple(roi_anchor_pixel)] > 0
                     # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                     # distance to the desired ROI < patch_size /2.
                     # (here final patch size to account for center_crop after data augmentation).
                     sample_seg_center = {}
                     for ii in crop_dims:
                         low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                         high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2,
                                        roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                         # happens if lesion on the edge of the image. dont care about roi anymore,
                         # just make sure pre-crop is inside image.
                         if low >= high:
                             low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                             high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                         sample_seg_center[ii] = np.random.randint(low=low, high=high)
 
                 else:
                     # not guaranteed to be empty. probability of emptiness depends on the data.
                     sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2,
                                                            high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims}
 
                 for ii in crop_dims:
                     min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                     max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
                     data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
                     seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)
 
             batch_data.append(data)
             batch_segs.append(seg[np.newaxis])
             for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item
                     batch_roi_items[o].append(patient[o])
 
             if self.cf.dim == 3:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero==0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
             elif self.cf.dim == 2:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero == 0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
 
 
         data = np.array(batch_data).astype(np.float16)
         seg = np.array(batch_segs).astype(np.uint8)
         batch = {'data': data, 'seg': seg, 'pid': batch_pids,
                 'roi_counts':batch_roi_counts, 'empty_counts': batch_empty_counts}
         for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...)
             batch[key] = np.array(val)
 
         return batch
 
 class PatientBatchIterator_merged(dutils.PatientBatchIterator):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actualy evaluation (done in 3D),
     if willing to accept speed-loss during training.
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D .
     """
 
     def __init__(self, cf, data):  # threads in augmenter
         super(PatientBatchIterator_merged, self).__init__(cf, data)
         self.patient_ix = 0
         self.patch_size = cf.patch_size + [1] if cf.dim == 2 else cf.patch_size
 
     def generate_train_batch(self, pid=None):
 
         if pid is None:
             pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
 
         data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))
         seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
 
         # pad data if smaller than patch_size seen during training.
         if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]):
             new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)]
             data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
             seg = dutils.pad_nd_image(seg, new_shape)
 
         # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor.
         if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
             out_data = data[np.newaxis, np.newaxis]
             out_seg = seg[np.newaxis, np.newaxis]
             batch_3D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_3D[o] = np.array([patient[o]])
             converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_3D = converter(**batch_3D)
             batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
             for o in self.cf.roi_items:
                 batch_3D["patient_" + o] = batch_3D[o]
 
         if self.cf.dim == 2:
             out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis]  # (z, c, x, y )
             out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]
 
             batch_2D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_2D[o] = np.repeat(np.array([patient[o]]), out_data.shape[0], axis=0)
 
             converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_2D = converter(**batch_2D)
 
             if self.cf.merge_2D_to_3D_preds:
                 batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_3D[o]
             else:
                 batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_2D[o]
 
         out_batch = batch_3D if self.cf.dim == 3 else batch_2D
         out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})
 
         # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
         # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
         if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]):
             patient_batch = out_batch
             patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size)
             new_img_batch, new_seg_batch = [], []
 
             for cix, c in enumerate(patch_crop_coords_list):
 
                 seg_patch = seg[c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                 new_seg_batch.append(seg_patch)
 
                 tmp_c_5 = c[5]
 
                 new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5])
 
             data = np.array(new_img_batch)[:, np.newaxis]  # (n_patches, c, x, y, z)
             seg = np.array(new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
             if self.cf.dim == 2:
                 # all patches have z dimension 1 (slices). discard dimension
                 data = data[..., 0]
                 seg = seg[..., 0]
 
             patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                            'pid': np.array([patient['pid']] * data.shape[0])}
             for o in self.cf.roi_items:
                 patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0)
             # patient-wise (orig) batch info for putting the patches back together after prediction
             for o in self.cf.roi_items:
                 patch_batch["patient_" + o] = patient_batch['patient_' + o]
                 if self.cf.dim == 2:
                     # this could also be named "unpatched_2d_roi_items"
                     patch_batch["patient_" + o + "_2d"] = patient_batch[o]
             # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288
             # and enables calculating test-dice/viewing patient-wise results in test
             # remove, but also remove dice from metrics, when like to save memory
             patch_batch['patient_data'] = patient_batch['data']
             patch_batch['patient_seg'] = patient_batch['seg']
             patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
             patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
             if self.cf.dim == 2:
                 patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
             patch_batch['original_img_shape'] = patient_batch['original_img_shape']
 
             converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False,
                                                            self.cf.class_specific_seg)
             patch_batch = converter(**patch_batch)
             out_batch = patch_batch
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return out_batch
 
 # single-annotator GTs
-class BatchGenerator_sa(BatchGeneratorParent):
+class BatchGenerator_sa(dutils.BatchGenerator):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
 
     # noinspection PyMethodOverriding
     def balance_target_distribution(self, rater, plot=False):
         """
         :param rater: for which rater slot to generate the distribution
         :param self.targets:  dic holding {patient_specifier : patient-wise-unique ROI targets}
         :param plot: whether to plot the generated patient distributions
         :return: probability distribution over all pids. draw without replace from this.
         """
         unique_ts = np.unique([v[rater] for pat in self.targets.values() for v in pat])
         sample_stats = pd.DataFrame(columns=[str(ix) + suffix for ix in unique_ts for suffix in ["", "_bg"]],
                                          index=list(self.targets.keys()))
         for pid in sample_stats.index:
             for targ in unique_ts:
                 fg_count = 0 if len(self.targets[pid]) == 0 else np.count_nonzero(self.targets[pid][:, rater] == targ)
                 sample_stats.loc[pid, str(targ)] = int(fg_count > 0)
                 sample_stats.loc[pid, str(targ) + "_bg"] = int(fg_count == 0)
 
         target_stats = sample_stats.agg(
             ("sum", lambda col: col.sum() / len(self._data)), axis=0, sort=False).rename({"<lambda>": "relative"})
 
         anchor = 1. - target_stats.loc["relative"].iloc[0]
         fg_bg_weights = anchor / target_stats.loc["relative"]
         cum_weights = anchor * len(fg_bg_weights)
         fg_bg_weights /= cum_weights
 
         p_probs = sample_stats.apply(self.sample_targets_to_weights, args=(fg_bg_weights,), axis=1).sum(axis=1)
         p_probs = p_probs / p_probs.sum()
         if plot:
             print("Rater: {}. Applying class-weights:\n {}".format(rater, fg_bg_weights))
         if len(sample_stats.columns) == 2:
             # assert that probs are calc'd correctly:
             # (p_probs * sample_stats["1"]).sum() == (p_probs * sample_stats["1_bg"]).sum()
             # only works if one label per patient (multi-label expectations depend on multi-label occurences).
             for rater in range(self.rater_bsize):
                 expectations = []
                 for targ in sample_stats.columns:
                     expectations.append((p_probs[rater] * sample_stats[targ]).sum())
                 assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format(
                     expectations)
 
         if plot:
             plg.plot_batchgen_distribution(self.cf, self.dataset_pids, p_probs, self.balance_target,
                                            out_file=os.path.join(self.plot_dir,
                                                                  "train_gen_distr_"+str(self.cf.fold)+"_rater"+str(rater)+".png"))
         return p_probs, unique_ts, sample_stats
 
 
 
     def __init__(self, cf, data, name="train"):
         super(BatchGenerator_sa, self).__init__(cf, data)
         self.name = name
         self.crop_margin = np.array(self.cf.patch_size) / 8.  # min distance of ROI center to edge of cropped_patch.
         self.p_fg = 0.5
         self.empty_samples_max_ratio = 0.6
 
         self.random_count = int(cf.batch_random_ratio * cf.batch_size)
 
         self.rater_bsize = 4
         unique_ts_total = set()
         self.p_probs = []
         self.sample_stats = []
 
         # todo resolve pickling error
         # p = Pool(processes=min(self.rater_bsize, cf.n_workers))
         # mp_res = p.starmap(self.balance_target_distribution, [(r, name=="train") for r in range(self.rater_bsize)])
         # p.close()
         # p.join()
         # for r, res in enumerate(mp_res):
         #     p_probs, unique_ts, sample_stats = res
         #     self.p_probs.append(p_probs)
         #     self.sample_stats.append(sample_stats)
         #     unique_ts_total.update(unique_ts)
 
         for r in range(self.rater_bsize):
             # todo multiprocess. takes forever
             p_probs, unique_ts, sample_stats = self.balance_target_distribution(r, plot=name == "train")
             self.p_probs.append(p_probs)
             self.sample_stats.append(sample_stats)
             unique_ts_total.update(unique_ts)
 
         self.unique_ts = sorted(list(unique_ts_total))
         self.stats = {"roi_counts": np.zeros(len(self.unique_ts,), dtype='uint32'),
                       "empty_counts": np.zeros(len(self.unique_ts,), dtype='uint32')}
 
     def generate_train_batch(self):
 
         rater = np.random.randint(self.rater_bsize)
 
         # samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio batch_random_ratio).
         # random patients
         batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
         # target-balanced patients
         batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False,
                                              p=self.p_probs[rater]))
 
         batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], []
         batch_roi_items = {name: [] for name in self.cf.roi_items}
         # record roi count of classes in batch
         batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         # empty count for full bg samples (empty slices in 2D/patients in 3D)
 
 
         for sample in range(self.batch_size):
 
             patient = self._data[batch_patient_ids[sample]]
 
             patient_balance_ts = np.array([roi[rater] for roi in patient[self.balance_target]])
             data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]
             seg = np.load(patient['seg'], mmap_mode='r')
             seg = np.transpose(seg[list(seg.keys())[0]][rater], axes=(1, 2, 0))
             batch_pids.append(patient['pid'])
             (c, y, x, z) = data.shape
 
             if self.cf.dim == 2:
 
                 elig_slices, choose_fg = [], False
                 if len(patient['fg_slices']) > 0:
                     if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                             np.random.rand(1) <= self.p_fg:
                         # fg is to be picked
                         for tix in np.argsort(batch_roi_counts):
                             # pick slices of patient that have roi of sought-for target
                             # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                             elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                 patient_balance_ts[np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                                 self.unique_ts[tix]) > 0]
                             if len(elig_slices) > 0:
                                 choose_fg = True
                                 break
                     else:
                         # pick bg
                         elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'][rater])
 
                 if len(elig_slices) > 0:
                     sl_pick_ix = np.random.choice(elig_slices, size=None)
                 else:
                     sl_pick_ix = np.random.choice(z, size=None)
 
                 data = data[..., sl_pick_ix]
                 seg = seg[..., sl_pick_ix]
 
             # pad data if smaller than pre_crop_size.
             if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
                 new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)]
                 data = dutils.pad_nd_image(data, new_shape, mode='constant')
                 seg = dutils.pad_nd_image(seg, new_shape, mode='constant')
 
             # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
             crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps]
             if len(crop_dims) > 0:
                 if self.cf.dim == 3:
                     choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                                 np.random.rand(1) <= self.p_fg
                 if choose_fg and np.any(seg):
                     available_roi_ids = np.unique(seg[seg>0])
                     assert np.all(patient_balance_ts[available_roi_ids-1]>0), "trying to choose roi with rating 0"
                     for tix in np.argsort(batch_roi_counts):
                         elig_roi_ids = available_roi_ids[ patient_balance_ts[available_roi_ids-1] == self.unique_ts[tix] ]
                         if len(elig_roi_ids)>0:
                             seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                             roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                             break
 
                     assert seg[tuple(roi_anchor_pixel)] > 0, "roi_anchor_pixel not inside roi: {}, pb_ts {}, elig ids {}".format(tuple(roi_anchor_pixel), patient_balance_ts, elig_roi_ids)
                     # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                     # distance to the desired ROI < patch_size /2.
                     # (here final patch size to account for center_crop after data augmentation).
                     sample_seg_center = {}
                     for ii in crop_dims:
                         low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                         high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2,
                                        roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                         # happens if lesion on the edge of the image. dont care about roi anymore,
                         # just make sure pre-crop is inside image.
                         if low >= high:
                             low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                             high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                         sample_seg_center[ii] = np.random.randint(low=low, high=high)
                 else:
                     # not guaranteed to be empty. probability of emptiness depends on the data.
                     sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2,
                                                            high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims}
 
                 for ii in crop_dims:
                     min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                     max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
                     data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
                     seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)
 
             batch_data.append(data)
             batch_segs.append(seg[np.newaxis])
             for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item
                 batch_roi_items[o].append([roi[rater] for roi in patient[o]])
 
             if self.cf.dim == 3:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero==0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
             elif self.cf.dim == 2:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero == 0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
 
 
         data = np.array(batch_data).astype('float16')
         seg = np.array(batch_segs).astype('uint8')
         batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'rater_id': rater,
                 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts}
         for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...)
             batch[key] = np.array(val)
 
         return batch
 
 class PatientBatchIterator_sa(dutils.PatientBatchIterator):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actual evaluation (done in 3D),
     if willing to accept speed loss during training.
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D .
 
     This is the data & gt loader for the 4-fold single-annotator GTs: each data input has separate annotations of 4 annotators.
     the way the pipeline is currently setup, the single-annotator GTs are only used if training with validation mode
     val_patient; during testing the Iterator with the merged GTs is used.
     # todo mode val_patient not implemented yet (since very slow). would need to sample from all available rater GTs.
     """
     def __init__(self, cf, data): #threads in augmenter
         super(PatientBatchIterator_sa, self).__init__(cf, data)
         self.cf = cf
         self.patient_ix = 0
         self.dataset_pids = list(self._data.keys())
         self.patch_size =  cf.patch_size+[1] if cf.dim==2 else cf.patch_size
 
         self.rater_bsize = 4
 
 
     def generate_train_batch(self, pid=None):
 
         if pid is None:
             pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
 
         data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))
         # all gts are 4-fold and npz!
         seg = np.load(patient['seg'], mmap_mode='r')
         seg = np.transpose(seg[list(seg.keys())[0]], axes=(0, 2, 3, 1))
 
         # pad data if smaller than patch_size seen during training.
         if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]):
             new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)]
             data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape.
             seg = dutils.pad_nd_image(seg, new_shape)
 
         # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor.
         if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
             out_data = data[np.newaxis, np.newaxis]
             out_seg = seg[:, np.newaxis]
             batch_3D = {'data': out_data, 'seg': out_seg}
 
             for item in self.cf.roi_items:
                 batch_3D[item] = []
             for r in range(self.rater_bsize):
                 for item in self.cf.roi_items:
                     batch_3D[item].append(np.array([roi[r] for roi in patient[item]]))
 
             converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_3D = converter(**batch_3D)
             batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
             for o in self.cf.roi_items:
                 batch_3D["patient_" + o] = batch_3D[o]
 
         if self.cf.dim == 2:
             out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis]  # (z, c, y, x )
             out_seg = np.transpose(seg, axes=(0, 3, 1, 2))[:, :, np.newaxis] # (n_raters, z, 1, y,x)
 
             batch_2D = {'data': out_data}
 
             for item in ["seg", "bb_target"]+self.cf.roi_items:
                 batch_2D[item] = []
 
             converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
             for r in range(self.rater_bsize):
                 tmp_batch = {"seg": out_seg[r]}
                 for item in self.cf.roi_items:
                     tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), out_data.shape[0], axis=0)
                 tmp_batch = converter(**tmp_batch)
                 for item in ["seg", "bb_target"]+self.cf.roi_items:
                     batch_2D[item].append(tmp_batch[item])
             # for item in ["seg", "bb_target"]+self.cf.roi_items:
             #     batch_2D[item] = np.array(batch_2D[item])
 
             if self.cf.merge_2D_to_3D_preds:
                 batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_3D[o]
             else:
                 batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_2D[o]
 
         out_batch = batch_3D if self.cf.dim == 3 else batch_2D
         out_batch.update({'pid': np.array([patient['pid']] * out_data.shape[0])})
 
         # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
         # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
         if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]):
             patient_batch = out_batch
             patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size)
             new_img_batch  = []
             new_seg_batch = []
 
             for cix, c in enumerate(patch_crop_coords_list):
                 seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                 new_seg_batch.append(seg_patch)
                 tmp_c_5 = c[5]
 
                 new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5])
 
             data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z)
             seg = np.transpose(np.array(new_seg_batch), axes=(1,0,2,3,4))[:,:,np.newaxis] # (n_raters, n_patches, x, y, z)
 
             if self.cf.dim == 2:
                 # all patches have z dimension 1 (slices). discard dimension
                 data = data[..., 0]
                 seg = seg[..., 0]
 
             patch_batch = {'data': data.astype('float32'),
                            'pid': np.array([patient['pid']] * data.shape[0])}
             # for o in self.cf.roi_items:
             #     patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0)
 
             converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False,
                                                            self.cf.class_specific_seg)
 
             for item in ["seg", "bb_target"]+self.cf.roi_items:
                 patch_batch[item] = []
             # coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1,
             # IndexError: index 2 is out of bounds for axis 1 with size 2
             for r in range(self.rater_bsize):
                 tmp_batch = {"seg": seg[r]}
                 for item in self.cf.roi_items:
                     tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), len(patch_crop_coords_list), axis=0)
                 tmp_batch = converter(**tmp_batch)
                 for item in ["seg", "bb_target"]+self.cf.roi_items:
                     patch_batch[item].append(tmp_batch[item])
 
             # patient-wise (orig) batch info for putting the patches back together after prediction
             for o in self.cf.roi_items:
                 patch_batch["patient_" + o] = patient_batch['patient_'+o]
                 if self.cf.dim==2:
                     # this could also be named "unpatched_2d_roi_items"
                     patch_batch["patient_"+o+"_2d"] = patient_batch[o]
             # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288
             # and enables calculating test-dice/viewing patient-wise results in test
             # remove, but also remove dice from metrics, if you like to save memory
             patch_batch['patient_data'] =  patient_batch['data']
             patch_batch['patient_seg'] = patient_batch['seg']
             patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
             patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
             if self.cf.dim==2:
                 patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
             patch_batch['original_img_shape'] = patient_batch['original_img_shape']
 
             out_batch = patch_batch
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return out_batch
 
 
 def create_data_gen_pipeline(cf, patient_data, is_training=True):
     """ create multi-threaded train/val/test batch generation and augmentation pipeline.
     :param cf: configs object.
     :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
     :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
     :return: multithreaded_generator
     """
     BG_name = "train" if is_training else "val"
     data_gen = BatchGenerator_merged(cf, patient_data, name=BG_name) if cf.training_gts=='merged' else \
         BatchGenerator_sa(cf, patient_data, name=BG_name)
 
     # add transformations to pipeline.
     my_transforms = []
     if is_training:
         if cf.da_kwargs["mirror"]:
             mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
             my_transforms.append(mirror_transform)
         spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                              patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                              do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                              alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                              do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                              angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                              do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                              random_crop=cf.da_kwargs['random_crop'])
 
         my_transforms.append(spatial_transform)
     else:
         my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
 
     if cf.create_bounding_box_targets:
         my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
     all_transforms = Compose(my_transforms)
 
     multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
     return multithreaded_generator
 
 def get_train_generators(cf, logger,  data_statistics=True):
     """
     wrapper function for creating the training batch generator pipeline. returns the train/val generators.
     selects patients according to cv folds (generated by first run/fold of experiment):
     splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set)
     If cf.held_out_test_set is True, adds the test split to the training data.
     """
     dataset = Dataset(cf, logger)
 
     dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
     dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
     set_splits = dataset.fg.splits
 
     test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
     train_ids = np.concatenate(set_splits, axis=0)
 
     if cf.held_out_test_set:
         train_ids = np.concatenate((train_ids, test_ids), axis=0)
         test_ids = []
 
     train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids}
     val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids}
 
     logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
                                                                                     len(test_ids)))
     if data_statistics:
         dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids},
                                 plot_dir=os.path.join(cf.plot_dir,"dataset"))
 
     batch_gen = {}
     batch_gen['train'] = create_data_gen_pipeline(cf, train_data, is_training=True)
     batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, is_training=False)
     if cf.val_mode == 'val_patient':
         assert cf.training_gts == 'merged', 'val_patient not yet implemented for sa gts'
         batch_gen['val_patient'] = PatientBatchIterator_merged(cf, val_data) if cf.training_gts=='merged' \
             else PatientBatchIterator_sa(cf, val_data)
         batch_gen['n_val'] = len(val_data) if cf.max_val_patients=="all" else min(len(val_data), cf.max_val_patients)
     else:
         batch_gen['n_val'] = cf.num_val_batches
 
     return batch_gen
 
 def get_test_generator(cf, logger):
     """
     wrapper function for creating the test batch generator pipeline.
     selects patients according to cv folds (generated by first run/fold of experiment)
     If cf.held_out_test_set is True, gets the data from an external folder instead.
     """
     if cf.held_out_test_set:
         sourcedir = cf.test_data_sourcedir
         test_ids = None
     else:
         sourcedir = None
         with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
             set_splits = pickle.load(handle)
         test_ids = set_splits[cf.fold]
 
     test_data = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode="test").data
     logger.info("data set loaded with: {} test patients".format(len(test_ids)))
     batch_gen = {}
     batch_gen['test'] = PatientBatchIterator_merged(cf, test_data)
     batch_gen['n_test'] = len(test_ids) if cf.max_test_patients == "all" else min(cf.max_test_patients, len(test_ids))
     return batch_gen
 
 
 if __name__ == "__main__":
     import sys
     sys.path.append('../')
     import plotting as plg
     import utils.exp_utils as utils
     from configs import Configs
 
     cf = Configs()
     cf.batch_size = 3
     #dataset_path = os.path.dirname(os.path.realpath(__file__))
     #exp_path = os.path.join(dataset_path, "experiments/dev")
     #cf = utils.prep_exp(dataset_path, exp_path, server_env=False, use_stored_settings=False, is_training=True)
     cf.created_fold_id_pickle = False
     total_stime = time.time()
     times = {}
 
     # cf.server_env = True
     # cf.data_dir = "experiments/dev_data"
 
     # dataset = Dataset(cf)
     # patient = dataset['Master_00018']
     cf.exp_dir = "experiments/dev/"
     cf.plot_dir = cf.exp_dir + "plots"
     os.makedirs(cf.exp_dir, exist_ok=True)
     cf.fold = 0
     logger = utils.get_logger(cf.exp_dir)
     gens = get_train_generators(cf, logger)
     train_loader = gens['train']
 
 
 
     for i in range(1):
         stime = time.time()
         #ex_batch = next(train_loader)
         print("train batch", i)
         times["train_batch"] = time.time() - stime
         #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True)
     #
     # # with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file:
     # #    train_loader.generator.print_stats(logger, file)
     #
     val_loader = gens['val_sampling']
     stime = time.time()
     ex_batch = next(val_loader)
     times["val_batch"] = time.time() - stime
     stime = time.time()
     #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False,
     #               show_info=False)
     times["val_plot"] = time.time() - stime
     #
     test_loader = get_test_generator(cf, logger)["test"]
     stime = time.time()
     ex_batch = test_loader.generate_train_batch()
     times["test_batch"] = time.time() - stime
     stime = time.time()
     plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", get_time=False)#, sample_picks=[0,1,2,3])
     times["test_patchbatch_plot"] = time.time() - stime
 
     # ex_batch['data'] = ex_batch['patient_data']
     # ex_batch['seg'] = ex_batch['patient_seg']
     # ex_batch['bb_target'] = ex_batch['patient_bb_target']
     # for item in cf.roi_items:
     #     ex_batch[]
     # stime = time.time()
     # #ex_batch = next(test_loader)
     # ex_batch = next(test_loader)
     # plg.view_batch(cf, ex_batch, show_gt_labels=False, show_gt_boxes=True, patient_items=True,# vol_slice_picks=[146,148, 218,220],
     #                 out_file="experiments/dev/dev_expatientbatch.png")  # , sample_picks=[0,1,2,3])
     # times["test_patient_batch_plot"] = time.time() - stime
 
 
 
     print("Times recorded throughout:")
     for (k, v) in times.items():
         print(k, "{:.2f}".format(v))
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
diff --git a/datasets/toy/data_loader.py b/datasets/toy/data_loader.py
index e77b3db..f4a444c 100644
--- a/datasets/toy/data_loader.py
+++ b/datasets/toy/data_loader.py
@@ -1,594 +1,595 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 sys.path.append('../') # works on cluster indep from where sbatch job is started
 import plotting as plg
+from multiprocessing import Pool
 
 import numpy as np
 import os
 from multiprocessing import Lock
 from collections import OrderedDict
 import pandas as pd
 import pickle
 import time
 
 # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
 from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
 from batchgenerators.transforms.abstract_transforms import Compose
 from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
 from batchgenerators.transforms.spatial_transforms import SpatialTransform
 from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
 
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import utils.dataloader_utils as dutils
 from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
 
 
 def load_obj(file_path):
     with open(file_path, 'rb') as handle:
         return pickle.load(handle)
 
 class Dataset(dutils.Dataset):
     r""" Load a dict holding memmapped arrays and clinical parameters for each patient,
     evtly subset of those.
         If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to
         cf.data_dir.
     :param cf: config file
     :param folds: number of folds out of @params n_cv folds to include
     :param n_cv: number of total folds
     :return: dict with imgs, segs, pids, class_labels, observables
     """
 
     def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'):
         super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir)
 
         load_exact_gts = (mode=='test' or cf.val_mode=="val_patient") and self.cf.test_against_exact_gt
 
         p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name))
 
         if subset_ids is not None:
             p_df = p_df[p_df.pid.isin(subset_ids)]
             logger.info('subset: selected {} instances from df'.format(len(p_df)))
 
         pids = p_df.pid.tolist()
         #evtly copy data from data_sourcedir to data_dest
         if cf.server_env and not hasattr(cf, "data_dir"):
             file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids]
             file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids]
             file_subset += [cf.info_df_name]
             if load_exact_gts:
                 file_subset += [os.path.join(self.data_dir, '{}_exact_seg.*'.format(pid)) for pid in pids]
             self.copy_data(cf, file_subset=file_subset)
 
         img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids]
         seg_paths = [os.path.join(self.data_dir, '{}_seg.npy'.format(pid)) for pid in pids]
         if load_exact_gts:
             exact_seg_paths = [os.path.join(self.data_dir, '{}_exact_seg.npy'.format(pid)) for pid in pids]
 
         class_targets = p_df['class_ids'].tolist()
         rg_targets = p_df['regression_vectors'].tolist()
         if load_exact_gts:
             exact_rg_targets = p_df['undistorted_rg_vectors'].tolist()
         fg_slices = p_df['fg_slices'].tolist()
 
         self.data = OrderedDict()
         for ix, pid in enumerate(pids):
             self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid,
                               'fg_slices': np.array(fg_slices[ix])}
             if load_exact_gts:
                 self.data[pid]['exact_seg'] = exact_seg_paths[ix]
             if 'class' in self.cf.prediction_tasks:
                 self.data[pid]['class_targets'] = np.array(class_targets[ix], dtype='uint8')
             else:
                 self.data[pid]['class_targets'] = np.ones_like(np.array(class_targets[ix]), dtype='uint8')
             if load_exact_gts:
                 self.data[pid]['exact_class_targets'] = self.data[pid]['class_targets']
             if any(['regression' in task for task in self.cf.prediction_tasks]):
                 self.data[pid]['regression_targets'] = np.array(rg_targets[ix], dtype='float16')
                 self.data[pid]["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in rg_targets[ix]], dtype='uint8')
                 if load_exact_gts:
                     self.data[pid]['exact_regression_targets'] = np.array(exact_rg_targets[ix], dtype='float16')
                     self.data[pid]["exact_rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in exact_rg_targets[ix]],
                                                                 dtype='uint8')
 
 
         cf.roi_items = cf.observables_rois[:]
         cf.roi_items += ['class_targets']
         if any(['regression' in task for task in self.cf.prediction_tasks]):
             cf.roi_items += ['regression_targets']
             cf.roi_items += ['rg_bin_targets']
 
         self.set_ids = np.array(list(self.data.keys()))
         self.df = None
 
 class BatchGenerator(dutils.BatchGenerator):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
     def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, seed=0):
         super(BatchGenerator, self).__init__(cf, data, sample_pids_w_replace=sample_pids_w_replace,
                                              max_batches=max_batches, raise_stop_iteration=raise_stop_iteration,
                                              seed=seed)
 
         self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
         assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
 
         self.crop_margin = np.array(self.cf.patch_size) / 8.  # min distance of ROI center to edge of cropped_patch.
         self.p_fg = 0.5
         self.empty_samples_max_ratio = 0.6
 
         self.balance_target_distribution(plot=sample_pids_w_replace)
 
     def generate_train_batch(self):
         # everything done in here is per batch
         # print statements in here get confusing due to multithreading
 
         batch_pids = self.get_batch_pids()
 
         batch_data, batch_segs, batch_patient_targets = [], [], []
         batch_roi_items = {name: [] for name in self.cf.roi_items}
         # record roi count and empty count of classes in batch
         # empty count for no presence of resp. class in whole sample (empty slices in 2D/patients in 3D)
         batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
 
         for b in range(len(batch_pids)):
             patient = self._data[batch_pids[b]]
 
             data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
             seg =  np.load(patient['seg'], mmap_mode='r').astype('uint8')
 
             (c, y, x, z) = data.shape
             if self.cf.dim == 2:
                 elig_slices, choose_fg = [], False
                 if len(patient['fg_slices']) > 0:
                     if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or np.random.rand(
                             1) <= self.p_fg:
                         # fg is to be picked
                         for tix in np.argsort(batch_roi_counts):
                             # pick slices of patient that have roi of sought-for target
                             # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                             elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                 patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                                 self.unique_ts[tix]) > 0]
                             if len(elig_slices) > 0:
                                 choose_fg = True
                                 break
                     else:
                         # pick bg
                         elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
                 if len(elig_slices) > 0:
                     sl_pick_ix = np.random.choice(elig_slices, size=None)
                 else:
                     sl_pick_ix = np.random.choice(z, size=None)
                 data = data[..., sl_pick_ix]
                 seg = seg[..., sl_pick_ix]
 
             spatial_shp = data[0].shape
             assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg"
             if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
                 new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
                 data = dutils.pad_nd_image(data, (len(data), *new_shape))
                 seg = dutils.pad_nd_image(seg, new_shape)
 
             # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center,
             # if possible, to that pixel, so that img still contains ROI after pre-cropping
             dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))]
             if np.any(dim_cropflags):
                 # sample pixel from random ROI and shift center, if possible, to that pixel
                 if self.cf.dim==3:
                     choose_fg = np.any(batch_empty_counts/self.batch_size>=self.empty_samples_max_ratio) or \
                                 np.random.rand(1) <= self.p_fg
                 if choose_fg and np.any(seg):
                     available_roi_ids = np.unique(seg)[1:]
                     for tix in np.argsort(batch_roi_counts):
                         elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]]
                         if len(elig_roi_ids)>0:
                             seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                             break
                     roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                     assert seg[tuple(roi_anchor_pixel)] > 0
 
                     # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and
                     # distance to the selected ROI < patch_size /2
                     def get_cropped_centercoords(dim):
                         low = np.max((self.cf.pre_crop_size[dim] // 2,
                                       roi_anchor_pixel[dim] - (
                                                   self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                         high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2,
                                        roi_anchor_pixel[dim] + (
                                                    self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                         if low >= high:  # happens if lesion on the edge of the image.
                             low = self.cf.pre_crop_size[dim] // 2
                             high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2
 
                         assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format(
                             dim,
                             spatial_shp, patient['pid'], low, high)
                         return np.random.randint(low=low, high=high)
                 else:
                     # sample crop center regardless of ROIs, not guaranteed to be empty
                     def get_cropped_centercoords(dim):
                         return np.random.randint(low=self.cf.pre_crop_size[dim] // 2,
                                                  high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2)
 
                 sample_seg_center = {}
                 for dim in np.where(dim_cropflags)[0]:
                     sample_seg_center[dim] = get_cropped_centercoords(dim)
                     min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2)
                     max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2)
                     data = np.take(data, indices=range(min_, max_), axis=dim + 1)  # +1 for channeldim
                     seg = np.take(seg, indices=range(min_, max_), axis=dim)
 
             batch_data.append(data)
             batch_segs.append(seg[np.newaxis])
 
             for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable
                     batch_roi_items[o].append(patient[o])
 
             if self.cf.dim == 3:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero==0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
             elif self.cf.dim == 2:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero == 0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
 
         batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'),
                  'pid': batch_pids,
                  'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts}
         for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic
             batch[key] = np.array(val)
 
         return batch
 
 class PatientBatchIterator(dutils.PatientBatchIterator):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D),
     if willing to accept speed-loss during training.
     Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are
     exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by
     the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs.
 
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D .
     """
 
     def __init__(self, cf, data, mode='test'):
         super(PatientBatchIterator, self).__init__(cf, data)
 
         self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D
         self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
         assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
 
         if (mode=="validation" and hasattr(self.cf, 'val_against_exact_gt') and self.cf.val_against_exact_gt) or \
                 (mode == 'test' and self.cf.test_against_exact_gt):
             self.gt_prefix = 'exact_'
             print("PatientIterator: Loading exact Ground Truths.")
         else:
             self.gt_prefix = ''
 
         self.patient_ix = 0  # running index over all patients in set
 
     def generate_train_batch(self, pid=None):
 
         if pid is None:
             pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
 
         # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling
         data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
         seg =  np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis]
 
         data_shp_raw = data.shape
         plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
         data = data[self.chans]
         discarded_chans = len(
             [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
         spatial_shp = data[0].shape  # spatial dims need to be in order x,y,z
         assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg"
 
         if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
             new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
             data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
             seg = dutils.pad_nd_image(seg, new_shape)
             if plot_bg is not None:
                 plot_bg = dutils.pad_nd_image(plot_bg, new_shape)
 
         if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
             # adds the batch dim here bc won't go through MTaugmenter
             out_data = data[np.newaxis]
             out_seg = seg[np.newaxis]
             if plot_bg is not None:
                out_plot_bg = plot_bg[np.newaxis]
             # data and seg shape: (1,c,x,y,z), where c=1 for seg
 
             batch_3D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_3D[o] = np.array([patient[self.gt_prefix+o]])
             converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_3D = converter(**batch_3D)
             batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
             for o in self.cf.roi_items:
                 batch_3D["patient_" + o] = batch_3D[o]
 
         if self.cf.dim == 2:
             out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32')  # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim
             out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8')  # (c,y,x,z) to (b=z,c,x,y)
 
             batch_2D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0)
             converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_2D = converter(**batch_2D)
 
             if plot_bg is not None:
                 out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32')
 
             if self.cf.merge_2D_to_3D_preds:
                 batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_3D[o]
             else:
                 batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_2D[o]
 
         out_batch = batch_3D if self.cf.dim == 3 else batch_2D
         out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})
 
         if self.cf.plot_bg_chan in self.chans and discarded_chans > 0:  # len(self.chans[:self.cf.plot_bg_chan])<data_shp_raw[0]:
             assert plot_bg is None
             plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
             out_plot_bg = plot_bg
         if plot_bg is not None:
             out_batch['plot_bg'] = out_plot_bg
 
         # eventual tiling into patches
         spatial_shp = out_batch["data"].shape[2:]
         if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
             patient_batch = out_batch
             print("patientiterator produced patched batch!")
             patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
             new_img_batch, new_seg_batch = [], []
 
             for c in patch_crop_coords_list:
                 new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                 seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                 new_seg_batch.append(seg_patch)
             shps = []
             for arr in new_img_batch:
                 shps.append(arr.shape)
 
             data = np.array(new_img_batch)  # (patches, c, x, y, z)
             seg = np.array(new_seg_batch)
             if self.cf.dim == 2:
                 # all patches have z dimension 1 (slices). discard dimension
                 data = data[..., 0]
                 seg = seg[..., 0]
             patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                            'pid': np.array([patient['pid']] * data.shape[0])}
             for o in self.cf.roi_items:
                 patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0)
             #patient-wise (orig) batch info for putting the patches back together after prediction
             for o in self.cf.roi_items:
                 patch_batch["patient_"+o] = patient_batch["patient_"+o]
                 if self.cf.dim == 2:
                     # this could also be named "unpatched_2d_roi_items"
                     patch_batch["patient_" + o + "_2d"] = patient_batch[o]
             patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
             patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
             if self.cf.dim == 2:
                 patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
             patch_batch['patient_data'] = patient_batch['data']
             patch_batch['patient_seg'] = patient_batch['seg']
             patch_batch['original_img_shape'] = patient_batch['original_img_shape']
             if plot_bg is not None:
                 patch_batch['patient_plot_bg'] = patient_batch['plot_bg']
 
             converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False,
                                                            class_specific_seg=self.cf.class_specific_seg)
 
             patch_batch = converter(**patch_batch)
             out_batch = patch_batch
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return out_batch
 
 
 def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs):
     """
     create mutli-threaded train/val/test batch generation and augmentation pipeline.
     :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
     :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
     :return: multithreaded_generator
     """
 
     # create instance of batch generator as first element in pipeline.
     data_gen = BatchGenerator(cf, patient_data, **kwargs)
 
     my_transforms = []
     if do_aug:
         if cf.da_kwargs["mirror"]:
             mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
             my_transforms.append(mirror_transform)
 
         spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                              patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                              do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                              alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                              do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                              angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                              do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                              random_crop=cf.da_kwargs['random_crop'])
 
         my_transforms.append(spatial_transform)
     else:
         my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
 
     my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
     all_transforms = Compose(my_transforms)
     # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
     multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
     return multithreaded_generator
 
 def get_train_generators(cf, logger, data_statistics=False):
     """
     wrapper function for creating the training batch generator pipeline. returns the train/val generators.
     selects patients according to cv folds (generated by first run/fold of experiment):
     splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set)
     If cf.hold_out_test_set is True, adds the test split to the training data.
     """
     dataset = Dataset(cf, logger)
     dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
     dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
     set_splits = dataset.fg.splits
 
     test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
     train_ids = np.concatenate(set_splits, axis=0)
 
     if cf.held_out_test_set:
         train_ids = np.concatenate((train_ids, test_ids), axis=0)
         test_ids = []
 
     train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids}
     val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids}
 
     logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
                                                                                     len(test_ids)))
     if data_statistics:
         dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=
         os.path.join(cf.plot_dir,"dataset"))
 
 
 
     batch_gen = {}
     batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True)
     if cf.val_mode == 'val_patient':
         batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation')
         batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else min(len(val_ids), cf.max_val_patients)
     elif cf.val_mode == 'val_sampling':
         batch_gen['n_val'] = int(np.ceil(len(val_data)/cf.batch_size)) if cf.num_val_batches == "all" else cf.num_val_batches
         # in current setup, val loader is used like generator. with max_batches being applied in train routine.
         batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False,
                                                              max_batches=None, raise_stop_iteration=False)
 
     return batch_gen
 
 def get_test_generator(cf, logger):
     """
     if get_test_generators is possibly called multiple times in server env, every time of
     Dataset initiation rsync will check for copying the data; this should be okay
     since rsync will not copy if files already exist in destination.
     """
 
     if cf.held_out_test_set:
         sourcedir = cf.test_data_sourcedir
         test_ids = None
     else:
         sourcedir = None
         with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
             set_splits = pickle.load(handle)
         test_ids = set_splits[cf.fold]
 
     test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test')
     logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids)))
     batch_gen = {}
     batch_gen['test'] = PatientBatchIterator(cf, test_set.data)
     batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \
         min(cf.max_test_patients, len(test_set.set_ids))
 
     return batch_gen
 
 
 if __name__=="__main__":
 
     import utils.exp_utils as utils
     from datasets.toy.configs import Configs
 
     cf = Configs()
 
     total_stime = time.time()
     times = {}
 
     # cf.server_env = True
     # cf.data_dir = "experiments/dev_data"
 
     cf.exp_dir = "experiments/dev/"
     cf.plot_dir = cf.exp_dir + "plots"
     os.makedirs(cf.exp_dir, exist_ok=True)
     cf.fold = 0
     logger = utils.get_logger(cf.exp_dir)
     gens = get_train_generators(cf, logger)
     train_loader = gens['train']
     for i in range(0):
         stime = time.time()
         print("producing training batch nr ", i)
         ex_batch = next(train_loader)
         times["train_batch"] = time.time() - stime
         #experiments/dev/dev_exbatch_{}.png".format(i)
         plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False)
 
 
     val_loader = gens['val_sampling']
     stime = time.time()
     for i in range(1):
         ex_batch = next(val_loader)
         times["val_batch"] = time.time() - stime
         stime = time.time()
         #"experiments/dev/dev_exvalbatch_{}.png"
         plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True)
         times["val_plot"] = time.time() - stime
     #
     test_loader = get_test_generator(cf, logger)["test"]
     stime = time.time()
     ex_batch = test_loader.generate_train_batch(pid=None)
     times["test_batch"] = time.time() - stime
     stime = time.time()
     plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0)
     times["test_patchbatch_plot"] = time.time() - stime
 
 
 
     print("Times recorded throughout:")
     for (k, v) in times.items():
         print(k, "{:.2f}".format(v))
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file