diff --git a/CodeDoc.odt b/CodeDoc.odt
deleted file mode 100644
index 880f0c7..0000000
Binary files a/CodeDoc.odt and /dev/null differ
diff --git a/code_optim/code_optim.py b/code_optim/code_optim.py
deleted file mode 100644
index 2702b3c..0000000
--- a/code_optim/code_optim.py
+++ /dev/null
@@ -1,328 +0,0 @@
-"""
-Created at 04/02/19 13:50
-@author: gregor 
-"""
-import plotting as plg
-
-import sys
-import os
-import pickle
-import json, socket, subprocess, time, threading
-
-import numpy as np
-import pandas as pd
-import torch
-from collections import OrderedDict
-from matplotlib.lines import  Line2D
-
-import utils.exp_utils as utils
-import utils.model_utils as mutils
-from predictor import Predictor
-from evaluator import Evaluator
-
-
-"""
-Need to start this script as sudo for background logging thread to work (needs to set niceness<0)
-"""
-
-
-def measure_train_batch_loading(logger, batch_gen, iters=1, warm_up=20, is_val=False, out_dir=None):
-    torch.cuda.empty_cache()
-    timer_key = "val_fw" if is_val else "train_fw"
-    for i in range(warm_up):
-        batch = next(batch_gen)
-        print("\rloaded warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True)
-    sysmetrics_start_ix = len(logger.sysmetrics.index)
-    for i in range(iters):
-        logger.time(timer_key)
-        batch = next(batch_gen)
-        print("\r{} batch {} loading took {:.3f}s.".format("val" if is_val else "train", i+1,
-                                                           logger.time(timer_key)), end="", flush=True)
-    print("Total avg fw {:.2f}s".format(logger.get_time(timer_key)/iters))
-    if out_dir is not None:
-        assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "train loading: empty df"
-        logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(
-            out_dir,"{}_loading.pickle".format("val" if is_val else "train")))
-    return logger.sysmetrics[sysmetrics_start_ix:-1]
-
-
-def measure_RPN(logger, net, batch, iters=1, warm_up=20, out_dir=None):
-    torch.cuda.empty_cache()
-    data = torch.from_numpy(batch["data"]).float().cuda()
-    fpn_outs = net.fpn(data)
-    rpn_feature_maps = [fpn_outs[i] for i in net.cf.pyramid_levels]
-
-    for i in range(warm_up):
-        layer_outputs = [net.rpn(p_feats) for p_feats in rpn_feature_maps]
-        print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True)
-    sysmetrics_start_ix = len(logger.sysmetrics.index)
-    for i in range(iters):
-        logger.time("RPN_fw")
-        layer_outputs = [net.rpn(p_feats) for p_feats in rpn_feature_maps]
-        print("\r{} batch took {:.3f}s.".format("RPN", logger.time("RPN_fw")), end="", flush=True)
-    print("Total avg fw {:.2f}s".format(logger.get_time("RPN_fw")/iters))
-
-    if out_dir is not None:
-        assert len(logger.sysmetrics[sysmetrics_start_ix:-1])>0, "six {}, sysm ix {}".format(sysmetrics_start_ix, logger.sysmetrics.index)
-        logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(out_dir,"RPN_msrmts.pickle"))
-    return logger.sysmetrics[sysmetrics_start_ix:-1]
-
-def measure_FPN(logger, net, batch, iters=1, warm_up=20, out_dir=None):
-    torch.cuda.empty_cache()
-    data = torch.from_numpy(batch["data"]).float().cuda()
-    for i in range(warm_up):
-        outputs = net.fpn(data)
-        print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True)
-    sysmetrics_start_ix = len(logger.sysmetrics.index)
-    for i in range(iters):
-        logger.time("FPN_fw")
-        outputs = net.fpn(data)
-        #print("in mean thread", logger.sysmetrics.index)
-        print("\r{} batch took {:.3f}s.".format("FPN", logger.time("FPN_fw")), end="", flush=True)
-    print("Total avg fw {:.2f}s".format(logger.get_time("FPN_fw")/iters))
-
-    if out_dir is not None:
-        assert len(logger.sysmetrics[sysmetrics_start_ix:-1])>0, "six {}, sysm ix {}".format(sysmetrics_start_ix, logger.sysmetrics.index)
-        logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(out_dir,"FPN_msrmts.pickle"))
-    return logger.sysmetrics[sysmetrics_start_ix:-1]
-
-def measure_forward(logger, net, batch, iters=1, warm_up=20, out_dir=None):
-    torch.cuda.empty_cache()
-    data = torch.from_numpy(batch["data"]).float().cuda()
-    for i in range(warm_up):
-        outputs = net.forward(data)
-        print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True)
-    sysmetrics_start_ix = len(logger.sysmetrics.index)
-    for i in range(iters):
-        logger.time("net_fw")
-        outputs = net.forward(data)
-        print("\r{} batch took {:.3f}s.".format("forward", logger.time("net_fw")), end="", flush=True)
-    print("Total avg fw {:.2f}s".format(logger.get_time("net_fw")/iters))
-    if out_dir is not None:
-        assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "fw: empty df"
-        logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(out_dir,"fw_msrmts.pickle"))
-    return logger.sysmetrics[sysmetrics_start_ix:-1].copy()
-
-def measure_train_forward(logger, net, batch, iters=1, warm_up=20, is_val=False, out_dir=None):
-    torch.cuda.empty_cache()
-    timer_key = "val_fw" if is_val else "train_fw"
-    optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
-    for i in range(warm_up):
-        results_dict = net.train_forward(batch)
-        print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True)
-    sysmetrics_start_ix = len(logger.sysmetrics.index)
-    for i in range(iters):
-        logger.time(timer_key)
-        if not is_val:
-            optimizer.zero_grad()
-        results_dict = net.train_forward(batch, is_validation=is_val)
-        #results_dict["torch_loss"] *= torch.rand(1).cuda()
-        if not is_val:
-            results_dict["torch_loss"].backward()
-            optimizer.step()
-        print("\r{} batch took {:.3f}s.".format("val" if is_val else "train", logger.time(timer_key)), end="", flush=True)
-    print("Total avg fw {:.2f}s".format(logger.get_time(timer_key)/iters))
-    if out_dir is not None:
-        assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "train_fw: empty df"
-        logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(
-            out_dir,"{}_msrmts.pickle".format("val_fw" if is_val else "train_fwbw")))
-    return logger.sysmetrics[sysmetrics_start_ix:-1].copy()
-
-def measure_train_fw_incl_batch_gen(logger, net, batch_gen, iters=1, warm_up=20, is_val=False, out_dir=None):
-    torch.cuda.empty_cache()
-    timer_key = "val_fw" if is_val else "train_fw"
-    for i in range(warm_up):
-        batch = next(batch_gen)
-        results_dict = net.train_forward(batch)
-        print("\rfinished warm-up batch {}/{}".format(i+1, warm_up), end="", flush=True)
-    sysmetrics_start_ix = len(logger.sysmetrics.index)
-    for i in range(iters):
-        logger.time(timer_key)
-        batch = next(batch_gen)
-        results_dict = net.train_forward(batch, is_validation=is_val)
-        if not is_val:
-            results_dict["torch_loss"].backward()
-        print("\r{} batch took {:.3f}s.".format("val" if is_val else "train", logger.time(timer_key)), end="", flush=True)
-    print("Total avg fw {:.2f}s".format(logger.get_time(timer_key)/iters))
-    if out_dir is not None:
-        assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "train_fw incl batch: empty df"
-        logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(
-            out_dir,"{}_incl_batch_msrmts.pickle".format("val_fw" if is_val else "train_fwbw")))
-    return logger.sysmetrics[sysmetrics_start_ix:-1]
-
-
-
-def measure_train_backward(cf, logger, net, batch, iters=1, warm_up=20, out_dir=None):
-    torch.cuda.empty_cache()
-    optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
-    results_dict = net.train_forward(batch, is_validation=False)
-    loss = results_dict["torch_loss"]
-    for i in range(warm_up):
-        loss.backward(retain_graph=True)
-        print("\rfinished warm-up batch {}/{}".format(i + 1, warm_up), end="", flush=True)
-    sysmetrics_start_ix = len(logger.sysmetrics.index)
-    for i in range(iters):
-        logger.time("train_bw")
-        optimizer.zero_grad()
-        loss.backward(retain_graph=True)
-        optimizer.step()
-        print("\r{} bw batch {} took {:.3f}s.".format("train", i+1, logger.time("train_bw")), end="", flush=True)
-    print("Total avg bw {:.2f}s".format(logger.get_time("train_bw") / iters))
-    if out_dir is not None:
-        assert len(logger.sysmetrics[sysmetrics_start_ix:-1]) > 0, "train_bw: empty df"
-        logger.sysmetrics[sysmetrics_start_ix:-1].to_pickle(os.path.join(out_dir,"train_bw.pickle"))
-    return logger.sysmetrics[sysmetrics_start_ix:-1]
-
-
-
-def measure_test_forward(logger, net, batch, iters=1, return_masks=False):
-    torch.cuda.empty_cache()
-    for i in range(iters):
-        logger.time("test_fw")
-        results_dict = net.test_forward(batch, return_masks=return_masks)
-        print("\rtest batch took {:.3f}s.".format(logger.time("test_fw")), end="", flush=True)
-    print("Total avg test fw {:.2f}s".format(logger.get_time('test_fw')/iters))
-
-
-def perform_measurements(args, iters=20):
-
-    cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, is_training=True, use_stored_settings=False)
-
-    cf.exp_dir = args.exp_dir
-
-    # pid = 1624
-    # cf.fold = find_pid_in_splits(pid)
-    cf.fold = 0
-    cf.merge_2D_to_3D_preds = False
-    cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold))
-
-    logger = utils.get_logger(cf.exp_dir, sysmetrics_interval=0.5)
-    model = utils.import_module('model', cf.model_path)
-    net = model.net(cf, logger).cuda()
-    test_predictor = Predictor(cf, None, logger, mode='test')
-    #cf.p_batchbalance = 0
-    #cf.do_aug = False
-    batch_gens = data_loader.get_train_generators(cf, logger)
-    train_gen, val_gen = batch_gens['train'], batch_gens['val_sampling']
-    test_gen = data_loader.get_test_generator(cf, logger)['test']
-    weight_paths = [os.path.join(cf.fold_dir, '{}_best_params.pth'.format(rank)) for rank in
-                    test_predictor.epoch_ranking]
-
-    try:
-        pids = test_gen.dataset_pids
-    except:
-        pids = test_gen.generator.dataset_pids
-    print("pids in test set: ", pids)
-    pid = pids[0]
-    assert pid in pids
-    pid = "285"
-
-    model_name = cf.model
-
-    results_dir = "/home/gregor/Documents/medicaldetectiontoolkit/code_optim/"+model_name
-    os.makedirs(results_dir, exist_ok=True)
-    print("Model: {}.".format(model_name))
-    #gpu_logger = utils.Nvidia_GPU_Logger()
-    #gpu_logger.start(interval=0.1)
-    #measure_train_batch_loading(logger, train_gen, iters=iters, out_dir=results_dir)
-    #measure_train_batch_loading(logger, val_gen, iters=iters, is_val=True, out_dir=results_dir)
-    #measure_RPN(logger, net, next(train_gen), iters=iters,  out_dir=results_dir)
-    #measure_FPN(logger, net, next(train_gen), iters=iters, out_dir=results_dir)
-    #measure_forward(logger, net, next(train_gen), iters=iters, out_dir=results_dir)
-    measure_train_forward(logger, net, next(train_gen), iters=iters, out_dir=results_dir) #[['global_step', 'gpu_utilization (%)']]
-    #measure_train_forward(logger, net, next(val_gen), iters=iters, is_val=True, out_dir=results_dir)
-    #measure_train_fw_incl_batch_gen(logger, net, train_gen, iters=iters, out_dir=results_dir)
-    #measure_train_fw_incl_batch_gen(logger, net, val_gen, iters=iters, is_val=True, out_dir=results_dir)
-    #measure_train_backward(cf, logger, net, next(train_gen), iters=iters, out_dir=results_dir)
-    #measure_test_forward(logger, net, next(test_gen), iters=iters, return_masks=cf.return_masks_in_test)
-
-    return results_dir, iters
-
-def plot_folder(cf, ax, results_dir, iters, markers='o', offset=(+0.01, -4)):
-    point_renaming = {"FPN_msrmts": ["FPN.forward", (offset[0], -4)], "fw_msrmts": "net.forward",
-                      "train_bw": "backward+optimizer",
-                      "train_fw_msrmts": "net.train_forward",
-                      "train_fw_incl_batch": "train_fw+batch", "RPN_msrmts": "RPN.forward",
-                      "train_fwbw_msrmts": ["train_fw+bw", (offset[0], +2)],
-                      "val_fw_msrmts": ["val_fw", (offset[0], -4)],
-                      "train_fwbw_incl_batch_msrmts": ["train_fw+bw+batchload", (offset[0], +2)],
-                      "train_fwbw_incl_batch_aug_msrmts": ["train_fw+bw+batchload+aug", (-0.2, +2)],
-                      "val_fw_incl_batch_msrmts": ["val_fw+batchload", (offset[0], -4)],
-                      "val_loading": ["val_load", (-0.06, -4)],
-                      "train_loading_wo_bal_fg_aug": ["train_load_w/o_bal,fg,aug", (offset[0], 2)],
-                      "train_loading_wo_balancing": ["train_load_w/o_balancing", (-0.05, 2)],
-                      "train_loading_wo_aug": ["train_load_w/o_aug", (offset[0], 2)],
-                      "train_loading_wo_bal_fg": ["train_load_w/o_bal,fg", (offset[0], -4)],
-                      "train_loading": ["train_load", (+0.01, -1.3)]
-                      }
-    dfs = OrderedDict()
-    for file in os.listdir(results_dir):
-        if os.path.splitext(file)[-1]==".pickle":
-           dfs[file.split(os.sep)[-1].split(".")[0]] = pd.read_pickle(os.path.join(results_dir,file))
-
-
-    for i, (name, df) in enumerate(dfs.items()):
-        time = (df["rel_time"].iloc[-1] - df["rel_time"].iloc[0])/iters
-        gpu_u = df["gpu_utilization (%)"].values.astype(int).mean()
-
-        color = cf.color_palette[i%len(cf.color_palette)]
-        ax.scatter(time, gpu_u, color=color, marker=markers)
-        if name in point_renaming.keys():
-            name = point_renaming[name]
-            if isinstance(name, list):
-                offset = name[1]
-                name = name[0]
-        ax.text(time+offset[0], gpu_u+offset[1], name, color=color)
-
-def analyze_measurements(cf, results_dir, iters, title=""):
-    fig, ax = plg.plt.subplots(1, 1)
-
-    settings = [(results_dir, iters, 'o'), (os.path.join(results_dir, "200iters_pre_optim"), 200, 'v', (-0.08, 2)),
-                (os.path.join(results_dir, "200iters_after_optim"), 200, 'o')]
-    for args in settings:
-        plot_folder(cf, ax, *args)
-    labels = ["after optim", "pre optim"]
-    handles = [Line2D([0], [0], marker=settings[i][2], label=labels[i], color="w", markerfacecolor=cf.black, markersize=10)
-               for i in range(len(settings[:2]))]
-    plg.plt.legend(handles=handles, loc="best")
-    ax.set_xlim(0,ax.get_xlim()[1]*1.05)
-    ax.set_ylim(0, 100)
-    ax.set_ylabel("Mean GPU Utilization (%)")
-    ax.set_xlabel("Runtime (s)")
-    plg.plt.title(title+"GPU utilization vs Method Runtime\nMean Over {} Iterations".format(iters))
-
-    major_ticks = np.arange(0, 101, 10)
-    minor_ticks = np.arange(0, 101, 5)
-    ax.set_yticks(major_ticks)
-    ax.set_yticks(minor_ticks, minor=True)
-    ax.grid(which='minor', alpha=0.2)
-    ax.grid(which='major', alpha=0.5)
-
-
-    plg.plt.savefig(os.path.join(results_dir, "measurements.png"))
-
-
-
-    return
-
-
-if __name__=="__main__":
-    class Args():
-        def __init__(self):
-            self.dataset_name = "datasets/prostate"
-            self.exp_dir = "datasets/prostate/experiments/dev"
-            self.server_env = False
-
-
-    args = Args()
-
-    sys.path.append(args.dataset_name)
-    import data_loader
-    from configs import Configs
-    cf = configs(args.server_env)
-    iters = 200
-    results_dir, iters = perform_measurements(args, iters=iters)
-    results_dir = "/home/gregor/Documents/medicaldetectiontoolkit/code_optim/" + cf.model
-    analyze_measurements(cf, results_dir, iters=iters, title=cf.model+": ")
-
-
diff --git a/datasets/cityscapes/configs.py b/datasets/cityscapes/configs.py
deleted file mode 100644
index ed2cdab..0000000
--- a/datasets/cityscapes/configs.py
+++ /dev/null
@@ -1,434 +0,0 @@
-__author__ = ''
-#credit Paul F. Jaeger
-
-#########################
-#     Example Config    #
-#########################
-
-import os
-import sys
-
-import numpy as np
-from collections import namedtuple
-
-sys.path.append('../')
-from default_configs import DefaultConfigs
-
-class Configs(DefaultConfigs):
-
-    def __init__(self, server_env=None):
-        super(Configs, self).__init__(server_env)
-
-        self.dim = 2
-
-        #########################
-        #         I/O           #
-        #########################
-
-        self.data_sourcedir = "/mnt/HDD2TB/Documents/data/cityscapes/cs_20190715/"
-        if server_env:
-            #self.source_dir = '/home/ramien/medicaldetectiontoolkit/'
-            self.data_sourcedir = '/datasets/data_ramien/cityscapes/cs_20190715_npz/'
-            #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/cityscapes/cs_6c_inst_only/"
-
-        self.datapath = "leftImg8bit/"
-        self.targetspath = "gtFine/"
-        
-        self.cities = {'train':['dusseldorf', 'aachen', 'bochum', 'cologne', 'erfurt',
-                                'hamburg', 'hanover', 'jena', 'krefeld', 'monchengladbach', 
-                                'strasbourg', 'stuttgart', 'tubingen', 'ulm', 'weimar',
-                                'zurich'], 
-                        'val':['frankfurt', 'munster'], 
-                        'test':['bremen', 'darmstadt', 'lindau'] }
-        self.set_splits = ["train", "val", "test"] # for training and val, mixed up
-        # test cities are not held out
-
-        self.info_dict_name = 'city_info.pkl'
-        self.info_dict_path = os.path.join(self.data_sourcedir, self.info_dict_name)
-        self.config_path = os.path.realpath(__file__)
-        self.backbone_path = 'models/backbone.py'
-
-        # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'detection_fpn'].
-        self.model = 'retina_unet'
-        self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
-        self.model_path = os.path.join(self.source_dir, self.model_path)
-
-        self.select_prototype_subset = None
-            
-        #########################
-        #      Preprocessing    #
-        #########################
-        self.prepro = {
-            
-		 'data_dir': '/mnt/HDD2TB/Documents/data/cityscapes_raw/', #raw files (input), needs to end with "/"
-         'targettype': "gtFine_instanceIds",
-         'set_splits': ["train", "val", "test"],
-         
-         'img_target_size': np.array([256, 512])*4, #y,x
-         
-         'output_directory': self.data_sourcedir,
-		 
-         'center_of_mass_crop': True, #not implemented
-         #'pre_crop_size': , #z,y,x
-		 'normalization': {'percentiles':[1., 99.]},#not implemented
-	     'interpolation': 'nearest', #not implemented
-         
-         'info_dict_path': self.info_dict_path,
-         
-         'npz_dir' : self.data_sourcedir[:-1]+"_npz" #if not None: convert to npz, copy data here
-         }
-
-        #########################
-        #      Architecture     #
-        #########################
-        # 'class', 'regression', 'regression_ken_gal'
-        # 'class': standard object classification per roi, pairwise combinable with each of below tasks.
-        # 'class' is only option implemented for CityScapes data set.
-        self.prediction_tasks = ['class',]
-        self.start_filts = 52
-        self.end_filts = self.start_filts * 4
-        self.res_architecture = 'resnet101'  # 'resnet101' , 'resnet50'
-        self.weight_init = None  # 'kaiming', 'xavier' or None for pytorch default
-        self.norm = 'instance_norm'  # 'batch_norm' # one of 'None', 'instance_norm', 'batch_norm'
-        self.relu = 'relu'
-
-        #########################
-        #      Data Loader      #
-        #########################
-
-        self.seed = 17
-        self.n_workers = 16 if server_env else os.cpu_count()
-        
-        self.batch_size = 8
-        self.n_cv_splits = 10 #at least 2 (train, val)
-        
-        self.num_classes = None #set below #for instance classification (excl background)
-        self.num_seg_classes = None #set below #incl background
-        
-        self.create_bounding_box_targets = True
-        self.class_specific_seg = True
-        
-        self.channels = [0,1,2] 
-        self.pre_crop_size = self.prepro['img_target_size'] # y,x
-        self.crop_margin   = [10,10] #has to be smaller than respective patch_size//2
-        self.patch_size_2D = [256, 512] #self.pre_crop_size #would be better to save as tuple since should not be altered
-        self.patch_size_3D = self.patch_size_2D + [1]
-        self.patch_size = self.patch_size_2D
-
-        self.balance_target = "class_targets"
-        # ratio of fully random patients drawn during batch generation
-        # resulting batch random count is rounded down to closest integer
-        self.batch_random_ratio = 0.2
-
-        self.observables_patient = []
-        self.observables_rois = []
-        
-        #########################
-        #   Data Augmentation   #
-        #########################
-        #the angle rotations are implemented incorrectly in batchgenerators! in 2D,
-        #the x-axis angle controls the z-axis angle.
-        self.do_aug = True
-        self.da_kwargs = {
-            'mirror': True,
-            'mirror_axes': (1,), #image axes, (batch and channel are ignored, i.e., actual tensor dims are +2)
-        	'random_crop': True,
-        	'rand_crop_dist': (self.patch_size[0] / 2., self.patch_size[1] / 2.),
-        	'do_elastic_deform': True,
-        	'alpha': (0., 1000.),
-        	'sigma': (28., 30.),
-        	'do_rotation': True,
-        	'angle_x': (-np.pi / 8., np.pi / 8.),
-        	'angle_y': (0.,0.),
-        	'angle_z': (0.,0.),
-        	'do_scale': True,
-        	'scale': (0.6, 1.4),
-        	'border_mode_data': 'constant',
-            'gamma_range': (0.6, 1.4)
-        }        
-        
-        #################################
-        #  Schedule / Selection / Optim #
-        #################################
-        #mrcnn paper: ~2.56m samples seen during coco-dataset training
-        self.num_epochs = 400
-        self.num_train_batches = 600
-        
-        self.do_validation = True
-        # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
-        # the former is morge accurate, while the latter is faster (depending on volume size)
-        self.val_mode = 'val_sampling' # one of 'val_sampling', 'val_patient'
-        # if 'all' iterates over entire val_set once.
-        self.num_val_batches = "all" # for val_sampling
-        
-        self.save_n_models = 3
-        self.min_save_thresh = 1 # in epochs
-        self.model_selection_criteria = {"human_ap": 1., "vehicle_ap": 0.9}
-        self.warm_up = 0
-
-        self.learning_rate = [5*1e-4] * self.num_epochs
-        self.dynamic_lr_scheduling = True #with scheduler set in exec
-        self.lr_decay_factor = 0.5
-        self.scheduling_patience = int(self.num_epochs//10)
-        self.weight_decay = 1e-6
-        self.clip_norm = None  # number or None
-
-        #########################
-        #   Colors and Legends  #
-        #########################
-        self.plot_frequency = 5
-
-        #colors
-        self.color_palette = [self.red, self.blue, self.green, self.orange, self.aubergine,
-                              self.yellow, self.gray, self.cyan, self.black]
-        
-        #legends
-        Label = namedtuple( 'Label' , [
-            'name'        , # The identifier of this label, e.g. 'car', 'person', ... .
-                            # We use them to uniquely name a class
-            'ppId'          , # An integer ID that is associated with this label.
-                            # The IDs are used to represent the label in ground truth images
-                            # An ID of -1 means that this label does not have an ID and thus
-                            # is ignored when creating ground truth images (e.g. license plate).
-                            # Do not modify these IDs, since exactly these IDs are expected by the
-                            # evaluation server.
-            'id'     , # Feel free to modify these IDs as suitable for your method.
-                            # Max value is 255!
-            'category'    , # The name of the category that this label belongs to
-            'categoryId'  , # The ID of this category. Used to create ground truth images
-                            # on category level.
-            'hasInstances', # Whether this label distinguishes between single instances or not
-            'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
-                            # during evaluations or not
-            'color'       , # The color of this label
-            ] )
-        segLabel = namedtuple( "segLabel", ["name", "id", "color"])
-        boxLabel = namedtuple( 'boxLabel', [ "name", "color"]) 
-        
-        self.labels = [
-            #       name                   ppId         id   category            catId     hasInstances   ignoreInEval   color
-            Label(  'ignore'               ,  0 ,        0 , 'void'            , 0       , False        , True         , (  0.,  0.,  0., 1.) ),
-            Label(  'ego vehicle'          ,  1 ,        0 , 'void'            , 0       , False        , True         , (  0.,  0.,  0., 1.) ),
-            Label(  'rectification border' ,  2 ,        0 , 'void'            , 0       , False        , True         , (  0.,  0.,  0., 1.) ),
-            Label(  'out of roi'           ,  3 ,        0 , 'void'            , 0       , False        , True         , (  0.,  0.,  0., 1.) ),
-            Label(  'static'               ,  4 ,        0 , 'void'            , 0       , False        , True         , (  0.,  0.,  0., 1.) ),
-            Label(  'dynamic'              ,  5 ,        0 , 'void'            , 0       , False        , True         , (0.44, 0.29,  0., 1.) ),
-            Label(  'ground'               ,  6 ,        0 , 'void'            , 0       , False        , True         , ( 0.32,  0., 0.32, 1.) ),
-            Label(  'road'                 ,  7 ,        0 , 'flat'            , 1       , False        , False        , (0.5, 0.25, 0.5, 1.) ),
-            Label(  'sidewalk'             ,  8 ,        0 , 'flat'            , 1       , False        , False        , (0.96, 0.14, 0.5, 1.) ),
-            Label(  'parking'              ,  9 ,        0 , 'flat'            , 1       , False        , True         , (0.98, 0.67, 0.63, 1.) ),
-            Label(  'rail track'           , 10 ,        0 , 'flat'            , 1       , False        , True         , ( 0.9,  0.59, 0.55, 1.) ),
-            Label(  'building'             , 11 ,        0 , 'construction'    , 2       , False        , False        , ( 0.27, 0.27, 0.27, 1.) ),
-            Label(  'wall'                 , 12 ,        0 , 'construction'    , 2       , False        , False        , (0.4,0.4,0.61, 1.) ),
-            Label(  'fence'                , 13 ,        0 , 'construction'    , 2       , False        , False        , (0.75,0.6,0.6, 1.) ),
-            Label(  'guard rail'           , 14 ,        0 , 'construction'    , 2       , False        , True         , (0.71,0.65,0.71, 1.) ),
-            Label(  'bridge'               , 15 ,        0 , 'construction'    , 2       , False        , True         , (0.59,0.39,0.39, 1.) ),
-            Label(  'tunnel'               , 16 ,        0 , 'construction'    , 2       , False        , True         , (0.59,0.47, 0.35, 1.) ),
-            Label(  'pole'                 , 17 ,        0 , 'object'          , 3       , False        , False        , (0.6,0.6,0.6, 1.) ),
-            Label(  'polegroup'            , 18 ,        0 , 'object'          , 3       , False        , True         , (0.6,0.6,0.6, 1.) ),
-            Label(  'traffic light'        , 19 ,        0 , 'object'          , 3       , False        , False        , (0.98,0.67, 0.12, 1.) ),
-            Label(  'traffic sign'         , 20 ,        0 , 'object'          , 3       , False        , False        , (0.86,0.86, 0., 1.) ),
-            Label(  'vegetation'           , 21 ,        0 , 'nature'          , 4       , False        , False        , (0.42,0.56, 0.14, 1.) ),
-            Label(  'terrain'              , 22 ,        0 , 'nature'          , 4       , False        , False        , (0.6, 0.98,0.6, 1.) ),
-            Label(  'sky'                  , 23 ,        0 , 'sky'             , 5       , False        , False        , (0.27,0.51,0.71, 1.) ),
-            Label(  'person'               , 24 ,        1 , 'human'           , 6       , True         , False        , (0.86, 0.08, 0.24, 1.) ),
-            Label(  'rider'                , 25 ,        1 , 'human'           , 6       , True         , False        , (1.,  0.,  0., 1.) ),
-            Label(  'car'                  , 26 ,        2 , 'vehicle'         , 7       , True         , False        , (  0., 0.,0.56, 1.) ),
-            Label(  'truck'                , 27 ,        2 , 'vehicle'         , 7       , True         , False        , (  0.,  0., 0.27, 1.) ),
-            Label(  'bus'                  , 28 ,        2 , 'vehicle'         , 7       , True         , False        , (  0., 0.24,0.39, 1.) ),
-            Label(  'caravan'              , 29 ,        2 , 'vehicle'         , 7       , True         , True         , (  0.,  0., 0.35, 1.) ),
-            Label(  'trailer'              , 30 ,        2 , 'vehicle'         , 7       , True         , True         , (  0.,  0.,0.43, 1.) ),
-            Label(  'train'                , 31 ,        2 , 'vehicle'         , 7       , True         , False        , (  0., 0.31,0.39, 1.) ),
-            Label(  'motorcycle'           , 32 ,        2 , 'vehicle'         , 7       , True         , False        , (  0.,  0., 0.9, 1.) ),
-            Label(  'bicycle'              , 33 ,        2 , 'vehicle'         , 7       , True         , False        , (0.47, 0.04, 0.13, 1.) ),
-            Label(  'license plate'        , -1 ,        0 , 'vehicle'         , 7       , False        , True         , (  0.,  0., 0.56, 1.) ),
-            Label(  'background'           , -1 ,        0 , 'void'            , 0       , False        , True         , (  0.,  0., 0.0, 0.) ),
-            Label(  'vehicle'              , 33 ,        2 , 'vehicle'         , 7       , True         , False        , (*self.aubergine, 1.)  ),
-            Label(  'human'                , 25 ,        1 , 'human'           , 6       , True         , False        , (*self.blue, 1.) )
-        ]
-        # evtl problem: class-ids (trainIds) don't start with 0 for the first class, 0 is bg.
-        #WONT WORK: class ids need to start at 0 (excluding bg!) and be consecutively numbered 
-
-        self.ppId2id = { label.ppId : label.id for label in self.labels}
-        self.class_id2label = { label.id : label for label in self.labels}
-        self.class_cmap = {label.id : label.color for label in self.labels}
-        self.class_dict = {label.id : label.name for label in self.labels if label.id!=0}
-        #c_dict: only for evaluation, remove bg class.
-        
-        self.box_type2label = {label.name : label for label in self.box_labels}
-        self.box_color_palette = {label.name:label.color for label in self.box_labels}
-
-        if self.class_specific_seg:
-            self.seg_labels = [label for label in self.class_id2label.values()]
-        else:
-            self.seg_labels = [
-                    #           name    id  color
-                    segLabel(  "bg" ,   0,  (1.,1.,1.,0.) ),
-                    segLabel(  "fg" ,   1,  (*self.orange, .8))
-                    ]
-
-        self.seg_id2label = {label.id : label for label in self.seg_labels}
-        self.cmap = {label.id : label.color for label in self.seg_labels}
-        
-        self.plot_prediction_histograms = True
-        self.plot_stat_curves = False
-        self.has_colorchannels = True
-        self.plot_class_ids = True
-        
-        self.num_classes = len(self.class_dict)
-        self.num_seg_classes = len(self.seg_labels)
-
-        #########################
-        #   Testing             #
-        #########################
-
-        self.test_aug_axes = None #None or list: choices are 2,3,(2,3)
-        self.held_out_test_set = False
-        self.max_test_patients = 'all' # 'all' for all
-        self.report_score_level = ['rois',]  # choose list from 'patient', 'rois'
-        self.patient_class_of_interest = 1
-        
-        self.metrics = ['ap', 'dice']
-        self.ap_match_ious = [0.1]  # threshold(s) for considering a prediction as true positive
-        # aggregation method for test and val_patient predictions.
-        # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
-        # nms = standard non-maximum suppression, or None = no clustering
-        self.clustering = 'wbc'
-        # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
-        self.clustering_iou = 0.1  # has to be larger than desired possible overlap iou of model predictions
-
-        self.min_det_thresh = 0.06
-        self.merge_2D_to_3D_preds = False
-        
-        self.n_test_plots = 1 #per fold and rankself.ap_match_ious = [0.1] #threshold(s) for considering a prediction as true positive
-        self.test_n_epochs = self.save_n_models
-
-
-        #########################
-        # shared model settings #
-        #########################
-
-        # max number of roi candidates to identify per image and class (slice in 2D, volume in 3D)
-        self.n_roi_candidates = 100
-        
-        #########################
-        #   Add model specifics #
-        #########################
-        
-        {'mrcnn': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs
-         }[self.model]()
-
-    def add_mrcnn_configs(self):
-
-        self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
-        self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
-
-        # number of classes for network heads: n_foreground_classes + 1 (background)
-        self.head_classes = self.num_classes + 1
-
-        # seg_classes here refers to the first stage classifier (RPN) reallY?
-
-        # feed +/- n neighbouring slices into channel dimension. set to None for no context.
-        self.n_3D_context = None
-
-
-        self.frcnn_mode = False
-
-        self.detect_while_training = True
-        # disable the re-sampling of mask proposals to original size for speed-up.
-        # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
-        # mask outputs are optional.
-        self.return_masks_in_train = True
-        self.return_masks_in_val = True
-        self.return_masks_in_test = True
-
-        # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
-        self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
-        # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
-        # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
-        self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
-        # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
-        self.pyramid_levels = [0, 1, 2, 3]
-        # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
-        self.n_rpn_features = 512 if self.dim == 2 else 64
-
-        # anchor ratios and strides per position in feature maps.
-        self.rpn_anchor_ratios = [0.5, 1., 2.]
-        self.rpn_anchor_stride = 1
-        # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
-        self.rpn_nms_threshold = 0.7
-
-        # loss sampling settings.
-        self.rpn_train_anchors_per_image = 8
-        self.train_rois_per_image = 10  # per batch_instance
-        self.roi_positive_ratio = 0.5
-        self.anchor_matching_iou = 0.8
-
-        # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
-        # where k<=#positive examples.
-        self.shem_poolsize = 3
-
-        self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
-        self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
-        self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
-
-        self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
-        self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
-        self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
-        self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
-                               self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
-
-        if self.dim == 2:
-            self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
-            self.bbox_std_dev = self.bbox_std_dev[:4]
-            self.window = self.window[:4]
-            self.scale = self.scale[:4]
-
-        self.plot_y_max = 1.5
-        self.n_plot_rpn_props = 5 # per batch_instance (slice in 2D / patient in 3D)
-
-        # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
-        self.pre_nms_limit = 3000
-
-        # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
-        # since proposals of the entire batch are forwarded through second stage as one "batch".
-        self.roi_batch_size = 2500
-        self.post_nms_rois_training = 500
-        self.post_nms_rois_inference = 500
-
-        # Final selection of detections (refine_detections)
-        self.model_max_instances_per_batch_element = 50 # per batch element and class.
-        self.detection_nms_threshold = 1e-5  # needs to be > 0, otherwise all predictions are one cluster.
-        self.model_min_confidence = 0.05  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
-
-        if self.dim == 2:
-            self.backbone_shapes = np.array(
-                [[int(np.ceil(self.patch_size[0] / stride)),
-                  int(np.ceil(self.patch_size[1] / stride))]
-                 for stride in self.backbone_strides['xy']])
-        else:
-            self.backbone_shapes = np.array(
-                [[int(np.ceil(self.patch_size[0] / stride)),
-                  int(np.ceil(self.patch_size[1] / stride)),
-                  int(np.ceil(self.patch_size[2] / stride_z))]
-                 for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
-                                             )])
-
-        if self.model == 'retina_net' or self.model == 'retina_unet':
-            # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
-            self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
-                                            self.rpn_anchor_scales['xy']]
-            self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
-                                           self.rpn_anchor_scales['z']]
-            self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
-
-            self.n_rpn_features = 256 if self.dim == 2 else 64
-
-            # pre-selection of detections for NMS-speedup. per entire batch.
-            self.pre_nms_limit = 10000 if self.dim == 2 else 30000
-
-            # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
-            self.anchor_matching_iou = 0.5
-
-            if self.model == 'retina_unet':
-                self.operate_stride1 = True
\ No newline at end of file
diff --git a/datasets/cityscapes/data_loader.py b/datasets/cityscapes/data_loader.py
deleted file mode 100644
index 01a1a45..0000000
--- a/datasets/cityscapes/data_loader.py
+++ /dev/null
@@ -1,452 +0,0 @@
-import sys
-sys.path.append('../') #works on cluster indep from where sbatch job is started
-import plotting as plg
-
-import warnings
-import os
-import time
-import pickle
-
-
-import numpy as np
-import pandas as pd
-from PIL import Image as pil
-
-import torch
-import torch.utils.data
-
-# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
-from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
-from batchgenerators.transforms.abstract_transforms import Compose
-from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
-from batchgenerators.transforms.spatial_transforms import SpatialTransform
-from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
-from batchgenerators.transforms.color_transforms import GammaTransform
-#from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates
-
-
-sys.path.append(os.path.dirname(os.path.realpath(__file__)))
-
-import utils.exp_utils as utils
-import utils.dataloader_utils as dutils
-from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
-
-from configs import Configs
-cf= configs()
-
-
-warnings.filterwarnings("ignore", message="This figure includes Axes.*")
-
-
-def load_obj(file_path):
-    with open(file_path, 'rb') as handle:
-        return pickle.load(handle)
-
-def save_to_npy(arr_out, array):
-    np.save(arr_out+".npy", array) 
-    print("Saved binary .npy-file to {}".format(arr_out))
-    return arr_out+".npy"
-
-def shape_small_first(shape):
-    if len(shape)<=2: #no changing dimensions if channel-dim is missing
-        return shape
-    smallest_dim = np.argmin(shape)
-    if smallest_dim!=0: #assume that smallest dim is color channel
-        new_shape = np.array(shape) #to support mask indexing
-        new_shape = (new_shape[smallest_dim],
-                    *new_shape[(np.arange(len(shape),dtype=int)!=smallest_dim)])
-        return new_shape
-    else:
-        return shape
-       
-class Dataset(dutils.Dataset):
-    def __init__(self,  cf, logger=None, subset_ids=None, data_sourcedir=None):
-        super(Dataset, self).__init__(cf, data_sourcedir=data_sourcedir)
-
-        info_dict = load_obj(cf.info_dict_path)
-
-        if subset_ids is not None:
-            img_ids = subset_ids
-            if logger is None:
-                print('subset: selected {} instances from df'.format(len(pids)))
-            else:
-                logger.info('subset: selected {} instances from df'.format(len(pids)))
-        else:
-            img_ids = list(info_dict.keys())
-
-        #evtly copy data from data_rootdir to data_dir
-        if cf.server_env and not hasattr(cf, "data_dir"):
-            file_subset = [info_dict[img_id]['img'][:-3]+"*" for img_id in img_ids]
-            file_subset+= [info_dict[img_id]['seg'][:-3]+"*" for img_id in img_ids]
-            file_subset+= [cf.info_dict_path]
-            self.copy_data(cf, file_subset=file_subset)
-            cf.data_dir = self.data_dir
-
-        img_paths = [os.path.join(self.data_dir, info_dict[img_id]['img']) for img_id in img_ids]
-        seg_paths = [os.path.join(self.data_dir, info_dict[img_id]['seg']) for img_id in img_ids]
-
-        # load all subject files
-        self.data = {}
-        for i, img_id in enumerate(img_ids):
-            subj_data = {'img_id':img_id}
-            subj_data['img'] = img_paths[i]
-            subj_data['seg'] = seg_paths[i]
-            if 'class' in self.cf.prediction_tasks:
-                subj_data['class_targets'] = np.array(info_dict[img_id]['roi_classes'])
-            else:
-                subj_data['class_targets'] = np.ones_like(np.array(info_dict[img_id]['roi_classes']))
-    
-            self.data[img_id] = subj_data
-
-        cf.roi_items = cf.observables_rois[:]
-        cf.roi_items += ['class_targets']
-        if 'regression' in cf.prediction_tasks:
-            cf.roi_items += ['regression_targets']
-
-        self.set_ids = list(self.data.keys())
-        
-        self.df = None
-
-class BatchGenerator(dutils.BatchGenerator):
-    """
-    create the training/validation batch generator. Randomly sample batch_size patients
-    from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array.
-    :param data: data dictionary as provided by 'load_dataset'
-    :param img_modalities: list of strings ['adc', 'b1500'] from config
-    :param batch_size: number of patients to sample for the batch
-    :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.)
-    :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array.
-    """
-    def __init__(self, cf, data, n_batches=None, sample_pids_w_replace=True):
-        super(BatchGenerator, self).__init__(cf, data, n_batches)
-        self.dataset_length = len(self._data)
-        self.cf = cf
-
-        self.sample_pids_w_replace = sample_pids_w_replace
-        self.eligible_pids = list(self._data.keys())
-
-        self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
-        assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
-
-        self.p_fg = 0.5
-        self.empty_samples_max_ratio = 0.33
-        self.random_count = int(cf.batch_random_ratio * cf.batch_size)
-
-        self.balance_target_distribution(plot=sample_pids_w_replace)
-        self.stats = {"roi_counts" : np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count" : 0}
-
-    def generate_train_batch(self):
-        #everything done in here is per batch
-        #print statements in here get confusing due to multithreading
-        if self.sample_pids_w_replace:
-            # fully random patients
-            batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
-            # target-balanced patients
-            batch_patient_ids += list(np.random.choice(
-                self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs))
-        else:
-            batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size, replace=False)
-        if self.sample_pids_w_replace == False:
-            self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids]
-            if len(self.eligible_pids) < self.batch_size:
-                self.eligible_pids = self.dataset_pids
-        
-        batch_data, batch_segs, batch_class_targets = [], [], []
-        # record roi count of classes in batch
-        batch_roi_counts, empty_samples_count = np.zeros((self.cf.num_classes,), dtype='uint32'), 0
-
-        for sample in range(self.batch_size):
-
-            patient = self._data[batch_patient_ids[sample]]
-            
-            data = np.load(patient["img"], mmap_mode="r")
-            seg = np.load(patient['seg'], mmap_mode="r")
-            
-            (c,y,x) = data.shape
-            spatial_shp = data[0].shape
-            assert spatial_shp==seg.shape, "spatial shape incongruence betw. data {} and seg {}".format(spatial_shp, seg.shape)
-
-            if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
-                new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
-                data = dutils.pad_nd_image(data, (len(data), *new_shape))
-                seg = dutils.pad_nd_image(seg, new_shape)
-            
-            #eventual cropping to pre_crop_size: with prob self.p_fg sample pixel from random ROI and shift center,
-            #if possible, to that pixel, so that img still contains ROI after pre-cropping
-            dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))]
-            if np.any(dim_cropflags):
-                #sample crop center regardless of ROIs, not guaranteed to be empty
-                def get_cropped_centercoords(dim):                        
-                    return np.random.randint(low=self.cf.pre_crop_size[dim]//2,
-                                             high=spatial_shp[dim] - self.cf.pre_crop_size[dim]//2)
-                    
-                sample_seg_center = {}
-                for dim in np.where(dim_cropflags)[0]:
-                    sample_seg_center[dim] = get_cropped_centercoords(dim)
-                    min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim]//2)
-                    max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim]//2)
-                    data = np.take(data, indices=range(min_, max_), axis=dim+1) #+1 for channeldim
-                    seg = np.take(seg, indices=range(min_, max_), axis=dim)
-                    
-            batch_data.append(data)
-            batch_segs.append(seg[np.newaxis])
-                
-            batch_class_targets.append(patient['class_targets'])
-
-            for cl in range(self.cf.num_classes):
-                batch_roi_counts[cl] += np.count_nonzero(patient['class_targets'][np.unique(seg[seg>0]) - 1] == cl)
-            if not np.any(seg):
-                empty_samples_count += 1
-        
-        batch = {'data': np.array(batch_data).astype('float32'), 'seg': np.array(batch_segs).astype('uint8'),
-                 'pid': batch_patient_ids, 'class_targets': np.array(batch_class_targets),
-                 'roi_counts': batch_roi_counts, 'empty_samples_count': empty_samples_count}
-        return batch
-
-class PatientBatchIterator(dutils.PatientBatchIterator):
-    """
-    creates a val/test generator. Step through the dataset and return dictionaries per patient.
-    For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions.
-
-    This iterator/these batches are not intended to go through MTaugmenter afterwards
-    """
-
-    def __init__(self, cf, data):
-        super(PatientBatchIterator, self).__init__(cf, data)
-
-        self.patch_size = cf.patch_size
-
-        self.patient_ix = 0  # running index over all patients in set
-
-    def generate_train_batch(self, pid=None):
-
-        if self.patient_ix == len(self.dataset_pids):
-            self.patient_ix = 0
-        if pid is None:
-            pid = self.dataset_pids[self.patient_ix]  # + self.thread_id
-        patient = self._data[pid]
-        batch_class_targets = np.array([patient['class_targets']])
-
-        data = np.load(patient["img"], mmap_mode="r")[np.newaxis]
-        seg = np.load(patient['seg'], mmap_mode="r")[np.newaxis, np.newaxis]
-        (b, c, y, x) = data.shape
-        spatial_shp = data.shape[2:]
-        assert spatial_shp == seg.shape[2:], "spatial shape incongruence betw. data {} and seg {}".format(spatial_shp,
-                                                                                                      seg.shape)
-        if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
-            new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
-            data = dutils.pad_nd_image(data, (len(data), *new_shape))
-            seg = dutils.pad_nd_image(seg, new_shape)
-
-        batch = {'data': data, 'seg': seg, 'class_targets': batch_class_targets}
-        converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg)
-        batch = converter(**batch)
-        batch.update({'patient_bb_target': batch['bb_target'],
-                      'patient_class_targets': batch['class_targets'],
-                      'original_img_shape': data.shape,
-                      'pid': np.array([pid] * len(data))})
-
-        # eventual tiling into patches
-        spatial_shp = batch["data"].shape[2:]
-        if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
-            patient_batch = batch
-            print("patientiterator produced patched batch!")
-            patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
-            new_img_batch, new_seg_batch = [], []
-
-            for c in patch_crop_coords_list:
-                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3]])
-                seg_patch = seg[:, c[0]:c[1], c[2]: c[3]]
-                new_seg_batch.append(seg_patch)
-
-            shps = []
-            for arr in new_img_batch:
-                shps.append(arr.shape)
-
-            data = np.array(new_img_batch)  # (patches, c, x, y, z)
-            seg = np.array(new_seg_batch)
-            batch_class_targets = np.repeat(batch_class_targets, len(patch_crop_coords_list), axis=0)
-
-            patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
-                           'class_targets': batch_class_targets,
-                           'pid': np.array([pid] * data.shape[0])}
-            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
-            patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
-            patch_batch['patient_class_targets'] = patient_batch['patient_class_targets']
-            patch_batch['patient_data'] = patient_batch['data']
-            patch_batch['patient_seg'] = patient_batch['seg']
-            patch_batch['original_img_shape'] = patient_batch['original_img_shape']
-
-            converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg)
-            patch_batch = converter(**patch_batch)
-            batch = patch_batch
-
-        self.patient_ix += 1
-        if self.patient_ix == len(self.dataset_pids):
-            self.patient_ix = 0
-
-        return batch
-
-def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True):
-    """
-    create mutli-threaded train/val/test batch generation and augmentation pipeline.
-    :param patient_data: dictionary containing one dictionary per patient in the train/test subset
-    :param test_pids: (optional) list of test patient ids, calls the test generator.
-    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
-    :return: multithreaded_generator
-    """
-    data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace)
-
-    my_transforms = []
-    if do_aug:
-        if cf.da_kwargs["mirror"]:
-            mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
-            my_transforms.append(mirror_transform)
-        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
-                                         patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2],
-                                         do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
-                                         alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
-                                         do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
-                                         angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
-                                         do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
-                                         random_crop=cf.da_kwargs['random_crop'],
-                                         border_mode_data=cf.da_kwargs['border_mode_data'])
-        my_transforms.append(spatial_transform)
-        gamma_transform = GammaTransform(gamma_range=cf.da_kwargs["gamma_range"], invert_image=False,
-                                         per_channel=False, retain_stats=False)
-        my_transforms.append(gamma_transform)
-    
-    else:
-        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
-
-    if cf.create_bounding_box_targets:
-        my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
-        #batch receives entry 'bb_target' w bbox coordinates as [y1,x1,y2,x2,z1,z2].
-    #my_transforms.append(ConvertSegToOnehotTransform(classes=range(cf.num_seg_classes)))
-    all_transforms = Compose(my_transforms)
-    #MTAugmenter creates iterator from data iterator data_gen after applying the composed transform all_transforms
-    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers,
-                                                     seeds=np.random.randint(0,cf.n_workers*2,size=cf.n_workers))
-    return multithreaded_generator
-
-
-def get_train_generators(cf, logger, data_statistics=True):
-    """
-    wrapper function for creating the training batch generator pipeline. returns the train/val generators
-    need to select cv folds on patient level, but be able to include both breasts of each patient.
-    """
-    dataset = Dataset(cf)
-    dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
-    dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
-    set_splits = dataset.fg.splits
-
-    test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
-    train_ids = np.concatenate(set_splits, axis=0)
-
-    if cf.held_out_test_set:
-        train_ids = np.concatenate((train_ids, test_ids), axis=0)
-        test_ids = []
-
-    train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids}
-    val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids}
-
-    logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
-                                                                                    len(test_ids)))
-    if data_statistics:
-        dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids},
-                                plot_dir=os.path.join(cf.plot_dir, "data_stats_fold_"+str(cf.fold)))
-
-    batch_gen = {}
-    batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=True)
-    batch_gen[cf.val_mode] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False)
-    batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches!="all" else len(val_data)
-        
-    return batch_gen
-
-def get_test_generator(cf, logger):
-    """
-    if get_test_generators is called multiple times in server env, every time of 
-    Dataset initiation rsync will check for copying the data; this should be okay
-    since rsync will not copy if files already exist in destination.
-    """
-
-    if cf.held_out_test_set:
-        sourcedir = cf.test_data_sourcedir
-        test_ids = None
-    else:
-        sourcedir = None
-        with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
-            set_splits = pickle.load(handle)
-        test_ids = set_splits[cf.fold]
-
-
-    test_set = Dataset(cf, test_ids, data_sourcedir=sourcedir)
-    logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids)))
-    batch_gen = {}
-    batch_gen['test'] = PatientBatchIterator(cf, test_set.data)
-    batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else min(cf.max_test_patients, len(test_set.set_ids))
-    
-    return batch_gen   
-
-def main():
-    total_stime = time.time()
-    times = {}
-    
-    CUDA = torch.cuda.is_available()
-    print("CUDA available: ", CUDA)
-
-
-    #cf.server_env = True
-    #cf.data_dir = "experiments/dev_data"
-
-    cf.exp_dir = "experiments/dev/"
-    cf.plot_dir = cf.exp_dir+"plots"
-    os.makedirs(cf.exp_dir, exist_ok=True)
-    cf.fold = 0
-    logger = utils.get_logger(cf.exp_dir)
-
-    gens = get_train_generators(cf, logger)
-    train_loader = gens['train']
-    
-    #for i in range(train_loader.dataset_length):
-    #    print("batch", i)
-    stime = time.time()
-    ex_batch = next(train_loader)
-   # plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_extrainbatch.png", has_colorchannels=True, isRGB=True)
-    times["train_batch"] = time.time()-stime
-
-    
-    val_loader = gens['val_sampling']
-    stime = time.time()
-    ex_batch = next(val_loader)
-    times["val_batch"] = time.time()-stime
-    stime = time.time()
-    plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", has_colorchannels=True, isRGB=True, show_gt_boxes=False)
-    times["val_plot"] = time.time()-stime
-    
-    test_loader = get_test_generator(cf, logger)["test"]
-    stime = time.time()
-    ex_batch = next(test_loader)
-    times["test_batch"] = time.time()-stime
-    #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_expatientbatch.png", has_colorchannels=True, isRGB=True)
-    
-    print(ex_batch["data"].shape)
-
-
-    print("Times recorded throughout:")
-    for (k,v) in times.items():
-        print(k, "{:.2f}".format(v))
-    
-    mins, secs = divmod((time.time() - total_stime), 60)
-    h, mins = divmod(mins, 60)
-    t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) 
-    print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
-    
-
-        
-if __name__=="__main__":
-    start_time = time.time()
-    
-    main()
-    
-    print("Program runtime in s: ", '{:.2f}'.format(time.time()-start_time))
\ No newline at end of file
diff --git a/datasets/cityscapes/preprocessing.py b/datasets/cityscapes/preprocessing.py
deleted file mode 100644
index 56c8c20..0000000
--- a/datasets/cityscapes/preprocessing.py
+++ /dev/null
@@ -1,267 +0,0 @@
-import sys
-import os
-from multiprocessing import Pool
-import time
-import pickle
-
-import numpy as np
-
-from PIL import Image as pil
-from matplotlib import pyplot as plt
-
-sys.path.append("../")
-import data_manager as dmanager
-
-from configs import Configs
-cf = configs()
-
-
-"""
-"""
-
-def load_obj(file_path):
-    with open(file_path, 'rb') as handle:
-        return pickle.load(handle)
-
-def save_obj(obj, path):
-    """Pickle a python object."""
-    with open(path, 'wb') as f:
-        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
-
-def merge_labelids(target, cf=cf):
-    """relabel preprocessing id to training id according to config.labels
-    :param target: np.array hxw holding the annotation (labelids at pixel positions)
-    :cf: The configurations file
-    """
-    for i in range(target.shape[0]):  #Iterate over height.
-        for j in range(target.shape[1]): #Iterate over width
-            target[i][j] = cf.ppId2id[int(target[i][j])]
-            
-    return target
-
-def generate_detection_labels(target, cf=cf):
-    """labels suitable to be used with batchgenerators.ConvertSegToBoundingBoxCoordinates.
-    Flaw: cannot handle more than 2 segmentation classes (fg/bg).
-    --> seg-info is lost, but not interested in seg rn anyway.
-    :param target: expected as instanceIds img
-         The pixel values encode both, class and the individual instance.
-         The integer part of a division by 1000 of each ID provides the class ID,
-         as described in labels.py. The remainder is the instance ID. If a certain
-         annotation describes multiple instances, then the pixels have the regular
-         ID of that class.
-    """
-
-    unique_IDs = np.unique(target)
-    roi_classes = []
-    
-    objs_in_img = 0
-    for i, instanceID in enumerate(unique_IDs):
-        if instanceID > max(list(cf.ppId2id.keys())):
-            instance_classID = instanceID // 1000
-        else:
-            # this is the group case (only class id assigned, no instance id)
-            instance_classID = instanceID 
-            if cf.ppId2id[instance_classID]!=0:
-                #discard this whole sample since it has group instead of 
-                #single instance annotations for a non-bg class
-                return None, None
-        
-        if cf.ppId2id[instance_classID]!=0:
-            #only pick reasonable objects, exclude road, sky, etc.
-            roi_classes.append(cf.ppId2id[instance_classID])
-            objs_in_img+=1 #since 0 is bg
-            target[target==instanceID] = objs_in_img
-        else:
-            target[target==instanceID] = 0
-
-    return target, roi_classes
-
-class Preprocessor():
-    
-    def __init__(self, cf, cities):
-        
-        self._cf = cf.prepro
-        
-        self.rootpath = cf.prepro['data_dir']
-        self.set_splits = self._cf["set_splits"]
-        self.cities = cities
-        self.datapath = cf.datapath
-        self.targetspath = cf.targetspath
-        self.targettype = cf.prepro["targettype"]
-        
-        self.img_t_size = cf.prepro["img_target_size"]
-        self.target_t_size = self.img_t_size
-        
-        self.rootpath_out = cf.prepro["output_directory"]
-        
-        self.info_dict = {}
-        """info_dict: will hold {img_identifier: img_dict} with
-            img_dict = {id: img_identifier, img:img_path, seg:seg_path,
-            roi_classes:roiclasses}
-        """
-     
-    def load_from_path_to_path(self, set_split, max_num=None):
-        """composes data and corresponding labels paths (to .png-files).
-        
-        assumes data tree structure:   datapath-|-->city1-->img1.png,img2.png,...
-                                                |-->city2-->img1.png, ...
-        """
-        data = []
-        labels = []
-        num=0
-        for city in self.cities[set_split]:
-            path = os.path.join(self.rootpath, self.datapath, set_split, city)
-            lpath = os.path.join(self.rootpath,self.targetspath,set_split, city)
-
-            files_in_dir = os.listdir(path)        
-            for file in files_in_dir:
-                split = os.path.splitext(file)
-                if split[1].lower() == ".png":
-                    num+=1
-                    filetag = file[:-(len(self.datapath)+3)]
-                    data.append(os.path.join(path,file))
-                    labels.append(os.path.join(lpath,filetag+self.targettype+".png"))
-                    
-                    if num==max_num:
-                        break
-            if num==max_num:
-                break
-      
-        return data, labels 
-        
-    def prep_img(self, args):
-        """suited for multithreading.
-        :param args: (img_path, targ_path)
-        """           
-
-        img_path, trg_path = args[0], args[1]
-        
-        img_rel_path = img_path[len(self.rootpath):]
-        trg_rel_path = trg_path[len(self.rootpath):]
-        
-        _path, img_name = os.path.split(img_path)        
-        img_identifier = "".join(img_name.split("_")[:3])
-        img_info_dict = {} #entry of img_identifier in full info_dict
-        
-        img, target = pil.open(img_path), pil.open(trg_path)
-        img, target = img.resize(self.img_t_size[::-1]), target.resize(self.target_t_size[::-1])
-        img, target = np.array(img), np.array(target) #shapes y,x(,c)
-        img         = np.transpose(img, axes=(2,0,1)) #shapes (c,)y,x
-        
-        target, roi_classes = generate_detection_labels(target)
-        if target is None:
-            return (img_identifier, target)
-        img_info_dict["roi_classes"] = roi_classes
-
-        path = os.path.join(self.rootpath_out,*img_rel_path.split(os.path.sep)[:-1])
-        os.makedirs(path, exist_ok=True)
-
-        img_path = os.path.join(self.rootpath_out, img_rel_path[:-3]+"npy")
-
-        #img.save(img_path)
-        img_info_dict["img"] = img_rel_path[:-3]+"npy"
-        np.save(img_path, img)
-        
-        path = os.path.join(self.rootpath_out,*trg_rel_path.split(os.path.sep)[:-1])
-        os.makedirs(path, exist_ok=True)
-        t_path = os.path.join(self.rootpath_out, trg_rel_path)[:-3]+"npy"
-        #target.save(t_path)
-        img_info_dict["seg"] = trg_rel_path[:-3]+"npy"
-        np.save(t_path, target)
-            
-        print("\rSaved npy images and targets of shapes {}, {} to files\n {},\n {}". \
-                format(img.shape, target.shape, img_path, t_path), flush=True, end="")
-        
-        return (img_identifier, img_info_dict)
-    
-    def prep_imgs(self, max_num=None, processes=4):
-        self.info_dict = {}
-        self.discarded = []
-        os.makedirs(self.rootpath_out, exist_ok=True)
-        for set_split in self.set_splits:
-            data, targets = self.load_from_path_to_path(set_split, max_num=max_num)
-            
-            print(next(zip(data, targets)))
-            p = Pool(processes)
-            
-            img_info_dicts = p.map(self.prep_img, zip(data, targets))
-
-            p.close()
-            p.join()
-
-            self.info_dict.update({id_:dict_ for (id_,dict_) in img_info_dicts if dict_ is not None})
-            self.discarded += [id_ for (id_, dict_) in img_info_dicts if dict_ is None]
-            #list of samples discarded due to group instead of single instance annotation
-        
-    def finish(self):
-        total_items = len(self.info_dict)+len(self.discarded)
-        
-        print("\n\nSamples discarded: {}/{}={:.1f}%, identifiers:".format(len(self.discarded),
-              total_items, len(self.discarded)/total_items*100))
-        for id_ in self.discarded:
-            print(id_)
-            
-        save_obj(self.info_dict, self._cf["info_dict_path"])
-
-
-    def convert_copy_npz(self):
-        if not self._cf["npz_dir"]:
-            return
-        print("converting & copying to npz dir", self._cf['npz_dir'])
-        os.makedirs(self._cf['npz_dir'], exist_ok=True)
-        save_obj(self.info_dict, os.path.join(self._cf['npz_dir'], 
-                                               self._cf['info_dict_path'].split("/")[-1]))
-        
-        dmanager.pack_dataset(self._cf["output_directory"], self._cf["npz_dir"], recursive=True, verbose=False)
-
-
-    def verification(self, max_num=None):
-        print("\n\n\nVerification\n")
-        for i, k in enumerate(self.info_dict):
-            if max_num is not None and i==max_num:
-                break
-            
-            subject = self.info_dict[k]
-            
-            seg = np.load(os.path.join(self.rootpath_out, subject["seg"]))
-            
-            #print("seg values", np.unique(seg))
-            print("nr of objects", len(subject["roi_classes"]))
-            print("nr of objects should equal highest seg value, fulfilled?",
-                  np.max(seg)==len(subject["roi_classes"]))
-            #print("roi_classes", subject["roi_classes"])
-            
-            img = np.transpose(np.load(os.path.join(self.rootpath_out, subject["img"])), axes=(1,2,0))
-            print("img shp", img.shape)
-            plt.imshow(img)         
-                            
-        
-def main():
-    #cf.set_splits = ["train"]
-    #cities = {'train':['dusseldorf'], 'val':['frankfurt']} #cf.cities
-    cities= cf.cities
-    
-    pp = Preprocessor(cf, cities)
-    pp.prep_imgs(max_num=None, processes=8)
-    pp.finish()
-
-    #pp.convert_copy_npz()
-
-    pp.verification(1)
-    
-    
-    
-    
-    
-    
-    return
-
-if __name__=="__main__":
-    stime = time.time()
-    
-    main()
-    
-    mins, secs = divmod((time.time() - stime), 60)
-    h, mins = divmod(mins, 60)
-    t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) 
-    print("Prepro program runtime: {}".format(t))
diff --git a/datasets/legacy/convert_folds_ids.py b/datasets/legacy/convert_folds_ids.py
deleted file mode 100644
index ba16b34..0000000
--- a/datasets/legacy/convert_folds_ids.py
+++ /dev/null
@@ -1,148 +0,0 @@
-"""
-Created at 28.05.19 16:46
-@author: gregor 
-"""
-
-import os
-import sys
-import subprocess
-
-import pickle
-import numpy as np
-import pandas as pd
-from collections import OrderedDict
-
-import utils.exp_utils as utils
-
-def get_cf(dataset_name, exp_dir=""):
-
-    cf_path = os.path.join('datasets', dataset_name, exp_dir, "configs.py")
-    cf_file = utils.import_module('configs', cf_path)
-
-    return cf_file.Configs()
-
-def vector(item):
-    """ensure item is vector-like (list or array or tuple)
-    :param item: anything
-    """
-    if not isinstance(item, (list, tuple, np.ndarray)):
-        item = [item]
-    return item
-
-def load_dataset(cf, subset_ixs=None):
-    """
-    loads the dataset. if deployed in cloud also copies and unpacks the data to the working directory.
-    :param subset_ixs: subset indices to be loaded from the dataset. used e.g. for testing to only load the test folds.
-    :return: data: dictionary with one entry per patient (in this case per patient-breast, since they are treated as
-    individual images for training) each entry is a dictionary containing respective meta-info as well as paths to the preprocessed
-    numpy arrays to be loaded during batch-generation
-    """
-
-    p_df = pd.read_pickle(os.path.join(cf.pp_data_path, cf.input_df_name))
-
-    exclude_pids = ["0305a", "0447a"] # due to non-bg segmentation but bg mal label in nodules 5728, 8840
-    p_df = p_df[~p_df.pid.isin(exclude_pids)]
-
-    if cf.select_prototype_subset is not None:
-        prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset]
-        p_df = p_df[p_df.pid.isin(prototype_pids)]
-        logger.warning('WARNING: using prototyping data subset!!!')
-    if subset_ixs is not None:
-        subset_pids = [np.unique(p_df.pid.tolist())[ix] for ix in subset_ixs]
-        p_df = p_df[p_df.pid.isin(subset_pids)]
-
-        print('subset: selected {} instances from df'.format(len(p_df)))
-
-    pids = p_df.pid.tolist()
-    cf.data_dir = cf.pp_data_path
-
-
-    imgs = [os.path.join(cf.data_dir, '{}_img.npy'.format(pid)) for pid in pids]
-    segs = [os.path.join(cf.data_dir,'{}_rois.npz'.format(pid)) for pid in pids]
-    orig_class_targets = p_df['class_target'].tolist()
-
-    data = OrderedDict()
-    for ix, pid in enumerate(pids):
-        data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid}
-        data[pid]['fg_slices'] = np.array(p_df['fg_slices'].tolist()[ix])
-        if 'class' in cf.prediction_tasks:
-            # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2)
-            raise NotImplementedError
-            # todo need to consider bg
-            data[pid]['class_targets'] = np.array([ [2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]])
-        else:
-            data[pid]['class_targets'] = np.array([ [1 if ii>0 else 0 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='uint8')
-        if any(['regression' in task for task in cf.prediction_tasks]):
-            data[pid]["regression_targets"] = np.array([ [vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix] ], dtype='float16')
-            data[pid]["rg_bin_targets"] = np.array([ [cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in data[pid]["regression_targets"]], dtype='uint8')
-
-    cf.roi_items = cf.observables_rois[:]
-    cf.roi_items += ['class_targets']
-    if any(['regression' in task for task in cf.prediction_tasks]):
-        cf.roi_items += ['regression_targets']
-        cf.roi_items += ['rg_bin_targets']
-
-    return data
-
-
-def get_patient_identifiers(cf, fold_lists):
-
-
-    all_data = load_dataset(cf)
-    all_pids_list = np.unique([v['pid'] for (k, v) in all_data.items()])
-
-
-    verifier = [] #list of folds
-    for fold in range(cf.n_cv_splits):
-        train_ix, val_ix, test_ix, fold_nr = fold_lists[fold]
-        assert fold==fold_nr
-        test_ids = [all_pids_list[ix] for ix in test_ix]
-        for ix, arr in enumerate(verifier):
-            inter = np.intersect1d(test_ids, arr)
-            #print("intersect of fold {} with fold {}: {}".format(fold, ix, inter))
-            assert len(inter)==0
-        verifier.append(test_ids)
-
-
-    return verifier
-
-def convert_folds_ids(exp_dir):
-    import inference_analysis
-    cf = get_cf('lidc', exp_dir=exp_dir)
-    cf.exp_dir = exp_dir
-    with open(os.path.join(exp_dir, 'fold_ids.pickle'), 'rb') as f:
-        fids = pickle.load(f)
-
-    pid_fold_splits = get_patient_identifiers(cf, fids)
-
-    with open(os.path.join(exp_dir, 'fold_real_ids.pickle'), 'wb') as handle:
-        pickle.dump(pid_fold_splits, handle)
-
-
-    #inference_analysis.find_pid_in_splits('0811a', exp_dir=exp_dir)
-    return
-
-
-def copy_to_new_exp_dir(old_dir, new_dir):
-
-
-    cp_ids = r"rsync {} {}".format(os.path.join(old_dir, 'fold_real_ids.pickle'), new_dir)
-    rn_ids = "mv {} {}".format(os.path.join(new_dir, 'fold_real_ids.pickle'), os.path.join(new_dir, 'fold_ids.pickle'))
-    cp_params = r"""rsync -a --include='*/' --include='*best_params.pth' --exclude='*' --prune-empty-dirs  
-    {}  {}""".format(old_dir, new_dir)
-    cp_ranking = r"""rsync -a --include='*/' --include='epoch_ranking.npy' --exclude='*' --prune-empty-dirs  
-        {}  {}""".format(old_dir, new_dir)
-    cp_results = r"""rsync -a --include='*/' --include='pred_results.pkl' --exclude='*' --prune-empty-dirs  
-        {}  {}""".format(old_dir, new_dir)
-
-    for cmd in  [cp_ids, rn_ids, cp_params, cp_ranking, cp_results]:
-        subprocess.call(cmd, shell=True)
-    print("Setup {} for inference with ids, params from {}".format(new_dir, old_dir))
-
-
-
-if __name__=="__main__":
-    exp_dir = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8'
-    new_exp_dir = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_copiedparams'
-    #convert_folds_ids(exp_dir)
-    copy_to_new_exp_dir(exp_dir, new_exp_dir)
\ No newline at end of file
diff --git a/datasets/prostate/check_GSBx_Re.py b/datasets/prostate/check_GSBx_Re.py
deleted file mode 100755
index 8f64ca1..0000000
--- a/datasets/prostate/check_GSBx_Re.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""
-Created at 20/11/18 16:18
-@author: gregor 
-"""
-import os
-import numpy as np
-import pandas as pd
-
-
-class CombinedPrinter(object):
-    """combined print function.
-    prints to logger and/or file if given, to normal print if non given.
-
-    """
-    def __init__(self, logger=None, file=None):
-
-        if logger is None and file is None:
-            self.out = [print]
-        elif logger is None:
-            self.out = [print, file.write]
-        elif file is None:
-            self.out = [print, logger.info]
-        else:
-            self.out = [print, logger.info, file.write]
-
-    def __call__(self, string):
-        for fct in self.out:
-            fct(string)
-
-def spec_to_id(spec):
-    """Get subject id from string"""
-    return int(spec[-5:])
-
-
-def pat_roi_GS_histo_check(root_dir):
-    """ Check, in histo files, whether patient-wide Gleason Score equals maximum GS found in single lesions of patient.
-    """
-
-    histo_les_path = os.path.join(root_dir, "MasterHistoAll.csv")
-    histo_pat_path = os.path.join(root_dir, "MasterPatientbasedAll_clean.csv")
-
-    with open(histo_les_path,mode="r") as les_file:
-        les_df = pd.read_csv(les_file, delimiter=",")
-    with open(histo_pat_path, mode="r") as pat_file:
-        pat_df = pd.read_csv(pat_file, delimiter=",")
-
-    merged_df = les_df.groupby('Master_ID').agg({'Gleason': 'max', 'segmentationsNameADC': 'last'})
-
-    for pid in merged_df.index:
-        merged_df.set_value(pid, "GSBx", pat_df[pat_df.Master_ID_Short==pid].GSBx.unique().astype('uint32'))
-
-    #print(merged_df)
-    print("All patient-wise GS are maximum of lesion-wise GS?", np.all(merged_df.Gleason == merged_df.GSBx), end="\n\n")
-    assert np.all(merged_df.Gleason == merged_df.GSBx)
-
-
-def lesion_redone_check(root_dir, out_path=None):
-    """check how many les annotations without post_fix _Re exist and if exists what their GS is
-    """
-
-    histo_les_path = os.path.join(root_dir, "Dokumente/MasterHistoAll.csv")
-    with open(histo_les_path,mode="r") as les_file:
-        les_df = pd.read_csv(les_file, delimiter=",")
-    if out_path is not None:
-        out_file = open(out_path, "w")
-    else:
-        out_file = None
-    print_f = CombinedPrinter(file=out_file)
-
-    data_dir = os.path.join(root_dir, "Daten")
-
-    matches = {}
-    for patient in [dir for dir in os.listdir(data_dir) if dir.startswith("Master_") \
-                    and os.path.isdir(os.path.join(data_dir, dir))]:
-        matches[patient] = {}
-        pat_dir = os.path.join(data_dir,patient)
-        lesions = [os.path.splitext(file)[0] for file in os.listdir(pat_dir) if os.path.isfile(os.path.join(pat_dir,file)) and file.startswith("seg") and "LES" in file]
-        lesions_wo = [os.path.splitext(file)[0] for file in lesions if not "_Re" in file]
-        lesions_with = [file for file in lesions if "_Re" in file and not "registered" in file]
-
-        matches[patient] = {les_wo : [] for les_wo in lesions_wo}
-
-        for les_wo in matches[patient].keys():
-            matches[patient][les_wo] += [les_with for les_with in lesions_with if les_with.startswith(les_wo)]
-
-    missing_les_count = 0
-    for patient, lesions in sorted(list(matches.items())):
-        pat_df = les_df[les_df.Master_ID==spec_to_id(patient)]
-        for les, les_matches in sorted(list(lesions.items())):
-            if len(les_matches)==0:
-                if "t2" in les.lower():
-                    les_GS = pat_df[pat_df.segmentationsNameT2==les]["Gleason"]
-                elif "adc" in les.lower():
-                    les_GS = pat_df[pat_df.segmentationsNameADC==les]["Gleason"]
-                if len(les_GS)==0:
-                    les_GS = r"[no histo finding!]"
-                print_f("Patient {}, lesion {} with GS {} has no matches!\n".format(patient, les, les_GS))
-                missing_les_count +=1
-            else:
-                del matches[patient][les]
-            #elif len(les_matches) > 1:
-            #    print("Patient {}, Lesion {} has {} matches: {}".format(patient, les, len(les_matches), les_matches))
-        if len(matches[patient])==0:
-            del matches[patient]
-
-    print_f("Total missing lesion matches: {} within {} patients".format(missing_les_count, len(matches)))
-
-    out_file.close()
-
-
-if __name__=="__main__":
-
-    #root_dir = "/mnt/HDD2TB/Documents/data/prostate/data_di_ana_081118_ps384_gs71/histos/"
-    root_dir = "/mnt/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master/Dokumente"
-    pat_roi_GS_histo_check(root_dir)
-
-    root_dir = "/mnt/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master"
-    out_path = os.path.join(root_dir,"lesion_redone_check.txt")
-    lesion_redone_check(root_dir, out_path=out_path)
-
diff --git a/datasets/prostate/configs.py b/datasets/prostate/configs.py
deleted file mode 100644
index 2de02f3..0000000
--- a/datasets/prostate/configs.py
+++ /dev/null
@@ -1,588 +0,0 @@
-__author__ = ''
-#credit Paul F. Jaeger
-
-#########################
-#     Example Config    #
-#########################
-
-import os
-import sys
-import pickle
-
-import numpy as np
-import torch
-
-from collections import namedtuple
-
-from default_configs import DefaultConfigs
-
-def load_obj(file_path):
-    with open(file_path, 'rb') as handle:
-        return pickle.load(handle)
-
-# legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope
-Label = namedtuple("Label", ['id', 'name', 'color', 'gleasons'])
-binLabel = namedtuple("Label", ['id', 'name', 'color', 'gleasons', 'bin_vals'])
-
-
-class Configs(DefaultConfigs): #todo change to Configs
-
-    def __init__(self, server_env=None):
-        #########################
-        #         General       #
-        #########################
-        super(Configs, self).__init__(server_env)
-
-        #########################
-        #         I/O           #
-        #########################
-
-        self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_di_250519_ps384_gs6071/"
-        #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071/"
-        #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_analysis/"
-
-        if server_env:
-            self.data_sourcedir = "/datasets/data_ramien/prostate/data_di_250519_ps384_gs6071_npz/"
-            #self.data_sourcedir = '/datasets/data_ramien/prostate/data_t2_250519_ps384_gs6071_npz/'
-            #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_di_ana_151118_ps384_gs60/"
-
-        self.histo_dir = os.path.join(self.data_sourcedir,"histos/")
-        self.info_dict_name = 'master_info.pkl'
-        self.info_dict_path = os.path.join(self.data_sourcedir, self.info_dict_name)
-
-        self.config_path = os.path.realpath(__file__)
-
-        # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn'].
-        self.model = 'detection_fpn'
-        self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
-        self.model_path = os.path.join(self.source_dir,self.model_path)
-                       
-        self.select_prototype_subset = None
-
-        #########################
-        #      Preprocessing    #
-        #########################
-        self.missing_pz_subjects = [#189, 196, 198, 205, 211, 214, 215, 217, 218, 219, 220,
-                                     #223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233,
-                                     #234, 235, 236, 237, 238, 239, 240, 241, 242, 244, 258,
-                                     #261, 262, 264, 267, 268, 269, 270, 271, 273, 275, 276,
-                                     #277, 278, 283
-                                    ]
-        self.no_bval_radval_subjects = [57] #this guy has master id 222
-        
-        self.prepro = {
-            'data_dir': '/home/gregor/networkdrives/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master/Daten/',
-            'dir_spec': 'Master',
-            #'images': {'t2': 'T2TRA', 'adc': 'ADC1500', 'b50': 'BVAL50', 'b500': 'BVAL500',
-            #     'b1000': 'BVAL1000', 'b1500': 'BVAL1500'},
-            #'images': {'adc': 'ADC1500', 'b50': 'BVAL50', 'b500': 'BVAL500', 'b1000': 'BVAL1000', 'b1500': 'BVAL1500'},
-            'images': {'t2': 'T2TRA'},
-            'anatomical_masks': ['seg_T2_PRO'], # try: 'seg_T2_PRO','seg_T2_PZ', 'seg_ADC_PRO', 'seg_ADC_PZ',
-            'merge_mode' : 'union', #if registered data w/ two gts: take 'union' or 'adc' or 't2' of gt
-            'rename_tags': {'seg_ADC_PRO':"pro", 'seg_T2_PRO':"pro", 'seg_ADC_PZ':"pz", 'seg_T2_PZ':"pz"},
-            'lesion_postfix': '_Re', #lesion files are tagged seg_MOD_LESx
-            'img_postfix': "_resampled2", #"_resampled2_registered",
-            'overall_postfix': ".nrrd", #including filetype ending!
-
-            'histo_dir': '/home/gregor/networkdrives/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master/Dokumente/',
-            'histo_dir_out': self.histo_dir,
-            'histo_lesion_based': 'MasterHistoAll.csv',
-            'histo_patient_based': 'MasterPatientbasedAll_clean.csv',
-            'histo_id_column_name': 'Master_ID',
-            'histo_pb_id_column_name': 'Master_ID_Short', #for patient histo
-
-            'excluded_prisma_subjects': [],
-            'excluded_radval_subjects': self.no_bval_radval_subjects,
-            'excluded_master_subjects': self.missing_pz_subjects,
-
-            'seg_labels': {'tz': 0, 'pz': 0, 'lesions':'roi'},
-            #set as hard label or 'roi' to have seg labels represent obj instance count
-            #if not given 'lesions' are numbered highest seg label +lesion-nr-in-histofile
-            'class_labels': {'lesions':'gleason'}, #0 is not bg, but first fg class!
-            #i.e., prepro labels are shifted by -1 towards later training labels in gt, legends, dicts, etc.
-            #evtly set lesions to 'gleason' and check gleason remap in prepro
-            #'gleason_thresh': 71,
-            'gleason_mapping': {0: -1, 60:0, 71:1, 72:1, 80:1, 90:1, 91:1, 92:1},
-            'gleason_map': self.gleason_map, #see below
-            'color_palette': [self.green, self.red],
-
-            'output_directory': self.data_sourcedir,
-
-            'modalities2concat' : "all", #['t2', 'adc','b50','b500','b1000','b1500'], #will be concatenated on colorchannel
-            'center_of_mass_crop': True,
-            'mod_scaling' : (1,1,1), #z,y,x
-            'pre_crop_size': [20, 384, 384], #z,y,x, z-cropping and non-square not implemented atm!!
-            'swap_yx_to_xy': False, #change final spatial shape from z,y,x to z,x,y
-            'normalization': {'percentiles':[1., 99.]},
-            'interpolation': 'nearest',
-
-            'observables_patient': ['Original_ID', 'GSBx', 'PIRADS2', 'PSA'],
-            'observables_rois': ['lesion_gleasons'],
-
-            'info_dict_path': self.info_dict_path,
-
-            'npz_dir' : self.data_sourcedir[:-1]+"_npz" #if not None: convert to npz, copy data here
-         }
-        if self.prepro["modalities2concat"] == "all":
-            self.prepro["modalities2concat"] = list(self.prepro["images"].keys())
-
-        #########################
-        #      Architecture     #
-        #########################
-
-        # dimension the model operates in. one out of [2, 3].
-        self.dim = 2
-
-        # 'class': standard object classification per roi, pairwise combinable with each of below tasks.
-        # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN.
-        # 'regression': regress some vector per each roi
-        # 'regression_ken_gal': use kendall-gal uncertainty sigma
-        # 'regression_bin': classify each roi into a bin related to a regression scale
-        self.prediction_tasks = ['class',]
-
-        self.start_filts = 48 if self.dim == 2 else 18
-        self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
-        self.res_architecture = 'resnet50' # 'resnet101' or 'resnet50'
-        self.weight_init = None #'kaiming_normal' #, 'xavier' or None-->pytorch standard,
-        self.norm = None #'instance_norm' # one of 'None', 'instance_norm', 'batch_norm'
-        self.relu = 'relu' # 'relu' or 'leaky_relu'
-
-        self.regression_n_features = 1 #length of regressor target vector (always 1D)
-
-        #########################
-        #      Data Loader      #
-        #########################
-
-        self.seed = 17
-        self.n_workers = 16 if server_env else os.cpu_count()
-
-        self.batch_size = 10 if self.dim == 2 else 6
-
-        self.channels = [1, 2, 3, 4]  # modalities2load, see prepo
-        self.n_channels = len(self.channels)  # for compatibility, but actually redundant
-        # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
-        self.plot_bg_chan = 0
-        self.pre_crop_size = list(np.array(self.prepro['pre_crop_size'])[[1, 2, 0]])  # now y,x,z
-        self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
-        self.patch_size_2D = self.pre_crop_size[:2] #[288, 288]
-        self.patch_size_3D = self.pre_crop_size[:2] + [8]  # only numbers divisible by 2 multiple times
-        # (at least 5 times for x,y, at least 3 for z)!
-        # otherwise likely to produce error in crop fct or net
-        self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
-
-        self.balance_target = "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets'
-        # ratio of fully random patients drawn during batch generation
-        # resulting batch random count is rounded down to closest integer
-        self.batch_random_ratio = 0.2 if self.dim==2 else 0.4
-
-        self.observables_patient = ['Original_ID', 'GSBx', 'PIRADS2']
-        self.observables_rois = ['lesion_gleasons']
-
-        self.regression_target = "lesion_gleasons"  # name of the info_dict entry holding regression targets
-        # linear mapping
-        self.rg_map = {0: 0, 60: 1, 71: 2, 72: 3, 80: 4, 90: 5, 91: 6, 92: 7, None: 0}
-        # non-linear mapping
-        #self.rg_map = {0: 0, 60: 1, 71: 6, 72: 7.5, 80: 9, 90: 10, 91: 10, 92: 10, None: 0}
-
-        #########################
-        #   Colors and Legends  #
-        #########################
-        self.plot_frequency = 5
-
-        # colors
-        self.gravity_col_palette = [self.green, self.yellow, self.orange, self.bright_red, self.red, self.dark_red]
-
-        self.gs_labels = [
-            Label(0,    'bg',   self.gray,     (0,)),
-            Label(60,   'GS60', self.dark_green,     (60,)),
-            Label(71,   'GS71', self.dark_yellow,    (71,)),
-            Label(72,   'GS72', self.orange,    (72,)),
-            Label(80,   'GS80', self.brighter_red,(80,)),
-            Label(90,   'GS90', self.bright_red,       (90,)),
-            Label(91,   'GS91', self.red,       (91,)),
-            Label(92,   'GS92', self.dark_red,  (92,))
-        ]
-        self.gs2label = {label.id: label for label in self.gs_labels}
-
-
-        binary_cl_labels = [Label(1, 'benign',      (*self.green, 1.),  (60,)),
-                            Label(2, 'malignant',   (*self.red, 1.),    (71,72,80,90,91,92)),
-                            #Label(3, 'pz',          (*self.blue, 1.),   (None,)),
-                            #Label(4, 'tz',          (*self.aubergine, 1.), (None,))
-                        ]
-
-        self.class_labels = [
-                    #id #name           #color              #gleason score
-            Label(  0,  'bg',           (*self.gray, 0.),  (0,))]
-        if "class" in self.prediction_tasks:
-                self.class_labels += binary_cl_labels
-                # self.class_labels += [Label(cl, cl_dic["name"], cl_dic["color"], tuple(cl_dic["gleasons"]))
-                #                      for cl, cl_dic in
-                #                      load_obj(os.path.join(self.data_sourcedir, "pp_class_labels.pkl")).items()]
-        else:
-            self.class_labels += [Label(  1,  'lesion',    (*self.red, 1.),    (60,71,72,80,90,91,92))]
-
-        if any(['regression' in task for task in self.prediction_tasks]):
-            self.bin_labels = [binLabel(0, 'bg', (*self.gray, 0.), (0,), (0,))]
-            self.bin_labels += [binLabel(cl, cl_dic["name"], cl_dic["color"], tuple(cl_dic["gleasons"]),
-                                         tuple([self.rg_map[gs] for gs in cl_dic["gleasons"]])) for cl, cl_dic in
-                                sorted(load_obj(os.path.join(self.data_sourcedir, "pp_class_labels.pkl")).items())]
-            self.bin_id2label = {label.id: label for label in self.bin_labels}
-            self.gs2bin_label = {gs: label for label in self.bin_labels for gs in label.gleasons}
-            bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
-            self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
-            self.bin_edges = [(bins[i][1] + bins[i+1][0]) / 2 for i in range(len(bins)-1)]
-            self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
-
-
-        if self.class_specific_seg:
-            self.seg_labels = self.class_labels
-        else:
-            self.seg_labels = [  # id      #name           #color
-                Label(0, 'bg', (*self.white, 0.)),
-                Label(1, 'fg', (*self.orange, 1.))
-            ]
-
-        self.class_id2label = {label.id: label for label in self.class_labels}
-        self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
-        # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be
-        # evaluated in debugging
-        self.class_cmap = {label.id: label.color for label in self.class_labels}
-
-        self.seg_id2label = {label.id: label for label in self.seg_labels}
-        self.cmap = {label.id: label.color for label in self.seg_labels}
-
-        self.plot_prediction_histograms = True
-        self.plot_stat_curves = False
-        self.plot_class_ids = True
-
-        self.num_classes = len(self.class_dict)  # for instance classification (excl background)
-        self.num_seg_classes = len(self.seg_labels)  # incl background
-
-        #########################
-        #   Data Augmentation   #
-        #########################
-        #the angle rotations are implemented incorrectly in batchgenerators! in 2D,
-        #the x-axis angle controls the z-axis angle.
-        if self.dim == 2:
-            angle_x = (-np.pi / 3., np.pi / 3.)
-            angle_z = (0.,0.)
-            rcd = (self.patch_size[0] / 2., self.patch_size[1] / 2.)
-        else:
-            angle_x = (0.,0.)
-            angle_z = (-np.pi / 2., np.pi / 2.)
-            rcd = (self.patch_size[0] / 2., self.patch_size[1] / 2.,
-                   self.patch_size[2] / 2.)
-        
-        self.do_aug = True
-        # DA settings for DWI
-        self.da_kwargs = {
-            'mirror': True,
-            'mirror_axes': tuple(np.arange(0, self.dim, 1)),
-            'random_crop': True,
-            'rand_crop_dist': rcd,
-            'do_elastic_deform': self.dim==2,
-            'alpha': (0., 1500.),
-            'sigma': (25., 50.),
-            'do_rotation': True,
-            'angle_x': angle_x,
-            'angle_y': (0., 0.),
-            'angle_z': angle_z,
-            'do_scale': True,
-            'scale': (0.7, 1.3),
-            'border_mode_data': 'constant',
-            'gamma_transform': True,
-            'gamma_range': (0.5, 2.)
-        }
-        # for T2
-        # self.da_kwargs = {
-        #     'mirror': True,
-        #     'mirror_axes': tuple(np.arange(0, self.dim, 1)),
-        #     'random_crop': False,
-        #     'rand_crop_dist': rcd,
-        #     'do_elastic_deform': False,
-        #     'alpha': (0., 1500.),
-        #     'sigma': (25., 50.),
-        #     'do_rotation': True,
-        #     'angle_x': angle_x,
-        #     'angle_y': (0., 0.),
-        #     'angle_z': angle_z,
-        #     'do_scale': False,
-        #     'scale': (0.7, 1.3),
-        #     'border_mode_data': 'constant',
-        #     'gamma_transform': False,
-        #     'gamma_range': (0.5, 2.)
-        # }
-
-
-        #################################
-        #  Schedule / Selection / Optim #
-        #################################
-
-        # good guess: train for n_samples = 1.1m = epochs*n_train_bs*b_size
-        self.num_epochs = 270
-        self.num_train_batches = 120 if self.dim == 2 else 140
-        
-        self.val_mode = 'val_patient' # one of 'val_sampling', 'val_patient'
-        # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
-        # the former is more accurate, while the latter is faster (depending on volume size)
-        self.num_val_batches = 200 if self.dim==2 else 40 # for val_sampling, number or "all"
-        self.max_val_patients = "all"  #for val_patient, "all" takes whole split
-
-        self.save_n_models = 6
-        self.min_save_thresh = 3 if self.dim == 2 else 4 #=wait time in epochs
-        if "class" in self.prediction_tasks:
-            # 'criterion': weight
-            self.model_selection_criteria = {"benign_ap": 0.2, "malignant_ap": 0.8}
-        elif any("regression" in task for task in self.prediction_tasks):
-            self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8}
-        #self.model_selection_criteria = {"GS71-92_ap": 0.9, "GS60_ap": 0.1}  # 'criterion':weight
-        #self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8}
-        #self.model_selection_criteria = {label.name+"_ap": 1. for label in self.class_labels if label.id!=0}
-
-        self.scan_det_thresh = False
-        self.warm_up = 0
-
-        self.optimizer = "ADAM"
-        self.weight_decay = 1e-5
-        self.clip_norm = None #number or None
-
-        self.learning_rate = [1e-4] * self.num_epochs
-        self.dynamic_lr_scheduling = True
-        self.lr_decay_factor = 0.5
-        self.scheduling_patience = int(self.num_epochs / 6)
-
-        #########################
-        #   Testing             #
-        #########################
-
-        self.test_aug_axes = (0,1,(0,1))  # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x).
-        self.held_out_test_set = False
-        self.max_test_patients = "all"  # "all" or number
-        self.report_score_level = ['rois', 'patient']  # 'patient' or 'rois' (incl)
-        self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1
-
-
-        self.eval_bins_separately = "additionally" if not 'class' in self.prediction_tasks else False
-        self.patient_bin_of_interest = 2
-        self.metrics = ['ap', 'auc', 'dice']
-        if any(['regression' in task for task in self.prediction_tasks]):
-            self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
-                             'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
-        if 'aleatoric' in self.model:
-            self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
-        self.evaluate_fold_means = True
-
-        self.min_det_thresh = 0.02
-
-        self.ap_match_ious = [0.1]  # threshold(s) for considering a prediction as true positive
-        # aggregation method for test and val_patient predictions.
-        # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
-        # nms = standard non-maximum suppression, or None = no clustering
-        self.clustering = 'wbc'
-        # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
-        self.clustering_iou = 0.1  # has to be larger than desired possible overlap iou of model predictions
-        # 2D-3D merging is applied independently from clustering setting.
-        self.merge_2D_to_3D_preds = True if self.dim == 2 else False
-        self.merge_3D_iou = 0.1
-        self.n_test_plots = 1  # per fold and rank
-        self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
-        # is multiplied by n_test_augs if test_aug
-
-        #########################
-        # shared model settings #
-        #########################
-
-        # max number of roi candidates to identify per image and class (slice in 2D, volume in 3D)
-        self.n_roi_candidates = 10 if self.dim == 2 else 15
-
-        #########################
-        #      assertions       #
-        #########################
-        if not 'class' in self.prediction_tasks:
-            assert self.num_classes == 1
-        for mod in self.prepro['modalities2concat']:
-            assert mod in self.prepro['images'].keys(), "need to adapt mods2concat to chosen images"
-
-        #########################
-        #   Add model specifics #
-        #########################
-        
-        {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
-         'mrcnn_gan': self.add_mrcnn_configs,
-         'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
-         'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
-         }[self.model]()
-
-    def gleason_map(self, GS):
-        """gleason to class id
-        :param GS: gleason score as in histo file
-        """
-        if "gleason_thresh" in self.prepro.keys():
-            assert "gleason_mapping" not in self.prepro.keys(), "cant define both, thresh and map, for GS to classes"
-            # -1 == bg, 0 == benign, 1 == malignant
-            # before shifting, i.e., 0!=bg, but 0==first class
-            remapping = 0 if GS >= self.prepro["gleason_thresh"] else -1
-            return remapping
-        elif "gleason_mapping" in self.prepro.keys():
-            return self.prepro["gleason_mapping"][GS]
-        else:
-            raise Exception("Need to define some remapping, at least GS 0 -> background (class -1)")
-
-    def rg_val_to_bin_id(self, rg_val):
-        return float(np.digitize(rg_val, self.bin_edges))
-
-    def add_det_fpn_configs(self):
-        self.scheduling_criterion = 'torch_loss'
-        self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
-
-        # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
-        self.seg_loss_mode = 'wce'
-        self.wce_weights = [1]*self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1]
-        # if <1, false positive predictions in foreground are penalized less.
-        self.fp_dice_weight = 1 if self.dim == 2 else 1
-
-
-        self.detection_min_confidence = 0.05
-        #how to determine score of roi: 'max' or 'median'
-        self.score_det = 'max'
-
-        self.cuda_benchmark = self.dim==3
-
-    def add_det_unet_configs(self):
-        self.scheduling_criterion = "torch_loss"
-        self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
-
-        # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
-        self.seg_loss_mode = 'wce'
-        self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1]
-        # if <1, false positive predictions in foreground are penalized less.
-        self.fp_dice_weight = 1 if self.dim == 2 else 1
-
-        self.detection_min_confidence = 0.05
-        #how to determine score of roi: 'max' or 'median'
-        self.score_det = 'max'
-
-        self.init_filts = 32
-        self.kernel_size = 3 #ks for horizontal, normal convs
-        self.kernel_size_m = 2 #ks for max pool
-        self.pad = "same" # "same" or integer, padding of horizontal convs
-
-        self.cuda_benchmark = True
-
-    def add_mrcnn_configs(self):
-
-        self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
-        self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
-
-        # number of classes for network heads: n_foreground_classes + 1 (background)
-        self.head_classes = self.num_classes + 1
-        #
-        # feed +/- n neighbouring slices into channel dimension. set to None for no context.
-        self.n_3D_context = None
-        if self.n_3D_context is not None and self.dim == 2:
-            self.n_channels *= (self.n_3D_context * 2 + 1)
-
-        self.frcnn_mode = False
-        # disable the re-sampling of mask proposals to original size for speed-up.
-        # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
-        # mask outputs are optional.
-        self.return_masks_in_train = True
-        self.return_masks_in_val = True
-        self.return_masks_in_test = True
-
-        # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
-        self.backbone_strides =  {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
-        # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
-        # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
-        self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
-        # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
-        self.pyramid_levels = [0, 1, 2, 3]
-        # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
-        self.n_rpn_features = 512 if self.dim == 2 else 128
-
-        # anchor ratios and strides per position in feature maps.
-        self.rpn_anchor_ratios = [0.5,1.,2.]
-        self.rpn_anchor_stride = 1
-        # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
-        self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
-
-        # loss sampling settings.
-        self.rpn_train_anchors_per_image = 6
-        self.train_rois_per_image = 6 #per batch_instance
-        self.roi_positive_ratio = 0.5
-        self.anchor_matching_iou = 0.7
-
-        # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
-        # where k<=#positive examples.
-        self.shem_poolsize = 3
-
-        self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
-        self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
-        self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
-
-        self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
-        self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
-        self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
-        self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
-                               self.patch_size_3D[2], self.patch_size_3D[2]]) #y1,x1,y2,x2,z1,z2
-
-        if self.dim == 2:
-            self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
-            self.bbox_std_dev = self.bbox_std_dev[:4]
-            self.window = self.window[:4]
-            self.scale = self.scale[:4]
-
-        self.plot_y_max = 1.5
-        self.n_plot_rpn_props = 5 if self.dim == 2 else 30 #per batch_instance (slice in 2D / patient in 3D)
-
-        # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
-        self.pre_nms_limit = 3000 if self.dim == 2 else 6000
-
-        # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
-        # since proposals of the entire batch are forwarded through second stage in as one "batch".
-        self.roi_chunk_size = 2000 if self.dim == 2 else 400
-        self.post_nms_rois_training = 250 * (self.head_classes-1) if self.dim == 2 else 500
-        self.post_nms_rois_inference = 250 * (self.head_classes-1)
-
-        # Final selection of detections (refine_detections)
-        self.model_max_instances_per_batch_element = self.n_roi_candidates  # per batch element and class.
-        # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py, otherwise all predictions are one cluster.
-        self.detection_nms_threshold = 1e-5
-        # detection score threshold in refine_detections()
-        self.model_min_confidence = 0.05 #self.min_det_thresh/2
-
-        if self.dim == 2:
-            self.backbone_shapes = np.array(
-                [[int(np.ceil(self.patch_size[0] / stride)),
-                  int(np.ceil(self.patch_size[1] / stride))]
-                 for stride in self.backbone_strides['xy']])
-        else:
-            self.backbone_shapes = np.array(
-                [[int(np.ceil(self.patch_size[0] / stride)),
-                  int(np.ceil(self.patch_size[1] / stride)),
-                  int(np.ceil(self.patch_size[2] / stride_z))]
-                 for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
-            )])
-
-        self.operate_stride1 = False
-
-        if self.model == 'retina_net' or self.model == 'retina_unet':
-            self.cuda_benchmark = self.dim == 3
-            #implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
-            self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
-                                            self.rpn_anchor_scales['xy']]
-            self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
-                                            self.rpn_anchor_scales['z']]
-            self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
-
-            self.n_rpn_features = 256 if self.dim == 2 else 64
-
-            # pre-selection of detections for NMS-speedup. per entire batch.
-            self.pre_nms_limit = (1000 if self.dim == 2 else 6250) * self.batch_size
-
-            # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
-            self.anchor_matching_iou = 0.5
-
-            if self.model == 'retina_unet':
-                self.operate_stride1 = True
\ No newline at end of file
diff --git a/datasets/prostate/data_loader.py b/datasets/prostate/data_loader.py
deleted file mode 100644
index 69c53e6..0000000
--- a/datasets/prostate/data_loader.py
+++ /dev/null
@@ -1,716 +0,0 @@
-__author__ = ''
-#credit derives from Paul Jaeger, Simon Kohl
-
-import os
-import time
-import warnings
-
-from collections import OrderedDict
-import pickle
-
-import numpy as np
-import pandas as pd
-
-# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
-from batchgenerators.augmentations.utils import resize_image_by_padding, center_crop_2D_image, center_crop_3D_image
-from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
-from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
-from batchgenerators.transforms.abstract_transforms import Compose
-from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
-from batchgenerators.dataloading import SingleThreadedAugmenter
-from batchgenerators.transforms.spatial_transforms import SpatialTransform
-from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
-#from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates
-from batchgenerators.transforms import AbstractTransform
-from batchgenerators.transforms.color_transforms import GammaTransform
-
-#sys.path.append(os.path.dirname(os.path.realpath(__file__)))
-
-#import utils.exp_utils as utils
-import utils.dataloader_utils as dutils
-from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
-import data_manager as dmanager
-
-
-def load_obj(file_path):
-    with open(file_path, 'rb') as handle:
-        return pickle.load(handle)
-
-def id_to_spec(id, base_spec):
-    """Construct subject specifier from base string and an integer subject number."""
-    num_zeros = 5 - len(str(id))
-    assert num_zeros>=0, "id_to_spec: patient id too long to fit into 5 figures"
-    return base_spec + '_' + ('').join(['0'] * num_zeros) + str(id)
-
-def convert_3d_to_2d_generator(data_dict, shape="bcxyz"):
-    """Fold/Shape z-dimension into color-channel.
-    :param shape: bcxyz or bczyx
-    :return: shape b(c*z)xy or b(c*z)yx
-    """
-    if shape=="bcxyz":
-        data_dict['data'] = np.transpose(data_dict['data'], axes=(0,1,4,3,2))
-        data_dict['seg'] = np.transpose(data_dict['seg'], axes=(0,1,4,3,2))
-    elif shape=="bczyx":
-        pass
-    else:
-        raise Exception("unknown datashape {} in 3d_to_2d transform converter".format(shape))
- 
-    shp = data_dict['data'].shape
-    data_dict['orig_shape_data'] = shp
-    seg_shp = data_dict['seg'].shape
-    data_dict['orig_shape_seg'] = seg_shp
-    
-    data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))
-    data_dict['seg'] = data_dict['seg'].reshape((seg_shp[0], seg_shp[1] * seg_shp[2], seg_shp[3], seg_shp[4]))
-
-    return data_dict
-
-def convert_2d_to_3d_generator(data_dict, shape="bcxyz"):
-    """Unfold z-dimension from color-channel.
-    data needs to be in shape bcxy or bcyx, x,y dims won't be swapped relative to each other.
-    :param shape: target shape, bcxyz or bczyx
-    """
-    shp = data_dict['orig_shape_data']
-    cur_shape = data_dict['data'].shape
-    seg_shp = data_dict['orig_shape_seg']
-    cur_shape_seg = data_dict['seg'].shape
-    
-    data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1], shp[2], cur_shape[-2], cur_shape[-1]))
-    data_dict['seg'] = data_dict['seg'].reshape((seg_shp[0], seg_shp[1], seg_shp[2], cur_shape_seg[-2], cur_shape_seg[-1]))
-    
-    if shape=="bcxyz":
-        data_dict['data'] = np.transpose(data_dict['data'], axes=(0,1,4,3,2))
-        data_dict['seg'] = np.transpose(data_dict['seg'], axes=(0,1,4,3,2)) 
-    return data_dict
-
-class Convert3DTo2DTransform(AbstractTransform):
-    def __init__(self):
-        pass
-
-    def __call__(self, **data_dict):
-        return convert_3d_to_2d_generator(data_dict)
-
-class Convert2DTo3DTransform(AbstractTransform):
-    def __init__(self):
-        pass
-
-    def __call__(self, **data_dict):
-        return convert_2d_to_3d_generator(data_dict)
-
-def vector(item):
-    """ensure item is vector-like (list or array or tuple)
-    :param item: anything
-    """
-    if not isinstance(item, (list, tuple, np.ndarray)):
-        item = [item]
-    return item
-
-class Dataset(dutils.Dataset):
-    r"""Load a dict holding memmapped arrays and clinical parameters for each patient,
-    evtly subset of those.
-        If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to
-        cf.data_dest.
-    :param cf: config file
-    :param data_dir: directory in which to find data, defaults to cf.data_dir if None.
-    :return: dict with imgs, segs, pids, class_labels, observables
-    """
-
-    def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None):
-        super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir)
-
-        info_dict = load_obj(cf.info_dict_path)
-
-        if subset_ids is not None:
-            pids = subset_ids
-            if logger is None:
-                print('subset: selected {} instances from df'.format(len(pids)))
-            else:
-                logger.info('subset: selected {} instances from df'.format(len(pids)))
-        else:
-            pids = list(info_dict.keys())
-
-        #evtly copy data from data_rootdir to data_dir
-        if cf.server_env and not hasattr(cf, "data_dir"):
-            file_subset = [info_dict[pid]['img'][:-3]+"*" for pid in pids]
-            file_subset+= [info_dict[pid]['seg'][:-3]+"*" for pid in pids]
-            file_subset += [cf.info_dict_path]
-            self.copy_data(cf, file_subset=file_subset)
-            cf.data_dir = self.data_dir
-
-        img_paths = [os.path.join(self.data_dir, info_dict[pid]['img']) for pid in pids]
-        seg_paths = [os.path.join(self.data_dir, info_dict[pid]['seg']) for pid in pids]
-
-        # load all subject files
-        self.data = OrderedDict()
-        for i, pid in enumerate(pids):
-            subj_spec = id_to_spec(pid, cf.prepro['dir_spec'])
-            subj_data = {'pid':pid, "spec":subj_spec}
-            subj_data['img'] = img_paths[i]
-            subj_data['seg'] = seg_paths[i]
-            #read, add per-roi labels
-            for obs in cf.observables_patient+cf.observables_rois:
-                subj_data[obs] = np.array(info_dict[pid][obs])
-            if 'class' in self.cf.prediction_tasks:
-                subj_data['class_targets'] = np.array(info_dict[pid]['roi_classes'], dtype='uint8') + 1
-            else:
-                subj_data['class_targets'] = np.ones_like(np.array(info_dict[pid]['roi_classes']), dtype='uint8')
-            if any(['regression' in task for task in self.cf.prediction_tasks]):
-                if hasattr(cf, "rg_map"):
-                    subj_data["regression_targets"] = np.array([vector(cf.rg_map[v]) for v in info_dict[pid][cf.regression_target]], dtype='float16')
-                else:
-                    subj_data["regression_targets"] = np.array([vector(v) for v in info_dict[pid][cf.regression_target]], dtype='float16')
-                subj_data["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in subj_data["regression_targets"]], dtype='uint8')
-            subj_data['fg_slices'] = info_dict[pid]['fg_slices']
-
-            self.data[pid] = subj_data
-
-        cf.roi_items = cf.observables_rois[:]
-        cf.roi_items += ['class_targets']
-        if any(['regression' in task for task in self.cf.prediction_tasks]):
-            cf.roi_items += ['regression_targets']
-            cf.roi_items += ['rg_bin_targets']
-        #cf.patient_items = cf.observables_patient[:]
-        #patient-wise items not used currently
-        self.set_ids = np.array(list(self.data.keys()))
-
-        self.df = None
-
-class BatchGenerator(dutils.BatchGenerator):
-    """
-    create the training/validation batch generator. Randomly sample batch_size patients
-    from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array.
-    :param data: data dictionary as provided by 'load_dataset'
-    :param img_modalities: list of strings ['adc', 'b1500'] from config
-    :param batch_size: number of patients to sample for the batch
-    :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.)
-    :param sample_pids_w_replace: whether to randomly draw pids from dataset for batch generation. if False, step through whole dataset
-        before repition.
-    :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array.
-    """
-    def __init__(self, cf, data, n_batches=None, sample_pids_w_replace=True):
-        super(BatchGenerator, self).__init__(cf, data,  n_batches)
-        self.dataset_length = len(self._data)
-        self.cf = cf
-
-        self.sample_pids_w_replace = sample_pids_w_replace
-        self.eligible_pids = list(self._data.keys())
-
-        self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
-        assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
-
-        self.p_fg = 0.5
-        self.empty_samples_max_ratio = 0.6
-        self.random_count = int(cf.batch_random_ratio * cf.batch_size)
-
-        self.balance_target_distribution(plot=sample_pids_w_replace)
-        self.stats = {"roi_counts" : np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count" : 0}
-        
-    def generate_train_batch(self):
-        #everything done in here is per batch
-        #print statements in here get confusing due to multithreading
-        if self.sample_pids_w_replace:
-            # fully random patients
-            batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
-            # target-balanced patients
-            batch_patient_ids += list(np.random.choice(
-                self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs))
-        else:
-            batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size,
-                                                 replace=False)
-        if self.sample_pids_w_replace == False:
-            self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids]
-            if len(self.eligible_pids) < self.batch_size:
-                self.eligible_pids = self.dataset_pids
-        
-        batch_data, batch_segs, batch_patient_specs = [], [], []
-        batch_roi_items = {name: [] for name in self.cf.roi_items}
-        #record roi count of classes in batch
-        batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0
-        #empty count for full bg samples (empty slices in 2D/patients in 3D)
-
-        for sample in range(self.batch_size):
-
-            patient = self._data[batch_patient_ids[sample]]
-            
-            #swap dimensions from (c,)z,y,x to (c,)y,x,z or h,w,d to ease 2D/3D-case handling
-            data = np.transpose(np.load(patient['img'], mmap_mode='r'), axes=(0, 2, 3, 1))[self.chans]
-            seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
-            (c,y,x,z) = data.shape 
-
-            #original data is 3D MRIs, so need to pick (e.g. randomly) single slice to make it 2D,
-            #consider batch roi-class balance
-            if self.cf.dim == 2:
-                elig_slices, choose_fg = [], False
-                if self.sample_pids_w_replace and len(patient['fg_slices']) > 0:
-                    if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand(
-                            1) <= self.p_fg:
-                        # fg is to be picked
-                        for tix in np.argsort(batch_roi_counts):
-                            # pick slices of patient that have roi of sought-for target
-                            # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
-                            elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
-                                patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
-                                self.unique_ts[tix]) > 0]
-                            if len(elig_slices) > 0:
-                                choose_fg = True
-                                break
-                    else:
-                        # pick bg
-                        elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
-                if len(elig_slices) == 0:
-                    elig_slices = z
-                sl_pick_ix = np.random.choice(elig_slices, size=None)
-                data = data[..., sl_pick_ix]
-                seg = seg[..., sl_pick_ix]
-
-            spatial_shp = data[0].shape
-            assert spatial_shp==seg.shape, "spatial shape incongruence betw. data and seg"
-
-            if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
-                new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
-                data = dutils.pad_nd_image(data, (len(data), *new_shape))
-                seg = dutils.pad_nd_image(seg, new_shape)
-            
-            #eventual cropping to pre_crop_size: with prob self.p_fg sample pixel from random ROI and shift center,
-            #if possible, to that pixel, so that img still contains ROI after pre-cropping
-            dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))]
-            if np.any(dim_cropflags):
-                print("dim crop applied")
-                # sample pixel from random ROI and shift center, if possible, to that pixel
-                if self.cf.dim==3:
-                    choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg
-                if self.sample_pids_w_replace and choose_fg and np.any(seg):
-                    available_roi_ids = np.unique(seg)[1:]
-                    for tix in np.argsort(batch_roi_counts):
-                        elig_roi_ids = available_roi_ids[
-                            patient[self.balance_target][available_roi_ids - 1] == self.unique_ts[tix]]
-                        if len(elig_roi_ids) > 0:
-                            seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
-                            break
-                    roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
-                    assert seg[tuple(roi_anchor_pixel)] > 0
-
-                    # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and 
-                    # distance to the selected ROI < patch_size /2
-                    def get_cropped_centercoords(dim):     
-                        low = np.max((self.cf.pre_crop_size[dim]//2,
-                                      roi_anchor_pixel[dim] - (self.cf.patch_size[dim]//2 - self.cf.crop_margin[dim])))
-                        high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim]//2,
-                                       roi_anchor_pixel[dim] + (self.cf.patch_size[dim]//2 - self.cf.crop_margin[dim])))
-                        if low >= high: #happens if lesion on the edge of the image.
-                            #print('correcting low/high:', low, high, spatial_shp, roi_anchor_pixel, dim)
-                            low = self.cf.pre_crop_size[dim] // 2
-                            high = spatial_shp[dim] - self.cf.pre_crop_size[dim]//2
-                        
-                        assert low<high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format(dim, 
-                                                                         spatial_shp, patient['pid'], low, high)
-                        return np.random.randint(low=low, high=high)
-                else:
-                    #sample crop center regardless of ROIs, not guaranteed to be empty
-                    def get_cropped_centercoords(dim):                        
-                        return np.random.randint(low=self.cf.pre_crop_size[dim]//2,
-                                                 high=spatial_shp[dim] - self.cf.pre_crop_size[dim]//2)
-                    
-                sample_seg_center = {}
-                for dim in np.where(dim_cropflags)[0]:
-                    sample_seg_center[dim] = get_cropped_centercoords(dim)
-                    min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim]//2)
-                    max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim]//2)
-                    data = np.take(data, indices=range(min_, max_), axis=dim+1) #+1 for channeldim
-                    seg = np.take(seg, indices=range(min_, max_), axis=dim)
-                    
-            batch_data.append(data)
-            batch_segs.append(seg[np.newaxis])
-                
-            for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item
-                    batch_roi_items[o].append(patient[o])
-            batch_patient_specs.append(patient['spec'])
-
-            if self.cf.dim == 3:
-                for tix in range(len(self.unique_ts)):
-                    batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
-            elif self.cf.dim == 2:
-                for tix in range(len(self.unique_ts)):
-                    batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
-            if not np.any(seg):
-                empty_samples_count += 1
-
-        #self.stats['roi_counts'] += batch_roi_counts #DOESNT WORK WITH MULTITHREADING! do outside
-        #self.stats['empty_samples_count'] += empty_samples_count
-
-        batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'),
-                 'pid': batch_patient_ids, 'spec': batch_patient_specs,
-                 'roi_counts':batch_roi_counts, 'empty_samples_count': empty_samples_count}
-        for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...)
-            batch[key] = np.array(val)
-
-        return batch
-
-class PatientBatchIterator(dutils.PatientBatchIterator):
-    """
-    creates a val/test generator. Step through the dataset and return dictionaries per patient.
-    2D is a special case of 3D patching with patch_size[2] == 1 (slices)
-    Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets.
-    Appends patient targets anyway for evaluation.
-    For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions.
-    
-    This iterator/these batches are not intended to go through MTaugmenter afterwards
-    """
-
-    def __init__(self, cf, data):
-        super(PatientBatchIterator, self).__init__(cf, data)
-
-        self.patient_ix = 0 #running index over all patients in set
-        
-        self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size
-        self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
-        assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
-
-    def generate_train_batch(self, pid=None):
-        
-        if self.patient_ix == len(self.dataset_pids):
-            self.patient_ix = 0          
-        if pid is None:
-            pid = self.dataset_pids[self.patient_ix] # + self.thread_id
-        patient = self._data[pid]
-
-        #swap dimensions from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling
-        data  = np.transpose(np.load(patient['img'], mmap_mode='r'), axes=(0, 2, 3, 1))
-        seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]
-        data_shp_raw = data.shape
-        plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
-        data = data[self.chans]
-        discarded_chans = len(
-            [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
-        spatial_shp = data[0].shape # spatial dims need to be in order x,y,z
-        assert spatial_shp==seg[0].shape, "spatial shape incongruence betw. data and seg"
-                
-        if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
-            new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
-            data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape.
-            seg = dutils.pad_nd_image(seg, new_shape)
-            if plot_bg is not None:
-                plot_bg = dutils.pad_nd_image(plot_bg, new_shape)
-
-        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
-            #adds the batch dim here bc won't go through MTaugmenter 
-            out_data = data[np.newaxis]
-            out_seg = seg[np.newaxis]
-            if plot_bg is not None:
-                out_plot_bg = plot_bg[np.newaxis]
-            #data and seg shape: (1,c,x,y,z), where c=1 for seg
-            batch_3D = {'data': out_data, 'seg': out_seg}
-            for o in self.cf.roi_items:
-                batch_3D[o] = np.array([patient[o]])
-            converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
-            batch_3D = converter(**batch_3D)
-            batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
-            for o in self.cf.roi_items:
-                batch_3D["patient_" + o] = batch_3D[o]
-
-        if self.cf.dim == 2:
-            out_data = np.transpose(data, axes=(3,0,1,2)) #(c,y,x,z) to (b=z,c,x,y), use z=b as batchdim
-            out_seg = np.transpose(seg, axes=(3,0,1,2)).astype('uint8')   #(c,y,x,z) to (b=z,c,x,y)
-
-            batch_2D = {'data': out_data, 'seg': out_seg}
-            for o in self.cf.roi_items:
-                batch_2D[o] = np.repeat(np.array([patient[o]]), len(out_data), axis=0)
-
-            converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
-            batch_2D = converter(**batch_2D)
-
-            if plot_bg is not None:
-                out_plot_bg = np.transpose(plot_bg, axes=(2,0,1)).astype('float32')
-
-            if self.cf.merge_2D_to_3D_preds:
-                batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
-                                      'original_img_shape': out_data.shape})
-                for o in self.cf.roi_items:
-                    batch_2D["patient_" + o] = batch_3D['patient_'+o]
-            else:
-                batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
-                                 'original_img_shape': out_data.shape})
-                for o in self.cf.roi_items:
-                    batch_2D["patient_" + o] = batch_2D[o]
-
-        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
-        out_batch.update({'pid': np.array([patient['pid']] * len(out_data)),
-                         'spec':np.array([patient['spec']] * len(out_data))})
-
-        if self.cf.plot_bg_chan in self.chans and discarded_chans>0:
-            assert plot_bg is None
-            plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
-            out_plot_bg = plot_bg
-        if plot_bg is not None:
-            out_batch['plot_bg'] = out_plot_bg
-
-        #eventual tiling into patches
-        spatial_shp = out_batch["data"].shape[2:]
-        if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
-            patient_batch = out_batch
-            #print("patientiterator produced patched batch!")
-            patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
-            new_img_batch, new_seg_batch = [], []
-
-            for c in patch_crop_coords_list:
-                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
-                seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
-                new_seg_batch.append(seg_patch)
-            shps = []
-            for arr in new_img_batch:
-                shps.append(arr.shape)
-            
-            data = np.array(new_img_batch) # (patches, c, x, y, z)
-            seg = np.array(new_seg_batch)
-            if self.cf.dim == 2:
-                # all patches have z dimension 1 (slices). discard dimension
-                data = data[..., 0]
-                seg = seg[..., 0]
-            patch_batch = {'data': data, 'seg': seg.astype('uint8'),
-                                'pid': np.array([patient['pid']] * data.shape[0]),
-                                'spec':np.array([patient['spec']] * data.shape[0])}
-            for o in self.cf.roi_items:
-                patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0)
-            # patient-wise (orig) batch info for putting the patches back together after prediction
-            for o in self.cf.roi_items:
-                patch_batch["patient_"+o] = patient_batch['patient_'+o]
-            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
-            patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
-            #patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels']
-            patch_batch['patient_data'] = patient_batch['data']
-            patch_batch['patient_seg'] = patient_batch['seg']
-            patch_batch['original_img_shape'] = patient_batch['original_img_shape']
-            if plot_bg is not None:
-                patch_batch['patient_plot_bg'] = patient_batch['plot_bg']
-
-            converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg)
-            
-            patch_batch = converter(**patch_batch)
-            out_batch = patch_batch
-        
-        self.patient_ix += 1
-        # todo raise stopiteration when in test mode
-        if self.patient_ix == len(self.dataset_pids):
-            self.patient_ix = 0
-
-        return out_batch
-
-
-def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True):
-    """
-    create mutli-threaded train/val/test batch generation and augmentation pipeline.
-    :param patient_data: dictionary containing one dictionary per patient in the train/test subset
-    :param test_pids: (optional) list of test patient ids, calls the test generator.
-    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
-    :return: multithreaded_generator
-    """
-    data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace)
-    
-    my_transforms = []
-    if do_aug:
-        if cf.da_kwargs["mirror"]:
-            mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
-            my_transforms.append(mirror_transform)
-        if cf.da_kwargs["gamma_transform"]:
-            gamma_transform = GammaTransform(gamma_range=cf.da_kwargs["gamma_range"], invert_image=False,
-                                             per_channel=False, retain_stats=True)
-            my_transforms.append(gamma_transform)
-        if cf.dim == 3:
-            # augmentations with desired effect on z-dimension
-            spatial_transform = SpatialTransform(patch_size=cf.patch_size,
-                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
-                                             do_elastic_deform=False,
-                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
-                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
-                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
-                                             random_crop=cf.da_kwargs['random_crop'],
-                                             border_mode_data=cf.da_kwargs['border_mode_data'])
-            my_transforms.append(spatial_transform)
-            # augmentations that are only meant to affect x-y
-            my_transforms.append(Convert3DTo2DTransform())
-            spatial_transform = SpatialTransform(patch_size=cf.patch_size[:2],
-                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2],
-                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
-                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
-                                             do_rotation=False,
-                                             do_scale=False,
-                                             random_crop=False,
-                                             border_mode_data=cf.da_kwargs['border_mode_data'])
-            my_transforms.append(spatial_transform)
-            my_transforms.append(Convert2DTo3DTransform())
-
-        else:
-            spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
-                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2],
-                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
-                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
-                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
-                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
-                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
-                                             random_crop=cf.da_kwargs['random_crop'],
-                                             border_mode_data=cf.da_kwargs['border_mode_data'])
-            my_transforms.append(spatial_transform)
-    else:
-        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
-
-    if cf.create_bounding_box_targets:
-        my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
-        #batch receives entry 'bb_target' w bbox coordinates as [y1,x1,y2,x2,z1,z2].
-    #my_transforms.append(ConvertSegToOnehotTransform(classes=range(cf.num_seg_classes)))
-    all_transforms = Compose(my_transforms)
-    #MTAugmenter creates iterator from data iterator data_gen after applying the composed transform all_transforms
-    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers,
-                                                     seeds=list(np.random.randint(0,cf.n_workers*2,size=cf.n_workers)))
-    return multithreaded_generator
-
-def get_train_generators(cf, logger, data_statistics=True):
-    """
-    wrapper function for creating the training batch generator pipeline. returns the train/val generators
-    need to select cv folds on patient level, but be able to include both breasts of each patient.
-    """
-    dataset = Dataset(cf, logger)
-
-    dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
-    dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
-    set_splits = dataset.fg.splits
-
-    test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold-1)
-    train_ids = np.concatenate(set_splits, axis=0)
-
-    if cf.held_out_test_set:
-        train_ids = np.concatenate((train_ids, test_ids), axis=0)
-        test_ids = []
-
-    train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids}
-    val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids}
-    
-    logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids)))
-    if data_statistics:
-        dataset.calc_statistics(subsets={"train":train_ids, "val":val_ids, "test":test_ids},
-                                plot_dir=os.path.join(cf.plot_dir,"dataset"))
-        
-    batch_gen = {}
-    batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug)
-    batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False)
-
-    if cf.val_mode == 'val_patient':
-        batch_gen['val_patient'] = PatientBatchIterator(cf, val_data)
-        batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else cf.max_val_patients
-    elif cf.val_mode == 'val_sampling':
-        batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches!="all" else len(val_ids)
-    
-    return batch_gen
-
-def get_test_generator(cf, logger):
-    """
-    if get_test_generators is called multiple times in server env, every time of 
-    Dataset initiation rsync will check for copying the data; this should be okay
-    since rsync will not copy if files already exist in destination.
-    """
-
-    if cf.held_out_test_set:
-        sourcedir = cf.test_data_sourcedir
-        test_ids = None
-    else:
-        sourcedir = None
-        with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
-            set_splits = pickle.load(handle)
-        test_ids = set_splits[cf.fold]
-
-    test_set = Dataset(cf, logger, test_ids, data_sourcedir=sourcedir)
-    logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids)))
-    batch_gen = {}
-    batch_gen['test'] = PatientBatchIterator(cf, test_set.data)
-    batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else min(cf.max_test_patients, len(test_set.set_ids))
-    
-    return batch_gen
-
-
-if __name__=="__main__":
-    import sys
-    sys.path.append('../')  # works on cluster indep from where sbatch job is started
-    import plotting as plg
-    import utils.exp_utils as utils
-    from configs import Configs
-    cf = configs()
-    
-    total_stime = time.time()
-    times = {}
-
-    #cf.server_env = True
-    #cf.data_dir = "experiments/dev_data"
-    
-    #dataset = Dataset(cf)
-    #patient = dataset['Master_00018'] 
-    cf.exp_dir = "experiments/dev/"
-    cf.plot_dir = cf.exp_dir+"plots"
-    os.makedirs(cf.exp_dir, exist_ok=True)
-    cf.fold = 0
-    logger = utils.get_logger(cf.exp_dir)
-    gens = get_train_generators(cf, logger)
-    train_loader = gens['train']
-    
-    #for i in range(train_loader.dataset_length):
-    #    print("batch", i)
-    stime = time.time()
-    ex_batch = next(train_loader)
-    #ex_batch = next(train_loader)
-    times["train_batch"] = time.time()-stime
-    plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True)
-
-    #with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file:
-    #    train_loader.generator.print_stats(logger, file)
-
-
-    val_loader = gens['val_sampling']
-    stime = time.time()
-    ex_batch = next(val_loader)
-    times["val_batch"] = time.time()-stime
-    stime = time.time()
-    plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, show_info=False)
-    times["val_plot"] = time.time()-stime
-    
-    test_loader = get_test_generator(cf, logger)["test"]
-    stime = time.time()
-    ex_batch = test_loader.generate_train_batch()
-    print(ex_batch["data"].shape)
-    times["test_batch"] = time.time()-stime
-    stime = time.time()
-    plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/ex_patchbatch.png", show_gt_boxes=False, show_info=False, dpi=400, sample_picks=[2,5], plot_mods=False)
-    times["test_patchbatch_plot"] = time.time()-stime
-
-    #stime = time.time()
-    #ex_batch['data'] = ex_batch['patient_data']
-    #ex_batch['seg'] = ex_batch['patient_seg']
-    #if 'patient_plot_bg' in ex_batch.keys():
-    #    ex_batch['plot_bg'] = ex_batch['patient_plot_bg']
-    #plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png")
-    #times["test_patientbatch_plot"] = time.time() - stime
-    
-    
-    #print("patch batch keys", ex_batch.keys())
-    #print("patch batch les gle", ex_batch["lesion_gleasons"].shape)
-    #print("patch batch gsbx", ex_batch["GSBx"].shape)
-    #print("patch batch class_targ", ex_batch["class_targets"].shape)
-    #print("patient b roi labels", ex_batch["patient_roi_labels"].shape)
-    #print("patient les gleas", ex_batch["patient_lesion_gleasons"].shape)
-    #print("patch&patient batch pid", ex_batch["pid"], len(ex_batch["pid"]))
-    #print("unique patient_seg", np.unique(ex_batch["patient_seg"]))
-    #print("pb patient roi labels", len(ex_batch["patient_roi_labels"]), ex_batch["patient_roi_labels"])
-    #print("pid", ex_batch["pid"])
-    
-    #patient_batch = {k[len("patient_"):]:v for (k,v) in ex_batch.items() if k.lower().startswith("patient")}
-    #patient_batch["pid"] = ex_batch["pid"]
-    #stime = time.time()
-    #plg.view_batch(cf, patient_batch, out_file="experiments/dev_expatientbatch")
-    #times["test_plot"] = time.time()-stime
-    
-    
-    print("Times recorded throughout:")
-    for (k,v) in times.items():
-        print(k, "{:.2f}".format(v))
-    
-    mins, secs = divmod((time.time() - total_stime), 60)
-    h, mins = divmod(mins, 60)
-    t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) 
-    print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/datasets/prostate/data_preprocessing.py b/datasets/prostate/data_preprocessing.py
deleted file mode 100644
index ca97532..0000000
--- a/datasets/prostate/data_preprocessing.py
+++ /dev/null
@@ -1,809 +0,0 @@
-__author__ = "Simon Kohl, Gregor Ramien"
-
-
-# subject-wise extractor that does not depend on Prisma/Radval and that checks for geometry miss-alignments
-# (corrects them if applicable), images and masks should be stored separately, each in its own memmap
-# at run-time, the data-loaders will assemble dicts using the histo csvs
-import os
-import sys
-from multiprocessing import Pool
-import warnings
-import time
-import shutil
-
-import pandas as pd
-import numpy as np
-import pickle
-
-import SimpleITK as sitk
-from scipy.ndimage.measurements import center_of_mass
-
-sys.path.append("../")
-import plotting as plg
-import data_manager as dmanager
-
-def save_obj(obj, name):
-    """Pickle a python object."""
-    with open(name + '.pkl', 'wb') as f:
-        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
-
-def load_array(path):
-    """Load an image as a numpy array."""
-    img = sitk.ReadImage(path)
-    return sitk.GetArrayFromImage(img)
-
-def id_to_spec(id, base_spec):
-    """Construct subject specifier from base string and an integer subject number."""
-    num_zeros = 5 - len(str(id))
-    assert num_zeros>=0, "id_to_spec: patient id too long to fit into 5 figures"
-    return base_spec + '_' + ('').join(['0'] * num_zeros) + str(id)
-
-def spec_to_id(spec):
-    """Get subject id from string"""
-    return int(spec[-5:])
-
-def has_equal_geometry(img1, img2, precision=0.001):
-    """Check whether geometries of 2 images match within a given precision."""
-    equal = True
-
-    # assert equal image extentions
-    delta = [abs((img1.GetSize()[i] - img2.GetSize()[i])) < precision for i in range(3)]
-    if not np.all(delta):
-        equal = False
-
-    # assert equal origins
-    delta = [abs((img1.GetOrigin()[i] - img2.GetOrigin()[i])) < precision for i in range(3)]
-    if not np.all(delta):
-        equal = False
-
-    # assert equal spacings
-    delta = [abs((img1.GetSpacing()[i] - img2.GetSpacing()[i])) < precision for i in range(3)]
-    if not np.all(delta):
-        equal = False
-
-    return equal
-
-def resample_to_reference(ref_img, img, interpolation):
-    """
-    Resample an sitk image to a reference image, the size, spacing,
-    origin and direction of the reference image will be used
-    :param ref_img:
-    :param img:
-    :param interpolation:
-    :return: interpolated SITK image
-    """
-    if interpolation == 'nearest':
-        interpolator = sitk.sitkNearestNeighbor #these are just integers
-    elif interpolation == 'linear':
-        interpolator = sitk.sitkLinear
-    elif interpolation == 'bspline':
-        # basis spline of order 3
-        interpolator = sitk.sitkBSpline
-    else:
-        raise NotImplementedError('Interpolation of type {} not implemented!'.format(interpolation))
-
-    img = sitk.Cast(img, sitk.sitkFloat64)
-
-    rif = sitk.ResampleImageFilter()
-    # set the output size, origin, spacing and direction to that of the provided image
-    rif.SetReferenceImage(ref_img) 
-    rif.SetInterpolator(interpolator)
-
-    return rif.Execute(img)
-
-def rescale(img, scaling, interpolation=sitk.sitkBSpline, out_fpath=None):
-    """
-    :param scaling: tuple (z_scale, y_scale, x_scale) of scaling factors
-    :param out_fpath: filepath (incl filename), if set will write .nrrd (uncompressed)
-        to that location
-    
-    sitk/nrrd images spacing: imgs are treated as physical objects. When resampling,
-    a given image is re-evaluated (resampled) at given gridpoints, the physical 
-    properties of the image don't change. Hence, if the resampling-grid has a smaller
-    spacing than the original image(grid), the image is sampled more often than before.
-    Since every sampling produces one pixel, the resampled image will have more pixels
-    (when sampled at undefined points of the image grid, the sample values will be
-    interpolated). I.e., for an upsampling of an image, we need to set a smaller
-    spacing for the resampling grid and a larger (pixel)size for the resampled image.
-    """
-    (z,y,x) = scaling
-    
-    old_size = np.array(img.GetSize())
-    old_spacing = np.array(img.GetSpacing())
-    
-
-    new_size = (int(old_size[0]*x), int(old_size[1]*y), int(old_size[2]*z))
-    new_spacing = old_spacing * (old_size/ new_size)
-    
-    rif = sitk.ResampleImageFilter()
-    
-    rif.SetReferenceImage(img)
-    rif.SetInterpolator(interpolation)
-    rif.SetOutputSpacing(new_spacing)
-    rif.SetSize(new_size)
-    
-    new_img = rif.Execute(img)
-    
-    if not out_fpath is None:
-        writer = sitk.ImageFileWriter()
-        writer.SetFileName(out_fpath)
-        writer.SetUseCompression(True)
-        writer.Execute(new_img)
-    
-    return new_img
-
-def get_valid_z_range(arr):
-    """
-    check which z-slices of an image array aren't constant
-    :param arr:
-    :return: min and max valid slice found; under the assumption that invalid 
-        slices occur never inbetween valid slices
-    """
-
-    valid_z_slices = []
-    for z in range(arr.shape[0]):
-        if np.var(arr[z]) != 0:
-            valid_z_slices.append(z)
-    return valid_z_slices[0], valid_z_slices[-1] 
-
-def convert_to_arrays(data):
-    """convert to numpy arrays.
-        sitk.Images have shape (x,y,z), but GetArrayFromImage returns shape (z,y,x)
-    """
-    for mod in data['img'].keys():
-        data['img'][mod] = sitk.GetArrayFromImage(data['img'][mod]).astype(np.float32)
-
-    for mask in data['anatomical_masks'].keys():
-        data['anatomical_masks'][mask] = sitk.GetArrayFromImage(data['anatomical_masks'][mask]).astype(np.uint8)
-
-    for mask in data['lesions'].keys():
-        data['lesions'][mask] = sitk.GetArrayFromImage(data['lesions'][mask]).astype(np.uint8)
-    return data
-
-def merge_crossmod_masks(data, rename_tags, mode="union"):
-    """if data has multiple ground truths (e.g. after registration), merge 
-        masks by mode. class labels (leason gleason) are assumed to be naturally registered (no ambiguity)
-    :param rename_tags: usually from prepro_cf['rename_tags']
-    :param mode: 'union' or name of mod ('adc', 't2') to consider only one gt
-    """   
-
-    if 'adc' in data['img'].keys() and 't2' in data['img'].keys():
-        if mode=='union':
-            #print("Merging gts of T2, ADC mods. Assuming data is registered!")
-            tags = list(data["anatomical_masks"].keys())
-            for tag in tags:
-                tags.remove(tag)
-                merge_with = [mtag for mtag in tags\
-                              if mtag.lower().split("_")[2]==tag.lower().split("_")[2]]
-                assert len(merge_with)==1, "attempted to merge {} ground truths".format(len(merge_with))
-                merge_with = merge_with[0]
-                tags.remove(merge_with)
-                #masks are binary
-                #will throw error if masks dont have same shape
-                data["anatomical_masks"][tag] = np.logical_or(data["anatomical_masks"][tag].astype(np.uint8),
-                    data["anatomical_masks"].pop(merge_with).astype(np.uint8)).astype(np.uint8)
-
-            tags = list(data["lesions"].keys())
-            for tag in tags:
-                tags.remove(tag)
-                merge_with = [mtag for mtag in tags\
-                              if mtag.lower().split("_")[2]==tag.lower().split("_")[2]]
-                assert len(merge_with)==1, "attempted to merge {} ground truths".format(len(merge_with))
-                merge_with = merge_with[0]
-                tags.remove(merge_with)
-                data["lesions"][tag] = np.logical_or(data["lesions"][tag],
-                    data["lesions"].pop(merge_with)).astype(np.uint8)
-
-        elif mode=='adc' or mode=='t2':
-            data["anatomical_masks"] = {tag:v for tag,v in data["anatomical_masks"].items() if
-                                        tag.lower().split("_")[1]==mode}
-            data["lesions"] = {tag: v for tag, v in data["lesions"].items() if tag.lower().split("_")[1] == mode}
-
-        else:
-            raise Exception("cross-mod gt merge mode {} not implemented".format(mode))
-
-    for tag in list(data["anatomical_masks"]):
-        data["anatomical_masks"][rename_tags[tag]] = data["anatomical_masks"].pop(tag)
-        #del data["anatomical_masks"][tag]
-    for tag in list(data["lesions"]):
-        new_tag = "seg_REG_"+"".join(tag.split("_")[2:])
-        data["lesions"][new_tag] = data["lesions"].pop(tag)
-        data["lesion_gleasons"][new_tag] = data["lesion_gleasons"].pop(tag)
-
-    return data
-
-def crop_3D(data, pre_crop_size, center_of_mass_crop=True):
-    pre_crop_size = np.array(pre_crop_size)
-    # restrain z-ranges to where ADC has valid entries
-    if 'adc' in data['img'].keys():
-        ref_mod = 'adc'
-        comp_mod = 't2'
-    else:
-        ref_mod = 't2'
-        comp_mod = 'adc'
-    min_z, max_z = get_valid_z_range(data['img'][ref_mod])    
-    if comp_mod in data['img'].keys():
-        assert (min_z, max_z) == get_valid_z_range(data['img'][comp_mod]), "adc, t2 different valid z range"
-
-    if center_of_mass_crop:
-        # cut the arrays to the desired x_y_crop_size around the center-of-mass of the PRO segmentation
-        pro_com = center_of_mass(data['anatomical_masks']['pro'])
-        center = [int(np.round(i, 0)) for i in pro_com]
-    else:
-        center = [data['img'][ref_mod].shape[i] // 2 for i in range(3)]
-    
-    
-    l = pre_crop_size // 2
-    #z_low, z_up = max(min_z, center[0] - l[0]), min(max_z + 1, center[0] + l[0])
-    z_low, z_up = center[0] - l[0], center[0] + l[0]
-    while z_low<min_z or z_up>max_z+1:
-        if z_low<min_z:
-            z_low += 1
-            z_up += 1
-            if z_up>max_z+1:
-                warnings.warn("could not crop patient {}'s z-dim to demanded size.".format(data['Original_ID']))
-        if z_up>max_z+1:
-            z_low -= 1
-            z_up -= 1
-            if z_low<min_z:
-                warnings.warn("could not crop patient {}'s z-dim to demanded size.".format(data['Original_ID']))
-
-    #ensure too small image/ too large pcropsize don't lead to error
-    d = np.array((z_low, center[1]-l[1], center[2]-l[2]))
-    assert np.all(d>=0),\
-        "Precropsize too large for image dimensions by {} pixels in patient {}".format(d, data['Original_ID'])
-
-    for mod in data['img'].keys():
-        data['img'][mod] = data['img'][mod][z_low:z_up, center[1]-l[1]: center[1] + l[1], center[2]-l[2]: center[2]+l[2]]
-    vals_lst = list(data['img'].values())
-    assert np.all([mod.shape==vals_lst[0].shape for mod in vals_lst]),\
-    "produced modalities for same subject with different shapes"
-    
-    for mask in data['anatomical_masks'].keys():
-        data['anatomical_masks'][mask] = data['anatomical_masks'][mask] \
-            [z_low:z_up, center[1]-l[1]: center[1]+l[1], center[2]-l[2]: center[2]+l[2]]
-            
-    for mask in data['lesions'].keys():
-        data['lesions'][mask] = data['lesions'][mask] \
-            [z_low:z_up, center[1]-l[1]: center[1]+l[1], center[2]-l[2]: center[2]+l[2]]
-    return data
-
-def add_transitional_zone_mask(data):
-    if 'pro' in data['anatomical_masks'] and 'pz' in data['anatomical_masks']:
-        intersection = data['anatomical_masks']['pro'] & data['anatomical_masks']['pz']
-        data['anatomical_masks']['tz'] = data['anatomical_masks']['pro'] - intersection
-    return data
-
-def generate_labels(data, seg_labels, class_labels, gleason_map, observables_rois):
-    """merge individual binary labels to an integer label mask and create class labels from Gleason score.
-        if seg_labels has seg_label 'roi': seg label will be roi count.
-    """
-    anatomical_masks2label = [l for l in data['anatomical_masks'].keys() if l in seg_labels.keys()]
-    
-    data['seg'] = np.zeros(shape=data['anatomical_masks']['pro'].shape, dtype=np.uint8)
-    data['roi_classes'] = []
-    #data['roi_observables']: dict, each entry is one list of length final roi_count in this patient
-    data['roi_observables'] = {obs:[] for obs in observables_rois}
-    roi_count = 0
-
-    for mask in anatomical_masks2label:
-        ixs = np.where(data['anatomical_masks'][mask])
-        roi_class = class_labels[mask]
-        if len(ixs)>0 and roi_class!=-1:
-            roi_count+=1
-            label = seg_labels[mask]
-            if label=='roi':
-                label = roi_count
-            data['seg'][ixs] = label
-            data['roi_classes'].append(roi_class)
-            for obs in observables_rois:
-                obs_val = data[obs][mask] if mask in data[obs].keys() else None
-                data['roi_observables'][obs].append(obs_val)
-        #print("appended mask lab", class_labels[mask])
-      
-    if "lesions" in seg_labels.keys():   
-        for lesion_key, lesion_mask in data['lesions'].items():
-            ixs = np.where(lesion_mask)
-            roi_class = class_labels['lesions']
-            if roi_class == "gleason":
-                roi_class = gleason_map(data['lesion_gleasons'][lesion_key])
-                # roi_class =  data['lesion_gleasons'][lesion_key]
-            if len(ixs)>0 and roi_class!=-1:
-                roi_count+=1
-                label = seg_labels['lesions']
-                if label=='roi':
-                    label = roi_count
-                data['seg'][ixs] = label
-                #segs have form: slices x h x w, i.e., one channel per z-slice, each lesion has its own label
-                data['roi_classes'].append(roi_class)
-                for obs in observables_rois:
-                    obs_val = data[obs][lesion_key] if lesion_key in data[obs].keys() else None
-                    data['roi_observables'][obs].append(obs_val)
-
-                # data['lesion_gleasons'][label] = data['lesion_gleasons'].pop(lesion_key)
-    for obs in data['roi_observables'].keys():
-        del data[obs]
-    return data
-
-def normalize_image(data, normalization_dict):
-    """normalize the full image."""
-    percentiles = normalization_dict['percentiles']
-    for mod in data['img'].keys():
-        p = np.percentile(data['img'][mod], percentiles[0])
-        q = np.percentile(data['img'][mod], percentiles[1])
-        masked_img = data['img'][mod][(data['img'][mod] > p) & (data['img'][mod] < q)]
-        data['img'][mod] = (data['img'][mod] - np.median(masked_img)) / np.std(masked_img)
-    return data
-
-def concat_mods(data, mods2concat):
-    """concat modalities on new first channel
-    """
-    concat_on_channel = [] #holds tmp data to be concatenated on the same channel
-    for mod in mods2concat:
-        mod_img = data['img'][mod][np.newaxis]
-        concat_on_channel.append(mod_img)
-    data['img'] = np.concatenate(concat_on_channel, axis=0)
-    
-    return data
-
-def swap_yx(data, apply_flag):
-    """swap x and y axes in img and seg
-    """
-    if apply_flag:
-        data["img"] = np.swapaxes(data["img"], -1,-2)
-        data["seg"] = np.swapaxes(data["seg"], -1,-2)
-
-    return data
-
-def get_fg_z_indices(seg):
-    """return z-indices of array at which the x-y-arrays have labels!=0, 0 is background
-    """
-    fg_slices = np.argwhere(seg.astype(int))[:,0]
-    fg_slices = np.unique(fg_slices)
-    return fg_slices
-
-
-class Preprocessor():
-
-    def __init__(self, config):
-
-        self._config_path = config.config_path
-        self.full_cf = config
-        self._cf = config.prepro
-
-    def get_excluded_master_ids(self):
-        """Get the Master IDs that are excluded from their corresponding Prisma/Radval/Master IDs."""
-
-        excluded_prisma = self._cf['excluded_prisma_subjects']
-        excluded_radval = self._cf['excluded_radval_subjects']
-        excluded_master = self._cf['excluded_master_subjects']
-        histo = self._histo_patient_based
-
-        excluded_master_ids = []
-
-        if len(excluded_prisma) > 0:
-            for prisma_id in excluded_prisma:
-                master_spec = histo['Master_ID'][histo['Original_ID'] == id_to_spec(prisma_id, 'Prisma')].values[0]
-                excluded_master_ids.append(spec_to_id(master_spec))
-
-        if len(excluded_radval) > 0:
-            for radval_id in excluded_radval:
-                master_spec = histo['Master_ID'][histo['Original_ID'] == id_to_spec(radval_id, 'Radiology')].values[0]
-                excluded_master_ids.append(spec_to_id(master_spec))
-
-        excluded_master_ids += excluded_master
-
-        return excluded_master_ids
-
-
-    def prepare_filenames(self):
-        """check whether histology-backed subjects and lesions are available in the data and
-        yield dict of subject file-paths."""
-
-        # assemble list of histology-backed subject ids and check that corresponding images are available
-        self._histo_lesion_based = pd.read_csv(os.path.join(self._cf['histo_dir'], self._cf['histo_lesion_based']))
-        self._histo_patient_based = pd.read_csv(os.path.join(self._cf['histo_dir'], self._cf['histo_patient_based']))
-
-        excluded_master_ids = self.get_excluded_master_ids()
-        self._subj_ids = np.unique(self._histo_lesion_based[self._cf['histo_id_column_name']].values)
-        self._subj_ids = [s for s in self._subj_ids.tolist() if
-                          s not in excluded_master_ids]
-
-        # get subject directory paths from
-        img_paths = os.listdir(self._cf['data_dir'])
-        self._img_paths = [p for p in img_paths if 'Master' in p and len(p) == len('Master') + 6]
-
-        # check that all images of subjects with histology are available
-        available_subj_ids = np.array([spec_to_id(s) for s in self._img_paths])
-        self._missing_image_ids = np.setdiff1d(self._subj_ids, available_subj_ids)
-
-        assert len(self._missing_image_ids)== 0,\
-                'Images of subjs {} are not available.'.format(self._missing_image_ids)
-
-        # make dict holding relevant paths to data of each subject
-        self._paths_by_subject = {}
-        for s in self._subj_ids:
-            self._paths_by_subject[s] = self.load_subject_paths(s)
-        
-
-    def load_subject_paths(self, subject_id):
-        """Make dict holding relevant paths to data of a given subject."""
-        dir_spec = self._cf['dir_spec']
-        s_dict = {}
-
-        # iterate images
-        images_paths = {}
-        for kind, filename in self._cf['images'].items():
-            filename += self._cf['img_postfix']+self._cf['overall_postfix']
-            images_paths[kind] = os.path.join(self._cf['data_dir'], id_to_spec(subject_id, dir_spec), filename)
-        s_dict['images'] = images_paths
-
-        # iterate anatomical structures
-        anatomical_masks_paths = {}
-        for tag in self._cf['anatomical_masks']:
-            filename = tag + self._cf['overall_postfix']
-            anatomical_masks_paths[tag] = os.path.join(self._cf['data_dir'], id_to_spec(subject_id, dir_spec), filename)
-        s_dict['anatomical_masks'] = anatomical_masks_paths
-
-        # iterate lesions
-        lesion_names = []
-        if 'adc' in self._cf['images']:
-            lesion_names.extend(self._histo_lesion_based[self._histo_lesion_based[self._cf['histo_id_column_name']]\
-                                                    == subject_id]['segmentationsNameADC'].dropna())
-        if 't2' in self._cf['images']:
-            lesion_names.extend(self._histo_lesion_based[self._histo_lesion_based[self._cf['histo_id_column_name']]\
-                                                    == subject_id]['segmentationsNameT2'].dropna())
-        lesion_paths = {}
-        for l in lesion_names:
-            lesion_path = os.path.join(self._cf['data_dir'], id_to_spec(subject_id, dir_spec),
-                                       l+self._cf['lesion_postfix']+self._cf['overall_postfix'])
-            assert os.path.isfile(lesion_path), 'Lesion mask not found under {}!'.format(lesion_path)
-
-            lesion_paths[l] = lesion_path
-
-        s_dict['lesions'] = lesion_paths
-        return s_dict
-
-
-    def load_subject_data(self, subject_id):
-        """load img data, masks, histo data for a single subject."""
-        subj_paths = self._paths_by_subject[subject_id]
-        data = {}
-
-        # iterate images
-        data['img'] = {}
-        for mod in subj_paths['images']:
-            data['img'][mod] = sitk.ReadImage(subj_paths['images'][mod])
-
-        # iterate anatomical masks
-        data['anatomical_masks'] = {} 
-        for tag in subj_paths['anatomical_masks']:
-            data['anatomical_masks'][tag] = sitk.ReadImage(subj_paths['anatomical_masks'][tag])
-
-        # iterate lesions, include gleason score
-        data['lesions'] = {}
-        data['lesion_gleasons'] = {}
-        idcol = self._cf['histo_id_column_name']
-        subj_histo = self._histo_lesion_based[self._histo_lesion_based[idcol]==subject_id]
-        for l in subj_paths['lesions']:
-            #print("subjpaths lesions l ", l)
-            data['lesions'][l] = sitk.ReadImage(subj_paths['lesions'][l])
-
-            try:
-                gleason = subj_histo[subj_histo["segmentationsNameADC"]==l]["Gleason"].tolist()[0]
-            except IndexError:
-                gleason = subj_histo[subj_histo["segmentationsNameT2"]==l]["Gleason"].tolist()[0]
-
-            data['lesion_gleasons'][l] = gleason
-        
-        # add other subj-specific histo and id data
-        idcol = self._cf['histo_pb_id_column_name']
-        subj_histo = self._histo_patient_based[self._histo_patient_based[idcol]==subject_id]
-        for d in self._cf['observables_patient']:
-            data[d] = subj_histo[d].values
-        
-        return data
-
-    def analyze_subject_data(self, data):
-        """record post-alignment geometries."""
-
-        ref_mods = data['img'].keys()
-        geos = {}
-        for ref_mod in ref_mods:
-            geos[ref_mod] = {'size': data['img'][ref_mod].GetSize(), 'origin': data['img'][ref_mod].GetOrigin(),
-                   'spacing': data['img'][ref_mod].GetSpacing()}
-
-        return geos
-
-    def process_subject_data(self, data):
-        """evtly rescale images, check for geometry miss-alignments and perform crop."""
-        
-        if not self._cf['mod_scaling'] == (1,1,1):
-            for img_name in data['img']:
-                res_img = rescale(data["img"][img_name], self._cf['mod_scaling'])
-                data['img'][img_name] = res_img
-
-        #----check geometry alignment between masks and image---
-        for tag in self._cf['anatomical_masks']:
-            if tag.lower().startswith("seg_adc"):
-                ref_mod = 'adc'
-            elif tag.lower().startswith("seg_t2"):
-                ref_mod = 't2'
-            if not has_equal_geometry(data['img'][ref_mod], data['anatomical_masks'][tag]):
-                #print("bef", np.unique(sitk.GetArrayFromImage(data['anatomical_masks'][tag])))
-                #print('Geometry mismatch: {}, {} is resampled to its image geometry!'.format(data["Original_ID"], tag))
-                data['anatomical_masks'][tag] =\
-                    resample_to_reference(data['img'][ref_mod], data['anatomical_masks'][tag],
-                                          interpolation=self._cf['interpolation'])
-                #print("aft", np.unique(sitk.GetArrayFromImage(data['anatomical_masks'][tag])))
-
-        for tag in data['lesions'].keys():
-            if tag.lower().startswith("seg_adc"):
-                ref_mod = 'adc'
-            elif tag.lower().startswith("seg_t2"):
-                ref_mod = 't2'
-            if not has_equal_geometry(data['img'][ref_mod], data['lesions'][tag]):
-                #print('Geometry mismatch: {}, {} is resampled to its image geometry!'.format(data["Original_ID"], tag))
-                #print("pre-sampling data type: {}".format(data['lesions'][tag]))
-                data['lesions'][tag] = resample_to_reference(data['img'][ref_mod], data['lesions'][tag],
-                                                              interpolation=self._cf['interpolation'])
-
-
-        data = convert_to_arrays(data)
-        data = merge_crossmod_masks(data, self._cf['rename_tags'], mode=self._cf['merge_mode'])
-        data = crop_3D(data, self._cf['pre_crop_size'], self._cf['center_of_mass_crop'])
-        data = add_transitional_zone_mask(data)
-        data = generate_labels(data, self._cf['seg_labels'], self._cf['class_labels'], self._cf['gleason_map'],
-                               self._cf['observables_rois'])
-        data = normalize_image(data, self._cf['normalization'])
-        data = concat_mods(data, self._cf['modalities2concat'])
-        data = swap_yx(data, self._cf["swap_yx_to_xy"])
-        
-        data['fg_slices'] = get_fg_z_indices(data['seg'])
-        
-        return data
-
-    def write_subject_arrays(self, data, subject_spec):
-        """Write arrays to disk and save file names in dict."""
-
-        out_dir = self._cf['output_directory']
-        os.makedirs(out_dir, exist_ok=True) #might throw error if restrictive permissions
-
-        out_dict = {}
-
-        # image(s)
-        name = subject_spec + '_imgs.npy'
-        np.save(os.path.join(out_dir, name), data['img'])
-        out_dict['img'] = name
-        
-        # merged labels
-        name = subject_spec + '_merged_seg.npy'
-        np.save(os.path.join(out_dir, name), data['seg'])
-        out_dict['seg'] = name
-
-        # anatomical masks separately
-        #for mask in list(data['anatomical_masks'].keys()) + (['tz'] if 'tz' in data.keys() else []):
-        #    name = subject_spec + '_{}.npy'.format(mask)
-        #    np.save(os.path.join(out_dir, name), data['anatomical_masks'][mask])
-        #    out_dict[mask] = name
-
-        # lesion masks and lesion classes separately
-        #out_dict['lesion_gleasons'] = {}
-        #for mask in data['lesions'].keys():
-        #    name = subject_spec + '_{}.npy'.format(mask)
-        #    np.save(os.path.join(out_dir, name), data['lesions'][mask])
-        #    out_dict[mask] = name
-        #    out_dict['lesion_gleasons'][int(mask[-1])] = data['lesion_gleasons'][int(mask[-1])]
-            
-        # roi classes
-        out_dict['roi_classes'] = data['roi_classes']
-
-        
-        # fg_slices info
-        out_dict['fg_slices'] = data['fg_slices']
-        
-        # other observables
-        for obs in self._cf['observables_patient']:
-            out_dict[obs] = data[obs]
-        for obs in data['roi_observables'].keys():
-            out_dict[obs] = data['roi_observables'][obs]
-        #print("subj outdict ", out_dict.keys())
-        return out_dict
-
-    def subject_iteration(self, subj_id): #single iteration, wrapped for pooling
-        data = self.load_subject_data(subj_id)
-        data = self.process_subject_data(data)
-        subj_out_dict = self.write_subject_arrays(data, id_to_spec(subj_id, self._cf['dir_spec']))
-        
-        print('Processed subject {}.'.format(id_to_spec(subj_id, self._cf['dir_spec'])))
-        
-        return (subj_id, subj_out_dict)
-        
-    def iterate_subjects(self, ids_subset=None, processes=6):
-        """process all subjects."""
-        
-        if ids_subset is None:
-            ids_subset = self._subj_ids
-        else:
-            ids_subset = np.array(ids_subset)
-            id_check = np.array([id in self._subj_ids for id in ids_subset])
-            assert np.all(id_check), "pids {} not in eligible pids".format(ids_subset[np.invert(id_check)])
-
-        p = Pool(processes)
-        subj_out_dicts = p.map(self.subject_iteration, ids_subset)
-        """note on Pool.map: only takes one arg, pickles the function for execution -->
-        cannot write to variables defined outside local scope --> cannot write to
-        self.variables, therefore need to return single subj_out_dicts and join after;
-        however p.map can access object methods via self.method().
-        Is a bit complicated, but speedup is huge.
-        """
-        p.close()
-        p.join()
-        assert len(subj_out_dicts)==len(ids_subset), "produced less subject dicts than demanded"
-        self._info_dict = {id:dic for (id, dic) in subj_out_dicts}
-        
-        return
-
-    def subject_analysis(self, subj_id):  # single iteration, wrapped for pooling
-        data = self.load_subject_data(subj_id)
-        analysis = self.analyze_subject_data(data)
-
-        print('Analyzed subject {}.'.format(id_to_spec(subj_id, self._cf['dir_spec'])))
-
-        return (subj_id, analysis)
-
-    def analyze_subjects(self, ids_subset=None, processes=os.cpu_count()):
-        """process all subjects."""
-
-        if ids_subset is None:
-            ids_subset = self._subj_ids
-        else:
-            ids_subset = np.array(ids_subset)
-            id_check = np.array([id in self._subj_ids for id in ids_subset])
-            assert np.all(id_check), "pids {} not in eligible pids".format(ids_subset[np.invert(id_check)])
-
-        p = Pool(processes)
-        subj_analyses = p.map(self.subject_analysis, ids_subset)
-        """note on Pool.map: only takes one arg, pickles the function for execution -->
-        cannot write to variables defined outside local scope --> cannot write to
-        self.variables, therefore need to return single subj_out_dicts and join after;
-        however p.map can access object methods via self.method().
-        Is a bit complicated, but speedup is huge.
-        """
-        p.close()
-        p.join()
-
-        df = pd.DataFrame(columns=['id', 'mod', 'size', 'origin', 'spacing'])
-        for subj_id, analysis in subj_analyses:
-            for mod, geo in analysis.items():
-                df.loc[len(df)] = [subj_id, mod, np.array(geo['size']), np.array(geo['origin']), np.array(geo['spacing'])]
-
-        os.makedirs(self._cf['output_directory'], exist_ok=True)
-        df.to_csv(os.path.join(self._cf['output_directory'], "analysis_df"))
-
-        print("\nOver all mods")
-        print("Size mean {}\u00B1{}".format(df['size'].mean(), np.std(df['size'].values)))
-        print("Origin mean {}\u00B1{}".format(df['origin'].mean(), np.std(df['origin'].values)))
-        print("Spacing mean {}\u00B1{}".format(df['spacing'].mean(), np.std(df['spacing'].values)))
-        print("-----------------------------------------\n")
-
-        for mod in df['mod'].unique():
-            print("\nModality: {}".format(mod))
-            mod_df = df[df['mod']==mod]
-            print("Size mean {}\u00B1{}".format(mod_df['size'].mean(), np.std(mod_df['size'].values)))
-            print("Origin mean {}\u00B1{}".format(mod_df['origin'].mean(), np.std(mod_df['origin'].values)))
-            print("Spacing mean {}\u00B1{}".format(mod_df['spacing'].mean(), np.std(mod_df['spacing'].values)))
-            print("-----------------------------------------\n")
-        return
-
-
-    def dump_class_labels(self, out_dir):
-        """save used GS mapping and class labels to file.
-            will likely not work if non-lesion classes (anatomy) are contained
-        """
-        #if "gleason_thresh" in self._cf.keys():
-        possible_gs = {gs for p_dict in self._info_dict.values() for gs in p_dict['lesion_gleasons']}
-        gs_mapping_inv = [(self._cf["gleason_map"](gs)+1, gs) for gs in possible_gs]
-        #elif "gleason_mapping" in self._cf.keys():
-            #gs_mapping_inv = [(val + 1, key) for (key, val) in self._cf["gleason_mapping"].items() if val != -1]
-        classes = {pair[0] for pair in gs_mapping_inv}
-        groups = [[pair[1] for pair in gs_mapping_inv if pair[0]==cl] for cl in classes]
-        gr_names = [ "GS{}-{}".format(min(gr), max(gr)) if len(gr)>1 else "GS"+str(*gr) for gr in groups ]
-        if "color_palette" in self._cf.keys():
-            class_labels = {cl: {"gleasons": groups[ix], "name": gr_names[ix], "color": self._cf["color_palette"][ix]}
-                            for ix, cl in enumerate(classes) }
-        else:
-            class_labels = {cl: {"gleasons": groups[ix], "name": gr_names[ix], "color": self.full_cf.color_palette[ix]}
-                            for ix, cl in enumerate(classes)}
-
-        save_obj(class_labels, os.path.join(out_dir,"pp_class_labels"))
-
-
-
-    def save_and_finish(self):
-        """copy config and used code to out_dir."""
-
-        out_dir = self._cf['output_directory']
-
-        # save script
-        current_script = os.path.realpath(__file__)
-        shutil.copyfile(current_script, os.path.join(out_dir, 'applied_preprocessing.py'))
-
-        # save config
-        if self._config_path[-1] == 'c':
-            self._config_path = self._config_path[:-1]
-        shutil.copyfile(self._config_path, os.path.join(out_dir, 'applied_config.py'))
-        
-        #copy histo data to local dir
-        lbased = self._cf['histo_lesion_based']
-        pbased = self._cf['histo_patient_based']
-        os.makedirs(self._cf['histo_dir_out'], exist_ok=True)
-        shutil.copyfile(self._cf['histo_dir']+lbased, self._cf['histo_dir_out']+lbased)
-        shutil.copyfile(self._cf['histo_dir']+pbased, self._cf['histo_dir_out']+pbased)
-       
-        # save info dict
-        #print("info dict ", self._info_dict)
-        save_obj(self._info_dict, self._cf['info_dict_path'][:-4])
-        self.dump_class_labels(out_dir)
-
-        return
-    
-    def convert_copy_npz(self):
-        if not self._cf["npz_dir"]:
-            return
-        print("npz dir", self._cf['npz_dir'])
-        os.makedirs(self._cf['npz_dir'], exist_ok=True)
-        save_obj(self._info_dict, os.path.join(self._cf['npz_dir'], 
-                                               self._cf['info_dict_path'].split("/")[-1][:-4]))
-        lbased = self._cf['histo_lesion_based']
-        pbased = self._cf['histo_patient_based']
-        histo_out = os.path.join(self._cf['npz_dir'], "histos/")
-        print("histo dir", histo_out)
-        os.makedirs(histo_out, exist_ok=True)
-        shutil.copyfile(self._cf['histo_dir']+lbased, histo_out+lbased)
-        shutil.copyfile(self._cf['histo_dir']+pbased, histo_out+pbased)
-        shutil.copyfile(os.path.join(self._cf['output_directory'], 'applied_config.py'),
-                        os.path.join(self._cf['npz_dir'], 'applied_config.py'))
-        shutil.copyfile(os.path.join(self._cf['output_directory'], 'applied_preprocessing.py'),
-                        os.path.join(self._cf['npz_dir'], 'applied_preprocessing.py'))
-        shutil.copyfile(os.path.join(self._cf['output_directory'], 'pp_class_labels.pkl'),
-                        os.path.join(self._cf['npz_dir'], 'pp_class_labels.pkl'))
-        
-        dmanager.pack_dataset(self._cf["output_directory"], self._cf["npz_dir"], recursive=True)
-        
-        
-        
-
-
-if __name__ == "__main__":
-
-    stime = time.time()
-    
-    from configs import Configs
-    cf = configs()
-    
-    
-    pp = Preprocessor(config=cf)
-    pp.prepare_filenames()
-    #pp.analyze_subjects(ids_subset=None)#[1,2,3])
-    pp.iterate_subjects(ids_subset=None, processes=os.cpu_count())
-    pp.save_and_finish()
-    pp.convert_copy_npz()
-    
-   
-    #patient_id = 17
-    #data = pp.load_subject_data(patient_id)
-    #data = pp.process_subject_data(data)
-    
-    #img = data['img']
-    #print("img shape ", img.shape)
-    #print("seg shape ",  data['seg'].shape)
-    #label_remap = {0:0}
-    #label_remap.update({roi_id : 1 for roi_id in range(1,5)})
-    #plg.view_slices(cf, img[0], data['seg'], instance_labels=True,
-    #                out_dir="experiments/dev/ex_slices.png")
-    
-    mins, secs = divmod((time.time() - stime), 60)
-    h, mins = divmod(mins, 60)
-    t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) 
-    print("Prepro program runtime: {}".format(t))
diff --git a/graphics_generation.py b/graphics_generation.py
deleted file mode 100644
index 6c59a0c..0000000
--- a/graphics_generation.py
+++ /dev/null
@@ -1,1932 +0,0 @@
-"""
-Created at 07/03/19 11:42
-@author: gregor 
-"""
-import plotting as plg
-import matplotlib.lines as mlines
-
-import os
-import sys
-import multiprocessing
-from copy import deepcopy
-import logging
-import time
-
-import numpy as np
-import pandas as pd
-from scipy.stats import norm
-from sklearn.metrics import confusion_matrix
-
-import utils.exp_utils as utils
-import utils.model_utils as mutils
-import utils.dataloader_utils as dutils
-from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
-
-import predictor as predictor_file
-import evaluator as evaluator_file
-
-
-
-class NoDaemonProcess(multiprocessing.Process):
-    # make 'daemon' attribute always return False
-    def _get_daemon(self):
-        return False
-    def _set_daemon(self, value):
-        pass
-    daemon = property(_get_daemon, _set_daemon)
-
-# We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool
-# because the latter is only a wrapper function, not a proper class.
-class NoDaemonProcessPool(multiprocessing.pool.Pool):
-    Process = NoDaemonProcess
-
-class AttributeDict(dict):
-    __getattr__ = dict.__getitem__
-    __setattr__ = dict.__setitem__
-
-def get_cf(dataset_name, exp_dir=""):
-
-    cf_path = os.path.join('datasets', dataset_name, exp_dir, "configs.py")
-    cf_file = utils.import_module('configs', cf_path)
-
-    return cf_file.Configs()
-
-
-def prostate_results_static(plot_dir=None):
-    cf = get_cf('prostate', '')
-    if plot_dir is None:
-        plot_dir = os.path.join('datasets', 'prostate', 'misc')
-
-    text_fs = 18
-    fig = plg.plt.figure(figsize=(6, 3)) #w,h
-    grid = plg.plt.GridSpec(1, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c
-
-    groups = ["b values", "ADC + b values", "T2"]
-    splits = ["Det. U-Net", "Mask R-CNN", "Faster R-CNN+"]
-    values = {"detu": [(0.296, 0.031), (0.312, 0.045), (0.090, 0.040)],
-              "mask": [(0.393, 0.051), (0.382, 0.047), (0.136, 0.016)],
-              "fast": [(0.424, 0.083), (0.390, 0.086), (0.036, 0.013)]}
-    bar_values = [[v[0] for v in split] for split in values.values()]
-    errors = [[v[1] for v in split] for split in values.values()]
-    ax = fig.add_subplot(grid[0,0])
-    colors = [cf.aubergine, cf.blue, cf.dark_blue]
-    plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=True,
-                               title="Prostate Main Results (3D)", ylabel=r"Performance as $\mathrm{AP}_{10}$", xlabel="Input Modalities")
-    plg.plt.tight_layout()
-    plg.plt.savefig(os.path.join(plot_dir, 'prostate_main_results.png'), dpi=600)
-
-def prostate_GT_examples(exp_dir='', plot_dir=None, pid=8., z_ix=None):
-
-    import datasets.prostate.data_loader as dl
-    cf = get_cf('prostate', exp_dir)
-    cf.exp_dir = exp_dir
-    cf.fold = 0
-    cf.data_sourcedir =  "/mnt/HDD2TB/Documents/data/prostate/data_di_250519_ps384_gs6071/"
-    dataset = dl.Dataset(cf)
-    dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
-    dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
-    set_splits = dataset.fg.splits
-
-    test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
-    train_ids = np.concatenate(set_splits, axis=0)
-
-    if cf.held_out_test_set:
-        train_ids = np.concatenate((train_ids, test_ids), axis=0)
-        test_ids = []
-    print("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
-                                                                                    len(test_ids)))
-
-
-    if plot_dir is None:
-        plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc')
-
-    text_fs = 18
-    fig = plg.plt.figure(figsize=(10, 7.7)) #w,h
-    grid = plg.plt.GridSpec(3, 4, wspace=0.0, hspace=0.0, figure=fig) #r,c
-    text_x, text_y = 0.1, 0.8
-
-    # ------- DWI -------
-    if z_ix is None:
-        z_ix_dwi = np.random.choice(dataset[pid]["fg_slices"])
-    img = np.load(dataset[pid]["img"])[:,z_ix_dwi] # mods, z,y,x
-    seg = np.load(dataset[pid]["seg"])[z_ix_dwi] # z,y,x
-    ax = fig.add_subplot(grid[0,0])
-    ax.imshow(img[0], cmap='gray')
-    ax.text(text_x, text_y, "ADC", size=text_fs, color=cf.white, transform=ax.transAxes,
-          bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7))
-    ax.axis('off')
-    ax = fig.add_subplot(grid[0,1])
-    ax.imshow(img[0], cmap='gray')
-    cmap = cf.class_cmap
-    for r_ix in np.unique(seg[seg>0]):
-        seg[seg==r_ix] = dataset[pid]["class_targets"][r_ix-1]
-    ax.imshow(plg.to_rgba(seg, cmap), alpha=1)
-    ax.text(text_x, text_y, "DWI GT", size=text_fs, color=cf.white, transform=ax.transAxes,
-          bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7))
-    ax.axis('off')
-    for b_ix, b in enumerate([50,500,1000,1500]):
-        ax = fig.add_subplot(grid[1, b_ix])
-        ax.imshow(img[b_ix+1], cmap='gray')
-        ax.text(text_x, text_y, r"{}{}".format("$b=$" if b_ix == 0 else "", b), size=text_fs, color=cf.white,
-                transform=ax.transAxes,
-                bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7))
-        ax.axis('off')
-
-    # ----- T2 -----
-    cf.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071/"
-    dataset = dl.Dataset(cf)
-    if z_ix is None:
-        if z_ix_dwi in dataset[pid]["fg_slices"]:
-            z_ix_t2 = z_ix_dwi
-        else:
-            z_ix_t2 = np.random.choice(dataset[pid]["fg_slices"])
-    img = np.load(dataset[pid]["img"])[:,z_ix_t2] # mods, z,y,x
-    seg = np.load(dataset[pid]["seg"])[z_ix_t2] # z,y,x
-    ax = fig.add_subplot(grid[2,0])
-    ax.imshow(img[0], cmap='gray')
-    ax.text(text_x, text_y, "T2w", size=text_fs, color=cf.white, transform=ax.transAxes,
-          bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7))
-    ax.axis('off')
-    ax = fig.add_subplot(grid[2,1])
-    ax.imshow(img[0], cmap='gray')
-    cmap = cf.class_cmap
-    for r_ix in np.unique(seg[seg>0]):
-        seg[seg==r_ix] = dataset[pid]["class_targets"][r_ix-1]
-    ax.imshow(plg.to_rgba(seg, cmap), alpha=1)
-    ax.text(text_x, text_y, "T2 GT", size=text_fs, color=cf.white, transform=ax.transAxes,
-          bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7))
-    ax.axis('off')
-
-    #grid.tight_layout(fig)
-    plg.plt.tight_layout()
-    plg.plt.savefig(os.path.join(plot_dir, 'prostate_gt_examples.png'), dpi=600)
-
-
-def prostate_dataset_stats(exp_dir='', plot_dir=None, show_splits=True,):
-
-    import datasets.prostate.data_loader as dl
-    cf = get_cf('prostate', exp_dir)
-    cf.exp_dir = exp_dir
-    cf.fold = 0
-    dataset = dl.Dataset(cf)
-    dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
-    dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
-    set_splits = dataset.fg.splits
-
-    test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
-    train_ids = np.concatenate(set_splits, axis=0)
-
-    if cf.held_out_test_set:
-        train_ids = np.concatenate((train_ids, test_ids), axis=0)
-        test_ids = []
-
-    print("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
-                                                                                    len(test_ids)))
-
-    df, labels = dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=None)
-
-    if plot_dir is None:
-        plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc')
-
-    if show_splits:
-        fig = plg.plt.figure(figsize=(6, 6)) # w, h
-        grid = plg.plt.GridSpec(2, 2, wspace=0.05, hspace=0.15, figure=fig) # rows, cols
-    else:
-        fig = plg.plt.figure(figsize=(6, 3.))
-        grid = plg.plt.GridSpec(1, 1, wspace=0.0, hspace=0.15, figure=fig)
-
-    ax = fig.add_subplot(grid[0,0])
-    ax = plg.plot_data_stats(cf, df, labels, ax=ax)
-    ax.set_xlabel("")
-    ax.set_xticklabels(df.columns, rotation='horizontal', fontsize=11)
-    ax.set_title("")
-    if show_splits:
-        ax.text(0.05,0.95, 'a)', horizontalalignment='center', verticalalignment='center', transform = ax.transAxes, weight='bold')
-    ax.text(0, 25, "GS$=6$", horizontalalignment='center', verticalalignment='center', bbox=dict(facecolor=(*cf.white, 0.8), edgecolor=cf.dark_green, pad=3))
-    ax.text(1, 25, "GS$\geq 7a$", horizontalalignment='center', verticalalignment='center', bbox=dict(facecolor=(*cf.white, 0.8), edgecolor=cf.red, pad=3))
-    ax.margins(y=0.1)
-
-    if show_splits:
-        ax = fig.add_subplot(grid[:, 1])
-        ax = plg.plot_fold_stats(cf, df, labels, ax=ax)
-        ax.set_xlabel("")
-        ax.set_title("")
-        ax.text(0.05, 0.98, 'c)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold')
-        ax.yaxis.tick_right()
-        ax.yaxis.set_label_position("right")
-        ax.margins(y=0.1)
-
-        ax = fig.add_subplot(grid[1, 0])
-        cf.balance_target = "lesion_gleasons"
-        dataset.df = None
-        df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True)
-        ax = plg.plot_data_stats(cf, df, labels, ax=ax)
-        ax.set_xlabel("")
-        ax.set_title("")
-        ax.text(0.05, 0.95, 'b)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold')
-        ax.margins(y=0.1)
-        # rename GS according to names in thesis
-        renamer = {'GS60':'GS 6', 'GS71':'GS 7a', 'GS72':'GS 7b', 'GS80':'GS 8', 'GS90': 'GS 9', 'GS91':'GS 9a', 'GS92':'GS 9b'}
-        x_ticklabels = [str(l.get_text()) for l in ax.xaxis.get_ticklabels()]
-        ax.xaxis.set_ticklabels([renamer[l] for l in x_ticklabels])
-
-    plg.plt.tight_layout()
-    plg.plt.savefig(os.path.join(plot_dir, 'data_stats_prostate.png'), dpi=600)
-
-    return
-
-def lidc_merged_sa_joint_plot(exp_dir='', plot_dir=None):
-    import datasets.lidc.data_loader as dl
-    cf = get_cf('lidc', exp_dir)
-    cf.balance_target = "regression_targets"
-
-    if plot_dir is None:
-        plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc')
-
-    cf.training_gts = 'merged'
-    dataset = dl.Dataset(cf, mode='train')
-    df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True)
-
-    fig = plg.plt.figure(figsize=(4, 5.6)) #w, h
-    # fig.subplots_adjust(hspace=0, wspace=0)
-    grid = plg.plt.GridSpec(3, 1, wspace=0.0, hspace=0.7, figure=fig) #rows, cols
-    fs = 9
-
-    ax = fig.add_subplot(grid[0, 0])
-
-    labels = [AttributeDict({ 'name': rg_val, 'color': cf.bin_id2label[cf.rg_val_to_bin_id(rg_val)].color}) for rg_val
-              in df.columns]
-    ax = plg.plot_data_stats(cf, df, labels, ax=ax, fs=fs)
-    ax.set_xlabel("averaged multi-rater malignancy scores (ms)", fontsize=fs)
-    ax.set_title("")
-    ax.text(0.05, 0.91, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes,
-            weight='bold', fontsize=fs)
-    ax.margins(y=0.2)
-
-    #----- single annotator -------
-    cf.training_gts = 'sa'
-    dataset = dl.Dataset(cf, mode='train')
-    df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True)
-
-    ax = fig.add_subplot(grid[1, 0])
-    labels = [AttributeDict({ 'name': '{:.0f}'.format(rg_val), 'color': cf.bin_id2label[cf.rg_val_to_bin_id(rg_val)].color}) for rg_val
-              in df.columns]
-    mapper = {rg_val:'{:.0f}'.format(rg_val) for rg_val in df.columns}
-    df = df.rename(mapper, axis=1)
-    ax = plg.plot_data_stats(cf, df, labels, ax=ax, fs=fs)
-    ax.set_xlabel("unaggregrated single-rater malignancy scores (ms)", fontsize=fs)
-    ax.set_title("")
-    ax.text(0.05, 0.91, 'b)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes,
-            weight='bold', fontsize=fs)
-    ax.margins(y=0.45)
-
-    #------ binned dissent -----
-    #cf.balance_target = "regression_targets"
-    all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()]
-    non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0]
-
-    mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions])
-    distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions]
-    distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion])
-
-    binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()]
-    for l_ix, mean_std in enumerate(mean_std_per_lesion):
-        bin_id = cf.rg_val_to_bin_id(mean_std[0])
-        bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix])
-        binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max))
-
-    ax = fig.add_subplot(grid[2, 0])
-    plg.plot_binned_rater_dissent(cf, binned_stats, ax=ax, fs=fs)
-    ax.set_title("")
-    ax.text(0.05, 0.91, 'c)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes,
-            weight='bold', fontsize=fs)
-    ax.margins(y=0.2)
-
-
-    plg.plt.savefig(os.path.join(plot_dir, 'data_stats_lidc_solarized.png'), bbox_inches='tight', dpi=600)
-
-    return
-
-def lidc_dataset_stats(exp_dir='', plot_dir=None):
-
-    import datasets.lidc.data_loader as dl
-    cf = get_cf('lidc', exp_dir)
-    cf.data_rootdir = cf.pp_data_path
-    cf.balance_target = "regression_targets"
-
-    dataset = dl.Dataset(cf, data_dir=cf.data_rootdir)
-    if plot_dir is None:
-        plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc')
-
-    df, labels = dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True)
-
-    return df, labels
-
-def lidc_sa_dataset_stats(exp_dir='', plot_dir=None):
-
-    import datasets.lidc_sa.data_loader as dl
-    cf = get_cf('lidc_sa', exp_dir)
-    #cf.data_rootdir = cf.pp_data_path
-    cf.balance_target = "regression_targets"
-
-    dataset = dl.Dataset(cf)
-    if plot_dir is None:
-        plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc_sa', 'misc')
-
-    dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True)
-
-    all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()]
-    empty_patients = [pid for (pid, lesions) in all_patients if len(lesions) == 0]
-    non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0]
-    full_consent_patients = [(pid, lesions) for (pid, lesions) in non_empty_patients if np.all([np.unique(roi).size == 1 for roi in lesions])]
-    all_lesions = [roi for (pid, lesions) in non_empty_patients for roi in lesions]
-    two_vote_min = [roi for (pid, lesions) in non_empty_patients for roi in lesions if np.count_nonzero(roi) > 1]
-    three_vote_min = [roi for (pid, lesions) in non_empty_patients for roi in lesions if np.count_nonzero(roi) > 2]
-    mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions])
-    avg_mean_std_pl = np.mean(mean_std_per_lesion, axis=0)
-    # call std dev per lesion disconsent from now on
-    disconsent_std = np.std(mean_std_per_lesion[:, 1])
-
-    distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions]
-    distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion])
-
-    mean_max_delta = abs(mean_std_per_lesion[:, 0] - distribution_max_per_lesion)
-
-    binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()]
-    for l_ix, mean_std in enumerate(mean_std_per_lesion):
-        bin_id = cf.rg_val_to_bin_id(mean_std[0])
-        bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix])
-        binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max))
-
-    plg.plot_binned_rater_dissent(cf, binned_stats, out_file=os.path.join(plot_dir, "binned_dissent.png"))
-
-
-    mean_max_bin_divergence = [[] for bin_id in cf.bin_id2rg_val.keys()]
-    for bin_id, bin_stats in enumerate(binned_stats):
-        mean_max_bin_divergence[bin_id].append([roi for roi in bin_stats if roi[3] != 0])
-        mean_max_bin_divergence[bin_id].insert(0,len(mean_max_bin_divergence[bin_id][0]))
-
-
-    return
-
-def lidc_annotator_confusion(exp_dir='', plot_dir=None, normalize=None, dataset=None, plot=True):
-    """
-    :param exp_dir:
-    :param plot_dir:
-    :param normalize: str or None. str in ['truth', 'pred']
-    :param dataset:
-    :param plot:
-    :return:
-    """
-    if dataset is None:
-        import datasets.lidc.data_loader as dl
-        cf = get_cf('lidc', exp_dir)
-        # cf.data_rootdir = cf.pp_data_path
-        cf.training_gts = "sa"
-        cf.balance_target = "regression_targets"
-        dataset = dl.Dataset(cf)
-    else:
-        cf = dataset.cf
-
-    if plot_dir is None:
-        plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc')
-
-    dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True)
-
-    all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()]
-    non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0]
-
-    y_true, y_pred = [], []
-    for (pid, lesions) in non_empty_patients:
-        for roi in lesions:
-            true_bin = cf.rg_val_to_bin_id(np.mean(roi))
-            y_true.extend([true_bin] * len(roi))
-            y_pred.extend(roi)
-    cm = confusion_matrix(y_true, y_pred)
-    if normalize in ["truth", "row"]:
-        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
-    elif normalize in ["pred", "prediction", "column", "col"]:
-        cm = cm.astype('float') / cm.sum(axis=0)[:, np.newaxis]
-
-    if plot:
-        plg.plot_confusion_matrix(cf, cm, out_file=os.path.join(plot_dir, "annotator_confusion.pdf"))
-
-    return cm
-
-def plot_lidc_dissent_and_example(confusion_matrix=True, bin_stds=False, plot_dir=None, numbering=True, example_title="Example"):
-    import datasets.lidc.data_loader as dl
-    dataset_name = 'lidc'
-    exp_dir1 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rg_bs8'
-    exp_dir2 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_bs8'
-    #exp_dir1 = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rg_bs8'
-    #exp_dir2 = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8'
-    cf = get_cf(dataset_name, exp_dir1)
-    #file_names = [f_name for f_name in os.listdir(os.path.join(exp_dir, 'inference_analysis')) if f_name.endswith('.pkl')]
-    # file_names = [os.path.join(exp_dir, "inference_analysis", f_name) for f_name in file_names]
-    file_names = ["bytes_merged_boxes_fold_0_pid_0811a.pkl",]
-    z_ics = [194,]
-    plot_files = [
-        {'files': [os.path.join(exp_dir, "inference_analysis", f_name) for exp_dir in [exp_dir1, exp_dir2]],
-         'z_ix': z_ix} for (f_name, z_ix) in zip(file_names, z_ics)
-    ]
-
-    cf.training_gts = 'sa'
-    info_df_path = '/mnt/HDD2TB/Documents/data/lidc/pp_20190805/patient_gts_{}/info_df.pickle'.format(cf.training_gts)
-    info_df = pd.read_pickle(info_df_path)
-
-    cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois
-
-    text_fs = 14
-    title_fs = text_fs
-    text_x, text_y = 0.06, 0.92
-    fig = plg.plt.figure(figsize=(8.6, 3)) #w, h
-    #fig.subplots_adjust(hspace=0, wspace=0)
-    grid = plg.plt.GridSpec(1, 4, wspace=0.0, hspace=0.0, figure=fig) #rows, cols
-    cf.plot_class_ids = True
-
-    f_ix = 0
-    z_ix = plot_files[f_ix]['z_ix']
-    for model_ix in range(2)[::-1]:
-        print("f_ix, m_ix", f_ix, model_ix)
-        plot_file = utils.load_obj(plot_files[f_ix]['files'][model_ix])
-        batch = plot_file["batch"]
-        pid = batch["pid"][0]
-        batch['patient_rg_bin_targets_sa'] = info_df[info_df.pid == pid]['class_target'].tolist()
-        # apply same filter as with merged GTs: need at least two non-zero votes to consider a RoI.
-        batch['patient_rg_bin_targets_sa'] = [[four_votes.astype("uint8") for four_votes in batch_el if
-                                               np.count_nonzero(four_votes>0)>=2] for batch_el in
-                                              batch['patient_rg_bin_targets_sa']]
-        results_dict = plot_file["res_dict"]
-
-        # pred
-        ax = fig.add_subplot(grid[0, model_ix+2])
-        plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=False, sample_picks=None, fontsize=text_fs*1.3,
-                              vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.2, plot_mods=False,
-                              seg_cmap="rg", show_cl_ids=False,
-                              out_file=None, dpi=600, patient_items=True, return_fig=False, axes={'pred': ax})
-
-        #ax.set_title("{}".format("Reg R-CNN" if model_ix==0 else "Mask R-CNN"), size=title_fs)
-        ax.set_title("")
-        ax.set_xlabel("{}".format("Reg R-CNN" if model_ix == 0 else "Mask R-CNN"), size=title_fs)
-        if numbering:
-            ax.text(text_x, text_y, chr(model_ix+99)+")", horizontalalignment='center', verticalalignment='center',
-                    transform=ax.transAxes, weight='bold', color=cf.white, fontsize=title_fs)
-        #ax.axis("off")
-        ax.axis("on")
-        plg.suppress_axes_lines(ax)
-
-        # GT
-        if model_ix==0:
-            ax.set_title(example_title, fontsize=title_fs)
-            ax = fig.add_subplot(grid[0, 1])
-            # ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray')
-            # ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8)
-            plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, fontsize=text_fs*1.3,
-                                  vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.13, plot_mods=False, seg_cmap="rg",
-                                  out_file=None, dpi=600, patient_items=True, return_fig=False, axes={'gt':ax})
-            if numbering:
-                ax.text(text_x, text_y, "b)", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes,
-            weight='bold', color=cf.white, fontsize=title_fs)
-            #ax.set_title("Ground Truth", size=title_fs)
-            ax.set_title("")
-            ax.set_xlabel("Ground Truth", size=title_fs)
-            plg.suppress_axes_lines(ax)
-            #ax.axis('off')
-    #----- annotator dissent plot(s) ------
-
-    cf.training_gts = 'sa'
-    cf.balance_targets = 'rg_bin_targets'
-    dataset = dl.Dataset(cf, mode='train')
-
-    if bin_stds:
-        #------ binned dissent -----
-        #cf = get_cf('lidc', "")
-
-        #cf.balance_target = "regression_targets"
-        all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()]
-        non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0]
-
-        mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions])
-        distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions]
-        distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion])
-
-        binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()]
-        for l_ix, mean_std in enumerate(mean_std_per_lesion):
-            bin_id = cf.rg_val_to_bin_id(mean_std[0])
-            bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix])
-            binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max))
-
-        ax = fig.add_subplot(grid[0, 0])
-        plg.plot_binned_rater_dissent(cf, binned_stats, ax=ax, fs=text_fs)
-        if numbering:
-            ax.text(text_x, text_y, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes,
-                    weight='bold', fontsize=title_fs)
-        ax.margins(y=0.2)
-        ax.set_xlabel("Malignancy-Score Bins", fontsize=title_fs)
-        #ax.yaxis.set_label_position("right")
-        #ax.yaxis.tick_right()
-        ax.set_yticklabels([])
-        #ax.xaxis.set_label_position("top")
-        #ax.xaxis.tick_top()
-        ax.set_title("Average Rater Dissent", fontsize=title_fs)
-
-    if confusion_matrix:
-        #------ confusion matrix -------
-        cm = lidc_annotator_confusion(dataset=dataset, plot=False, normalize="truth")
-        ax = fig.add_subplot(grid[0, 0])
-        cmap = plg.make_colormap([(1,1,1), cf.dkfz_blue])
-        plg.plot_confusion_matrix(cf, cm, ax=ax, fs=text_fs, color_bar=False, cmap=cmap )#plg.plt.cm.Purples)
-        ax.set_xticks(np.arange(cm.shape[1]))
-        if numbering:
-            ax.text(-0.16, text_y, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes,
-                    weight='bold', fontsize=title_fs)
-        ax.margins(y=0.2)
-        ax.set_title("Annotator Dissent", fontsize=title_fs)
-
-    #fig.suptitle("               Example", fontsize=title_fs)
-    #fig.text(0.63, 1.03, "Example", va="center", ha="center", size=title_fs, transform=fig.transFigure)
-
-    #fig_patches = fig_leg.get_patches()
-    #patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if label.id!=0]
-    #fig.legends.append(fig_leg)
-    #plg.plt.figlegend(handles=patches, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0.,
-    # ncol=len(patches), bbox_transform=fig.transFigure, title="Binned Malignancy Score", fontsize= text_fs)
-    plg.plt.tight_layout()
-    if plot_dir is None:
-        plot_dir = "datasets/lidc/misc"
-    out_file = os.path.join(plot_dir, "regrcnn_lidc_diss_example.png")
-    if out_file is not None:
-        plg.plt.savefig(out_file, dpi=600, bbox_inches='tight')
-
-def lidc_annotator_dissent_images(exp_dir='', plot_dir=None):
-    if plot_dir is None:
-        plot_dir = "datasets/lidc/misc"
-
-    import datasets.lidc.data_loader as dl
-    cf = get_cf('lidc', exp_dir)
-    cf.training_gts = "sa"
-
-    dataset = dl.Dataset(cf, mode='train')
-
-    pids = {'0069a': 132, '0493a':125, '1008a': 164}#, '0355b': 138, '0484a': 86} # pid : (z_ix to show)
-    # add_pids = dataset.set_ids[65:80]
-    # for pid in add_pids:
-    #     try:
-    #
-    #         pids[pid] = int(np.median(dataset.data[pid]['fg_slices'][0]))
-    #
-    #     except (IndexError, ValueError):
-    #         print("pid {} has no foreground".format(pid))
-
-    if not os.path.exists(plot_dir):
-        os.mkdir(plot_dir)
-    out_file = os.path.join(plot_dir, "lidc_example_rater_dissent.png")
-
-    #cf.training_gts = 'sa'
-    cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois
-
-    title_fs = 14
-    text_fs = 14
-    fig = plg.plt.figure(figsize=(10, 5.9)) #w, h
-    #fig.subplots_adjust(hspace=0, wspace=0)
-    grid = plg.plt.GridSpec(len(pids.keys()), 5, wspace=0.0, hspace=0.0, figure=fig) #rows, cols
-    cf.plot_class_ids = True
-    cmap = {id : (label.color if id!=0 else (0.,0.,0.)) for id, label in cf.bin_id2label.items()}
-    legend_handles = set()
-    window_size = (250,250)
-
-    for p_ix, (pid, z_ix) in enumerate(pids.items()):
-        try:
-            print("plotting pid, z_ix", pid, z_ix)
-            patient = dataset[pid]
-            img = np.load(patient['data'], mmap_mode='r')[z_ix] # z,y,x --> y,x
-            seg = np.load(patient['seg'], mmap_mode='r')['seg'][:,z_ix] # rater,z,y,x --> rater,y,x
-            rg_bin_targets = patient['rg_bin_targets']
-
-            contours = np.nonzero(seg[0])
-            center_y, center_x = np.median(contours[0]), np.median(contours[1])
-            #min_y, min_x = np.min(contours[0]), np.min(contours[1])
-            #max_y, max_x = np.max(contours[0]), np.max(contours[1])
-            #buffer_y, buffer_x = int(seg.shape[1]*0.5), int(seg.shape[2]*0.5)
-            #y_range = np.arange(max(min_y-buffer_y, 0), min(min_y+buffer_y, seg.shape[1]))
-            #x_range =  np.arange(max(min_x-buffer_x, 0), min(min_x+buffer_x, seg.shape[2]))
-            y_range = np.arange(max(int(center_y-window_size[0]/2), 0), min(int(center_y+window_size[0]/2), seg.shape[1]))
-
-            min_x = int(center_x-window_size[1]/2)
-            max_x = int(center_x+window_size[1]/2)
-            if min_x<0:
-                max_x += abs(min_x)
-            elif max_x>seg.shape[2]:
-                min_x -= max_x-seg.shape[2]
-            x_range =  np.arange(max(min_x, 0), min(max_x, seg.shape[2]))
-            img = img[y_range][:,x_range]
-            seg = seg[:, y_range][:,:,x_range]
-            # data
-            ax = fig.add_subplot(grid[p_ix, 0])
-            ax.imshow(img, cmap='gray')
-
-            plg.suppress_axes_lines(ax)
-            # key = "spec" if "spec" in batch.keys() else "pid"
-            ylabel = str(pid) + "/" + str(z_ix)
-            ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs)  # show id-number
-            if p_ix == 0:
-                ax.set_title("Image", fontsize=title_fs)
-
-            # raters
-            for r_ix in range(seg.shape[0]):
-                rater_bin_targets = rg_bin_targets[:,r_ix]
-                for roi_ix, rating in enumerate(rater_bin_targets):
-                    seg[r_ix][seg[r_ix]==roi_ix+1] = rating
-                ax = fig.add_subplot(grid[p_ix, r_ix+1])
-                ax.imshow(seg[r_ix], cmap='gray')
-                ax.imshow(plg.to_rgba(seg[r_ix], cmap), alpha=0.8)
-                ax.axis('off')
-                if p_ix == 0:
-                    ax.set_title("Rating {}".format(r_ix+1), fontsize=title_fs)
-                legend_handles.update([cf.bin_id2label[id] for id in np.unique(seg[r_ix]) if id!=0])
-        except:
-            print("failed pid", pid)
-            pass
-
-    legend_handles = [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in legend_handles]
-    legend_handles = sorted(legend_handles, key=lambda h: h._label)
-    fig.suptitle("LIDC Single-Rater Annotations", fontsize=title_fs)
-    #patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if label.id!=0]
-
-    legend = fig.legend(handles=legend_handles, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0, fontsize=text_fs,
-                      bbox_transform=fig.transFigure, ncol=len(legend_handles), title="Malignancy Score")
-    plg.plt.setp(legend.get_title(), fontsize=title_fs)
-    #grid.tight_layout(fig)
-    #plg.plt.tight_layout(rect=[0, 0.00, 1, 1.5])
-    if out_file is not None:
-        plg.plt.savefig(out_file, dpi=600, bbox_inches='tight')
-
-
-
-    return
-
-def lidc_results_static(xlabels=None, plot_dir=None, in_percent=True):
-    cf = get_cf('lidc', '')
-    if plot_dir is None:
-        plot_dir = os.path.join('datasets', 'lidc', 'misc')
-
-    text_fs = 18
-    fig = plg.plt.figure(figsize=(3, 2.5)) #w,h
-    grid = plg.plt.GridSpec(2, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c
-
-    #--- LIDC 3D -----
-
-
-    splits = ["Reg R-CNN", "Mask R-CNN"]#, "Reg R-CNN 2D", "Mask R-CNN 2D"]
-    values = {"reg3d": [(0.259, 0.035), (0.628, 0.038), (0.477, 0.035)],
-              "mask3d": [(0.235, 0.027), (0.622, 0.029), (0.411, 0.026)],}
-    groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."]
-    if in_percent:
-        bar_values = [[v[0]*100 for v in split] for split in values.values()]
-        errors = [[v[1]*100 for v in split] for split in values.values()]
-    else:
-        bar_values = [[v[0] for v in split] for split in values.values()]
-        errors = [[v[1] for v in split] for split in values.values()]
-
-    ax = fig.add_subplot(grid[0,0])
-    colors = [cf.blue, cf.dkfz_blue]
-    plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}",
-                               title="LIDC Results", ylabel=r"3D Perf. (%)", xlabel="Metric", yticklabels=[], ylim=(0,80 if in_percent else 0.8))
-    #------ LIDC 2D -------
-
-    splits = ["Reg R-CNN", "Mask R-CNN"]
-    values = {"reg2d": [(0.148, 0.046), (0.414, 0.052), (0.468, 0.057)],
-              "mask2d": [(0.127, 0.034), (0.406, 0.040), (0.447, 0.018)]}
-    groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."]
-    if in_percent:
-        bar_values = [[v[0]*100 for v in split] for split in values.values()]
-        errors = [[v[1]*100 for v in split] for split in values.values()]
-    else:
-        bar_values = [[v[0] for v in split] for split in values.values()]
-        errors = [[v[1] for v in split] for split in values.values()]
-    ax = fig.add_subplot(grid[1,0])
-    colors = [cf.blue, cf.dkfz_blue]
-    plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}",
-                               title="", ylabel=r"2D Perf.", xlabel="Metric", xticklabels=xlabels, yticklabels=[], ylim=(None,60 if in_percent else 0.6))
-    plg.plt.tight_layout()
-    plg.plt.savefig(os.path.join(plot_dir, 'lidc_static_results.png'), dpi=700)
-
-def toy_results_static(xlabels=None, plot_dir=None, in_percent=True):
-    cf = get_cf('toy', '')
-    if plot_dir is None:
-        plot_dir = os.path.join('datasets', 'toy', 'misc')
-
-    text_fs = 18
-    fig = plg.plt.figure(figsize=(3, 2.5)) #w,h
-    grid = plg.plt.GridSpec(2, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c
-
-    #--- Toy 3D -----
-    groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."]
-    splits = ["Reg R-CNN", "Mask R-CNN"]#, "Reg R-CNN 2D", "Mask R-CNN 2D"]
-    values = {"reg3d": [(0.881, 0.014), (0.998, 0.004), (0.887, 0.014)],
-              "mask3d": [(0.822, 0.070), (1.0, 0.0), (0.826, 0.069)],}
-    if in_percent:
-        bar_values = [[v[0]*100 for v in split] for split in values.values()]
-        errors = [[v[1]*100 for v in split] for split in values.values()]
-    else:
-        bar_values = [[v[0] for v in split] for split in values.values()]
-        errors = [[v[1] for v in split] for split in values.values()]
-    ax = fig.add_subplot(grid[0,0])
-    colors = [cf.blue, cf.dkfz_blue]
-    plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=True, label_format="{:.1f}",
-                               title="Toy Results", ylabel=r"3D Perf. (%)", xlabel="Metric", yticklabels=[], ylim=(0,130 if in_percent else .3))
-    #------ Toy 2D -------
-    groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."]
-    splits = ["Reg R-CNN", "Mask R-CNN"]
-    values = {"reg2d": [(0.859, 0.021), (1., 0.0), (0.860, 0.021)],
-              "mask2d": [(0.748, 0.022), (1., 0.0), (0.748, 0.021)]}
-    if in_percent:
-        bar_values = [[v[0]*100 for v in split] for split in values.values()]
-        errors = [[v[1]*100 for v in split] for split in values.values()]
-    else:
-        bar_values = [[v[0] for v in split] for split in values.values()]
-        errors = [[v[1] for v in split] for split in values.values()]
-    ax = fig.add_subplot(grid[1,0])
-    colors = [cf.blue, cf.dkfz_blue]
-    plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}",
-                               title="", ylabel=r"2D Perf.", xlabel="Metric", xticklabels=xlabels, yticklabels=[], ylim=(None,130 if in_percent else 1.3))
-    plg.plt.tight_layout()
-    plg.plt.savefig(os.path.join(plot_dir, 'toy_static_results.png'), dpi=700)
-
-def analyze_test_df(dataset_name, exp_dir='', cf=None, logger=None, plot_dir=None):
-    evaluator_file = utils.import_module('evaluator', "evaluator.py")
-    if cf is None:
-        cf = get_cf(dataset_name, exp_dir)
-        cf.exp_dir = exp_dir
-        cf.test_dir = os.path.join(exp_dir, 'test')
-    if logger is None:
-        logger = utils.get_logger(cf.exp_dir, False)
-    evaluator = evaluator_file.Evaluator(cf, logger, mode='test')
-
-    fold_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_df.pkl' in ii])
-    fold_seg_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_seg_df.pkl' in ii])
-    metrics_to_score = ['ap', 'auc']#, 'patient_ap', 'patient_auc', 'patient_dice'] #'rg_bin_accuracy_weighted_tp', 'rg_MAE_w_std_weighted_tp'] #cf.metrics
-    if cf.evaluate_fold_means:
-        means_to_score = [m for m in metrics_to_score] #+ ['rg_MAE_w_std_weighted_tp']
-    #metrics_to_score += ['rg_MAE_std']
-    metrics_to_score = []
-
-
-    cf.fold = 'overall'
-    dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_df_paths]
-    evaluator.test_df = pd.concat(dfs_list, sort=True)
-
-    seg_dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_seg_df_paths]
-    if len(seg_dfs_list) > 0:
-        evaluator.seg_df = pd.concat(seg_dfs_list, sort=True)
-
-    # stats, _ = evaluator.return_metrics(evaluator.test_df, cf.class_dict)
-    # results_table_path = os.path.join(cf.exp_dir, "../", "semi_man_summary.csv")
-    # # ---column headers---
-    # col_headers = ["Experiment Name", "CV Folds", "Spatial Dim", "Clustering Kind", "Clustering IoU", "Merge-2D-to-3D IoU"]
-    # if hasattr(cf, "test_against_exact_gt"):
-    #     col_headers.append('Exact GT')
-    # for s in stats:
-    #     assert "overall" in s['name'].split(" ")[0]
-    #     if cf.class_dict[cf.patient_class_of_interest] in s['name']:
-    #         for metric in metrics_to_score:
-    #             #if metric in s.keys() and not np.isnan(s[metric]):
-    #             col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], metric))
-    #         for mean in means_to_score:
-    #             if mean == "rg_MAE_w_std_weighted_tp":
-    #                 col_headers.append('(MAE_fold_mean\u00B1std_fold_mean)\u00B1fold_mean_std\u00B1fold_std_std)'.format(*s['name'].split(" ")[1:], mean))
-    #             elif mean in s.keys() and not np.isnan(s[mean]):
-    #                 col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], mean))
-    #             else:
-    #                 print("skipping {}".format(mean))
-    # with open(results_table_path, 'a') as handle:
-    #     with open(results_table_path, 'r') as doublehandle:
-    #         last_header = doublehandle.readlines()
-    #     if len(last_header)==0 or len(col_headers)!=len(last_header[1].split(",")[:-1]) or \
-    #             not all([col_headers[ix]==lhix for ix, lhix in enumerate(last_header[1].split(",")[:-1])]):
-    #         handle.write('\n')
-    #         for head in col_headers:
-    #             handle.write(head+',')
-    #         handle.write('\n')
-    #
-    #     # --- columns content---
-    #     handle.write('{},'.format(cf.exp_dir.split(os.sep)[-1]))
-    #     handle.write('{},'.format(str(evaluator.test_df.fold.unique().tolist()).replace(",", "")))
-    #     handle.write('{}D,'.format(cf.dim))
-    #     handle.write('{},'.format(cf.clustering))
-    #     handle.write('{},'.format(cf.clustering_iou if cf.clustering else str("N/A")))
-    #     handle.write('{},'.format(cf.merge_3D_iou if cf.merge_2D_to_3D_preds else str("N/A")))
-    #     if hasattr(cf, "test_against_exact_gt"):
-    #         handle.write('{},'.format(cf.test_against_exact_gt))
-    #     for s in stats:
-    #         if cf.class_dict[cf.patient_class_of_interest] in s['name']:
-    #             for metric in metrics_to_score:
-    #                 #if metric in s.keys() and not np.isnan(s[metric]):  # needed as long as no dice on patient level poss
-    #                 handle.write('{:0.3f}, '.format(s[metric]))
-    #             for mean in means_to_score:
-    #                 #if metric in s.keys() and not np.isnan(s[metric]):
-    #                 if mean=="rg_MAE_w_std_weighted_tp":
-    #                     handle.write('({:0.3f}\u00B1{:0.3f})\u00B1({:0.3f}\u00B1{:0.3f}),'.format(*s[mean + "_folds_mean"], *s[mean + "_folds_std"]))
-    #                 elif mean in s.keys() and not np.isnan(s[mean]):
-    #                     handle.write('{:0.3f}\u00B1{:0.3f},'.format(s[mean+"_folds_mean"], s[mean+"_folds_std"]))
-    #                 else:
-    #                     print("skipping {}".format(mean))
-    #
-    #     handle.write('\n')
-
-    return evaluator.test_df
-
-def cluster_results_to_df(dataset_name, exp_dir='', overall_df=None, cf=None, logger=None, plot_dir=None):
-    evaluator_file = utils.import_module('evaluator', "evaluator.py")
-    if cf is None:
-        cf = get_cf(dataset_name, exp_dir)
-        cf.exp_dir = exp_dir
-        cf.test_dir = os.path.join(exp_dir, 'test')
-    if logger is None:
-        logger = utils.get_logger(cf.exp_dir, False)
-    evaluator = evaluator_file.Evaluator(cf, logger, mode='test')
-    cf.fold = 'overall'
-    metrics_to_score = ['ap', 'auc']#, 'patient_ap', 'patient_auc', 'patient_dice'] #'rg_bin_accuracy_weighted_tp', 'rg_MAE_w_std_weighted_tp'] #cf.metrics
-    if cf.evaluate_fold_means:
-        means_to_score = [m for m in metrics_to_score] #+ ['rg_MAE_w_std_weighted_tp']
-    #metrics_to_score += ['rg_MAE_std']
-    metrics_to_score = []
-
-    # use passed overall_df or, if not given, read dfs from file
-    if overall_df is None:
-        fold_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_df.pkl' in ii])
-        fold_seg_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_seg_df.pkl' in ii])
-        for paths in [fold_df_paths, fold_seg_df_paths]:
-            assert len(paths) <= cf.n_cv_splits, "found {} > nr of cv splits results dfs in {}".format(len(paths), cf.test_dir)
-        dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_df_paths]
-        evaluator.test_df = pd.concat(dfs_list, sort=True)
-
-        # seg_dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_seg_df_paths]
-        # if len(seg_dfs_list) > 0:
-        #     evaluator.seg_df = pd.concat(seg_dfs_list, sort=True)
-
-    else:
-        evaluator.test_df = overall_df
-        # todo seg_df if desired
-
-    stats, _ = evaluator.return_metrics(evaluator.test_df, cf.class_dict)
-    # ---column headers---
-    col_headers = ["Experiment Name", "Model", "CV Folds", "Spatial Dim", "Clustering Kind", "Clustering IoU", "Merge-2D-to-3D IoU"]
-    for s in stats:
-        assert "overall" in s['name'].split(" ")[0]
-        if cf.class_dict[cf.patient_class_of_interest] in s['name']:
-            for metric in metrics_to_score:
-                #if metric in s.keys() and not np.isnan(s[metric]):
-                col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], metric))
-            for mean in means_to_score:
-                if mean in s.keys() and not np.isnan(s[mean]):
-                    col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], mean+"_folds_mean"))
-                else:
-                    print("skipping {}".format(mean))
-    results_df = pd.DataFrame(columns=col_headers)
-    # --- columns content---
-    row = []
-    row.append('{}'.format(cf.exp_dir.split(os.sep)[-1]))
-    model = 'frcnn' if (cf.model=="mrcnn" and cf.frcnn_mode) else cf.model
-    row.append('{}'.format(model))
-    row.append('{}'.format(str(evaluator.test_df.fold.unique().tolist()).replace(",", "")))
-    row.append('{}D'.format(cf.dim))
-    row.append('{}'.format(cf.clustering))
-    row.append('{}'.format(cf.clustering_iou if cf.clustering else "N/A"))
-    row.append('{}'.format(cf.merge_3D_iou if cf.merge_2D_to_3D_preds else "N/A"))
-    for s in stats:
-        if cf.class_dict[cf.patient_class_of_interest] in s['name']:
-            for metric in metrics_to_score:
-                #if metric in s.keys() and not np.isnan(s[metric]):  # needed as long as no dice on patient level poss
-                row.append('{:0.3f} '.format(s[metric]))
-            for mean in means_to_score:
-                #if metric in s.keys() and not np.isnan(s[metric]):
-                if mean+"_folds_mean" in s.keys() and not np.isnan(s[mean+"_folds_mean"]):
-                    row.append('{:0.3f}\u00B1{:0.3f}'.format(s[mean+"_folds_mean"], s[mean+"_folds_std"]))
-                else:
-                    print("skipping {}".format(mean+"_folds_mean"))
-    #print("row, clustering, iou, exp", row, cf.clustering, cf.clustering_iou, cf.exp_dir)
-    results_df.loc[0] = row
-
-    return results_df
-
-def multiple_clustering_results(dataset_name, exp_dir, plot_dir=None, plot_hist=False):
-    print("Gathering exp {}".format(exp_dir))
-    cf = get_cf(dataset_name, exp_dir)
-    cf.n_workers = 1
-    logger = logging.getLogger("dummy")
-    logger.setLevel(logging.DEBUG)
-    #logger.addHandler(logging.StreamHandler())
-    cf.exp_dir = exp_dir
-    cf.test_dir = os.path.join(exp_dir, 'test')
-    cf.plot_prediction_histograms = False
-    if plot_dir is None:
-        #plot_dir = os.path.join(cf.test_dir, 'histograms')
-        plot_dir = os.path.join("datasets", dataset_name, "misc")
-        os.makedirs(plot_dir, exist_ok=True)
-
-    # fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
-    #                     os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
-    folds = range(cf.n_cv_splits)
-    clusterings = {None: ['lol'], 'wbc': [0.0, 0.1, 0.2, 0.3, 0.4], 'nms': [0.0, 0.1, 0.2, 0.3, 0.4]}
-    #clusterings = {'wbc': [0.1,], 'nms': [0.1,]}
-    #clusterings = {None: ['lol']}
-    if plot_hist:
-        clusterings = {None: ['lol'], 'nms': [0.1, ], 'wbc': [0.1, ]}
-    class_of_interest = cf.patient_class_of_interest
-
-    try:
-        if plot_hist:
-            title_fs, text_fs = 16, 13
-            fig = plg.plt.figure(figsize=(11, 8)) #width, height
-            grid = plg.plt.GridSpec(len(clusterings.keys()), max([len(v) for v in clusterings.values()])+1, wspace=0.0,
-                                    hspace=0.0, figure=fig) #rows, cols
-            plg.plt.suptitle("Faster R-CNN+", fontsize=title_fs, va='bottom', y=0.925)
-
-        results_df = pd.DataFrame()
-        for cl_ix, (clustering, ious) in enumerate(clusterings.items()):
-            cf.clustering = clustering
-            for iou_ix, iou in enumerate(ious):
-                cf.clustering_iou = iou
-                print(r"Producing Results for Clustering {} @ IoU {}".format(cf.clustering, cf.clustering_iou))
-                overall_test_df = pd.DataFrame()
-                for fold in folds[:]:
-                    cf.fold = fold
-                    cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold))
-
-                    predictor = predictor_file.Predictor(cf, net=None, logger=logger, mode='analysis')
-                    results_list = predictor.load_saved_predictions()
-                    logger.info('starting evaluation...')
-                    evaluator = evaluator_file.Evaluator(cf, logger, mode='test')
-                    evaluator.evaluate_predictions(results_list)
-                    #evaluator.score_test_df(max_fold=100)
-                    overall_test_df = overall_test_df.append(evaluator.test_df)
-
-                results_df = results_df.append(cluster_results_to_df(dataset_name, overall_df=overall_test_df,cf=cf,
-                                                                     logger=logger))
-
-                if plot_hist:
-                    if clustering=='wbc' and iou_ix==len(ious)-1:
-                        # plot n_missing histogram for last wbc clustering only
-                        out_filename = os.path.join(plot_dir, 'analysis_n_missing_overall_hist_{}_{}.png'.format(clustering, iou))
-                        ax = fig.add_subplot(grid[cl_ix, iou_ix+1])
-                        plg.plot_wbc_n_missing(cf, overall_test_df, outfile=out_filename, fs=text_fs, ax=ax)
-                        ax.set_title("WBC Missing Predictions per Cluster.", fontsize=title_fs)
-                        #ax.set_ylabel(r"Average Missing Preds per Cluster (%)")
-                        ax.yaxis.tick_right()
-                        ax.yaxis.set_label_position("right")
-                        ax.text(0.07, 0.87, "{}) WBC".format(chr(len(clusterings.keys())*len(ious)+97)), transform=ax.transAxes, color=cf.white, fontsize=title_fs,
-                                bbox=dict(boxstyle='square', facecolor='black', edgecolor='none', alpha=0.9))
-                    overall_test_df = overall_test_df[overall_test_df.pred_class == class_of_interest]
-                    overall_test_df = overall_test_df[overall_test_df.det_type!='patient_tn']
-                    out_filename = "analysis_fold_overall_hist_{}_{}.png".format(clustering, iou)
-                    out_filename = os.path.join(plot_dir, out_filename)
-                    ax = fig.add_subplot(grid[cl_ix, iou_ix])
-                    plg.plot_prediction_hist(cf, overall_test_df, out_filename, fs=text_fs, ax=ax)
-                    ax.text(0.11, 0.87, "{}) {}".format(chr((cl_ix+1)*len(ious)+96), clustering.upper() if clustering else "Raw Preds"), transform=ax.transAxes, color=cf.white,
-                            bbox=dict(boxstyle='square', facecolor='black', edgecolor='none', alpha=0.9), fontsize=title_fs)
-                    if cl_ix==0 and iou_ix==0:
-                        ax.set_title("Prediction Histograms Malignant Class", fontsize=title_fs)
-                        ax.legend(loc="best", fontsize=text_fs)
-                    else:
-                        ax.set_title("")
-                #analyze_test_df(dataset_name, cf=cf, logger=logger)
-        if plot_hist:
-            #plg.plt.subplots_adjust(top=0.)
-            plg.plt.savefig(os.path.join(plot_dir, "combined_hist_plot.pdf"), dpi=600, bbox_inches='tight')
-
-    except FileNotFoundError as e:
-        print("Ignoring exp dir {} due to\n{}".format(exp_dir, e))
-    logger.handlers = []
-    del cf; del logger
-    return results_df
-
-def gather_clustering_results(dataset_name, exp_parent_dir, exps_filter=None, processes=os.cpu_count()//2):
-    exp_dirs = [os.path.join(exp_parent_dir, i) for i in os.listdir(exp_parent_dir + "/") if
-                os.path.isdir(os.path.join(exp_parent_dir, i))]#[:1]
-    if exps_filter is not None:
-        exp_dirs = [ed for ed in exp_dirs if not exps_filter in ed]
-    # for debugging
-    #exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6"
-    #exp_dirs = [exp_dir,]
-    #exp_dirs = ["/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn2d_cl_bs10",]
-
-    results_df = pd.DataFrame()
-
-    p = NoDaemonProcessPool(processes=processes)
-    mp_inputs = [(dataset_name, exp_dir) for exp_dir in exp_dirs][:]
-    results_dfs = p.starmap(multiple_clustering_results, mp_inputs)
-    p.close()
-    p.join()
-    for df in results_dfs:
-        results_df = results_df.append(df)
-
-    results_df.to_csv(os.path.join(exp_parent_dir, "df_cluster_summary.csv"), index=False)
-
-    return results_df
-
-def plot_cluster_results_grid(cf, res_df, ylim=None, out_file=None):
-    """
-    :param cf:
-    :param res_df: results over a single dimension setting (2D or 3D), over all clustering methods and ious.
-    :param out_file:
-    :return:
-    """
-    is_2d = np.all(res_df["Spatial Dim"]=="2D")
-    # pandas has problems with recognising "N/A" string --> replace by None
-    #res_df['Merge-2D-to-3D IoU'].iloc[res_df['Merge-2D-to-3D IoU'] == "N/A"] = None
-    n_rows = 3#4 if is_2d else 3
-    grid = plg.plt.GridSpec(n_rows, 5, wspace=0.4, hspace=0.3)
-
-    fig = plg.plt.figure(figsize=(11,6))
-
-    splits = res_df["Model"].unique().tolist() # need to be model names
-    for split in splits:
-        assoc_exps = res_df[res_df["Model"]==split]["Experiment Name"].unique()
-        if len(assoc_exps)>1:
-            print("Model {} has multiple experiments:\n{}".format(split, assoc_exps))
-            #res_df = res_df.where(~(res_df["Model"] == split), res_df["Experiment Name"], axis=0)
-            raise Exception("Multiple Experiments")
-
-    sort_map = {'detection_fpn': 0, 'mrcnn':1, 'frcnn':2, 'retina_net':3, 'retina_unet':4}
-    splits.sort(key=sort_map.__getitem__)
-    #colors = [cf.color_palette[ix+3 % len(cf.color_palette)] for ix in range(len(splits))]
-    color_map = {'detection_fpn': cf.magenta, 'mrcnn':cf.blue, 'frcnn': cf.dark_blue, 'retina_net': cf.aubergine, 'retina_unet': cf.purple}
-
-    colors = [color_map[split] for split in splits]
-    alphas =  [0.9,] * len(splits)
-    legend_handles = []
-    model_renamer = {'detection_fpn': "Detection U-Net", 'mrcnn': "Mask R-CNN", 'frcnn': "Faster R-CNN+", 'retina_net': "RetinaNet", 'retina_unet': "Retina U-Net"}
-
-    for rix, c_kind in zip([0, 1],['wbc', 'nms']):
-        kind_df = res_df[res_df['Clustering Kind'] == c_kind]
-        groups = kind_df['Clustering IoU'].unique()
-        #for cix, iou in enumerate(groups):
-        assert np.all([split in splits for split in kind_df["Model"].unique()]) #need to be model names
-        ax = fig.add_subplot(grid[rix,:])
-        bar_values = [kind_df[kind_df["Model"]==split]["rois_malignant : ap_folds_mean"] for split in splits]
-        bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values]
-        bar_values = [ [float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values ]
-
-
-        xlabel='' if rix == 0 else "Clustering IoU"
-        ylabel = str(c_kind.upper()) + " / AP"
-        lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds,
-                                        ax=ax, ylabel=ylabel, xlabel=xlabel)
-        legend_handles.append(lh)
-        if rix == 0:
-            ax.axes.get_xaxis().set_ticks([])
-            #ax.spines['top'].set_visible(False)
-            #ax.spines['right'].set_visible(False)
-            ax.spines['bottom'].set_visible(False)
-            #ax.spines['left'].set_visible(False)
-        else:
-            ax.spines['top'].set_visible(False)
-            #ticklab = ax.xaxis.get_ticklabels()
-            #trans = ticklab.get_transform()
-            ax.xaxis.set_label_coords(0.05, -0.05)
-        ax.set_ylim(0.,ylim)
-
-    if is_2d:
-        # only 2d-3d merging @ 0.1
-        ax = fig.add_subplot(grid[2, 1])
-        kind_df = res_df[(res_df['Clustering Kind'] == 'None') & ~(res_df['Merge-2D-to-3D IoU'].isna())]
-        groups = kind_df['Clustering IoU'].unique()
-        bar_values = [kind_df[kind_df["Model"] == split]["rois_malignant : ap_folds_mean"] for split in splits]
-        bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values]
-        bar_values = np.array([[float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values])
-        lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds,
-                                        ax=ax, ylabel="2D-3D Merging\nOnly / AP")
-        legend_handles.append(lh)
-        ax.axes.get_xaxis().set_ticks([])
-        ax.spines['top'].set_visible(False)
-        ax.spines['right'].set_visible(False)
-        ax.spines['bottom'].set_visible(False)
-        ax.spines['left'].set_visible(False)
-        ax.set_ylim(0., ylim)
-
-        next_row = 2
-        next_col = 2
-    else:
-        next_row = 2
-        next_col = 2
-
-    # No clustering at all
-    ax = fig.add_subplot(grid[next_row, next_col])
-    kind_df = res_df[(res_df['Clustering Kind'] == 'None') & (res_df['Merge-2D-to-3D IoU'].isna())]
-    groups = kind_df['Clustering IoU'].unique()
-    bar_values = [kind_df[kind_df["Model"] == split]["rois_malignant : ap_folds_mean"] for split in splits]
-    bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values]
-    bar_values = np.array([[float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values])
-    lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds,
-                                    ax=ax, ylabel="No Clustering / AP")
-    legend_handles.append(lh)
-    #plg.suppress_axes_lines(ax)
-    #ax = fig.add_subplot(grid[next_row, 0])
-    #ax.set_ylabel("No Clustering")
-    #plg.suppress_axes_lines(ax)
-    ax.axes.get_xaxis().set_ticks([])
-    ax.spines['top'].set_visible(False)
-    ax.spines['right'].set_visible(False)
-    ax.spines['bottom'].set_visible(False)
-    ax.spines['left'].set_visible(False)
-    ax.set_ylim(0., ylim)
-
-
-    ax = fig.add_subplot(grid[next_row, 3])
-    # awful hot fix: only legend_handles[0] used in order to have same order as in plots.
-    legend_handles = [plg.mpatches.Patch(color=handle[0], alpha=handle[1], label=model_renamer[handle[2]]) for handle in legend_handles[0]]
-    ax.legend(handles=legend_handles)
-    ax.axis('off')
-
-    fig.suptitle('Prostate {} Results over Clustering Settings'.format(res_df["Spatial Dim"].unique().item()), fontsize=14)
-
-    if out_file is not None:
-        plg.plt.savefig(out_file)
-
-    return
-
-def get_plot_clustering_results(dataset_name, exp_parent_dir, res_from_file=True, exps_filter=None):
-    if not res_from_file:
-        results_df = gather_clustering_results(dataset_name, exp_parent_dir, exps_filter=exps_filter)
-    else:
-        results_df = pd.read_csv(os.path.join(exp_parent_dir, "df_cluster_summary.csv"))
-        if os.path.isfile(os.path.join(exp_parent_dir, "df_cluster_summary_no_clustering_2D.csv")):
-            results_df = results_df.append(pd.read_csv(os.path.join(exp_parent_dir, "df_cluster_summary_no_clustering_2D.csv")))
-
-    cf = get_cf(dataset_name)
-    if np.count_nonzero(results_df["Spatial Dim"] == "3D") >0:
-        # 3D
-        plot_cluster_results_grid(cf, results_df[results_df["Spatial Dim"] == "3D"], ylim=0.52, out_file=os.path.join(exp_parent_dir, "cluster_results_3D.pdf"))
-    if np.count_nonzero(results_df["Spatial Dim"] == "2D") > 0:
-        # 2D
-        plot_cluster_results_grid(cf, results_df[results_df["Spatial Dim"]=="2D"], ylim=0.4, out_file=os.path.join(exp_parent_dir, "cluster_results_2D.pdf"))
-
-
-def plot_single_results(cf, exp_dir, plot_files, res_df=None):
-    out_file = os.path.join(exp_dir, "inference_analysis", "single_results.pdf")
-
-    plot_files = utils.load_obj(plot_files)
-    batch = plot_files["batch"]
-    results_dict = plot_files["res_dict"]
-    cf.roi_items = ['class_targets']
-
-    class_renamer = {1: "GS 6", 2: "GS $\geq 7$"}
-    gs_renamer = {60: "6", 71: "7a"}
-
-    if "adcb" in exp_dir:
-        modality = "adcb"
-    elif "t2" in exp_dir:
-        modality = "t2"
-    else:
-        modality = "b"
-    text_fs = 16
-
-    if modality=="t2":
-        n_rows, n_cols = 2, 3
-        gt_col = 1
-        fig_w, fig_h = 14, 4
-        input_x, input_y = 0.05, 0.9
-        z_ix = 11
-        thresh = 0.22
-        input_title = "Input"
-    elif modality=="b":
-        n_rows, n_cols = 2, 6
-        gt_col = 2 # = gt_span
-        fig_w, fig_h = 14, 4
-        input_x, input_y = 0.08, 0.8
-        z_ix = 8
-        thresh = 0.16
-        input_title = "                                 Input"
-    elif modality=="adcb":
-        n_rows, n_cols = 2, 7
-        gt_col = 3
-        fig_w, fig_h = 14, 4
-        input_x, input_y = 0.08, 0.8
-        z_ix = 8
-        thresh = 0.16
-        input_title = "Input"
-    fig_w, fig_h = 12, 3.87
-    fig = plg.plt.figure(figsize=(fig_w, fig_h))
-    grid = plg.plt.GridSpec(n_rows, n_cols, wspace=0.0, hspace=0.0, figure=fig)
-    cf.plot_class_ids = True
-
-    if modality=="t2":
-        ax = fig.add_subplot(grid[:, 0])
-        ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray')
-        ax.set_title("Input", size=text_fs)
-        ax.text(0.05, 0.9, "T2", size=text_fs, color=cf.white, transform=ax.transAxes,
-                bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7))
-        ax.axis("off")
-    elif modality=="b":
-        for m_ix, b in enumerate([50, 500, 1000, 1500]):
-            ax = fig.add_subplot(grid[int(np.round(m_ix/4+0.0001)), m_ix%2])
-            print(int(np.round(m_ix/4+0.0001)), m_ix%2)
-            ax.imshow(batch['patient_data'][0, m_ix, :, :, z_ix], cmap='gray')
-            ax.text(input_x, input_y, r"{}{}".format("$b=$" if m_ix==0 else "", b), size=text_fs, color=cf.white, transform=ax.transAxes,
-                    bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7))
-            ax.axis("off")
-            if b==50:
-                ax.set_title(input_title, size=text_fs)
-    elif modality=="adcb":
-        for m_ix, b in enumerate(["ADC", 50, 500, 1000, 1500]):
-            p_ix = m_ix + 1 if m_ix>2 else m_ix
-            ax = fig.add_subplot(grid[int(np.round(p_ix/6+0.0001)), p_ix%3])
-            print(int(np.round(p_ix/4+0.0001)), p_ix%2)
-            ax.imshow(batch['patient_data'][0, m_ix, :, :, z_ix], cmap='gray')
-            ax.text(input_x, input_y, r"{}{}".format("$b=$" if m_ix==1 else "", b), size=text_fs, color=cf.white, transform=ax.transAxes,
-                    bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7))
-            ax.axis("off")
-            if b==50:
-                ax.set_title(input_title, size=text_fs)
-
-    ax_gt = fig.add_subplot(grid[:, gt_col:gt_col+2]) # GT
-    ax_pred = fig.add_subplot(grid[:, gt_col+2:gt_col+4]) # Prediction
-    #ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray')
-    #ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray')
-    #ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8)
-    plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, patient_items=True,
-                          vol_slice_picks=[z_ix,], show_gt_labels=True, box_score_thres=thresh, plot_mods=True,
-                          out_file=None, dpi=600, return_fig=False, axes={'gt':ax_gt, 'pred':ax_pred}, fontsize=text_fs)
-
-
-    ax_gt.set_title("Ground Truth", size=text_fs)
-    ax_pred.set_title("Prediction", size=text_fs)
-    texts = list(ax_gt.texts)
-    ax_gt.texts = []
-    for text in texts:
-        cl_id = int(text.get_text())
-        x, y = text.get_position()
-        text_str = "GS="+str(gs_renamer[cf.class_id2label[cl_id].gleasons[0]])
-        ax_gt.text(x-4*text_fs//2, y,  text_str, color=text.get_color(),
-        fontsize=text_fs, bbox=dict(facecolor=text.get_bbox_patch().get_facecolor(), alpha=0.7, edgecolor='none', clip_on=True, pad=0))
-    texts = list(ax_pred.texts)
-    ax_pred.texts = []
-    for text in texts:
-        x, y = text.get_position()
-        x -= 4 * text_fs // 2
-        try:
-            cl_id = int(text.get_text())
-            text_str = class_renamer[cl_id]
-        except ValueError:
-            text_str = text.get_text()
-        if text.get_bbox_patch().get_facecolor()[:3]==cf.dark_green:
-            x -= 4* text_fs
-        ax_pred.text(x, y,  text_str, color=text.get_color(),
-        fontsize=text_fs, bbox=dict(facecolor=text.get_bbox_patch().get_facecolor(), alpha=0.7, edgecolor='none', clip_on=True, pad=0))
-
-    ax_gt.axis("off")
-    ax_pred.axis("off")
-
-    plg.plt.tight_layout()
-
-    if out_file is not None:
-        plg.plt.savefig(out_file, dpi=600, bbox_inches='tight')
-
-
-
-    return
-
-def find_suitable_examples(exp_dir1, exp_dir2):
-    test_df1 = analyze_test_df('lidc',exp_dir1)
-    test_df2 = analyze_test_df('lidc', exp_dir2)
-    test_df1 = test_df1[test_df1.pred_score>0.3]
-    test_df2 = test_df2[test_df2.pred_score > 0.3]
-
-    tp_df1 = test_df1[test_df1.det_type == 'det_tp']
-
-    tp_pids = tp_df1.pid.unique()
-    tp_fp_pids = test_df2[(test_df2.pid.isin(tp_pids)) &
-                          ((test_df2.regressions-test_df2.rg_targets).abs()>1)].pid.unique()
-    cand_df = tp_df1[tp_df1.pid.isin(tp_fp_pids)]
-    sorter = (cand_df.regressions - cand_df.rg_targets).abs().argsort()
-    cand_df = cand_df.iloc[sorter]
-    print("Good guesses for examples: ", cand_df.pid.unique()[:20])
-    return
-
-def plot_single_results_lidc():
-    dataset_name = 'lidc'
-    exp_dir1 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rg_copiedparams'
-    exp_dir2 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_copiedparams'
-    cf = get_cf(dataset_name, exp_dir1)
-    #file_names = [f_name for f_name in os.listdir(os.path.join(exp_dir, 'inference_analysis')) if f_name.endswith('.pkl')]
-    # file_names = [os.path.join(exp_dir, "inference_analysis", f_name) for f_name in file_names]
-    file_names = ['bytes_merged_boxes_fold_0_pid_0296a.pkl', 'bytes_merged_boxes_fold_2_pid_0416a.pkl',
-                  'bytes_merged_boxes_fold_1_pid_0635a.pkl', "bytes_merged_boxes_fold_0_pid_0811a.pkl",
-                  "bytes_merged_boxes_fold_0_pid_0969a.pkl",
-                  # 'bytes_merged_boxes_fold_0_pid_0484a.pkl', 'bytes_merged_boxes_fold_0_pid_0492a.pkl',
-                  # 'bytes_merged_boxes_fold_0_pid_0505a.pkl','bytes_merged_boxes_fold_2_pid_0164a.pkl',
-                  # 'bytes_merged_boxes_fold_3_pid_0594a.pkl',
-
-
-                  ]
-    z_ics = [167, 159,
-             107, 194,
-             177,
-             # 84, 145,
-             # 212, 219,
-             # 67
-             ]
-    plot_files = [
-        {'files': [os.path.join(exp_dir, "inference_analysis", f_name) for exp_dir in [exp_dir1, exp_dir2]],
-         'z_ix': z_ix} for (f_name, z_ix) in zip(file_names, z_ics)
-    ]
-
-    info_df_path = '/mnt/HDD2TB/Documents/data/lidc/pp_20190318/patient_gts_{}/info_df.pickle'.format(cf.training_gts)
-    info_df = pd.read_pickle(info_df_path)
-
-    #cf.training_gts = 'sa'
-    cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois
-
-    text_fs = 8
-    fig = plg.plt.figure(figsize=(6, 9.9)) #w, h
-    #fig = plg.plt.figure(figsize=(6, 6.5))
-    #fig.subplots_adjust(hspace=0, wspace=0)
-    grid = plg.plt.GridSpec(len(plot_files), 3, wspace=0.0, hspace=0.0, figure=fig) #rows, cols
-    cf.plot_class_ids = True
-
-
-    for f_ix, pack in enumerate(plot_files):
-        z_ix = plot_files[f_ix]['z_ix']
-        for model_ix in range(2)[::-1]:
-            print("f_ix, m_ix", f_ix, model_ix)
-            plot_file = utils.load_obj(plot_files[f_ix]['files'][model_ix])
-            batch = plot_file["batch"]
-            pid = batch["pid"][0]
-            batch['patient_rg_bin_targets_sa'] = info_df[info_df.pid == pid]['class_target'].tolist()
-            # apply same filter as with merged GTs: need at least two non-zero votes to consider a RoI.
-            batch['patient_rg_bin_targets_sa'] = [[four_votes for four_votes in batch_el if
-                                                   np.count_nonzero(four_votes>0)>=2] for batch_el in
-                                                  batch['patient_rg_bin_targets_sa']]
-            results_dict = plot_file["res_dict"]
-
-            # pred
-            ax = fig.add_subplot(grid[f_ix, model_ix+1])
-            plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None,
-                                              vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.2,
-                                              plot_mods=False,
-                                              out_file=None, dpi=600, patient_items=True, return_fig=False,
-                                              axes={'pred': ax})
-            if f_ix==0:
-                ax.set_title("{}".format("Reg R-CNN" if model_ix==0 else "Mask R-CNN"), size=text_fs*1.3)
-            else:
-                ax.set_title("")
-
-            ax.axis("off")
-            #grid.tight_layout(fig)
-
-            # GT
-            if model_ix==0:
-                ax = fig.add_subplot(grid[f_ix, 0])
-                # ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray')
-                # ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8)
-                boxes_fig = plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None,
-                                                  vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.1,
-                                                  plot_mods=False, seg_cmap="rg",
-                                                  out_file=None, dpi=600, patient_items=True, return_fig=False,
-                                                  axes={'gt':ax})
-                ax.set_ylabel(r"$\mathbf{"+chr(f_ix+97)+")}$ " + ax.get_ylabel())
-                ax.set_ylabel("")
-                if f_ix==0:
-                    ax.set_title("Ground Truth", size=text_fs*1.3)
-                else:
-                    ax.set_title("")
-
-
-    #fig_patches = fig_leg.get_patches()
-    patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if not label.id in [0,]]
-    #fig.legends.append(fig_leg)
-    plg.plt.figlegend(handles=patches, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0.,
-                      ncol=len(patches), bbox_transform=fig.transFigure, title="Binned Malignancy Score",
-                      fontsize= text_fs)
-    plg.plt.tight_layout()
-    out_file = os.path.join(exp_dir1, "inference_analysis", "lidc_example_results_solarized.pdf")
-    if out_file is not None:
-        plg.plt.savefig(out_file, dpi=600, bbox_inches='tight')
-
-
-def box_clustering(exp_dir='', plot_dir=None):
-    import datasets.prostate.data_loader as dl
-    cf = get_cf('prostate', exp_dir)
-    if plot_dir is None:
-        plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc')
-
-    fig = plg.plt.figure(figsize=(10, 4))
-    #fig.subplots_adjust(hspace=0, wspace=0)
-    grid = plg.plt.GridSpec(2, 3, wspace=0.0, hspace=0., figure=fig)
-    fs = 14
-    xyA = (.9, 0.5)
-    xyB = (0.05, .5)
-
-    patch_size = np.array([200, 320])
-    clustering_iou = 0.1
-    img_y, img_x = patch_size
-
-    boxes = [
-        {'box_coords': [img_y * 0.2, img_x * 0.04, img_y * 0.55, img_x * 0.31], 'box_score': 0.45, 'box_cl': 1,
-         'regression': 2., 'rg_bin': cf.rg_val_to_bin_id(1.),
-         'box_patch_center_factor': 1., 'ens_ix': 1, 'box_n_overlaps': 1.},
-        {'box_coords': [img_y*0.05, img_x*0.05, img_y*0.5, img_x*0.3], 'box_score': 0.85, 'box_cl': 2,
-         'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.),
-         'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.},
-        {'box_coords': [img_y * 0.1, img_x * 0.2, img_y * 0.4, img_x * 0.7], 'box_score': 0.95, 'box_cl': 2,
-         'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.),
-         'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.},
-        {'box_coords': [img_y * 0.80, img_x * 0.35, img_y * 0.95, img_x * 0.85], 'box_score': 0.6, 'box_cl': 2,
-         'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.),
-         'box_patch_center_factor': 1., 'ens_ix': 1, 'box_n_overlaps': 1.},
-        {'box_coords': [img_y * 0.85, img_x * 0.4, img_y * 0.93, img_x * 0.9], 'box_score': 0.85, 'box_cl': 2,
-         'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.),
-         'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.},
-    ]
-    for box in boxes:
-        c = box['box_coords']
-        box_centers = np.array([(c[ii + 2] - c[ii]) / 2 for ii in range(len(c) // 2)])
-        box['box_patch_center_factor'] = np.mean(
-            [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in
-             zip(box_centers, patch_size / 2)])
-        print("pc fact", box['box_patch_center_factor'])
-
-    box_coords = np.array([box['box_coords'] for box in boxes])
-    box_scores = np.array([box['box_score'] for box in boxes])
-    box_cl_ids = np.array([box['box_cl'] for box in boxes])
-    ax0 = fig.add_subplot(grid[:,:2])
-    plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_unclustered.png"), ax=ax0)
-    ax0.text(*xyA, 'a) Raw ', horizontalalignment='right', verticalalignment='center', transform=ax0.transAxes,
-            weight='bold', fontsize=fs)
-
-    nms_boxes = []
-    for cl in range(1,3):
-        cl_boxes = [box for box in boxes if box['box_cl'] == cl ]
-        box_coords = np.array([box['box_coords'] for box in cl_boxes])
-        box_scores = np.array([box['box_score'] for box in cl_boxes])
-        if 0 not in box_scores.shape:
-            keep_ix = mutils.nms_numpy(box_coords, box_scores, thresh=clustering_iou)
-        else:
-            keep_ix = []
-        nms_boxes += [cl_boxes[ix] for ix in keep_ix]
-        box_coords = np.array([box['box_coords'] for box in nms_boxes])
-        box_scores = np.array([box['box_score'] for box in nms_boxes])
-        box_cl_ids = np.array([box['box_cl'] for box in nms_boxes])
-    ax1 = fig.add_subplot(grid[1, 2])
-    nms_color = cf.black
-    plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_nms_iou_{}.png".format(clustering_iou)), ax=ax1)
-    ax1.text(*xyB, ' c) NMS', horizontalalignment='left', verticalalignment='center', transform=ax1.transAxes,
-            weight='bold', color=nms_color, fontsize=fs)
-
-    #------ WBC -------------------
-    regress_flag = False
-
-    wbc_boxes = []
-    for cl in range(1,3):
-        cl_boxes = [box for box in boxes if box['box_cl'] == cl]
-        box_coords = np.array([box['box_coords'] for box in cl_boxes])
-        box_scores = np.array([box['box_score'] for box in cl_boxes])
-        box_center_factor = np.array([b['box_patch_center_factor'] for b in cl_boxes])
-        box_n_overlaps = np.array([b['box_n_overlaps'] for b in cl_boxes])
-        box_ens_ix = np.array([b['ens_ix'] for b in cl_boxes])
-        box_regressions = np.array([b['regression'] for b in cl_boxes]) if regress_flag else None
-        box_rg_bins = np.array([b['rg_bin'] if 'rg_bin' in b.keys() else float('NaN') for b in cl_boxes])
-        box_rg_uncs = np.array([b['rg_uncertainty'] if 'rg_uncertainty' in b.keys() else float('NaN') for b in cl_boxes])
-        if 0 not in box_scores.shape:
-            keep_scores, keep_coords, keep_n_missing, keep_regressions, keep_rg_bins, keep_rg_uncs = \
-                predictor_file.weighted_box_clustering(box_coords, box_scores, box_center_factor, box_n_overlaps, box_rg_bins, box_rg_uncs,
-                                        box_regressions, box_ens_ix, clustering_iou, n_ens=1)
-
-            for boxix in range(len(keep_scores)):
-                clustered_box = {'box_type': 'det', 'box_coords': keep_coords[boxix],
-                                 'box_score': keep_scores[boxix], 'cluster_n_missing': keep_n_missing[boxix],
-                                 'box_pred_class_id': cl}
-                if regress_flag:
-                    clustered_box.update({'regression': keep_regressions[boxix],
-                                          'rg_uncertainty': keep_rg_uncs[boxix],
-                                          'rg_bin': keep_rg_bins[boxix]})
-                wbc_boxes.append(clustered_box)
-
-    box_coords = np.array([box['box_coords'] for box in wbc_boxes])
-    box_scores = np.array([box['box_score'] for box in wbc_boxes])
-    box_cl_ids = np.array([box['box_pred_class_id'] for box in wbc_boxes])
-    ax2 = fig.add_subplot(grid[0, 2])
-    wbc_color = cf.black
-    plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_wbc_iou_{}.png".format(clustering_iou)), ax=ax2)
-    ax2.text(*xyB, ' b) WBC', horizontalalignment='left', verticalalignment='center', transform=ax2.transAxes,
-            weight='bold', color=wbc_color, fontsize=fs)
-    # ax2.spines['bottom'].set_color(wbc_color)
-    # ax2.spines['top'].set_color(wbc_color)
-    # ax2.spines['right'].set_color(wbc_color)
-    # ax2.spines['left'].set_color(wbc_color)
-
-    from matplotlib.patches import ConnectionPatch
-    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA="axes fraction", coordsB="axes fraction",
-                          axesA=ax0, axesB=ax2, color=wbc_color, lw=1.5, arrowstyle='-|>')
-    ax0.add_artist(con)
-
-    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA="axes fraction", coordsB="axes fraction",
-                          axesA=ax0, axesB=ax1, color=nms_color, lw=1.5, arrowstyle='-|>')
-    ax0.add_artist(con)
-    # ax0.text(0.5, 0.5, "Test", size=30, va="center", ha="center", rotation=30,
-    #          bbox=dict(boxstyle="angled,pad=0.5", alpha=0.2))
-    plg.plt.tight_layout()
-    plg.plt.savefig(os.path.join(plot_dir, "box_clustering.pdf"), bbox_inches='tight')
-
-def sketch_AP_AUC(plot_dir=None, draw_auc=True):
-    from sklearn.metrics import roc_curve, roc_auc_score
-    from understanding_metrics import get_det_types
-    import matplotlib.transforms as mtrans
-    cf = get_cf('prostate', '')
-    if plot_dir is None:
-        plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('.')
-
-    if draw_auc:
-        fig = plg.plt.figure(figsize=(7, 6)) #width, height
-        # fig.subplots_adjust(hspace=0, wspace=0)
-        grid = plg.plt.GridSpec(2, 2, wspace=0.23, hspace=.45, figure=fig) #rows, cols
-    else:
-        fig = plg.plt.figure(figsize=(12, 3)) #width, height
-        # fig.subplots_adjust(hspace=0, wspace=0)
-        grid = plg.plt.GridSpec(1, 3, wspace=0.23, hspace=.45, figure=fig) #rows, cols
-    fs = 13
-    text_fs = 11
-    optim_color = cf.dark_green
-    non_opt_color = cf.aubergine
-
-    df = pd.DataFrame(columns=['pred_score', 'class_label', 'pred_class', 'det_type', 'match_iou'])
-    df2 = df.copy()
-    df["pred_score"] = [0,0.3,0.25,0.2, 0.8, 0.9, 0.9, 0.9, 0.9]
-    df["class_label"] = [0,0,0,0, 1, 1, 1, 1, 1]
-    df["det_type"] = get_det_types(df)
-    df["match_iou"] = [0.1] * len(df)
-
-    df2["pred_score"] = [0, 0.77, 0.5, 1., 0.5, 0.35, 0.3, 0., 0.7, 0.85, 0.9]
-    df2["class_label"] = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
-    df2["det_type"] = get_det_types(df2)
-    df2["match_iou"] = [0.1] * len(df2)
-
-    #------ PRC -------
-    # optimal
-    if draw_auc:
-        ax = fig.add_subplot(grid[1, 0])
-    else:
-        ax = fig.add_subplot(grid[0, 2])
-    pr, rc = evaluator_file.compute_prc(df)
-    ax.plot(rc, pr, color=optim_color, label="Optimal Detection")
-    ax.fill_between(rc, pr, alpha=0.33, color=optim_color)
-
-    # suboptimal
-    pr, rc = evaluator_file.compute_prc(df2)
-    ax.plot(rc, pr, color=non_opt_color, label="Suboptimal")
-    ax.fill_between(rc, pr, alpha=0.33, color=non_opt_color)
-    #plt.title()
-    #plt.legend(loc=3 if c == 'prc' else 4)
-    ax.set_ylabel('precision', fontsize=text_fs)
-    ax.set_ylim((0., 1.1))
-    ax.set_xlabel('recall', fontsize=text_fs)
-    ax.set_title('Precision-Recall Curves', fontsize=fs)
-    #ax.legend(ncol=2, loc='center')#, bbox_to_anchor=(0.5, 1.05))
-
-
-    #---- ROC curve
-    if draw_auc:
-        ax = fig.add_subplot(grid[1, 1])
-        roc = roc_curve(df.class_label.tolist(), df.pred_score.tolist())
-        ax.plot(roc[0], roc[1], color=optim_color)
-        ax.fill_between(roc[0], roc[1], alpha=0.33, color=optim_color)
-        ax.set_xlabel('false-positive rate', fontsize=text_fs)
-        ax.set_ylim((0., 1.1))
-        ax.set_ylabel('recall', fontsize=text_fs)
-
-        roc = roc_curve(df2.class_label.tolist(), df2.pred_score.tolist())
-        ax.plot(roc[0], roc[1], color=non_opt_color)
-        ax.fill_between(roc[0], roc[1], alpha=0.33, color=non_opt_color)
-
-        roc = ([0, 1], [0, 1])
-        ax.plot(roc[0], roc[1], color=cf.gray, linestyle='dashed', label="random predictor")
-
-        ax.set_title('ROC Curves', fontsize=fs)
-        ax.legend(ncol=2, loc='lower right', fontsize=text_fs)
-
-    #--- hist optimal
-    text_left = 0.05
-    ax = fig.add_subplot(grid[0, 0])
-    tn_count = df.det_type.tolist().count('det_tn')
-    AUC = roc_auc_score(df.class_label, df.pred_score)
-    df = df[(df.det_type=="det_tp") | (df.det_type=="det_fp") | (df.det_type=="det_fn")]
-    labels = df.class_label.values
-    preds = df.pred_score.values
-    type_list = df.det_type.tolist()
-
-    ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="FP")
-    ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="FN at score 0 and TP")
-    #ax.axvline(x=cf.min_det_thresh, alpha=0.4, color=cf.orange, linewidth=1.5, label="min det thresh")
-    fp_count = type_list.count('det_fp')
-    fn_count = type_list.count('det_fn')
-    tp_count = type_list.count('det_tp')
-    pos_count = fn_count + tp_count
-    if draw_auc:
-        text = "AP: {:.2f} ROC-AUC: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df, 0.0, False)), AUC)
-    else:
-        text = "AP: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df, 0.0, False)))
-    text += 'TP: {} FP: {} FN: {} TN: {}\npositives: {}'.format(tp_count, fp_count, fn_count, tn_count, pos_count)
-
-    ax.text(text_left,4, text, fontsize=text_fs)
-    ax.set_yscale('log')
-    ax.set_ylim(bottom=10**-2, top=10**2)
-    ax.set_xlabel("prediction score", fontsize=text_fs)
-    ax.set_ylabel("occurences", fontsize=text_fs)
-    #autoAxis = ax.axis()
-    # rec = plg.mpatches.Rectangle((autoAxis[0] - 0.7, autoAxis[2] - 0.2), (autoAxis[1] - autoAxis[0]) + 1,
-    #                 (autoAxis[3] - autoAxis[2]) + 0.4, fill=False, lw=2)
-    # rec = plg.mpatches.Rectangle((autoAxis[0] , autoAxis[2] ), (autoAxis[1] - autoAxis[0]) ,
-    #                 (autoAxis[3] - autoAxis[2]) , fill=False, lw=2, color=optim_color)
-    # rec = ax.add_patch(rec)
-    # rec.set_clip_on(False)
-    plg.plt.setp(ax.spines.values(), color=optim_color, linewidth=2)
-    ax.set_facecolor((*optim_color,0.1))
-    ax.set_title("Detection Histograms", fontsize=fs)
-
-    ax = fig.add_subplot(grid[0, 1])
-    tn_count = df2.det_type.tolist().count('det_tn')
-    AUC = roc_auc_score(df2.class_label, df2.pred_score)
-    df2 = df2[(df2.det_type=="det_tp") | (df2.det_type=="det_fp") | (df2.det_type=="det_fn")]
-    labels = df2.class_label.values
-    preds = df2.pred_score.values
-    type_list = df2.det_type.tolist()
-
-    ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="FP")
-    ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="FN at score 0 and TP")
-    # ax.axvline(x=cf.min_det_thresh, alpha=0.4, color=cf.orange, linewidth=1.5, label="min det thresh")
-    fp_count = type_list.count('det_fp')
-    fn_count = type_list.count('det_fn')
-    tp_count = type_list.count('det_tp')
-    pos_count = fn_count + tp_count
-    if draw_auc:
-        text = "AP: {:.2f} ROC-AUC: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df2, 0.0, False)), AUC)
-    else:
-        text = "AP: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df2, 0.0, False)))
-    text += 'TP: {} FP: {} FN: {} TN: {}\npositives: {}'.format(tp_count, fp_count, fn_count, tn_count, pos_count)
-
-    ax.text(text_left, 4*10**0, text, fontsize=text_fs)
-    ax.set_yscale('log')
-    ax.margins(y=10e2)
-    ax.set_ylim(bottom=10**-2, top=10**2)
-    ax.set_xlabel("prediction score", fontsize=text_fs)
-    ax.set_yticks([])
-    plg.plt.setp(ax.spines.values(), color=non_opt_color, linewidth=2)
-    ax.set_facecolor((*non_opt_color, 0.05))
-    ax.legend(ncol=2, loc='upper center', bbox_to_anchor=(0.5, 1.18), fontsize=text_fs)
-
-    if draw_auc:
-        # Draw a horizontal line
-        line = plg.plt.Line2D([0.1, .9], [0.48, 0.48], transform=fig.transFigure, color="black")
-        fig.add_artist(line)
-
-    outfile = os.path.join(plot_dir, "metrics.png")
-    print("Saving plot to {}".format(outfile))
-    plg.plt.savefig(outfile, bbox_inches='tight', dpi=600)
-
-    return
-
-def draw_toy_cylinders(plot_dir=None):
-    source_path = "datasets/toy"
-    if plot_dir is None:
-        plot_dir = os.path.join(source_path, "misc")
-        #plot_dir = '/home/gregor/Dropbox/Thesis/Main/tmp'
-    os.makedirs(plot_dir, exist_ok=True)
-
-    cf = get_cf('toy', '')
-    cf.pre_crop_size = [2200, 2200,1] #y,x,z;
-    #cf.dim = 2
-    cf.ambiguities = {"radius_calib": (1., 1. / 6) }
-    cf.pp_blur_min_intensity = 0.2
-
-    generate_toys = utils.import_module("generate_toys", os.path.join(source_path, 'generate_toys.py'))
-    ToyGen = generate_toys.ToyGenerator(cf)
-
-    fig = plg.plt.figure(figsize=(10, 8.2)) #width, height
-    grid = plg.plt.GridSpec(4, 5, wspace=0.0, hspace=.0, figure=fig) #rows, cols
-    fs, text_fs = 16, 14
-    text_x, text_y = 0.5, 0.85
-    true_gt_col, dist_gt_col = cf.dark_green, cf.blue
-    true_cmap = {1:true_gt_col}
-
-    img = np.random.normal(loc=0.0, scale=cf.noise_scale, size=ToyGen.sample_size)
-    img[img < 0.] = 0.
-    # one-hot-encoded seg
-    seg = np.zeros((cf.num_classes + 1, *ToyGen.sample_size)).astype('uint8')
-    undistorted_seg = np.copy(seg)
-    applied_gt_distort = False
-
-    class_id, shape = 1, 'cylinder'
-    #all_radii = ToyGen.generate_sample_radii(class_ids, shapes)
-    enlarge_f = 20
-    all_radii = np.array([np.mean(label.bin_vals) if label.id!=5 else label.bin_vals[0]+5 for label in cf.bin_labels if label.id!=0])
-    bins = [(min(label.bin_vals), max(label.bin_vals)) for label in cf.bin_labels]
-    bin_edges = [(bins[i][1] + bins[i + 1][0])*enlarge_f / 2 for i in range(len(bins) - 1)]
-    all_radii = [np.array([r*enlarge_f, r*enlarge_f, 1]) for r in all_radii] # extend to required 3D format
-    regress_targets, undistorted_rg_targets = [], []
-    ics = np.argwhere(np.ones(seg[0].shape)) # indices ics equal positions within img/volume
-    center = np.array([dim//2 for dim in img.shape])
-
-    # for illustrating GT distribution, keep scale same size
-    #x = np.linspace(mu - 300, mu + 300, 100)
-    x = np.linspace(0, 50*enlarge_f, 500)
-    ax_gauss = fig.add_subplot(grid[3, :])
-    mus, sigmas = [], []
-
-    for roi_ix, radii in enumerate(all_radii):
-        print('processing {} {}'.format(roi_ix, radii))
-        cur_img, cur_seg, cur_undistorted_seg, cur_regress_targets, cur_undistorted_rg_targets, cur_applied_gt_distort = \
-            ToyGen.draw_object(img.copy(), seg.copy(), undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort,
-                             roi_ix, class_id, shape, np.copy(radii), center)
-
-        ax = fig.add_subplot(grid[0,roi_ix])
-        ax.imshow(cur_img[...,0], cmap='gray', vmin=0)
-        ax.set_title("r{}".format(roi_ix+1), fontsize=fs)
-        if roi_ix==0:
-            ax.set_ylabel(r"$\mathbf{a)}$ Input", fontsize=fs)
-            plg.suppress_axes_lines(ax)
-        else:
-            ax.axis('off')
-
-        ax = fig.add_subplot(grid[1, roi_ix])
-        ax.imshow(cur_img[..., 0], cmap='gray')
-        ax.imshow(plg.to_rgba(np.argmax(cur_undistorted_seg[...,0], axis=0), true_cmap), alpha=0.8)
-        ax.text(text_x, text_y, r"$r_{a}=$"+"{:.1f}".format(cur_undistorted_rg_targets[roi_ix][0]/enlarge_f), transform=ax.transAxes,
-                color=cf.white, bbox=dict(facecolor=true_gt_col, alpha=0.7, edgecolor=cf.white, clip_on=False,pad=2.5),
-                fontsize=text_fs, ha='center', va='center')
-        if roi_ix==0:
-            ax.set_ylabel(r"$\mathbf{b)}$ Exact GT", fontsize=fs)
-            plg.suppress_axes_lines(ax)
-        else:
-            ax.axis('off')
-        ax = fig.add_subplot(grid[2, roi_ix])
-        ax.imshow(cur_img[..., 0], cmap='gray')
-        ax.imshow(plg.to_rgba(np.argmax(cur_seg[..., 0], axis=0), cf.cmap), alpha=0.7)
-        ax.text(text_x, text_y, r"$r_{a}=$"+"{:.1f}".format(cur_regress_targets[roi_ix][0]/enlarge_f), transform=ax.transAxes,
-                color=cf.white, bbox=dict(facecolor=cf.blue, alpha=0.7, edgecolor=cf.white, clip_on=False,pad=2.5),
-                fontsize=text_fs, ha='center', va='center')
-        if roi_ix == 0:
-            ax.set_ylabel(r"$\mathbf{c)}$ Noisy GT", fontsize=fs)
-            plg.suppress_axes_lines(ax)
-        else:
-            ax.axis('off')
-
-        # GT distributions
-        assert radii[0]==radii[1]
-        mu, sigma = radii[0], radii[0] * cf.ambiguities["radius_calib"][1]
-        ax_gauss.axvline(mu, color=true_gt_col)
-        ax_gauss.text(mu, -0.003, "$r=${:.0f}".format(mu/enlarge_f), color=true_gt_col, fontsize=text_fs, ha='center', va='center',
-                      bbox = dict(facecolor='none', alpha=0.7, edgecolor=true_gt_col, clip_on=False, pad=2.5))
-        mus.append(mu); sigmas.append(sigma)
-        lower_bound = max(bin_edges[roi_ix], min(x))# if roi_ix>0 else 2*mu-bin_edges[roi_ix+1]
-        upper_bound = bin_edges[roi_ix+1] if len(bin_edges)>roi_ix+1 else max(x)#2*mu-bin_edges[roi_ix]
-        if roi_ix<len(all_radii)-1:
-            ax_gauss.axvline(upper_bound, color='white', linewidth=7)
-        ax_gauss.axvspan(lower_bound, upper_bound, ymax=0.9999, facecolor=true_gt_col, alpha=0.4, edgecolor='none')
-        if roi_ix == 0:
-            ax_gauss.set_ylabel(r"$\mathbf{d)}$ GT Distr.", fontsize=fs)
-            #plg.suppress_axes_lines(ax_gauss)
-            #min_x, max_x = min(x/enlarge_f), max(x/enlarge_f)
-            #ax_gauss.xaxis.set_ticklabels(["{:.0f}".format(x_tick) for x_tick in np.arange(min_x, max_x, (max_x-min_x)/5)])
-            ax_gauss.xaxis.set_ticklabels([])
-            ax_gauss.axes.yaxis.set_ticks([])
-            ax_gauss.spines['top'].set_visible(False)
-            ax_gauss.spines['right'].set_visible(False)
-            #ax.spines['bottom'].set_visible(False)
-            ax_gauss.spines['left'].set_visible(False)
-    for d_ix, (mu, sigma) in enumerate(zip(mus, sigmas)):
-        ax_gauss.plot(x, norm.pdf(x, mu, sigma), color=dist_gt_col, alpha=0.6+d_ix/10)
-    ax_gauss.margins(x=0)
-    # in-axis coordinate cross
-    arrow_x, arrow_y, arrow_dx, arrow_dy = 30, ax_gauss.get_ylim()[1]/3, 30, ax_gauss.get_ylim()[1]/3
-    ax_gauss.arrow(arrow_x, arrow_y, 0., arrow_dy, length_includes_head=False, head_width=10, head_length=0.001, head_starts_at_zero=False, shape="full", width=0.5, fc="black", ec="black")
-    ax_gauss.arrow(arrow_x, arrow_y, arrow_dx, 0, length_includes_head=False, head_width=0.001, head_length=8,
-                   head_starts_at_zero=False, shape="full", width=0.00005, fc="black", ec="black")
-    ax_gauss.text(arrow_x-20, arrow_y + arrow_dy*0.5, r"$prob$", fontsize=text_fs, ha='center', va='center', rotation=90)
-    ax_gauss.text(arrow_x + arrow_dx * 0.5, arrow_y *0.85, r"$r$", fontsize=text_fs, ha='center', va='center', rotation=0)
-    # ax_gauss.annotate(r"$p$", xytext=(0, 0), xy=(0, arrow_y), fontsize=fs,
-    #             arrowprops=dict(arrowstyle="-|>, head_length = 0.05, head_width = .005", lw=1))
-    #ax_gauss.arrow(1, 0.5, 0., 0.1)
-    handles = [plg.mpatches.Patch(facecolor=dist_gt_col, label='Inexact Seg.', alpha=0.7, edgecolor='none'),
-               mlines.Line2D([], [], color=dist_gt_col, marker=r'$\curlywedge$', linestyle='none', markersize=11, label='GT Sampling Distr.'),
-               mlines.Line2D([], [], color=true_gt_col, marker='|', markersize=12, label='Exact GT Radius.', linestyle='none'),
-               plg.mpatches.Patch(facecolor=true_gt_col, label='a)-c) Exact Seg., d) Bin', alpha=0.7, edgecolor='none')]
-    fig.legend(handles=handles, loc="lower center", ncol=len(handles), fontsize=text_fs)
-    outfile = os.path.join(plot_dir, "toy_cylinders.png")
-    print("Saving plot to {}".format(outfile))
-    plg.plt.savefig(outfile, bbox_inches='tight', dpi=600)
-
-
-    return
-
-def seg_det_cityscapes_example(plot_dir=None):
-    cf = get_cf('cityscapes', '')
-    source_path = "datasets/cityscapes"
-    if plot_dir is None:
-        plot_dir = os.path.join(source_path, "misc")
-    os.makedirs(plot_dir, exist_ok=True)
-
-
-    dl = utils.import_module("dl", os.path.join(source_path, 'data_loader.py'))
-    #from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
-    data_set = dl.Dataset(cf)
-    Converter = dl.ConvertSegToBoundingBoxCoordinates(2, cf.roi_items)
-
-    fig = plg.plt.figure(figsize=(9, 3)) #width, height
-    grid = plg.plt.GridSpec(1, 2, wspace=0.05, hspace=.0, figure=fig) #rows, cols
-    fs, text_fs = 12, 10
-
-    nice_imgs = ["bremen000099000019", "hamburg000000033506", "frankfurt000001058914",]
-    img_id = nice_imgs[2]
-    #img_id = np.random.choice(data_set.set_ids)
-
-
-    print("Selected img", img_id)
-    img = np.load(data_set[img_id]["img"]).transpose(1,2,0)
-    seg = np.load(data_set[img_id]["seg"])
-    cl_targs = data_set[img_id]["class_targets"]
-    roi_ids = np.unique(seg[seg > 0])
-    # ---- detection example -----
-    cl_id2name = {1: "h", 2: "v"}
-    color_palette = [cf.purple, cf.aubergine, cf.magenta, cf.dark_blue, cf.blue, cf.bright_blue, cf.cyan, cf.dark_green,
-                     cf.green, cf.dark_yellow, cf.yellow, cf.orange,  cf.red, cf.dark_red, cf.bright_red]
-    n_colors = len(color_palette)
-    cmap = {roi_id : color_palette[(roi_id-1)%n_colors] for roi_id in roi_ids}
-    cmap[0] = (1,1,1,0.)
-
-    ax = fig.add_subplot(grid[0, 1])
-    ax.imshow(img)
-    ax.imshow(plg.to_rgba(seg, cmap), alpha=0.7)
-
-    data_dict = Converter(**{'seg':seg[np.newaxis, np.newaxis], 'class_targets': [cl_targs]}) # needs batch dim and channel
-    for roi_ix, bb_target in enumerate(data_dict['bb_target'][0]):
-        [y1, x1, y2, x2] = bb_target
-        width, height = x2 - x1, y2 - y1
-        cl_id = cl_targs[roi_ix]
-        label = cf.class_id2label[cl_id]
-        text_x, text_y = x2, y1
-        id_text = cl_id2name[cl_id]
-        text_str = '{}'.format(id_text)
-        text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0)
-        #ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=text_fs, ha="center", va="center")
-        edgecolor = label.color
-        bbox = plg.mpatches.Rectangle((x1, y1), width, height, linewidth=1.05, edgecolor=edgecolor, facecolor='none')
-        ax.add_patch(bbox)
-    ax.axis('off')
-
-    # ---- seg example -----
-    for roi_id in roi_ids:
-        seg[seg==roi_id] = cl_targs[roi_id-1]
-
-    ax = fig.add_subplot(grid[0,0])
-    ax.imshow(img)
-    ax.imshow(plg.to_rgba(seg, cf.cmap), alpha=0.7)
-    ax.axis('off')
-
-    plg.plt.tight_layout()
-    outfile = os.path.join(plot_dir, "cityscapes_example.png")
-    print("Saving plot to {}".format(outfile))
-    plg.plt.savefig(outfile, bbox_inches='tight', dpi=600)
-
-
-
-
-
-if __name__=="__main__":
-    stime = time.time()
-    #seg_det_cityscapes_example()
-    #box_clustering()
-    #sketch_AP_AUC(draw_auc=False)
-    #draw_toy_cylinders()
-    #prostate_GT_examples(plot_dir="/home/gregor/Dropbox/Thesis/Main/MFPPresentation/graphics")
-    #prostate_results_static()
-    #prostate_dataset_stats(plot_dir="/home/gregor/Dropbox/Thesis/Main/MFPPresentation/graphics", show_splits=False)
-    #lidc_dataset_stats()
-    #lidc_sa_dataset_stats()
-    #lidc_annotator_confusion()
-    #lidc_merged_sa_joint_plot()
-    #lidc_annotator_dissent_images()
-    exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6"
-    #multiple_clustering_results('prostate', exp_dir, plot_hist=True)
-    exp_parent_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments"
-    exp_parent_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments_debug_retinas"
-    #get_plot_clustering_results('prostate', exp_parent_dir, res_from_file=False)
-
-    exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6"
-    #cf = get_cf('prostate', exp_dir)
-    #plot_file = os.path.join(exp_dir, "inference_analysis/bytes_merged_boxes_fold_1_pid_177.pkl")
-    #plot_single_results(cf, exp_dir, plot_file)
-
-    exp_dir1 = "/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rg_bs8"
-    exp_dir2 = "/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8"
-    #find_suitable_examples(exp_dir1, exp_dir2)
-    #plot_single_results_lidc()
-    plot_dir = "/home/gregor/Dropbox/Thesis/MICCAI2019/Graphics"
-    #lidc_results_static(plot_dir=plot_dir)
-    #toy_results_static(plot_dir=plot_dir)
-    plot_lidc_dissent_and_example(plot_dir=plot_dir, confusion_matrix=True, numbering=False, example_title="LIDC example result")
-
-    mins, secs = divmod((time.time() - stime), 60)
-    h, mins = divmod(mins, 60)
-    t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
-    print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/models/mrcnn_aleatoric.py b/models/mrcnn_aleatoric.py
deleted file mode 100644
index 30d54d5..0000000
--- a/models/mrcnn_aleatoric.py
+++ /dev/null
@@ -1,735 +0,0 @@
-#!/usr/bin/env python
-# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""
-Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
-published under MIT license.
-"""
-import time
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils
-
-import utils.model_utils as mutils
-import utils.exp_utils as utils
-#from cuda_functions.nms_2D.pth_nms import nms_gpu as nms_2D
-#from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D
-#from cuda_functions.roi_align_2D.roi_align.crop_and_resize import CropAndResizeFunction as ra2D
-#from cuda_functions.roi_align_3D.roi_align.crop_and_resize import CropAndResizeFunction as ra3D
-
-
-class RPN(nn.Module):
-    """
-    Region Proposal Network.
-    """
-
-    def __init__(self, cf, conv):
-
-        super(RPN, self).__init__()
-        self.dim = conv.dim
-
-        self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu)
-        self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
-        self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
-
-
-    def forward(self, x):
-        """
-        :param x: input feature maps (b, in_channels, y, x, (z))
-        :return: rpn_class_logits (b, 2, n_anchors)
-        :return: rpn_probs_logits (b, 2, n_anchors)
-        :return: rpn_bbox (b, 2 * dim, n_anchors)
-        """
-
-        # Shared convolutional base of the RPN.
-        x = self.conv_shared(x)
-
-        # Anchor Score. (batch, anchors per location * 2, y, x, (z)).
-        rpn_class_logits = self.conv_class(x)
-        # Reshape to (batch, 2, anchors)
-        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
-        rpn_class_logits = rpn_class_logits.permute(*axes)
-        rpn_class_logits = rpn_class_logits.contiguous()
-        rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2)
-
-        # Softmax on last dimension (fg vs. bg).
-        rpn_probs = F.softmax(rpn_class_logits, dim=2)
-
-        # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z))
-        rpn_bbox = self.conv_bbox(x)
-
-        # Reshape to (batch, 2*dim, anchors)
-        rpn_bbox = rpn_bbox.permute(*axes)
-        rpn_bbox = rpn_bbox.contiguous()
-        rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2)
-
-        return [rpn_class_logits, rpn_probs, rpn_bbox]
-
-class Classifier(nn.Module):
-    """
-    Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a
-    shared convolutional base and finally branches off the classifier- and regression head.
-    """
-    def __init__(self, cf, conv):
-        super(Classifier, self).__init__()
-
-        self.cf = cf
-        self.dim = conv.dim
-        self.in_channels = cf.end_filts
-        self.pool_size = cf.pool_size
-        self.pyramid_levels = cf.pyramid_levels
-        # instance_norm does not work with spatial dims (1, 1, (1))
-        norm = cf.norm if cf.norm != 'instance_norm' else None
-
-        self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu)
-        self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu)
-        self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim)
-
-
-        if 'regression_ken_gal' in self.cf.prediction_tasks:
-            self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes*cf.regression_n_features)
-            self.uncert_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes)
-        else:
-            raise NotImplementedError
-        if 'class' in self.cf.prediction_tasks:
-            #raise NotImplementedError
-            self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes)
-        else:
-            assert cf.head_classes==2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes"
-            self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda()
-            #assert hasattr(cf, "regression_n_features"), "cannot choose class inference from regression if regression not applied"
-
-    def forward(self, x, rois):
-        """
-        :param x: input feature maps (b, in_channels, y, x, (z))
-        :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
-        the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements
-        have been merged to one vector, while the origin info has been stored for re-allocation.
-        :return: mrcnn_class_logits (n_proposals, n_head_classes)
-        :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
-        :return: mrcnn_regress (n_proposals, n_head_classes, regression_n_features+1) +1 is aleatoric uncertainty
-        """
-        x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = x.view(-1, self.in_channels * 4)
-
-        mrcnn_bbox = self.linear_bbox(x)
-        mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2)
-        mrcnn_class_logits = self.linear_class(x)
-        mrcnn_regress, uncert_rg = self.linear_regressor(x), self.uncert_regressor(x)
-        mrcnn_regress = torch.cat((mrcnn_regress.view(mrcnn_regress.shape[0], -1, self.cf.regression_n_features),
-                                   uncert_rg.unsqueeze(-1)), dim=2)
-
-        return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress]
-
-class Mask(nn.Module):
-    """
-    Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the
-    output logits to allow for overlapping classes.
-    """
-    def __init__(self, cf, conv):
-        super(Mask, self).__init__()
-        self.pool_size = cf.mask_pool_size
-        self.pyramid_levels = cf.pyramid_levels
-        self.dim = conv.dim
-        self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
-        self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
-        self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
-        self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
-        if conv.dim == 2:
-            self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
-        else:
-            self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
-
-        self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True)
-        self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None)
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x, rois):
-        """
-        :param x: input feature maps (b, in_channels, y, x, (z))
-        :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
-        the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements
-        have been merged to one vector, while the origin info has been stored for re-allocation.
-        :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z))
-        """
-        x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.conv3(x)
-        x = self.conv4(x)
-        x = self.relu(self.deconv(x))
-        x = self.conv5(x)
-        x = self.sigmoid(x)
-        return x
-
-
-############################################################
-#  Loss Functions
-############################################################
-
-def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize):
-    """
-    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
-    :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier.
-    :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining).
-    :return: loss: torch tensor
-    :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
-    """
-
-    # Filter out netural anchors
-    pos_indices = torch.nonzero(rpn_match == 1)
-    neg_indices = torch.nonzero(rpn_match == -1)
-
-    # loss for positive samples
-    if not 0 in pos_indices.size():
-        pos_indices = pos_indices.squeeze(1)
-        roi_logits_pos = rpn_class_logits[pos_indices]
-        pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda())
-    else:
-        pos_loss = torch.FloatTensor([0]).cuda()
-
-    # loss for negative samples: draw hard negative examples (SHEM)
-    # that match the number of positive samples, but at least 1.
-    if not 0 in neg_indices.size():
-        neg_indices = neg_indices.squeeze(1)
-        roi_logits_neg = rpn_class_logits[neg_indices]
-        negative_count = np.max((1, pos_indices.cpu().data.numpy().size))
-        roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
-        neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
-        neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda())
-        np_neg_ix = neg_ix.cpu().data.numpy()
-        #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count)
-    else:
-        neg_loss = torch.FloatTensor([0]).cuda()
-        np_neg_ix = np.array([]).astype('int32')
-
-    loss = (pos_loss + neg_loss) / 2
-    return loss, np_neg_ix
-
-
-def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match):
-    """
-    :param rpn_target_deltas:   (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))).
-    Uses 0 padding to fill in unsed bbox deltas.
-    :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
-    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
-    :return: loss: torch 1D tensor.
-    """
-    if not 0 in torch.nonzero(rpn_match == 1).size():
-
-        indices = torch.nonzero(rpn_match == 1).squeeze(1)
-        # Pick bbox deltas that contribute to the loss
-        rpn_pred_deltas = rpn_pred_deltas[indices]
-        # Trim target bounding box deltas to the same length as rpn_bbox.
-        target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :]
-        # Smooth L1 loss
-        loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas)
-    else:
-        loss = torch.FloatTensor([0]).cuda()
-
-    return loss
-
-def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids):
-    """
-    :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dh)))
-    :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dh)))
-    :param target_class_ids: (n_sampled_rois)
-    :return: loss: torch 1D tensor.
-    """
-    if not 0 in torch.nonzero(target_class_ids > 0).size():
-        positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
-        positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
-        target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach()
-        pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :]
-        loss = F.smooth_l1_loss(pred_bbox, target_bbox)
-    else:
-        loss = torch.FloatTensor([0]).cuda()
-
-    return loss
-
-def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids):
-    """
-    :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1].
-    :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array.
-    :param target_class_ids: (n_sampled_rois)
-    :return: loss: torch 1D tensor.
-    """
-    if not 0 in torch.nonzero(target_class_ids > 0).size():
-        # Only positive ROIs contribute to the loss. And only
-        # the class specific mask of each ROI.
-        positive_ix = torch.nonzero(target_class_ids > 0)[:, 0]
-        positive_class_ids = target_class_ids[positive_ix].long()
-        y_true = target_masks[positive_ix, :, :].detach()
-        y_pred = pred_masks[positive_ix, positive_class_ids, :, :]
-        loss = F.binary_cross_entropy(y_pred, y_true)
-    else:
-        loss = torch.FloatTensor([0]).cuda()
-
-    return loss
-
-def compute_mrcnn_class_loss(tasks, pred_class_logits, target_class_ids):
-    """
-    :param pred_class_logits: (n_sampled_rois, n_classes)
-    :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension.
-    :return: loss: torch 1D tensor.
-    """
-    if 'class' in tasks and not 0 in target_class_ids.size():
-        loss = F.cross_entropy(pred_class_logits, target_class_ids.long())
-    else:
-        loss = torch.FloatTensor([0.]).cuda()
-
-    return loss
-
-def compute_mrcnn_regression_loss(pred, target, target_class_ids):
-    """regression loss is a distance metric between target vector and predicted regression vector.
-    :param pred: (n_sample_rois, n_classes, n_regr_feats+1) regression pred where last entry of each regression
-        pred is the uncertainty parameter
-    :param target: (n_sample_rois, n_regr_feats)
-    :param target_class_ids: (n_sample_rois)
-    :return: differentiable loss, torch 1D tensor on cuda
-    """
-
-    if not 0 in torch.nonzero(target_class_ids > 0).size():
-         positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
-         positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
-         target = target[positive_roi_ix, :].float().detach()
-         pred = pred[positive_roi_ix, positive_roi_class_ids, :]
-
-         # loss is 1/(2N)*[Sum_i^N exp(-s_i) distance(pred_vec, targ_vec) + s_i]
-         loss = F.smooth_l1_loss(pred[...,:-1], target, reduction='none').sum(dim=1) * torch.exp(-pred[...,-1])
-         loss += pred[...,-1] #regularizer for sigma
-         loss = 0.5*loss.mean()
-    else:
-        loss = torch.FloatTensor([0.]).cuda()
-
-    return loss
-
-############################################################
-#  Detection Layer
-############################################################
-
-def compute_roi_scores(cf, batch_rpn_proposals, mrcnn_cl_logits):
-    """Compute scores from uncertainty measures (lower=better) to use for sorting/clustering algos (higher=better).
-    :param cf:
-    :param uncert_class:
-    :param uncert_regression:
-    :return:
-    """
-    if 'class' in cf.prediction_tasks:
-        scores = F.softmax(mrcnn_cl_logits, dim=1)
-    else:
-        scores = batch_rpn_proposals[:,:,-1].view(-1, 1)
-        scores = torch.cat((1-scores, scores), dim=1)
-
-    return scores
-
-############################################################
-#  MaskRCNN Class
-############################################################
-
-class net(nn.Module):
-
-
-    def __init__(self, cf, logger):
-
-        super(net, self).__init__()
-        self.cf = cf
-        self.logger = logger
-        self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks])
-        self.build()
-
-
-        if self.cf.weight_init=="custom":
-            logger.info("Tried to use custom weight init which is not defined. Using pytorch default.")
-        elif self.cf.weight_init:
-            mutils.initialize_weights(self)
-        else:
-            logger.info("using default pytorch weight init")
-
-    def build(self):
-        """Build Mask R-CNN architecture."""
-
-        # Image size must be dividable by 2 multiple times.
-        h, w = self.cf.patch_size[:2]
-        if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5):
-            raise Exception("Image size must be dividable by 2 at least 5 times "
-                            "to avoid fractions when downscaling and upscaling."
-                            "For example, use 256, 320, 384, 448, 512, ... etc.,i.e.,"
-                            "any number x*32 will do!")
-
-        # instantiate abstract multi-dimensional conv generator and load backbone module.
-        backbone = utils.import_module('bbone', self.cf.backbone_path)
-        conv = backbone.ConvGenerator(self.cf.dim)
-
-        # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head
-        self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
-        self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
-        self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda()
-        self.rpn = RPN(self.cf, conv)
-        self.classifier = Classifier(self.cf, conv)
-        self.mask = Mask(self.cf, conv)
-
-    def forward(self, img, is_training=True):
-        """
-        :param img: input images (b, c, y, x, (z)).
-        :return: rpn_pred_logits: (b, n_anchors, 2)
-        :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
-        :return: batch_unnormed_props: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting.
-        :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
-        :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
-        """
-        # extract features.
-        fpn_outs = self.fpn(img)
-        rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels]
-        self.mrcnn_feature_maps = rpn_feature_maps
-
-        # loop through pyramid layers and apply RPN.
-        layer_outputs = []  # list of lists
-        for p in rpn_feature_maps:
-            layer_outputs.append(self.rpn(p))
-
-        # concatenate layer outputs.
-        # convert from list of lists of level outputs to list of lists of outputs across levels.
-        # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
-        outputs = list(zip(*layer_outputs))
-        outputs = [torch.cat(list(o), dim=1) for o in outputs]
-        rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs
-
-        # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier.
-        proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference
-        batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, self.anchors, self.cf)
-
-        # merge batch dimension of proposals while storing allocation info in coordinate dimension.
-        batch_ixs = torch.from_numpy(np.repeat(np.arange(batch_normed_props.shape[0]), batch_normed_props.shape[1])).float().cuda()
-        rpn_rois = batch_normed_props[:,:,:-1].view(-1, batch_normed_props[:,:,:-1].shape[2])
-        self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1)
-
-        # this is the first of two forward passes in the second stage, where no activations are stored for backprop.
-        # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.)
-        # for inference/monitoring as well as sampling of rois for the loss functions.
-        # processed in chunks of roi_chunk_size to re-adjust to gpu-memory.
-        chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size)
-        bboxes_list, class_logits_list, regressions_list = [], [], []
-        with torch.no_grad():
-            for chunk in chunked_rpn_rois:
-                chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk)
-                bboxes_list.append(chunk_bboxes)
-                class_logits_list.append(chunk_class_logits)
-                regressions_list.append(chunk_regressions)
-        mrcnn_bbox = torch.cat(bboxes_list, 0)
-        mrcnn_class_logits = torch.cat(class_logits_list, 0)
-        mrcnn_regressions = torch.cat(regressions_list, 0)
-        #self.mrcnn_class_logits = F.softmax(mrcnn_class_logits, dim=1)
-        #why were mrcnn_bbox, class_logs, regress called batch_ ? they have no batch dim, in contrast to batch_normed_props
-        self.mrcnn_roi_scores = compute_roi_scores(self.cf, batch_normed_props, mrcnn_class_logits)
-        # refine classified proposals, filter and return final detections.
-        # returns (cf.max_inst_per_batch_element, n_coords+1+...)
-        detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores,
-                                       mrcnn_regressions)
-
-        # forward remaining detections through mask-head to generate corresponding masks.
-        scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2
-        scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda()
-
-        # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ics
-        detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale
-        with torch.no_grad():
-            detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes)
-
-        return [rpn_pred_logits, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks]
-
-    def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions):
-        """
-        this is the second forward pass through the second stage (features from stage one are re-used).
-        samples few rois in loss_example_mining and forwards only those for loss computation.
-        :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels. can be None.
-        :param batch_gt_regressions: can be None.
-        :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
-        :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c)
-        :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores.
-        :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
-        :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal.
-        :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals.
-        :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement.
-        :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals.
-        :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting.
-        """
-        # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses.
-        sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \
-            mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks,
-                                   self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions)
-
-        # re-use feature maps and RPN output from first forward pass.
-        sample_proposals = self.rpn_rois_batch_info[sample_ics]
-        if not 0 in sample_proposals.size():
-            sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals)
-            sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals)
-        else:
-            sample_logits = torch.FloatTensor().cuda()
-            sample_deltas = torch.FloatTensor().cuda()
-            sample_mask = torch.FloatTensor().cuda()
-
-        return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals,
-                sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions]
-
-    def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True):
-        """
-        Restores batch dimension of merged detections, unmolds detections, creates and fills results dict.
-        :param img_shape:
-        :param detections: shape (n_final_detections, len(info)), where
-            info=( y1, x1, y2, x2, (z1,z2), batch_ix, pred_class_id, pred_score )
-        :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
-        :param box_results_list: None or list of output boxes for monitoring/plotting.
-        each element is a list of boxes per batch element.
-        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
-        :return: results_dict: dictionary with keys:
-                 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
-                          [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
-                 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now.
-                 class-specific return of masks will come with implementation of instance segmentation evaluation.
-        """
-
-        detections = detections.cpu().data.numpy()
-        if self.cf.dim == 2:
-            detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy()
-        else:
-            detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy()
-        # det masks shape now (n_dets, y,x(,z), n_classes)
-        # restore batch dimension of merged detections using the batch_ix info.
-        batch_ixs = detections[:, self.cf.dim*2]
-        detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])]
-        mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])]
-        #mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size
-
-        if box_results_list == None: # for test_forward, where no previous list exists.
-            box_results_list =  [[] for _ in range(img_shape[0])]
-
-        seg_logits = []
-        # loop over batch and unmold detections.
-        for ix in range(img_shape[0]):
-
-            # final masks are one-hot encoded (b, n_classes, y, x, (z))
-            final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:]))
-            #+1 for bg, 0.5 bc mask head classifies only bg/fg with logits between 0,1--> bg is <0.5
-            if self.cf.num_classes + 1 != self.cf.num_seg_classes:
-                self.logger.warning("n of box classifier head classes {} doesnt match cf.num_seg_classes {}".format(
-                    self.cf.num_classes + 1, self.cf.num_seg_classes))
-
-            if not 0 in detections[ix].shape:
-                boxes = detections[ix][:, :self.cf.dim*2].astype(np.int32)
-                class_ids = detections[ix][:, self.cf.dim*2 + 1].astype(np.int32)
-                scores = detections[ix][:, self.cf.dim*2 + 2]
-                masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids]
-                regressions = detections[ix][:,self.cf.dim*2+3:]
-
-                # Filter out detections with zero area. Often only happens in early
-                # stages of training when the network weights are still a bit random.
-                if self.cf.dim == 2:
-                    exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
-                else:
-                    exclude_ix = np.where(
-                        (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0]
-
-                if exclude_ix.shape[0] > 0:
-                    boxes = np.delete(boxes, exclude_ix, axis=0)
-                    masks = np.delete(masks, exclude_ix, axis=0)
-                    class_ids = np.delete(class_ids, exclude_ix, axis=0)
-                    scores = np.delete(scores, exclude_ix, axis=0)
-                    regressions = np.delete(regressions, exclude_ix, axis=0)
-
-                # Resize masks to original image size and set boundary threshold.
-                if return_masks:
-                    for i in range(masks.shape[0]): #masks per this batch instance/element/image
-                        # Convert neural network mask to full size mask
-                        if self.cf.dim == 2:
-                            full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:])
-                        else:
-                            full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:])
-                        # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class
-                        # has the max seg_logit value over all instances of that class in one sample
-                        final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0)
-                    final_masks[0] = np.full(final_masks[0].shape, 0.49999999) #effectively min_det_thres at 0.5 per pixel
-
-                # add final predictions to results.
-                if not 0 in boxes.shape:
-                    for ix2, coords in enumerate(boxes):
-                        box = {'box_coords': coords, 'box_type': 'det', 'box_score': scores[ix2],
-                               'box_pred_class_id': class_ids[ix2]}
-                        if 'regression_ken_gal' or 'regression_feindt' in self.cf.prediction_tasks:
-                            rg_uncert = np.sqrt(np.exp(regressions[ix2][-1]))
-                            box.update({'regression': regressions[ix2][:-1], 'rg_uncertainty': rg_uncert })
-                        if hasattr(self.cf, "rg_val_to_bin_id"):
-                            box['rg_bin'] = self.cf.rg_val_to_bin_id(regressions[ix2][:-1])
-                        box_results_list[ix].append(box)
-
-            # if no detections were made--> keep full bg mask (zeros).
-            seg_logits.append(final_masks)
-
-        # create and fill results dictionary.
-        results_dict = {}
-        results_dict['boxes'] = box_results_list
-        results_dict['seg_preds'] = np.array(seg_logits)
-
-        return results_dict
-
-
-    def train_forward(self, batch, is_validation=False):
-        """
-        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
-        for processing, computes losses, and stores outputs in a dictionary.
-        :param batch: dictionary containing 'data', 'seg', etc.
-        :return: results_dict: dictionary with keys:
-                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
-                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
-                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
-                'torch_loss': 1D torch tensor for backprop.
-                'class_loss': classification loss for monitoring.
-        """
-        img = batch['data']
-        gt_boxes = batch['bb_target']
-        axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
-        gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))]
-        gt_regressions = batch["regression_targets"] if self.regress_flag else None
-        gt_class_ids = batch['class_targets']
-
-
-        img = torch.from_numpy(img).float().cuda()
-        batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
-        batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()
-
-        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
-        box_results_list = [[] for _ in range(img.shape[0])]
-
-        #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
-        # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop.
-        rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img)
-
-        mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \
-        mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \
-            self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions)
-
-        #loop over batch
-        for b in range(img.shape[0]):
-            if len(gt_boxes[b]) > 0:
-                # add gt boxes to output list
-                for tix in range(len(gt_boxes[b])):
-                    gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]}
-                    for name in self.cf.roi_items:
-                        gt_box.update({name: batch[name][b][tix]})
-                    box_results_list[b].append(gt_box)
-
-                # match gt boxes with anchors to generate targets for RPN losses.
-                rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])
-
-                # add positive anchors used for loss to output list for monitoring.
-                pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
-                for p in pos_anchors:
-                    box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
-
-            else:
-                rpn_match = np.array([-1]*self.np_anchors.shape[0])
-                rpn_target_deltas = np.array([0])
-
-            rpn_match = torch.from_numpy(rpn_match).cuda()
-            rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()
-
-            # compute RPN losses.
-            rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match, self.cf.shem_poolsize)
-            rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match)
-            batch_rpn_class_loss += rpn_class_loss /img.shape[0]
-            batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0]
-
-            # add negative anchors used for loss to output list for monitoring.
-            neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][0, neg_anchor_ix], img.shape[2:])
-            for n in neg_anchors:
-                box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})
-
-            # add highest scoring proposals to output list for monitoring.
-            rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
-            for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
-                box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})
-
-        # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
-        if not 0 in sample_proposals.shape:
-            rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
-            for ix, r in enumerate(rois):
-                box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
-                                            'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})
-
-        # compute mrcnn losses.
-        mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids)
-        mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids)
-        mrcnn_regression_loss = compute_mrcnn_regression_loss(mrcnn_regressions, target_regressions, target_class_ids)
-        # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
-        # In this case, the mask_loss is taken out of training.
-        if not self.cf.frcnn_mode:
-            mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids)
-        else:
-            mrcnn_mask_loss = torch.FloatTensor([0]).cuda()
-
-        loss = batch_rpn_class_loss + batch_rpn_bbox_loss +\
-               mrcnn_bbox_loss + mrcnn_mask_loss +  mrcnn_class_loss + mrcnn_regression_loss
-
-        # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
-        #dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]]
-        #self.logger.info("regression loss {:.3f}".format(mrcnn_regression_loss.item()))
-        #self.logger.info("loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, "
-        #      "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(),
-        #      batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(), mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount))
-
-        # run unmolding of predictions for monitoring and merge all results to one dictionary.
-
-        return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train
-        results_dict = self.get_results(
-            img.shape, detections, detection_masks, box_results_list, return_masks=return_masks)
-        results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis]
-        if 'dice' in self.cf.metrics:
-            results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
-                results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)
-
-
-        results_dict['torch_loss'] = loss
-        results_dict['class_loss'] = mrcnn_class_loss.item()
-        results_dict['rg_loss'] = mrcnn_regression_loss.item()
-        results_dict['bbox_loss'] = mrcnn_bbox_loss.item()
-        results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item()
-        results_dict['rpn_class_loss'] = rpn_class_loss.item()
-
-        return results_dict
-
-
-    def test_forward(self, batch, return_masks=True):
-        """
-        test method. wrapper around forward pass of network without usage of any ground truth information.
-        prepares input data for processing and stores outputs in a dictionary.
-        :param batch: dictionary containing 'data'
-        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
-        :return: results_dict: dictionary with keys:
-               'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
-                       [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
-               'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
-        """
-        img = batch['data']
-        img = torch.from_numpy(img).float().cuda()
-        _, _, _, detections, detection_masks = self.forward(img)
-        results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks)
-
-        return results_dict
\ No newline at end of file
diff --git a/models/mrcnn_gan.py b/models/mrcnn_gan.py
deleted file mode 100644
index af5632c..0000000
--- a/models/mrcnn_gan.py
+++ /dev/null
@@ -1,844 +0,0 @@
-#!/usr/bin/env python
-# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""
-Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
-published under MIT license.
-"""
-import time
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils
-
-import utils.model_utils as mutils
-import utils.exp_utils as utils
-
-
-class Generator_RPN(nn.Module):
-    """
-    Region Proposal Network.
-    """
-
-    def __init__(self, cf, conv):
-
-        super(Generator_RPN, self).__init__()
-        self.dim = conv.dim
-
-        #assert cf.batch_size%2==0
-        self.conv_shared = conv(cf.end_filts+1, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu)
-        self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
-        self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
-
-
-    def forward(self, x):
-        """
-        :param x: input feature maps (b, in_channels, y, x, (z))
-        :return: rpn_class_logits (b, n_anchors, 2)
-        :return: rpn_probs_logits (b, n_anchors, 2)
-        :return: rpn_bbox (b, n_anchors, 2*dim)
-        """
-        # latent vector from vanilla base distribution
-        z = torch.randn(x.shape[0], 1, *x.shape[2:], requires_grad=True).cuda()
-        x = torch.cat((x,z), dim=1)
-        # Shared convolutional base of the RPN.
-        x = self.conv_shared(x)
-
-        # Anchor Score. (batch, anchors per location * 2, y, x, (z)).
-        rpn_class_logits = self.conv_class(x)
-        # Reshape to (batch, anchors, 2)
-        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
-        rpn_class_logits = rpn_class_logits.permute(*axes)
-        rpn_class_logits = rpn_class_logits.contiguous()
-        rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2)
-
-        # Softmax on last dimension (fg vs. bg).
-        rpn_probs = F.softmax(rpn_class_logits, dim=2)
-
-        # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z))
-        rpn_bbox = self.conv_bbox(x)
-
-        # Reshape to (batch, anchors, 2*dim)
-        rpn_bbox = rpn_bbox.permute(*axes)
-        rpn_bbox = rpn_bbox.contiguous()
-        rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2)
-
-        return [rpn_class_logits, rpn_probs, rpn_bbox]
-
-class RPN_Discriminator(nn.Module):
-    """
-    Region Proposal Network.
-    """
-
-    def __init__(self, cf, conv):
-
-        super(RPN_Discriminator, self).__init__()
-        self.dim = conv.dim
-
-        #assert cf.batch_size%2==0
-        self.resizer = nn.Sequential(
-            conv(cf.end_filts, cf.end_filts//2, ks=3, stride=cf.rpn_anchor_stride, pad=0, relu=cf.relu),
-            nn.MaxPool2d(kernel_size=3, stride=2, padding=0) if \
-                conv.dim == 2 else nn.MaxPool3d(kernel_size=3,stride=(2, 2, 1),padding=0),
-            conv(cf.end_filts//2, cf.end_filts // 2, ks=1, stride=1, pad=0, relu=cf.relu),
-            nn.MaxPool2d(kernel_size=3, stride=2, padding=0) if \
-                conv.dim == 2 else nn.MaxPool3d(kernel_size=3, stride=(2, 2, 1), padding=0),
-
-        )
-        self.in_channels = cf.end_filts * 4
-        self.conv2 = conv(cf.end_filts, cf.n_rpn_features, ks=1, stride=1, pad=1, relu=cf.relu)
-        self.conv3 = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
-
-    def forward(self, f_maps, probs, deltas):
-        """
-        :param feature_maps: list of tensors of sizes (bsize, cf.end_filts, varying map dimensions)
-        :param probs: tensor of size (bsize, n_proposals on all fpn layers, 2)
-        :param deltas: tensor of size (bsize, n_proposals on all fpn layers, cf.dim*2)
-        :return:
-        """
-        f_maps = [self.resizer(m) for m in f_maps]
-        x = torch.cat([t.view(t.shape[0], t.shape[1], -1) for t in f_maps], dim=-1)
-        x = x.view(-1, self.in_channels)
-        x = torch.cat((x,z), dim=1)
-        # Shared convolutional base of the RPN.
-        x = self.conv_shared(x)
-
-        # Anchor Score. (batch, anchors per location * 2, y, x, (z)).
-        rpn_class_logits = self.conv_class(x)
-        # Reshape to (batch, 2, anchors)
-        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
-        rpn_class_logits = rpn_class_logits.permute(*axes)
-        rpn_class_logits = rpn_class_logits.contiguous()
-        rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2)
-
-        # Softmax on last dimension (fg vs. bg).
-        rpn_probs = F.softmax(rpn_class_logits, dim=2)
-
-        # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z))
-        rpn_bbox = self.conv_bbox(x)
-
-        # Reshape to (batch, 2*dim, anchors)
-        rpn_bbox = rpn_bbox.permute(*axes)
-        rpn_bbox = rpn_bbox.contiguous()
-        rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2)
-
-        return [rpn_class_logits, rpn_probs, rpn_bbox]
-
-
-
-
-
-class Classifier(nn.Module):
-    """
-    Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a
-    shared convolutional base and finally branches off the classifier- and regression head.
-    """
-    def __init__(self, cf, conv):
-        super(Classifier, self).__init__()
-
-        self.cf = cf
-        self.dim = conv.dim
-        self.in_channels = cf.end_filts
-        self.pool_size = cf.pool_size
-        self.pyramid_levels = cf.pyramid_levels
-        # instance_norm does not work with spatial dims (1, 1, (1))
-        norm = cf.norm if cf.norm != 'instance_norm' else None
-
-        self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu)
-        self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu)
-        self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim)
-
-
-        if 'regression' in self.cf.prediction_tasks:
-            self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * cf.regression_n_features)
-            self.rg_n_feats = cf.regression_n_features
-        #classify into bins of regression values
-        elif 'regression_bin' in self.cf.prediction_tasks:
-            self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * len(cf.bin_labels))
-            self.rg_n_feats = len(cf.bin_labels)
-        else:
-            self.linear_regressor = lambda x: torch.zeros((x.shape[0], cf.head_classes * cf.regression_n_features), dtype=torch.float32).fill_(float('NaN')).cuda()
-            self.rg_n_feats = cf.regression_n_features
-        if 'class' in self.cf.prediction_tasks:
-            self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes)
-        else:
-            assert cf.head_classes == 2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes".format(cf.head_classes)
-            self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda()
-            #print("\n\nWARNING: using extra class head\n\n")
-            #self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes)
-
-    def forward(self, x, rois):
-        """
-        :param x: input feature maps (b, in_channels, y, x, (z))
-        :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
-        the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements
-        have been merged to one vector, while the origin info has been stored for re-allocation.
-        :return: mrcnn_class_logits (n_proposals, n_head_classes)
-        :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
-        """
-        x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = x.view(-1, self.in_channels * 4)
-
-        mrcnn_bbox = self.linear_bbox(x)
-        mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2)
-        mrcnn_class_logits = self.linear_class(x)
-        mrcnn_regress = self.linear_regressor(x)
-        mrcnn_regress = mrcnn_regress.view(mrcnn_regress.size()[0], -1, self.rg_n_feats)
-
-        return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress]
-
-
-class Mask(nn.Module):
-    """
-    Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the
-    output logits to allow for overlapping classes.
-    """
-    def __init__(self, cf, conv):
-        super(Mask, self).__init__()
-        self.pool_size = cf.mask_pool_size
-        self.pyramid_levels = cf.pyramid_levels
-        self.dim = conv.dim
-        self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
-        self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
-        self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
-        self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
-        if conv.dim == 2:
-            self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
-        else:
-            self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
-
-        self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True)
-        self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None)
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x, rois):
-        """
-        :param x: input feature maps (b, in_channels, y, x, (z))
-        :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
-        the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements
-        have been merged to one vector, while the origin info has been stored for re-allocation.
-        :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z))
-        """
-        x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.conv3(x)
-        x = self.conv4(x)
-        x = self.relu(self.deconv(x))
-        x = self.conv5(x)
-        x = self.sigmoid(x)
-        return x
-
-
-############################################################
-#  Loss Functions
-############################################################
-
-def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize):
-    """
-    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
-    :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier.
-    :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining).
-    :return: loss: torch tensor
-    :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
-    """
-
-    # Filter out netural anchors
-    pos_indices = torch.nonzero(rpn_match == 1)
-    neg_indices = torch.nonzero(rpn_match == -1)
-
-    # loss for positive samples
-    if not 0 in pos_indices.size():
-        pos_indices = pos_indices.squeeze(1)
-        roi_logits_pos = rpn_class_logits[pos_indices]
-        pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda())
-    else:
-        pos_loss = torch.FloatTensor([0]).cuda()
-
-    # loss for negative samples: draw hard negative examples (SHEM)
-    # that match the number of positive samples, but at least 1.
-    if not 0 in neg_indices.size():
-        neg_indices = neg_indices.squeeze(1)
-        roi_logits_neg = rpn_class_logits[neg_indices]
-        negative_count = np.max((1, pos_indices.cpu().data.numpy().size))
-        roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
-        neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
-        neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda())
-        np_neg_ix = neg_ix.cpu().data.numpy()
-        #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count)
-    else:
-        neg_loss = torch.FloatTensor([0]).cuda()
-        np_neg_ix = np.array([]).astype('int32')
-
-    loss = (pos_loss + neg_loss) / 2
-    return loss, np_neg_ix
-
-
-def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match):
-    """
-    :param rpn_target_deltas:   (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))).
-    Uses 0 padding to fill in unsed bbox deltas.
-    :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
-    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
-    :return: loss: torch 1D tensor.
-    """
-    if not 0 in torch.nonzero(rpn_match == 1).size():
-
-        indices = torch.nonzero(rpn_match == 1).squeeze(1)
-        # Pick bbox deltas that contribute to the loss
-        rpn_pred_deltas = rpn_pred_deltas[indices]
-        # Trim target bounding box deltas to the same length as rpn_bbox.
-        target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :]
-        # Smooth L1 loss
-        loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas)
-    else:
-        loss = torch.FloatTensor([0]).cuda()
-
-    return loss
-
-def compute_disc_loss(d_target, d_pred, target, shem_poolsize):
-
-
-
-
-    return
-
-
-def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids):
-    """
-    :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dh)))
-    :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dh)))
-    :param target_class_ids: (n_sampled_rois)
-    :return: loss: torch 1D tensor.
-    """
-    if not 0 in torch.nonzero(target_class_ids > 0).size():
-        positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
-        positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
-        target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach()
-        pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :]
-        loss = F.smooth_l1_loss(pred_bbox, target_bbox)
-    else:
-        loss = torch.FloatTensor([0]).cuda()
-
-    return loss
-
-def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids):
-    """
-    :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array.
-    :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1].
-    :param target_class_ids: (n_sampled_rois)
-    :return: loss: torch 1D tensor.
-    """
-    if not 0 in torch.nonzero(target_class_ids > 0).size():
-        # Only positive ROIs contribute to the loss. And only
-        # the class-specific mask of each ROI.
-        positive_ix = torch.nonzero(target_class_ids > 0)[:, 0]
-        positive_class_ids = target_class_ids[positive_ix].long()
-        y_true = target_masks[positive_ix, :, :].detach()
-        y_pred = pred_masks[positive_ix, positive_class_ids, :, :]
-        loss = F.binary_cross_entropy(y_pred, y_true)
-    else:
-        loss = torch.FloatTensor([0]).cuda()
-
-    return loss
-
-def compute_mrcnn_class_loss(tasks, pred_class_logits, target_class_ids):
-    """
-    :param pred_class_logits: (n_sampled_rois, n_classes)
-    :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension.
-    :return: loss: torch 1D tensor.
-    """
-    if 'class' in tasks and not 0 in target_class_ids.size():
-    #if 0 in target_class_ids.size():
-    #    print("WARNING: using additional cl head")
-        loss = F.cross_entropy(pred_class_logits, target_class_ids.long())
-    else:
-        loss = torch.FloatTensor([0.]).cuda()
-
-    return loss
-
-def compute_mrcnn_regression_loss(tasks, pred, target, target_class_ids):
-    """regression loss is a distance metric between target vector and predicted regression vector.
-    :param pred: (n_sampled_rois, n_classes, [n_rg_feats if real regression or 1 if rg_bin task)
-    :param target: (n_sampled_rois, [n_rg_feats or n_rg_bins])
-    :return: differentiable loss, torch 1D tensor on cuda
-    """
-
-    if not 0 in target.shape and not 0 in torch.nonzero(target_class_ids > 0).shape:
-        if "regression_bin" in tasks:
-            positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
-            positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
-            target = target[positive_roi_ix].detach()
-            pred = pred[positive_roi_ix, positive_roi_class_ids] #are the class logits
-            loss = F.cross_entropy(pred, target.long())
-        else:
-            positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
-            positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
-            target = target[positive_roi_ix, :].detach()
-            pred = pred[positive_roi_ix, positive_roi_class_ids, :]
-            loss = F.smooth_l1_loss(pred, target)
-    else:
-        loss = torch.FloatTensor([0.]).cuda()
-
-    return loss
-
-############################################################
-#  Detection Layer
-############################################################
-
-def compute_roi_scores(cf, batch_rpn_proposals, mrcnn_cl_logits):
-    """Compute scores from uncertainty measures (lower=better) to use for sorting/clustering algos (higher=better).
-    :param cf:
-    :param uncert_class:
-    :param uncert_regression:
-    :return:
-    """
-    if not 'class' in cf.prediction_tasks:
-        scores = batch_rpn_proposals[:, :, -1].view(-1, 1)
-        scores = torch.cat((1 - scores, scores), dim=1)
-    else:
-        #print("WARNING: using extra class head")
-        scores = F.softmax(mrcnn_cl_logits, dim=1)
-
-    return scores
-
-############################################################
-#  MaskRCNN Class
-############################################################
-
-class net(nn.Module):
-
-
-    def __init__(self, cf, logger):
-
-        super(net, self).__init__()
-        self.cf = cf
-        self.logger = logger
-        self.build()
-
-
-        if self.cf.weight_init=="custom":
-            logger.info("Tried to use custom weight init which is not defined. Using pytorch default.")
-        elif self.cf.weight_init:
-            mutils.initialize_weights(self)
-        else:
-            logger.info("using default pytorch weight init")
-
-    def build(self):
-        """Build Mask R-CNN architecture."""
-
-        # Image size must be dividable by 2 multiple times.
-        h, w = self.cf.patch_size[:2]
-        if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5):
-            raise Exception("Image size must be divisible by 2 at least 5 times "
-                            "to avoid fractions when downscaling and upscaling."
-                            "For example, use 256, 320, 384, 448, 512, ... etc.,i.e.,"
-                            "any number x*32 will do!")
-
-        # instantiate abstract multi-dimensional conv generator and load backbone module.
-        backbone = utils.import_module('bbone', self.cf.backbone_path)
-        conv = backbone.ConvGenerator(self.cf.dim)
-
-        # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head
-        self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
-        self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
-        self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda()
-        self.rpn = Generator_RPN(self.cf, conv)
-        self.discriminator = RPN_Discriminator(self.cf, conv)
-        self.classifier = Classifier(self.cf, conv)
-        self.mask = Mask(self.cf, conv)
-
-    def forward(self, img, is_training=True):
-        """
-        :param img: input images (b, c, y, x, (z)).
-        :return: rpn_pred_logits: (b, n_anchors, 2)
-        :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
-        :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting.
-        :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
-        :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
-        """
-        # extract features.
-        fpn_outs = self.fpn(img)
-        rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels]
-        self.mrcnn_feature_maps = rpn_feature_maps
-
-        # loop through pyramid layers and apply RPN.
-        layer_outputs = [ self.rpn(p_feats) for p_feats in rpn_feature_maps ]
-
-        # concatenate layer outputs.
-        # convert from list of lists of level outputs to list of lists of outputs across levels.
-        # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
-        outputs = list(zip(*layer_outputs))
-        rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = [torch.cat(list(o), dim=1) for o in outputs]
-        #
-        # # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier.
-        proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference
-        batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas,proposal_count,
-                                                                    self.anchors, self.cf)
-        # merge batch dimension of proposals while storing allocation info in coordinate dimension.
-        batch_ixs = torch.arange(
-            batch_normed_props.shape[0]).cuda().unsqueeze(1).repeat(1, batch_normed_props.shape[1]).view(-1).float()
-        rpn_rois = batch_normed_props[:, :, :-1].view(-1, batch_normed_props[:, :, :-1].shape[2])
-        self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1)
-
-        # this is the first of two forward passes in the second stage, where no activations are stored for backprop.
-        # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.)
-        # for inference/monitoring as well as sampling of rois for the loss functions.
-        # processed in chunks of roi_batch_size to re-adjust to gpu-memory.
-        chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_batch_size)
-        bboxes_list, class_logits_list, regressions_list = [], [], []
-        with torch.no_grad():
-            for chunk in chunked_rpn_rois:
-                chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk)
-                bboxes_list.append(chunk_bboxes)
-                class_logits_list.append(chunk_class_logits)
-                regressions_list.append(chunk_regressions)
-        mrcnn_bbox = torch.cat(bboxes_list, 0)
-        mrcnn_class_logits = torch.cat(class_logits_list, 0)
-        mrcnn_regressions = torch.cat(regressions_list, 0)
-        self.mrcnn_roi_scores = compute_roi_scores(self.cf, batch_normed_props, mrcnn_class_logits)
-
-        # refine classified proposals, filter and return final detections.
-        # returns (cf.max_inst_per_batch_element, n_coords+1+...)
-        detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores,
-                                       mrcnn_regressions)
-
-        # forward remaining detections through mask-head to generate corresponding masks.
-        scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2
-        scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda()
-
-        # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ix
-        detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale
-        with torch.no_grad():
-            detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes)
-
-        return rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks
-
-    def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions=None):
-        """
-        this is the second forward pass through the second stage (features from stage one are re-used).
-        samples few rois in loss_example_mining and forwards only those for loss computation.
-        :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
-        :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
-        :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c)
-        :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores.
-        :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
-        :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal.
-        :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals.
-        :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement.
-        :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals.
-        :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting.
-        """
-        # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses.
-        sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \
-            mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks,
-                                       self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions)
-
-        # re-use feature maps and RPN output from first forward pass.
-        sample_proposals = self.rpn_rois_batch_info[sample_ics]
-        if not 0 in sample_proposals.size():
-            sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals)
-            sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals)
-        else:
-            sample_logits = torch.FloatTensor().cuda()
-            sample_deltas = torch.FloatTensor().cuda()
-            sample_regressions = torch.FloatTensor().cuda()
-            sample_mask = torch.FloatTensor().cuda()
-
-        return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals,
-                sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions]
-
-    def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True):
-        """
-        Restores batch dimension of merged detections, unmolds detections, creates and fills results dict.
-        :param img_shape:
-        :param detections: shape (n_final_detections, len(info)), where
-            info=( y1, x1, y2, x2, (z1,z2), batch_ix, pred_class_id, pred_score )
-        :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
-        :param box_results_list: None or list of output boxes for monitoring/plotting.
-        each element is a list of boxes per batch element.
-        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
-        :return: results_dict: dictionary with keys:
-                 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
-                          [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
-                 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now.
-                 class-specific return of masks will come with implementation of instance segmentation evaluation.
-        """
-
-        detections = detections.cpu().data.numpy()
-        if self.cf.dim == 2:
-            detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy()
-        else:
-            detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy()
-        # det masks shape now (n_dets, y,x(,z), n_classes)
-        # restore batch dimension of merged detections using the batch_ix info.
-        batch_ixs = detections[:, self.cf.dim*2]
-        detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])]
-        mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])]
-        #mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size
-
-        if box_results_list == None: # for test_forward, where no previous list exists.
-            box_results_list =  [[] for _ in range(img_shape[0])]
-        # seg_logits == seg_probs in mrcnn since mask head finishes with sigmoid (--> image space = [0,1])
-        seg_probs = []
-        # loop over batch and unmold detections.
-        for ix in range(img_shape[0]):
-
-            # final masks are one-hot encoded (b, n_classes, y, x, (z))
-            final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:]))
-            #+1 for bg, 0.5 bc mask head classifies only bg/fg with logits between 0,1--> bg is <0.5
-            if self.cf.num_classes + 1 != self.cf.num_seg_classes:
-                self.logger.warning("n of box classifier head classes {} doesnt match cf.num_seg_classes {}".format(
-                    self.cf.num_classes + 1, self.cf.num_seg_classes))
-
-            if not 0 in detections[ix].shape:
-                boxes = detections[ix][:, :self.cf.dim*2].astype(np.int32)
-                class_ids = detections[ix][:, self.cf.dim*2 + 1].astype(np.int32)
-                scores = detections[ix][:, self.cf.dim*2 + 2]
-                masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids]
-                regressions = detections[ix][:,self.cf.dim*2+3:]
-
-                # Filter out detections with zero area. Often only happens in early
-                # stages of training when the network weights are still a bit random.
-                if self.cf.dim == 2:
-                    exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
-                else:
-                    exclude_ix = np.where(
-                        (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0]
-
-                if exclude_ix.shape[0] > 0:
-                    boxes = np.delete(boxes, exclude_ix, axis=0)
-                    masks = np.delete(masks, exclude_ix, axis=0)
-                    class_ids = np.delete(class_ids, exclude_ix, axis=0)
-                    scores = np.delete(scores, exclude_ix, axis=0)
-                    regressions = np.delete(regressions, exclude_ix, axis=0)
-
-                # Resize masks to original image size and set boundary threshold.
-                if return_masks:
-                    for i in range(masks.shape[0]): #masks per this batch instance/element/image
-                        # Convert neural network mask to full size mask
-                        if self.cf.dim == 2:
-                            full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:])
-                        else:
-                            full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:])
-                        # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class
-                        # has the max seg_logit value over all instances of that class in one sample
-                        final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0)
-                    final_masks[0] = np.full(final_masks[0].shape, 0.49999999) #effectively min_det_thres at 0.5 per pixel
-
-                # add final predictions to results.
-                if not 0 in boxes.shape:
-                    for ix2, coords in enumerate(boxes):
-                        box = {'box_coords': coords, 'box_type': 'det'}
-                        box.update({'box_score': scores[ix2], 'box_pred_class_id': class_ids[ix2]})
-                        #if (hasattr(self.cf, "convert_cl_to_rg") and self.cf.convert_cl_to_rg):
-                        if "regression_bin" in self.cf.prediction_tasks:
-                            # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin stands for
-                            box['rg_bin'] = regressions[ix2].argmax()
-                            box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']]
-                        else:
-                            if hasattr(self.cf, "rg_val_to_bin_id"):
-                                box.update({'rg_bin': self.cf.rg_val_to_bin_id(regressions[ix2])})
-                            box['regression'] = regressions[ix2]
-
-                        box_results_list[ix].append(box)
-
-            # if no detections were made--> keep full bg mask (zeros).
-            seg_probs.append(final_masks)
-
-        # create and fill results dictionary.
-        results_dict = {}
-        results_dict['boxes'] = box_results_list
-        results_dict['seg_preds'] = np.array(seg_probs)
-
-        return results_dict
-
-
-    def train_forward(self, batch, is_validation=False):
-        """
-        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
-        for processing, computes losses, and stores outputs in a dictionary.
-        :param batch: dictionary containing 'data', 'seg', etc.
-        :return: results_dict: dictionary with keys:
-                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
-                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
-                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
-                'torch_loss': 1D torch tensor for backprop.
-                'class_loss': classification loss for monitoring.
-        """
-        img = batch['data']
-        gt_boxes = batch['bb_target']
-        axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
-        gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))]
-        gt_class_ids = batch['class_targets']
-        if 'regression' in self.cf.prediction_tasks:
-            gt_regressions = batch["regression_targets"]
-        elif 'regression_bin' in self.cf.prediction_tasks:
-            gt_regressions = batch["rg_bin_targets"]
-        else:
-            gt_regressions = None
-
-
-        img = torch.from_numpy(img).float().cuda()
-        batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
-        batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()
-        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
-        box_results_list = [[] for _ in range(img.shape[0])]
-
-        #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
-        # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop.
-        rpn_class_logits, rpn_probs, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img)
-
-        mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \
-        mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \
-            self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions)
-
-        rpn_batch_match_targets = torch.zeros(img.shape[0], self.np_anchors.shape[0]).cuda()
-        rpn_batch_delta_targets = torch.zeros(img.shape[0], self.np_anchors.shape[0], self.cf.dim*2).cuda()
-        #loop over batch
-        for b in range(img.shape[0]):
-            rpn_target_deltas = np.zeros((self.np_anchors.shape[0], self.cf.dim * 2))
-            if len(gt_boxes[b]) > 0:
-                # add gt boxes to output list
-                for tix in range(len(gt_boxes[b])):
-                    gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]}
-                    for name in self.cf.roi_items:
-                        gt_box.update({name: batch[name][b][tix]})
-                    box_results_list[b].append(gt_box)
-
-                # match gt boxes with anchors to generate targets for RPN losses.
-                rpn_match, rpn_t_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])
-                indices = np.nonzero(rpn_match == 1)[0]
-                rpn_target_deltas[indices] = rpn_t_deltas[:indices.shape[0]]
-
-                # add positive anchors used for loss to output list for monitoring.
-                # pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
-                # for p in pos_anchors:
-                #     box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
-            else:
-                rpn_match = np.array([-1]*self.np_anchors.shape[0])
-
-            rpn_batch_match_targets[b] = torch.from_numpy(rpn_match).cuda()
-            rpn_batch_delta_targets[b] = torch.from_numpy(rpn_target_deltas).float().cuda()
-            # compute RPN losses.
-            #rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match, self.cf.shem_poolsize)
-            #rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match)
-
-            # batch_rpn_class_loss += rpn_class_loss /img.shape[0]
-            # batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0]
-
-            # add negative anchors used for loss to output list for monitoring.
-            # neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][0, neg_anchor_ix], img.shape[2:])
-            # for n in neg_anchors:
-            #     box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})
-
-            # add highest scoring proposals to output list for monitoring.
-            rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
-            for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
-                box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})
-
-        #filter_anchors(rpn_batch_match_targets, rpn_class_logits, rpn_batch_delta_targets, rpn_pred_deltas,
-        #                        self.cf.shem_poolsize)
-        # todo maybe send fixed number of rois to disc (fill up targets with bg-rois)?
-        non_neutral_mask = (rpn_batch_match_targets == 1) | (rpn_batch_match_targets == -1)
-        rpn_batch_match_targets = rpn_batch_match_targets[non_neutral_mask]
-        rpn_batch_delta_targets = rpn_batch_delta_targets[non_neutral_mask]
-        rpn_probs = rpn_probs[non_neutral_mask]
-        rpn_pred_deltas = rpn_pred_deltas[non_neutral_mask]
-
-        # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
-        # if not 0 in sample_proposals.shape:
-        #     rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
-        #     for ix, r in enumerate(rois):
-        #         box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
-        #                                     'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})
-
-        # get discriminator judgement on predicted proposals
-        # d_z = self.discriminator(self.mrcnn_feature_maps, rpn_probs, rpn_pred_deltas)
-        d_judgement_gen = self.discriminator(self.mrcnn_feature_maps, rpn_batch_match_targets, rpn_batch_delta_targets)
-
-        # compute Discriminator loss
-        compute_disc_loss(d_pred_target, d_pred_pred, d_target, self.cf.shem_poolsize)
-
-
-        # compute mrcnn losses.
-        mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids)
-        mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids)
-        mrcnn_regressions_loss = compute_mrcnn_regression_loss(self.cf.prediction_tasks, mrcnn_regressions, target_regressions, target_class_ids)
-        # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
-        # In this case, the mask_loss is taken out of training.
-        if not self.cf.frcnn_mode:
-            mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids)
-        else:
-            mrcnn_mask_loss = torch.FloatTensor([0]).cuda()
-
-        loss = batch_rpn_class_loss + batch_rpn_bbox_loss +\
-               mrcnn_bbox_loss + mrcnn_mask_loss +  mrcnn_class_loss + mrcnn_regressions_loss
-
-        # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
-        #dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]]
-        #self.logger.info("regression loss {:.3f}".format(mrcnn_regressions_loss.item()))
-        #self.logger.info("loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, "
-        #      "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(),
-        #      batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(), mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount))
-
-        # run unmolding of predictions for monitoring and merge all results to one dictionary.
-        if is_validation or self.cf.detect_while_training:
-            return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train
-            results_dict = self.get_results(
-                img.shape, detections, detection_masks, box_results_list, return_masks=return_masks) #TODO make multithreaded?
-            results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis]
-            if 'dice' in self.cf.metrics:
-                results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
-                    results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)
-        else:
-            results_dict = {'boxes': box_results_list}
-
-        results_dict['torch_loss'] = loss
-        results_dict['class_loss'] = mrcnn_class_loss.item()
-        results_dict['bbox_loss'] = mrcnn_bbox_loss.item()
-        results_dict['rg_loss'] = mrcnn_regressions_loss.item()
-        results_dict['rpn_class_loss'] = rpn_class_loss.item()
-        results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item()
-        # #todo remove assert when sufficiently checked
-        # boxescoords = [b['box_coords'] for boxlist in box_results_list for b in boxlist]
-        # coords_check = np.array([len(coords) == self.cf.dim*2 for coords in boxescoords])
-        # assert np.all(coords_check), "cand box with wrong bcoords dim: {}".format(boxescoords[~coords_check])
-
-        return results_dict
-
-
-    def test_forward(self, batch, return_masks=True):
-        """
-        test method. wrapper around forward pass of network without usage of any ground truth information.
-        prepares input data for processing and stores outputs in a dictionary.
-        :param batch: dictionary containing 'data'
-        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
-        :return: results_dict: dictionary with keys:
-               'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
-                       [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
-               'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
-        """
-        img = batch['data']
-        img = torch.from_numpy(img).float().cuda()
-        _, _, _, detections, detection_masks = self.forward(img)
-        results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks)
-
-        return results_dict
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 90510b3..2ab3c74 100644
--- a/setup.py
+++ b/setup.py
@@ -1,60 +1,60 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 from setuptools import find_packages, setup
-import os, site
+import os
 
 def parse_requirements(filename, exclude=[]):
     lineiter = (line.strip() for line in open(filename))
     return [line for line in lineiter if line and not line.startswith("#") and not line.split("==")[0] in exclude]
 
 def install_custom_ext(setup_path):
     os.system("python "+setup_path+" install")
     return
 
 def clean():
     """Custom clean command to tidy up the project root."""
     os.system('rm -vrf ./build ./dist ./*.pyc ./*.tgz ./*.egg-info')
 
 req_file = "requirements.txt"
 custom_exts = ["nms-extension", "RoIAlign-extension-2D", "RoIAlign-extension-3D"]
 install_reqs = parse_requirements(req_file, exclude=custom_exts)
 
 setup(name='RegRCNN',
       version='0.0.2',
       url="https://github.com/MIC-DKFZ/RegRCNN",
       author='G. Ramien, P. Jaeger, MIC at DKFZ Heidelberg',
       author_email='g.ramien@dkfz.de',
       licence="Apache 2.0",
       description="Medical Object-Detection Toolkit incl. Regression Capability.",
       classifiers=[
           "Development Status :: 4 - Beta",
           "Intended Audience :: Developers",
           "Programming Language :: Python :: 3.7"
       ],
       packages=find_packages(exclude=['test', 'test.*']),
       install_requires=install_reqs,
       )
 
 custom_exts =  ["custom_extensions/nms", "custom_extensions/roi_align"]
 for path in custom_exts:
     setup_path = os.path.join(path, "setup.py")
     try:
         install_custom_ext(setup_path)
     except Exception as e:
         print("FAILED to install custom extension {} due to Error:\n{}".format(path, e))
 
 clean()
\ No newline at end of file
diff --git a/shell_scripts/ana_starter.sh b/shell_scripts/ana_starter.sh
deleted file mode 100644
index 1eeb63d..0000000
--- a/shell_scripts/ana_starter.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-mode=${1}
-dataset_name=${2}
-
-source_dir=/home/gregor/Documents/medicaldetectiontoolkit
-
-exps_dir=/home/gregor/networkdrives/E132-Cluster-Projects/${dataset_name}/experiments_float_data
-exps_dirs=$(ls -d ${exps_dir}/*)
-for dir in ${exps_dirs}; do
-	echo "starting ${mode} in ${dir}"
-	(python ${source_dir}/exec.py --use_stored_settings --mode ${mode} --dataset_name ${dataset_name} --exp_dir ${dir}) || (echo "FAILED!")
-done
diff --git a/understanding_metrics.py b/understanding_metrics.py
deleted file mode 100644
index 6e1532f..0000000
--- a/understanding_metrics.py
+++ /dev/null
@@ -1,66 +0,0 @@
-
-"""
-Created at 06/12/18 13:34
-@author: gregor 
-"""
-import sys
-import os
-import numpy as np
-import pandas as pd
-from sklearn.metrics import roc_auc_score, average_precision_score
-from sklearn.metrics import roc_curve, precision_recall_curve
-
-import plotting as plg
-import evaluator
-
-sys.path.append("datasets/prostate/")
-from configs import Configs
-
-""" This is just a supplementary file which you may use to demonstrate or understand detection metrics.
-"""
-
-
-def get_det_types(df):
-    det_types = []
-    for ix, score in enumerate(df["pred_score"]):
-        if score > 0 and df["class_label"][ix] == 1:
-            det_types.append("det_tp")
-        elif score > 0 and df["class_label"][ix] == 0:
-            det_types.append("det_fp")
-        elif score == 0 and df["class_label"][ix] == 1:
-            det_types.append("det_fn")
-        elif score == 0 and df["class_label"][ix] == 0:
-            det_types.append("det_tn")
-    return det_types
-
-
-if __name__=="__main__":
-    cf = Configs()
-
-    working_dir = "/home/gregor/Documents/ramien/Thesis/UnderstandingMetrics"
-
-    df = pd.DataFrame(columns=['pred_score', 'class_label', 'pred_class', 'det_type', 'match_iou'])
-
-    df["pred_score"] = [0.3,  0.]
-    df["class_label"] = [0,   1]
-    #df["pred_class"] = [1]*len(df)
-    det_types = get_det_types(df)
-
-    df["det_type"] = det_types
-    df["match_iou"] = [0.1]*len(df)
-
-    prc_own = evaluator.compute_prc(df)
-    all_stats = [{"prc":prc_own, 'roc':np.nan, 'name': "demon"}]
-    plg.plot_stat_curves(cf, all_stats, os.path.join(working_dir, "understanding_ap_own"), fill=True)
-
-    prc_sk = precision_recall_curve(df.class_label.tolist(), df.pred_score.tolist())
-    all_stats = [{"prc":prc_sk, 'roc':np.nan, 'name': "demon"}]
-    plg.plot_stat_curves(cf, all_stats, os.path.join(working_dir, "understanding_ap"), fill=True)
-
-    ap = evaluator.get_roi_ap_from_df((df, 0.02, False))
-    ap_sk = average_precision_score(df.class_label.tolist(), df.pred_score.tolist())
-    print("roi_ap_from_df (own implement):",ap)
-    print("aver_prec_sc (sklearn):",ap_sk)
-
-    plg.plot_prediction_hist(cf, df, os.path.join(working_dir, "understanding_ap.png"), title="AP_own {:.2f}, AP_sklearn {:.2f}".format(ap, ap_sk))
-