diff --git a/custom_extensions/nms/setup.py b/custom_extensions/nms/setup.py
index 90a5d13..70ed1e6 100644
--- a/custom_extensions/nms/setup.py
+++ b/custom_extensions/nms/setup.py
@@ -1,22 +1,25 @@
 """
 Created at 07.11.19 19:12
 @author: gregor
 
 """
 
 import os, sys, site
 from pathlib import Path
 
 # recognise newly installed packages in path
 site.main()
 
 from setuptools import setup
 from torch.utils import cpp_extension
 
 dir_ = Path(os.path.dirname(sys.argv[0]))
 
+sources = [str(dir_/'src/nms_interface.cpp'), str(dir_/'src/nms.cu')]
+
 setup(name='nms_extension',
-      ext_modules=[cpp_extension.CUDAExtension('nms_extension', [str(dir_/'src/nms_interface.cpp'), str(dir_/'src/nms.cu')])],
+      ext_modules=[cpp_extension.CUDAExtension(
+            'nms_extension', sources
+      )],
       cmdclass={'build_ext': cpp_extension.BuildExtension}
-      )
-
+      )
\ No newline at end of file
diff --git a/custom_extensions/roi_align/setup.py b/custom_extensions/roi_align/2D/setup.py
similarity index 64%
copy from custom_extensions/roi_align/setup.py
copy to custom_extensions/roi_align/2D/setup.py
index ebe50f4..921913f 100644
--- a/custom_extensions/roi_align/setup.py
+++ b/custom_extensions/roi_align/2D/setup.py
@@ -1,28 +1,22 @@
 """
 Created at 07.11.19 19:12
 @author: gregor
 
 """
 
 import os, sys, site
 from pathlib import Path
 
 # recognise newly installed packages in path
 site.main()
 
 from setuptools import setup
 from torch.utils import cpp_extension
 
 dir_ = Path(os.path.dirname(sys.argv[0]))
 
 setup(name='RoIAlign extension 2D',
       ext_modules=[cpp_extension.CUDAExtension('roi_al_extension', [str(dir_/'src/RoIAlign_interface.cpp'),
                                                                     str(dir_/'src/RoIAlign_cuda.cu')])],
       cmdclass={'build_ext': cpp_extension.BuildExtension}
       )
-
-setup(name='RoIAlign extension 3D',
-      ext_modules=[cpp_extension.CUDAExtension('roi_al_extension_3d', [str(dir_/'src/RoIAlign_interface_3d.cpp'),
-                                                                       str(dir_/'src/RoIAlign_cuda_3d.cu')])],
-      cmdclass={'build_ext': cpp_extension.BuildExtension}
-      )
\ No newline at end of file
diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda.cu b/custom_extensions/roi_align/2D/src/RoIAlign_cuda.cu
similarity index 100%
rename from custom_extensions/roi_align/src/RoIAlign_cuda.cu
rename to custom_extensions/roi_align/2D/src/RoIAlign_cuda.cu
diff --git a/custom_extensions/roi_align/src/RoIAlign_interface.cpp b/custom_extensions/roi_align/2D/src/RoIAlign_interface.cpp
similarity index 100%
rename from custom_extensions/roi_align/src/RoIAlign_interface.cpp
rename to custom_extensions/roi_align/2D/src/RoIAlign_interface.cpp
diff --git a/custom_extensions/roi_align/src/cuda_helpers.h b/custom_extensions/roi_align/2D/src/cuda_helpers.h
similarity index 100%
copy from custom_extensions/roi_align/src/cuda_helpers.h
copy to custom_extensions/roi_align/2D/src/cuda_helpers.h
diff --git a/custom_extensions/roi_align/setup.py b/custom_extensions/roi_align/3D/setup.py
similarity index 64%
rename from custom_extensions/roi_align/setup.py
rename to custom_extensions/roi_align/3D/setup.py
index ebe50f4..f2d164b 100644
--- a/custom_extensions/roi_align/setup.py
+++ b/custom_extensions/roi_align/3D/setup.py
@@ -1,28 +1,22 @@
 """
 Created at 07.11.19 19:12
 @author: gregor
 
 """
 
 import os, sys, site
 from pathlib import Path
 
 # recognise newly installed packages in path
 site.main()
 
 from setuptools import setup
 from torch.utils import cpp_extension
 
 dir_ = Path(os.path.dirname(sys.argv[0]))
 
-setup(name='RoIAlign extension 2D',
-      ext_modules=[cpp_extension.CUDAExtension('roi_al_extension', [str(dir_/'src/RoIAlign_interface.cpp'),
-                                                                    str(dir_/'src/RoIAlign_cuda.cu')])],
-      cmdclass={'build_ext': cpp_extension.BuildExtension}
-      )
-
 setup(name='RoIAlign extension 3D',
       ext_modules=[cpp_extension.CUDAExtension('roi_al_extension_3d', [str(dir_/'src/RoIAlign_interface_3d.cpp'),
                                                                        str(dir_/'src/RoIAlign_cuda_3d.cu')])],
       cmdclass={'build_ext': cpp_extension.BuildExtension}
       )
\ No newline at end of file
diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu b/custom_extensions/roi_align/3D/src/RoIAlign_cuda_3d.cu
similarity index 100%
rename from custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
rename to custom_extensions/roi_align/3D/src/RoIAlign_cuda_3d.cu
diff --git a/custom_extensions/roi_align/src/RoIAlign_interface_3d.cpp b/custom_extensions/roi_align/3D/src/RoIAlign_interface_3d.cpp
similarity index 100%
rename from custom_extensions/roi_align/src/RoIAlign_interface_3d.cpp
rename to custom_extensions/roi_align/3D/src/RoIAlign_interface_3d.cpp
diff --git a/custom_extensions/roi_align/src/cuda_helpers.h b/custom_extensions/roi_align/3D/src/cuda_helpers.h
similarity index 100%
rename from custom_extensions/roi_align/src/cuda_helpers.h
rename to custom_extensions/roi_align/3D/src/cuda_helpers.h
diff --git a/custom_extensions/sandbox/setup.py b/custom_extensions/sandbox/setup.py
deleted file mode 100644
index 71ac616..0000000
--- a/custom_extensions/sandbox/setup.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""
-Created at 07.11.19 19:12
-@author: gregor
-
-"""
-
-
-from setuptools import setup
-from torch.utils import cpp_extension
-
-setup(name='sandbox_cuda',
-      ext_modules=[cpp_extension.CUDAExtension('sandbox', ['src/sandbox.cpp', 'src/sandbox_cuda.cu'])],
-      cmdclass={'build_ext': cpp_extension.BuildExtension})
\ No newline at end of file
diff --git a/custom_extensions/sandbox/src/sandbox.cpp b/custom_extensions/sandbox/src/sandbox.cpp
deleted file mode 100644
index e1fbe12..0000000
--- a/custom_extensions/sandbox/src/sandbox.cpp
+++ /dev/null
@@ -1,82 +0,0 @@
-// ------------------------------------------------------------------
-// Faster R-CNN
-// Copyright (c) 2015 Microsoft
-// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
-// Written by Shaoqing Ren, rewritten in C++ by Gregor Ramien
-// ------------------------------------------------------------------
-//#include <THC/THC.h>
-//#include <TH/TH.h>
-
-#include <torch/extension.h>
-#include <iostream>
-
-#include "sandbox.h"
-
-//#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) // divide m by n, add +1 if there is a remainder
-#define getNBlocks(m,n) ( (m+n-1) / (n) ) // m = nr of total (required) threads, n = nr of threads per block.
-int const threadsPerBlock = sizeof(unsigned long long) * 8;
-
-//---- declarations that will be defined in cuda kernel
-void add_cuda(int n=1<<3);
-void nms_cuda(at::Tensor *boxes, at::Tensor *scores, float thresh);
-//-----------------------------------------------------------------
-
-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-void sandbox() {
-
-    //std::cout<< "number: "<< number << std::endl;
-
-    torch::Tensor tensor = torch::tensor({0,1,2,3,4,5}, at::kLong).view({2,3});
-    std::cout<< "tensor: " << tensor << std::endl;
-    std::cout<< "tensor shape: " << at::size(tensor,0) << ", " << at::size(tensor,1) << std::endl;
-    return;
-}
-
-void add(){
-    //tutorial function: add two arrays (x,y) of length n
-    add_cuda();
-}
-
-//void nms(at::Tensor boxes, at::Tensor scores, float thresh) {
-void nms() {
-
-    // x1, y1, x2, y2
-    at::Tensor boxes = torch::tensor({
-        {20, 10, 60, 40},
-        {10, 20, 40, 60},
-        {20, 20, 80, 50}
-    }, at::TensorOptions().dtype(at::kInt).device(at::kCUDA));
-    std::cout << "boxes: \n" << boxes << std::endl;
-    at::Tensor scores = torch::tensor({
-        0.5,
-        0.6,
-        0.7
-    }, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA));
-    std::cout<< "scores: \n" << scores << std::endl;
-
-    CHECK_INPUT(boxes); CHECK_INPUT(scores);
-
-    int boxes_num = at::size(boxes,0);
-    int boxes_dim = at::size(boxes,1);
-
-    std::cout << "boxes shape: " << boxes_num << ", " << boxes_dim << std::endl;
-
-    float * boxes_dev; unsigned long long * mask_dev; float thresh = 0.05;
-
-    dim3 blocks((boxes_num, threadsPerBlock);
-    int threads(threadsPerBlock);
-
-
-}
-
-
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("sandbox", &sandbox, "Sandy Box Function");
-  m.def("nms", &nms, "NMS in cpp");
-  m.def("add", &add, "tutorial add function");
-
-}
\ No newline at end of file
diff --git a/custom_extensions/sandbox/src/sandbox.h b/custom_extensions/sandbox/src/sandbox.h
deleted file mode 100644
index c67254f..0000000
--- a/custom_extensions/sandbox/src/sandbox.h
+++ /dev/null
@@ -1,3 +0,0 @@
-
-
-#define getNBlocks(m,n) ( (m+n-1) / (n) ) // m = nr of total (required) threads, n = nr of threads per block.
\ No newline at end of file
diff --git a/custom_extensions/sandbox/src/sandbox_cuda.cu b/custom_extensions/sandbox/src/sandbox_cuda.cu
deleted file mode 100644
index 8bb6764..0000000
--- a/custom_extensions/sandbox/src/sandbox_cuda.cu
+++ /dev/null
@@ -1,130 +0,0 @@
-// ------------------------------------------------------------------
-// Faster R-CNN
-// Copyright (c) 2015 Microsoft
-// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
-// Written by Shaoqing Ren
-// ------------------------------------------------------------------
-
-#include <cuda.h>
-#include <cuda_runtime.h>
-
-#include <math.h>
-#include <stdio.h>
-#include <iostream>
-#include <float.h>
-
-#include "sandbox.h"
-//tutorial cuda function add
-
-__global__ void add_kernel(float *x, float *y, int n){
-    printf("block %d: threadIdx.x %d, threadIdx.y %d, threadIdx.z %d.\n", blockIdx.x, threadIdx.x, threadIdx.y, threadIdx.z);
-    int index = blockIdx.x * blockDim.x + threadIdx.x;
-    int stride = blockDim.x * gridDim.x;
-    for (int i=index; i<n; i+=stride)
-        y[i] = x[i] + y[i];
-}
-
-void add_cuda(int n=1<<3){
-    float *x, *y;
-    std::cout << "n: " << n << std::endl;
-    cudaMallocManaged(&x, n*sizeof(float));
-    cudaMallocManaged(&y, n*sizeof(float));
-
-    for (int i=0; i<n;i++){
-        x[i] = 1.0f;
-        y[i] = 2.0f;
-    }
-
-    int blockSize = 256;
-    int numBlocks = getNBlocks(n, blockSize);
-    std::cout << "numBlocks " << numBlocks << std::endl;
-    add_kernel<<<numBlocks, blockSize>>>(x, y, n);
-
-    cudaDeviceSynchronize();
-
-    float maxError = 0.0f;
-    for (int i = 0; i < n; i++)
-        maxError = fmax(maxError, fabs(y[i]-3.0f));
-    std::cout << "Max error: " << maxError << std::endl;
-
-    cudaFree(x);
-    cudaFree(y);
-
-}
-
-/*
-__device__ inline float devIoU(float const * const a, float const * const b) {
-  float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
-  float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
-  float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f);
-  float interS = width * height;
-  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
-  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
-  return interS / (Sa + Sb - interS);
-}
-
-__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
-                           const float *dev_boxes, unsigned long long *dev_mask) {
-  const int row_start = blockIdx.y;
-  const int col_start = blockIdx.x;
-
-  // if (row_start > col_start) return;
-
-  const int row_size =
-        fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
-  const int col_size =
-        fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
-
-  __shared__ float block_boxes[threadsPerBlock * 5];
-  if (threadIdx.x < col_size) {
-    block_boxes[threadIdx.x * 5 + 0] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
-    block_boxes[threadIdx.x * 5 + 1] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
-    block_boxes[threadIdx.x * 5 + 2] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
-    block_boxes[threadIdx.x * 5 + 3] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
-    block_boxes[threadIdx.x * 5 + 4] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
-  }
-  __syncthreads();
-
-  if (threadIdx.x < row_size) {
-    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
-    const float *cur_box = dev_boxes + cur_box_idx * 5;
-    int i = 0;
-    unsigned long long t = 0;
-    int start = 0;
-    if (row_start == col_start) {
-      start = threadIdx.x + 1;
-    }
-    for (i = start; i < col_size; i++) {
-      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
-        t |= 1ULL << i;
-      }
-    }
-    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
-    dev_mask[cur_box_idx * col_blocks + col_start] = t;
-  }
-}
-//
-*/
-
-__global__ void nms_kernel(){
-
-}
-
-void nms_cuda(){
-
-
-
-}
-
-
-int main(void){
-
-    nms_cuda();
-
-    return 0;
-}
\ No newline at end of file
diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py
index 8210f14..94288ad 100644
--- a/datasets/toy/configs.py
+++ b/datasets/toy/configs.py
@@ -1,490 +1,490 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 from collections import namedtuple
 
 boxLabel = namedtuple('boxLabel', ["name", "color"])
 Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion'])
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals'])
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #         Prepro        #
         #########################
 
         self.pp_rootdir = os.path.join('/home/gregor/datasets/toy', "cyl1ps_dev")
         self.pp_npz_dir = self.pp_rootdir+"_npz"
 
         self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now)
         self.min_2d_radius = 6 #in pixels
         self.n_train_samples, self.n_test_samples = 1200, 1000
 
         # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes.
         self.pp_create_ohe_seg = False
         self.pp_empty_samples_ratio = 0.1
 
         self.pp_place_radii_mid_bin = True
         self.pp_only_distort_2d = True
         # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing.
         # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity.
         self.pp_blur_min_intensity = 0.2
 
         self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d)
         self.max_instances_per_class = self.max_instances_per_sample  # how many max instances per image per class
         self.noise_scale = 0.  # std-dev of gaussian noise
 
         self.ambigs_sampling = "gaussian" #"gaussian" or "uniform"
         """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from
             image by distinguishing it from the rest of the object.
             blurring width around edge will be shifted so that symmetric rel to orig radius.
             blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale
             since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius.
             if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2.
         """
         self.ambiguities = {
              #set which classes to apply which ambs to below in class labels
              #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'.
              #kind              #probability   #scale (gaussian std, relative to unperturbed value)
             #"outer_radius":     (1.,            0.5),
             #"outer_radius_xy":  (1.,            0.5),
             #"inner_radius":     (0.5,            0.1),
             #"radii_relations":  (0.5,            0.1),
             "radius_calib":     (1.,            1./6)
         }
 
         # shape choices: 'cylinder', 'block'
         #                        id,    name,       shape,      radius,                 color,              regression,     ambiguities,    gt_distortion
         self.pp_classes = [Label(1,     'cylinder', 'cylinder', ((6,6,1),(40,40,8)),    (*self.blue, 1.),   "radius_2d",    (),             ()),
                            #Label(2,      'block',      'block',        ((6,6,1),(40,40,8)),  (*self.aubergine,1.),  "radii_2d", (), ('radius_calib',))
             ]
 
 
         #########################
         #         I/O           #
         #########################
 
         self.data_sourcedir = '/home/gregor/datasets/toy/cyl1ps_dev'
 
         if server_env:
             self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_dev_npz'
 
 
         self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test')
         self.data_sourcedir = os.path.join(self.data_sourcedir, "train")
 
         self.info_df_name = 'info_df.pickle'
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn'].
         self.model = 'mrcnn'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # one out of [2, 3]. dimension the model operates in.
-        self.dim = 3
+        self.dim = 2
 
         # 'class', 'regression', 'regression_bin', 'regression_ken_gal'
         # currently only tested mode is a single-task at a time (i.e., only one task in below list)
         # but, in principle, tasks could be combined (e.g., object classes and regression per class)
         self.prediction_tasks = ['class', ]
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm'
         self.relu = 'relu'
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1  # length of regressor target vector
 
 
         #########################
         #      Data Loader      #
         #########################
 
         self.num_epochs = 32
-        self.num_train_batches = 120 if self.dim == 2 else 80
+        self.num_train_batches = 120 if self.dim == 2 else 180
         self.batch_size = 8 if self.dim == 2 else 4
 
         self.n_cv_splits = 4
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
         self.plot_bg_chan = 0
         self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
         self.patch_size_2D = self.pre_crop_size[:2]
         self.patch_size_3D = self.pre_crop_size[:2]+[8]
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.2
         self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets"
 
         self.observables_patient = []
         self.observables_rois = []
 
         self.seed = 3 #for generating folds
 
         #############################
         # Colors, Classes, Legends  #
         #############################
-        self.plot_frequency = 1
+        self.plot_frequency = 4
 
         binary_bin_labels = [binLabel(1,  'r<=25',      (*self.green, 1.),      (1,25)),
                              binLabel(2,  'r>25',       (*self.red, 1.),        (25,))]
         quintuple_bin_labels = [binLabel(1,  'r2-10',   (*self.green, 1.),      (2,10)),
                                 binLabel(2,  'r10-20',  (*self.yellow, 1.),     (10,20)),
                                 binLabel(3,  'r20-30',  (*self.orange, 1.),     (20,30)),
                                 binLabel(4,  'r30-40',  (*self.bright_red, 1.), (30,40)),
                                 binLabel(5,  'r>40',    (*self.red, 1.), (40,))]
 
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_bin_labels = quintuple_bin_labels
 
         self.class_labels = [
             # regression: regression-task label, either value or "(x,y,z)_radius" or "radii".
             # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables!
             # gt_distortion: name of ambig to apply to gt only; needs to be iterable!
             #      #id  #name   #shape  #radius     #color              #regression #ambiguities    #gt_distortion
             Label(  0,  'bg',   None,   (0, 0, 0),  (*self.white, 0.),  (0, 0, 0),  (),             ())]
         if "class" in self.prediction_tasks:
             self.class_labels += self.pp_classes
         else:
             self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))]
 
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0,  'bg',       (*self.white, 1.),      (0,))]
             self.bin_labels += task_spec_bin_labels
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
             self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
 
         if self.class_specific_seg:
           self.seg_labels = self.class_labels
 
         self.box_type2label = {label.name: label for label in self.box_labels}
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)
         self.num_seg_classes = len(self.seg_labels)
 
         #########################
         #   Data Augmentation   #
         #########################
         self.do_aug = True
         self.da_kwargs = {
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': False,
             'alpha': (500., 1500.),
             'sigma': (40., 45.),
             'do_rotation': False,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': False,
             'scale': (0.8, 1.1),
             'random_crop': False,
             'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #########################
         #  Schedule / Selection #
         #########################
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 220  # if 'all' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 35 if self.dim==2 else 25
 
         self.save_n_models = 2
         self.min_save_thresh = 1 if self.dim == 2 else 1  # =wait time in epochs
         if "class" in self.prediction_tasks:
             self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()}
             self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()})
 
         self.lr_decay_factor = 0.5
         self.scheduling_patience = int(self.num_epochs / 5)
         self.weight_decay = 1e-5
         self.clip_norm = None  # number or None
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1)
         self.held_out_test_set = True
         self.max_test_patients = "all"  # number or "all" for all
 
         self.test_against_exact_gt = True # only True implemented
         self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario.
         self.report_score_level = ['rois']  # 'patient' or 'rois' (incl)
         self.patient_class_of_interest = 1
         self.patient_bin_of_interest = 2
 
         self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False
         self.metrics = ['ap', 'auc', 'dice']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.5]  # threshold(s) for considering a prediction as true positive
         self.min_det_thresh = 0.3
 
         self.model_max_iou_resolution = 0.2
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = self.model_max_iou_resolution  # has to be larger than desired possible overlap iou of model predictions
 
         self.merge_2D_to_3D_preds = False
         self.merge_3D_iou = self.model_max_iou_resolution
         self.n_test_plots = 1  # per fold and rank
 
         self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
         # is multiplied by (1 + nr of test augs)
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
          'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
          }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         #only meant for isotropic radii!!
         # only 2D radii (x and y dims) or 1D (x or y) are expected
         return np.round(np.digitize(rg_val, self.bin_edges).mean())
 
 
     def add_det_fpn_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = 'torch_loss'
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       self.n_roi_candidates = 4 if self.dim == 2 else 6
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
 
       self.fp_dice_weight = 1 if self.dim == 2 else 1
       # if <1, false positive predictions in foreground are penalized less.
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
     def add_det_unet_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = "torch_loss"
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
       self.n_roi_candidates = 4 if self.dim == 2 else 6
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
       # if <1, false positive predictions in foreground are penalized less.
       self.fp_dice_weight = 1 if self.dim == 2 else 1
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
       self.init_filts = 32
       self.kernel_size = 3  # ks for horizontal, normal convs
       self.kernel_size_m = 2  # ks for max pool
       self.pad = "same"  # "same" or integer, padding of horizontal convs
 
     def add_mrcnn_configs(self):
 
       self.learning_rate = [1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True  # with scheduler set in exec
       self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # number of classes for network heads: n_foreground_classes + 1 (background)
       self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2
 
       # feed +/- n neighbouring slices into channel dimension. set to None for no context.
       self.n_3D_context = None
       if self.n_3D_context is not None and self.dim == 2:
         self.n_channels *= (self.n_3D_context * 2 + 1)
 
       self.detect_while_training = True
       # disable the re-sampling of mask proposals to original size for speed-up.
       # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
       # mask outputs are optional.
       self.return_masks_in_train = True
       self.return_masks_in_val = True
       self.return_masks_in_test = True
 
       # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
       self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
       # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
       # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
       self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
       # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
       self.pyramid_levels = [0, 1, 2, 3]
       # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
       self.n_rpn_features = 512 if self.dim == 2 else 64
 
       # anchor ratios and strides per position in feature maps.
       self.rpn_anchor_ratios = [0.5, 1., 2.]
       self.rpn_anchor_stride = 1
       # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
       self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution)
 
       # loss sampling settings.
       self.rpn_train_anchors_per_image = 4
       self.train_rois_per_image = 6 # per batch_instance
       self.roi_positive_ratio = 0.5
       self.anchor_matching_iou = 0.8
 
       # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
       # where k<=#positive examples.
       self.shem_poolsize = 2
 
       self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
       self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
       self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
       self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
       self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                              self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
 
       if self.dim == 2:
         self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
         self.bbox_std_dev = self.bbox_std_dev[:4]
         self.window = self.window[:4]
         self.scale = self.scale[:4]
 
       self.plot_y_max = 1.5
       self.n_plot_rpn_props = 5 if self.dim == 2 else 30  # per batch_instance (slice in 2D / patient in 3D)
 
       # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
       self.pre_nms_limit = 2000 if self.dim == 2 else 4000
 
       # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
       # since proposals of the entire batch are forwarded through second stage as one "batch".
       self.roi_chunk_size = 1300 if self.dim == 2 else 500
       self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400
       self.post_nms_rois_inference = 200 * (self.head_classes-1)
 
       # Final selection of detections (refine_detections)
       self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class.
       self.detection_nms_threshold = self.model_max_iou_resolution  # needs to be > 0, otherwise all predictions are one cluster.
       self.model_min_confidence = 0.2  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
 
       if self.dim == 2:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride))]
            for stride in self.backbone_strides['xy']])
       else:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride)),
             int(np.ceil(self.patch_size[2] / stride_z))]
            for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                        )])
 
       if self.model == 'retina_net' or self.model == 'retina_unet':
         # whether to use focal loss or SHEM for loss-sample selection
         self.focal_loss = False
         # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
         self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                         self.rpn_anchor_scales['xy']]
         self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                        self.rpn_anchor_scales['z']]
         self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
         # pre-selection of detections for NMS-speedup. per entire batch.
         self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
         # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
         self.anchor_matching_iou = 0.7
 
         if self.model == 'retina_unet':
           self.operate_stride1 = True
diff --git a/exec.py b/exec.py
index dc570eb..5d46dd6 100644
--- a/exec.py
+++ b/exec.py
@@ -1,341 +1,341 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """ execution script. this where all routines come together and the only script you need to call.
     refer to parse args below to see options for execution.
 """
 
 import plotting as plg
 
 import os
 import warnings
 import argparse
 import time
 
 import torch
 
 import utils.exp_utils as utils
 from evaluator import Evaluator
 from predictor import Predictor
 
 
 for msg in ["Attempting to set identical bottom==top results",
             "This figure includes Axes that are not compatible with tight_layout",
             "Data has no positive values, and therefore cannot be log-scaled.",
             ".*invalid value encountered in true_divide.*"]:
     warnings.filterwarnings("ignore", msg)
 
 
 def train(cf, logger):
     """
     performs the training routine for a given fold. saves plots and selected parameters to the experiment dir
     specified in the configs. logs to file and tensorboard.
     """
     logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
         cf.dim, cf.fold, cf.exp_dir, cf.model))
     logger.time("train_val")
 
     # -------------- inits and settings -----------------
     net = model.net(cf, logger).cuda()
     if cf.optimizer == "ADAM":
         optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
     elif cf.optimizer == "SGD":
         optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3)
     if cf.dynamic_lr_scheduling:
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
                                                                     patience=cf.scheduling_patience)
     model_selector = utils.ModelSelector(cf, logger)
 
     starting_epoch = 1
     if cf.resume:
         checkpoint_path = os.path.join(cf.fold_dir, "last_state.pth")
         starting_epoch, net, optimizer, model_selector = \
             utils.load_checkpoint(checkpoint_path, net, optimizer, model_selector)
         logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))
 
     # prepare monitoring
     monitor_metrics = utils.prepare_monitoring(cf)
 
     logger.info('loading dataset and initializing batch generators...')
     batch_gen = data_loader.get_train_generators(cf, logger)
 
     # -------------- training -----------------
     for epoch in range(starting_epoch, cf.num_epochs + 1):
 
         logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs))
         logger.time("train_epoch")
 
         net.train()
 
         train_results_list = []
         train_evaluator = Evaluator(cf, logger, mode='train')
 
         for i in range(cf.num_train_batches):
             logger.time("train_batch_loadfw")
             batch = next(batch_gen['train'])
             batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts']
             batch_gen['train'].generator.stats['empty_counts'] += batch['empty_counts']
 
             logger.time("train_batch_loadfw")
             logger.time("train_batch_netfw")
             results_dict = net.train_forward(batch)
             logger.time("train_batch_netfw")
             logger.time("train_batch_bw")
             optimizer.zero_grad()
             results_dict['torch_loss'].backward()
             if cf.clip_norm:
                 torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping
             optimizer.step()
             train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict
             if not cf.server_env:
                 print("\rFinished training batch " +
                       "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches,
                                                                                              logger.get_time("train_batch_loadfw")+
                                                                                              logger.get_time("train_batch_netfw")
                                                                                              +logger.time("train_batch_bw"),
                                                                                              logger.get_time("train_batch_loadfw",reset=True),
                                                                                              logger.get_time("train_batch_netfw", reset=True),
                                                                                              logger.get_time("train_batch_bw", reset=True)), end="", flush=True)
         print()
 
         #--------------- train eval ----------------
         if (epoch-1)%cf.plot_frequency==0:
             # view an example batch
             utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels,
                                     show_gt_labels=True, get_time="train-example plot",
                                     out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold)))
 
 
         logger.time("evals")
         _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])
         logger.time("evals")
         logger.time("train_epoch", toggle=False)
         del train_results_list
 
         #----------- validation ------------
         logger.info('starting validation in mode {}.'.format(cf.val_mode))
         logger.time("val_epoch")
         with torch.no_grad():
             net.eval()
             val_results_list = []
             val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
             val_predictor = Predictor(cf, net, logger, mode='val')
 
             for i in range(batch_gen['n_val']):
                 logger.time("val_batch")
                 batch = next(batch_gen[cf.val_mode])
                 if cf.val_mode == 'val_patient':
                     results_dict = val_predictor.predict_patient(batch)
                 elif cf.val_mode == 'val_sampling':
                     results_dict = net.train_forward(batch, is_validation=True)
                 val_results_list.append([results_dict, batch["pid"]])
                 if not cf.server_env:
                     print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch',
                                                                               i + 1, batch_gen['n_val'],
                                                                               logger.time("val_batch")), end="", flush=True)
             print()
 
             #------------ val eval -------------
             if (epoch - 1) % cf.plot_frequency == 0:
                 utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels,
                                         show_gt_labels=True, get_time="val-example plot",
                                         out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold)))
 
             logger.time("evals")
             _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
 
             model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
             del val_results_list
             #----------- monitoring -------------
             monitor_metrics.update({"lr": 
                 {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}})
             logger.metrics2tboard(monitor_metrics, global_step=epoch)
             logger.time("evals")
 
             logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format(
                 epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"),
                 logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"),
                 logger.get_time("val_epoch", reset=True)/batch_gen["n_val"]))
             logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True)))
 
         #-------------- scheduling -----------------
-        if not cf.dynamic_lr_scheduling:
+        if cf.dynamic_lr_scheduling:
+            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
+        else:
             for param_group in optimizer.param_groups:
                 param_group['lr'] = cf.learning_rate[epoch-1]
-        else:
-            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
 
     logger.time("train_val")
     logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True)))
     batch_gen['train'].generator.print_stats(logger, plot=True)
 
 def test(cf, logger, max_fold=None):
     """performs testing for a given fold (or held out set). saves stats in evaluator.
     """
     logger.time("test_fold")
     logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir))
     net = model.net(cf, logger).cuda()
     batch_gen = data_loader.get_test_generator(cf, logger)
 
     test_predictor = Predictor(cf, net, logger, mode='test')
     test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr(
         cf, "eval_test_separately") or not cf.eval_test_separately)
 
     if test_results_list is not None:
         test_evaluator = Evaluator(cf, logger, mode='test')
         test_evaluator.evaluate_predictions(test_results_list)
         test_evaluator.score_test_df(max_fold=max_fold)
 
     logger.info('Testing of fold {} took {}.\n'.format(cf.fold, logger.get_time("test_fold", reset=True, format="hms")))
 
 if __name__ == '__main__':
     stime = time.time()
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--dataset_name', type=str, default='toy',
                         help="path to the dataset-specific code in source_dir/datasets")
     parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/regrcnn/datasets/toy/experiments/dev',
                         help='path to experiment dir. will be created if non existent.')
     parser.add_argument('-m', '--mode', type=str,  default='train_test', help='one out of: create_exp, analysis, train, train_test, or test')
     parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.')
     parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.')
     parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config")
     parser.add_argument('--use_stored_settings', default=False, action='store_true',
                         help='load configs from existing exp_dir instead of source dir. always done for testing, '
                              'but can be set to true to do the same for training. useful in job scheduler environment, '
                              'where source code might change before the job actually runs.')
     parser.add_argument('--resume', action="store_true", default=False,
                         help='if given, resume from checkpoint(s) of the specified folds.')
     parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything")
 
     args = parser.parse_args()
     args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name
     folds = args.folds
     resume = None if args.resume in ['None', 'none'] else args.resume
 
     if args.mode == 'create_exp':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False)
         logger = utils.get_logger(cf.exp_dir, cf.server_env, -1)
         logger.info('created experiment directory at {}'.format(args.exp_dir))
 
     elif args.mode == 'train' or args.mode == 'train_test':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings)
         if args.dev:
             folds = [0,1]
             cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 2, 0, 1
             cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1
             cf.test_n_epochs =  cf.save_n_models
             cf.max_test_patients = 1
             torch.backends.cudnn.benchmark = cf.dim==3
         else:
             torch.backends.cudnn.benchmark = cf.cuda_benchmark
         if args.data_dest is not None:
             cf.data_dest = args.data_dest
             
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
         data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py'))
         model = utils.import_module('model', cf.model_path)
         logger.info("loaded model from {}".format(cf.model_path))
         if folds is None:
             folds = range(cf.n_cv_splits)
 
         for fold in folds:
             """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation,
             one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the
             splits. k==folds, fold in [0,folds) says which split is used for testing.
             """
             cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold
             logger.set_logfile(fold=fold)
             cf.resume = resume
             if not os.path.exists(cf.fold_dir):
                 os.mkdir(cf.fold_dir)
             train(cf, logger)
             cf.resume = None
             if args.mode == 'train_test':
                 test(cf, logger)
 
     elif args.mode == 'test':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False)
         if args.data_dest is not None:
             cf.data_dest = args.data_dest
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
         data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py'))
         model = utils.import_module('model', cf.model_path)
         logger.info("loaded model from {}".format(cf.model_path))
 
         fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
                      os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
         if folds is None:
             folds = range(cf.n_cv_splits)
         if args.dev:
             folds = folds[:2]
             cf.batch_size, cf.max_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2
         else:
             torch.backends.cudnn.benchmark = cf.cuda_benchmark
         for fold in folds:
             cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold
             logger.set_logfile(fold=fold)
             if cf.fold_dir in fold_dirs:
                 test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs]))
             else:
                 logger.info("Skipping fold {} since no model parameters found.".format(fold))
     # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation.
     elif args.mode == 'analysis':
         """ analyse already saved predictions.
         """
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False)
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
 
         if cf.held_out_test_set and not cf.eval_test_fold_wise:
             predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
             results_list = predictor.load_saved_predictions()
             logger.info('starting evaluation...')
             cf.fold = 0
             evaluator = Evaluator(cf, logger, mode='test')
             evaluator.evaluate_predictions(results_list)
             evaluator.score_test_df(max_fold=0)
         else:
             fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
                          os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
             if args.dev:
                 fold_dirs = fold_dirs[:1]
             if folds is None:
                 folds = range(cf.n_cv_splits)
             for fold in folds:
                 cf.fold = fold; cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold))
                 logger.set_logfile(fold=fold)
                 if cf.fold_dir in fold_dirs:
                     predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
                     results_list = predictor.load_saved_predictions()
                     # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds
                     # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x
                     logger.info('starting evaluation...')
                     evaluator = Evaluator(cf, logger, mode='test')
                     evaluator.evaluate_predictions(results_list)
                     max_fold = max([int(f[-1]) for f in fold_dirs])
                     evaluator.score_test_df(max_fold=max_fold)
                 else:
                     logger.info("Skipping fold {} since no model parameters found.".format(fold))
     else:
         raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode))
         
     mins, secs = divmod((time.time() - stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t))
     del logger
     torch.cuda.empty_cache()
 
diff --git a/requirements.txt b/requirements.txt
index 5a09f51..18097a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,64 +1,67 @@
 absl-py==0.8.1
 backcall==0.1.0
 batchgenerators==0.19.7
 cachetools==3.1.1
 certifi==2019.11.28
 chardet==3.0.4
 cycler==0.10.0
 Cython==0.29.14
 decorator==4.4.1
 future==0.18.2
 google-auth==1.7.2
 google-auth-oauthlib==0.4.1
 grpcio==1.25.0
 idna==2.8
 imageio==2.6.1
 ipython==7.9.0
 ipython-genutils==0.2.0
 jedi==0.15.1
 joblib==0.14.0
 kiwisolver==1.1.0
 linecache2==1.0.0
 Markdown==3.1.1
 matplotlib==3.1.2
 networkx==2.4
 nms-extension==0.0.0
 numpy==1.17.4
 nvidia-ml-py3==7.352.0
 oauthlib==3.1.0
-pandas==0.25.3
+pandas==1.0.3
 parso==0.5.1
 pexpect==4.7.0
 pickleshare==0.7.5
-Pillow==6.2.1
+Pillow==7.1.0
 prompt-toolkit==2.0.10
 protobuf==3.11.1
 psutil==5.7.0
 ptyprocess==0.6.0
 pyasn1==0.4.8
 pyasn1-modules==0.2.7
 Pygments==2.5.2
 pyparsing==2.4.5
 python-dateutil==2.8.1
 pytz==2019.3
 PyWavelets==1.1.1
 RegRCNN==0.0.2
 requests==2.22.0
 requests-oauthlib==1.3.0
+RoIAlign-extension-2D==0.0.0
+RoIAlign-extension-3D==0.0.0
 rsa==4.0
 scikit-image==0.16.2
 scikit-learn==0.21.3
 scipy==1.3.1
 SimpleITK==1.2.3
 six==1.13.0
-tensorboard==2.0.2
+tensorboard==2.2.0
+tensorboard-plugin-wit==1.6.0.post2
 threadpoolctl==2.0.0
-torch==1.3.1
-torchvision==0.4.2
+torch==1.4.0
+torchvision==0.5.0
 tqdm==4.39.0
 traceback2==1.4.0
 traitlets==4.3.3
 unittest2==1.1.0
 urllib3==1.25.7
 wcwidth==0.1.7
-Werkzeug==0.16.0
+Werkzeug==1.0.1
diff --git a/setup.py b/setup.py
index 90510b3..b59bbe0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,60 +1,72 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 from setuptools import find_packages, setup
-import os, site
+import os, sys, subprocess
 
 def parse_requirements(filename, exclude=[]):
     lineiter = (line.strip() for line in open(filename))
     return [line for line in lineiter if line and not line.startswith("#") and not line.split("==")[0] in exclude]
 
+def pip_install(item):
+    subprocess.check_call([sys.executable, "-m", "pip", "install", item])
+
 def install_custom_ext(setup_path):
-    os.system("python "+setup_path+" install")
-    return
+    try:
+        pip_install(setup_path)
+    except Exception as e:
+        print("Could not install custom extension {} from source due to error:\n{}\n".format(path, e) +
+              "Trying to install from pre-compiled wheel.")
+        dist_path = setup_path+"/dist"
+        wheel_file = [fn for fn in os.listdir(dist_path) if fn.endswith(".whl")][0]
+        pip_install(os.path.join(dist_path, wheel_file))
 
 def clean():
     """Custom clean command to tidy up the project root."""
     os.system('rm -vrf ./build ./dist ./*.pyc ./*.tgz ./*.egg-info')
 
-req_file = "requirements.txt"
-custom_exts = ["nms-extension", "RoIAlign-extension-2D", "RoIAlign-extension-3D"]
-install_reqs = parse_requirements(req_file, exclude=custom_exts)
-
-setup(name='RegRCNN',
-      version='0.0.2',
-      url="https://github.com/MIC-DKFZ/RegRCNN",
-      author='G. Ramien, P. Jaeger, MIC at DKFZ Heidelberg',
-      author_email='g.ramien@dkfz.de',
-      licence="Apache 2.0",
-      description="Medical Object-Detection Toolkit incl. Regression Capability.",
-      classifiers=[
-          "Development Status :: 4 - Beta",
-          "Intended Audience :: Developers",
-          "Programming Language :: Python :: 3.7"
-      ],
-      packages=find_packages(exclude=['test', 'test.*']),
-      install_requires=install_reqs,
-      )
-
-custom_exts =  ["custom_extensions/nms", "custom_extensions/roi_align"]
-for path in custom_exts:
-    setup_path = os.path.join(path, "setup.py")
-    try:
-        install_custom_ext(setup_path)
-    except Exception as e:
-        print("FAILED to install custom extension {} due to Error:\n{}".format(path, e))
 
-clean()
\ No newline at end of file
+
+if __name__ == "__main__":
+
+    req_file = "requirements.txt"
+    custom_exts = ["nms-extension", "RoIAlign-extension-2D", "RoIAlign-extension-3D"]
+    install_reqs = parse_requirements(req_file, exclude=custom_exts)
+
+    setup(name='RegRCNN',
+          version='0.0.2',
+          url="https://github.com/MIC-DKFZ/RegRCNN",
+          author='G. Ramien, P. Jaeger, MIC at DKFZ Heidelberg',
+          author_email='g.ramien@dkfz.de',
+          licence="Apache 2.0",
+          description="Medical Object-Detection Toolkit incl. Regression Capability.",
+          classifiers=[
+              "Development Status :: 4 - Beta",
+              "Intended Audience :: Developers",
+              "Programming Language :: Python :: 3.7"
+          ],
+          packages=find_packages(exclude=['test', 'test.*']),
+          install_requires=install_reqs,
+          )
+
+    custom_exts =  ["custom_extensions/nms", "custom_extensions/roi_align/2D", "custom_extensions/roi_align/3D"]
+    for path in custom_exts:
+        try:
+            install_custom_ext(path)
+        except Exception as e:
+            print("FAILED to install custom extension {} due to Error:\n{}".format(path, e))
+
+    clean()
\ No newline at end of file