diff --git a/.gitignore b/.gitignore
index 0096abe..fc1fdd0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,15 +1,16 @@
 *.pyc
 *.pickle
 *.ipynb_checkpoints*
 *.pkl
 *.log
 *.png
 *.jpg
 *.pdf
 *.egg-info
+*.so
 sandbox/*
 .idea/*
 __pycache__/
 
 
 !/assets/*
diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda.cu b/custom_extensions/roi_align/src/RoIAlign_cuda.cu
index 47c870a..39426bf 100644
--- a/custom_extensions/roi_align/src/RoIAlign_cuda.cu
+++ b/custom_extensions/roi_align/src/RoIAlign_cuda.cu
@@ -1,422 +1,422 @@
 /*
 ROIAlign implementation in CUDA from pytorch framework
 (https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019)
 
 */
 
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <typeinfo>
 #include "cuda_helpers.h"
 
 template <typename T>
 __device__ T bilinear_interpolate(
     const T* input,
     const int height,
     const int width,
     T y,
     T x,
     const int index /* index for debug only*/) {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width) {
     // empty
     return 0;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
 
   int y_low = (int)y;
   int x_low = (int)x;
   int y_high;
   int x_high;
 
   if (y_low >= height - 1) {
     y_high = y_low = height - 1;
     y = (T)y_low;
   } else {
     y_high = y_low + 1;
   }
 
   if (x_low >= width - 1) {
     x_high = x_low = width - 1;
     x = (T)x_low;
   } else {
     x_high = x_low + 1;
   }
 
   T ly = y - y_low;
   T lx = x - x_low;
   T hy = 1. - ly, hx = 1. - lx;
 
   // do bilinear interpolation
   T v1 = input[y_low * width + x_low];
   T v2 = input[y_low * width + x_high];
   T v3 = input[y_high * width + x_low];
   T v4 = input[y_high * width + x_high];
   T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
 
   T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
 
   return val;
 }
 
 template <typename T>
 __global__ void RoIAlignForward(
     const int nthreads,
     const T* input,
     const T spatial_scale,
     const int channels,
     const int height,
     const int width,
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
     const T* rois,
     T* output) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     const int pw = index % pooled_width;
     const int ph = (index / pooled_width) % pooled_height;
     const int c = (index / pooled_width / pooled_height) % channels;
     const int n = index / pooled_width / pooled_height / channels;
 
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];
 
     // Do not using rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
 
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
 
     const T* offset_input =
         input + (roi_batch_ind * channels + c) * height * width;
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0)
         ? sampling_ratio
         : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
 
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
     T output_val = 0.;
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
       for (int ix = 0; ix < roi_bin_grid_w; ix++) {
         const T x = roi_start_w + pw * bin_size_w +
-            static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
         T val = bilinear_interpolate(offset_input, height, width, y, x, index);
         output_val += val;
       }
     }
     output_val /= count;
 
     output[index] = output_val;
   }
 }
 
 template <typename T>
 __device__ void bilinear_interpolate_gradient(
     const int height,
     const int width,
     T y,
     T x,
     T& w1,
     T& w2,
     T& w3,
     T& w4,
     int& x_low,
     int& x_high,
     int& y_low,
     int& y_high,
     const int index /* index for debug only*/) {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width) {
     // empty
     w1 = w2 = w3 = w4 = 0.;
     x_low = x_high = y_low = y_high = -1;
     return;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
 
   y_low = (int)y;
   x_low = (int)x;
 
   if (y_low >= height - 1) {
     y_high = y_low = height - 1;
     y = (T)y_low;
   } else {
     y_high = y_low + 1;
   }
 
   if (x_low >= width - 1) {
     x_high = x_low = width - 1;
     x = (T)x_low;
   } else {
     x_high = x_low + 1;
   }
 
   T ly = y - y_low;
   T lx = x - x_low;
   T hy = 1. - ly, hx = 1. - lx;
 
   // reference in forward
   // T v1 = input[y_low * width + x_low];
   // T v2 = input[y_low * width + x_high];
   // T v3 = input[y_high * width + x_low];
   // T v4 = input[y_high * width + x_high];
   // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
 
   w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
 
   return;
 }
 
 template <typename T>
 __global__ void RoIAlignBackward(
     const int nthreads,
     const T* grad_output,
     const T spatial_scale,
     const int channels,
     const int height,
     const int width,
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
     T* grad_input,
     const T* rois,
     const int n_stride,
     const int c_stride,
     const int h_stride,
     const int w_stride)
 {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
 
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];
 
     // Do not using rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
 
     T* offset_grad_input =
         grad_input + ((roi_batch_ind * channels + c) * height * width);
 
     // We need to index the gradient using the tensor strides to access the
     // correct values.
     int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin =
         offset_grad_output[ph * h_stride + pw * w_stride];
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0)
         ? sampling_ratio
         : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
 
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
 
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
       for (int ix = 0; ix < roi_bin_grid_w; ix++) {
         const T x = roi_start_w + pw * bin_size_w  +
-            static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
 
         T w1, w2, w3, w4;
         int x_low, x_high, y_low, y_high;
 
         bilinear_interpolate_gradient(
             height,
             width,
             y,
             x,
             w1,
             w2,
             w3,
             w4,
             x_low,
             x_high,
             y_low,
             y_high,
             index);
 
         T g1 = grad_output_this_bin * w1 / count;
         T g2 = grad_output_this_bin * w2 / count;
         T g3 = grad_output_this_bin * w3 / count;
         T g4 = grad_output_this_bin * w4 / count;
 
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
           atomicAdd(
               offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
           atomicAdd(
               offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
           atomicAdd(
               offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
           atomicAdd(
               offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
         } // if
       } // ix
     } // iy
   } // CUDA_1D_KERNEL_LOOP
 } // RoIAlignBackward
 
 at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
                                 const int pooled_height, const int pooled_width, const int sampling_ratio) {
   /*
    input: feature-map tensor, shape (batch, n_channels, y, x(, z))
    */
   AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_forward_cuda";
   at::checkAllSameGPU(c, {input_t, rois_t});
   at::checkAllSameType(c, {input_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(input.device());
 
   int num_rois = rois.size(0);
   int channels = input.size(1);
   int height = input.size(2);
   int width = input.size(3);
 
   at::Tensor output = at::zeros(
       {num_rois, channels, pooled_height, pooled_width}, input.options());
 
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
   if (output.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return output;
   }
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
     RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.contiguous().data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         pooled_height,
         pooled_width,
         sampling_ratio,
         rois.contiguous().data_ptr<scalar_t>(),
         output.data_ptr<scalar_t>());
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return output;
 }
 
 at::Tensor ROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const float spatial_scale,
     const int pooled_height,
     const int pooled_width,
     const int batch_size,
     const int channels,
     const int height,
     const int width,
     const int sampling_ratio) {
   AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_backward_cuda";
   at::checkAllSameGPU(c, {grad_t, rois_t});
   at::checkAllSameType(c, {grad_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(grad.device());
 
   at::Tensor grad_input =
       at::zeros({batch_size, channels, height, width}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
   // handle possibly empty gradients
   if (grad.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }
 
   int n_stride = grad.stride(0);
   int c_stride = grad.stride(1);
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
     RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         pooled_height,
         pooled_width,
         sampling_ratio,
         grad_input.data_ptr<scalar_t>(),
         rois.contiguous().data_ptr<scalar_t>(),
         n_stride,
         c_stride,
         h_stride,
         w_stride);
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
\ No newline at end of file
diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu b/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
index 0c75a34..26cde29 100644
--- a/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
+++ b/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
@@ -1,487 +1,487 @@
 /*
 ROIAlign implementation in CUDA from pytorch framework
 (https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019)
 
 Adapted for additional 3D capability by G. Ramien, DKFZ Heidelberg
 */
 
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <cstdio>
 #include "cuda_helpers.h"
 
 /*-------------- gpu kernels -----------------*/
 
 template <typename T>
 __device__ T linear_interpolate(const T xl, const T val_low, const T val_high){
 
   T val = (val_high - val_low) * xl + val_low;
   return val;
 }
 
 template <typename T>
 __device__ T trilinear_interpolate(const T* input, const int height, const int width, const int depth,
                 T y, T x, T z, const int index /* index for debug only*/) {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) {
     // empty
     return 0;
   }
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
   if (z <= 0)
     z = 0;
 
   int y0 = (int)y;
   int x0 = (int)x;
   int z0 = (int)z;
   int y1;
   int x1;
   int z1;
 
   if (y0 >= height - 1) {
   /*if nearest gridpoint to y on the lower end is on border or border-1, set low, high, mid(=actual point) to border-1*/
     y1 = y0 = height - 1;
     y = (T)y0;
   } else {
     /* y1 is one pixel from y0, y is the actual point somewhere in between  */
     y1 = y0 + 1;
   }
   if (x0 >= width - 1) {
     x1 = x0 = width - 1;
     x = (T)x0;
   } else {
     x1 = x0 + 1;
   }
   if (z0 >= depth - 1) {
     z1 = z0 = depth - 1;
     z = (T)z0;
   } else {
     z1 = z0 + 1;
   }
 
 
   // do linear interpolation of x values
   // distance of actual point to lower boundary point, already normalized since x_high - x0 = 1
   T dis = x - x0;
   /*  accessing element b,c,y,x,z in 1D-rolled-out array of a tensor with dimensions (B, C, Y, X, Z):
       tensor[b,c,y,x,z] = arr[ (((b*C+c)*Y+y)*X + x)*Z + z ] = arr[ alpha + (y*X + x)*Z + z ]
       with alpha = batch&channel locator = (b*C+c)*YXZ.
-      hence, as current input pointer is already offset by alpha: y,x,z at input[( y*X + x)*Z + z], where
+      hence, as current input pointer is already offset by alpha: y,x,z is at input[( y*X + x)*Z + z], where
       X = width, Z = depth.
   */
   T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]);
   T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]);
   T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]);
   T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]);
 
   // linear interpol of y values = bilinear interpol of f(x,y)
   dis = y - y0;
   T xy0 = linear_interpolate(dis, x00, x10);
   T xy1 = linear_interpolate(dis, x01, x11);
 
   // linear interpol of z value = trilinear interpol of f(x,y,z)
   dis = z - z0;
   T xyz = linear_interpolate(dis, xy0, xy1);
 
   return xyz;
 }
 
 template <typename T>
 __device__ void trilinear_interpolate_gradient(const int height, const int width, const int depth, T y, T x, T z,
     T& g000, T& g001, T& g010, T& g100, T& g011, T& g101, T& g110, T& g111,
     int& x0, int& x1, int& y0, int& y1, int& z0, int&z1, const int index /* index for debug only*/)
 {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) {
     // empty
     g000 = g001 = g010 = g100 = g011 = g101 = g110 = g111 = 0.;
     x0 = x1 = y0 = y1 = z0 = z1 = -1;
     return;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
   if (z <= 0)
     z = 0;
 
   y0 = (int)y;
   x0 = (int)x;
   z0 = (int)z;
 
   if (y0 >= height - 1) {
     y1 = y0 = height - 1;
     y = (T)y0;
   } else {
     y1 = y0 + 1;
   }
 
   if (x0 >= width - 1) {
     x1 = x0 = width - 1;
     x = (T)x0;
   } else {
     x1 = x0 + 1;
   }
 
   if (z0 >= depth - 1) {
     z1 = z0 = depth - 1;
     z = (T)z0;
   } else {
     z1 = z0 + 1;
   }
 
   // forward calculations are added as hints
   T dis_x = x - x0;
   //T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]); // v000, v100
   //T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]); // v010, v110
   //T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]); // v001, v101
   //T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]); // v011, v111
 
   // linear interpol of y values = bilinear interpol of f(x,y)
   T dis_y = y - y0;
   //T xy0 = linear_interpolate(dis, x00, x10);
   //T xy1 = linear_interpolate(dis, x01, x11);
 
   // linear interpol of z value = trilinear interpol of f(x,y,z)
   T dis_z = z - z0;
   //T xyz = linear_interpolate(dis, xy0, xy1);
 
   /* need: grad_i := d(xyz)/d(v_i) with v_i = input_value_i  for all i = 0,..,7 (eight input values --> eight-entry gradient)
      d(lin_interp(dis,x,y))/dx = (-dis +1) and d(lin_interp(dis,x,y))/dy = dis --> derivatives are indep of x,y.
      notation: gxyz = gradient for d(trilin_interp)/d(input_value_at_xyz)
      below grads were calculated by hand
      save time by reusing (1-dis_x) = 1-x+x0 = x1-x =: dis_x1 */
   T dis_x1 = (1-dis_x), dis_y1 = (1-dis_y), dis_z1 = (1-dis_z);
 
   g000 = dis_z1 * dis_y1  * dis_x1;
   g001 = dis_z  * dis_y1  * dis_x1;
   g010 = dis_z1 * dis_y   * dis_x1;
   g100 = dis_z1 * dis_y1  * dis_x;
   g011 = dis_z  * dis_y   * dis_x1;
   g101 = dis_z  * dis_y1  * dis_x;
   g110 = dis_z1 * dis_y   * dis_x;
   g111 = dis_z  * dis_y   * dis_x;
 
   return;
 }
 
 template <typename T>
 __global__ void RoIAlignForward(const int nthreads, const T* input, const T spatial_scale, const int channels,
     const int height, const int width, const int depth, const int pooled_height, const int pooled_width,
     const int pooled_depth, const int sampling_ratio, const T* rois, T* output)
 {
 
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw, pd) is an element in the pooled output
     int pd =  index % pooled_depth;
     int pw = (index / pooled_depth) % pooled_width;
     int ph = (index / pooled_depth / pooled_width) % pooled_height;
     int c  = (index / pooled_depth / pooled_width / pooled_height) % channels;
     int n  =  index / pooled_depth / pooled_width / pooled_height / channels;
 
 
     // rois are (y1,x1,y2,x2,z1,z2) --> tensor of shape (n_rois, 6)
     const T* offset_rois = rois + n * 7;
     int roi_batch_ind = offset_rois[0];
     // Do not use rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
     T roi_start_d = offset_rois[5] * spatial_scale;
     T roi_end_d = offset_rois[6] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_depth = max(roi_end_d - roi_start_d, (T)1.);
 
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
     T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);
 
     const T* offset_input =
         input + (roi_batch_ind * channels + c) * height * width * depth;
 
     // We use roi_bin_grid to sample the grid and mimic integral
     // roi_bin_grid == nr of sampling points per bin >= 1
     int roi_bin_grid_h =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
     int roi_bin_grid_d =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth);
 
     // We do average (integral) pooling inside a bin
     const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g. = 4
 
     T output_val = 0.;
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
           static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5, always in the middle of two grid pointsk
 
       for (int ix = 0; ix < roi_bin_grid_w; ix++)
       {
         const T x = roi_start_w + pw * bin_size_w +
             static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
 
         for (int iz = 0; iz < roi_bin_grid_d; iz++)
         {
           const T z = roi_start_d + pd * bin_size_d +
               static_cast<T>(iz + .5f) * (bin_size_d - 1.f) / static_cast<T>(roi_bin_grid_d);
           // TODO verify trilinear interpolation
           T val = trilinear_interpolate(offset_input, height, width, depth, y, x, z, index);
           output_val += val;
         } // z iterator and calc+add value
       } // x iterator
     } // y iterator
     output_val /= n_voxels;
 
     output[index] = output_val;
   }
 }
 
 template <typename T>
 __global__ void RoIAlignBackward(const int nthreads, const T* grad_output, const T spatial_scale, const int channels,
     const int height, const int width, const int depth, const int pooled_height, const int pooled_width,
     const int pooled_depth, const int sampling_ratio, T* grad_input, const T* rois,
     const int n_stride, const int c_stride, const int h_stride, const int w_stride, const int d_stride)
 {
 
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw, pd) is an element in the pooled output
     int pd =  index % pooled_depth;
     int pw = (index / pooled_depth) % pooled_width;
     int ph = (index / pooled_depth / pooled_width) % pooled_height;
     int c  = (index / pooled_depth / pooled_width / pooled_height) % channels;
     int n  =  index / pooled_depth / pooled_width / pooled_height / channels;
 
 
     const T* offset_rois = rois + n * 7;
     int roi_batch_ind = offset_rois[0];
 
     // Do not using rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
     T roi_start_d = offset_rois[5] * spatial_scale;
     T roi_end_d = offset_rois[6] * spatial_scale;
 
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T roi_depth = max(roi_end_d - roi_start_d, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
     T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);
 
     // offset: index b,c,y,x,z of tensor of shape (B,C,Y,X,Z) is
     // b*C*Y*X*Z + c * Y*X*Z + y * X*Z + x *Z + z = (b*C+c)Y*X*Z + ...
     T* offset_grad_input =
         grad_input + ((roi_batch_ind * channels + c) * height * width * depth);
 
     // We need to index the gradient using the tensor strides to access the correct values.
     int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin = offset_grad_output[ph * h_stride + pw * w_stride + pd * d_stride];
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
     int roi_bin_grid_d = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth);
 
     // We do average (integral) pooling inside a bin
     const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g. = 6
 
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
           static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
 
       for (int ix = 0; ix < roi_bin_grid_w; ix++)
       {
         const T x = roi_start_w + pw * bin_size_w +
           static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
 
         for (int iz = 0; iz < roi_bin_grid_d; iz++)
         {
           const T z = roi_start_d + pd * bin_size_d +
               static_cast<T>(iz + .5f) * (bin_size_d - 1.f) / static_cast<T>(roi_bin_grid_d);
 
           T g000, g001, g010, g100, g011, g101, g110, g111; // will hold the current partial derivatives
           int x0, x1, y0, y1, z0, z1;
           /* notation: gxyz = gradient at xyz, where x,y,z need to lie on feature-map grid (i.e., =x0,x1 etc.) */
           trilinear_interpolate_gradient(height, width, depth, y, x, z,
                                          g000, g001, g010, g100, g011, g101, g110, g111,
                                          x0, x1, y0, y1, z0, z1, index);
           /* chain rule: derivatives (i.e., the gradient) of trilin_interpolate(v1,v2,v3,v4,...) (div by n_voxels
              as we actually need gradient of whole roi_align) are multiplied with gradient so far*/
           g000 *= grad_output_this_bin / n_voxels;
           g001 *= grad_output_this_bin / n_voxels;
           g010 *= grad_output_this_bin / n_voxels;
           g100 *= grad_output_this_bin / n_voxels;
           g011 *= grad_output_this_bin / n_voxels;
           g101 *= grad_output_this_bin / n_voxels;
           g110 *= grad_output_this_bin / n_voxels;
           g111 *= grad_output_this_bin / n_voxels;
 
           if (x0 >= 0 && x1 >= 0 && y0 >= 0 && y1 >= 0 && z0 >= 0 && z1 >= 0)
           { // atomicAdd(address, content) reads content under address, adds content to it, while: no other thread
             // can interfere with the memory at address during this operation (thread lock, therefore "atomic").
             atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z0, static_cast<T>(g000));
             atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z1, static_cast<T>(g001));
             atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z0, static_cast<T>(g010));
             atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z0, static_cast<T>(g100));
             atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z1, static_cast<T>(g011));
             atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z1, static_cast<T>(g101));
             atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z0, static_cast<T>(g110));
             atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z1, static_cast<T>(g111));
           } // if
         } // iz
       } // ix
     } // iy
   } // CUDA_1D_KERNEL_LOOP
 } // RoIAlignBackward
 
 
 /*----------- wrapper functions ----------------*/
 
 at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
                                 const int pooled_height, const int pooled_width, const int pooled_depth, const int sampling_ratio) {
   /*
    input: feature-map tensor, shape (batch, n_channels, y, x(, z))
    */
   AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_forward_cuda";
   at::checkAllSameGPU(c, {input_t, rois_t});
   at::checkAllSameType(c, {input_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(input.device());
 
   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
   auto width = input.size(3);
   auto depth = input.size(4);
 
   at::Tensor output = at::zeros(
       {num_rois, channels, pooled_height, pooled_width, pooled_depth}, input.options());
 
   auto output_size = num_rois * channels * pooled_height * pooled_width * pooled_depth;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
   dim3 block(512);
 
   if (output.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return output;
   }
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign forward in 3d", [&] {
     RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.contiguous().data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         depth,
         pooled_height,
         pooled_width,
         pooled_depth,
         sampling_ratio,
         rois.contiguous().data_ptr<scalar_t>(),
         output.data_ptr<scalar_t>());
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return output;
 }
 
 at::Tensor ROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const float spatial_scale,
     const int pooled_height,
     const int pooled_width,
     const int pooled_depth,
     const int batch_size,
     const int channels,
     const int height,
     const int width,
     const int depth,
     const int sampling_ratio)
 {
   AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_backward_cuda";
   at::checkAllSameGPU(c, {grad_t, rois_t});
   at::checkAllSameType(c, {grad_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(grad.device());
 
   at::Tensor grad_input =
       at::zeros({batch_size, channels, height, width, depth}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
   // handle possibly empty gradients
   if (grad.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }
 
   int n_stride = grad.stride(0);
   int c_stride = grad.stride(1);
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
   int d_stride = grad.stride(4);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign backward 3D", [&] {
     RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         depth,
         pooled_height,
         pooled_width,
         pooled_depth,
         sampling_ratio,
         grad_input.data_ptr<scalar_t>(),
         rois.contiguous().data_ptr<scalar_t>(),
         n_stride,
         c_stride,
         h_stride,
         w_stride,
         d_stride);
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
\ No newline at end of file
diff --git a/datasets/lidc/configs.py b/datasets/lidc/configs.py
index 413ce8f..ff2c2d4 100644
--- a/datasets/lidc/configs.py
+++ b/datasets/lidc/configs.py
@@ -1,445 +1,445 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 from collections import namedtuple
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/../..")
 from default_configs import DefaultConfigs
 
 # legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope
 Label = namedtuple("Label", ['id', 'name', 'color', 'm_scores']) # m_scores = malignancy scores
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'm_scores', 'bin_vals'])
 
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #    Preprocessing      #
         #########################
 
         self.root_dir = '/home/gregor/networkdrives/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI'
         self.raw_data_dir = '{}/new_nrrd'.format(self.root_dir)
         self.pp_dir = '/media/gregor/HDD2TB/data/lidc/pp_20200309_dev'
         # 'merged' for one gt per image, 'single_annotator' for four gts per image.
         self.gts_to_produce = ["single_annotator", "merged"]
 
         self.target_spacing = (0.7, 0.7, 1.25)
 
         #########################
         #         I/O          #
         #########################
 
         # path to preprocessed data.
         #self.pp_name = 'pp_20190318'
         self.pp_name = 'pp_20200309_dev'
 
         self.input_df_name = 'info_df.pickle'
         self.data_sourcedir = '/media/gregor/HDD2TB/data/lidc/{}/'.format(self.pp_name)
 
         # settings for deployment on cluster.
         if server_env:
             # path to preprocessed data.
             self.data_sourcedir = '/datasets/data_ramien/lidc/{}_npz/'.format(self.pp_name)
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn'].
         self.model = 'mrcnn'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # dimension the model operates in. one out of [2, 3].
         self.dim = 2
 
         # 'class': standard object classification per roi, pairwise combinable with each of below tasks.
         # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN.
         # 'regression': regress some vector per each roi
         # 'regression_ken_gal': use kendall-gal uncertainty sigma
         # 'regression_bin': classify each roi into a bin related to a regression scale
-        self.prediction_tasks = ['class']
+        self.prediction_tasks = ['regression']
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = None # one of None, 'instance_norm', 'batch_norm'
 
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1
 
         #########################
         #      Data Loader      #
         #########################
 
         # distorted gt experiments: train on single-annotator gts in a random fashion to investigate network's
         # handling of noisy gts.
         # choose 'merged' for single, merged gt per image, or 'single_annotator' for four gts per image.
         # validation is always performed on same gt kind as training, testing always on merged gt.
         self.training_gts = "merged"
 
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.pre_crop_size_2D = [320, 320]
         self.patch_size_2D = [320, 320]
         self.pre_crop_size_3D = [160, 160, 96]
         self.patch_size_3D = [160, 160, 96]
 
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
         self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.3
         self.balance_target =  "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets'
 
         # set 2D network to match 3D gt boxes.
         self.merge_2D_to_3D_preds = self.dim==2
 
         self.observables_rois = []
 
         #self.rg_map = {1:1, 2:2, 3:3, 4:4, 5:5}
 
         #########################
         #   Colors and Legends  #
         #########################
         self.plot_frequency = 5
 
         binary_cl_labels = [Label(1, 'benign',  (*self.dark_green, 1.),  (1, 2)),
                             Label(2, 'malignant', (*self.red, 1.),  (3, 4, 5))]
         quintuple_cl_labels = [Label(1, 'MS1',  (*self.dark_green, 1.),      (1,)),
                                Label(2, 'MS2',  (*self.dark_yellow, 1.),     (2,)),
                                Label(3, 'MS3',  (*self.orange, 1.),     (3,)),
                                Label(4, 'MS4',  (*self.bright_red, 1.), (4,)),
                                Label(5, 'MS5',  (*self.red, 1.),        (5,))]
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_cl_labels = quintuple_cl_labels
 
         self.class_labels = [
             #       #id #name     #color              #malignancy score
             Label(  0,  'bg',     (*self.gray, 0.),  (0,))]
         if "class" in self.prediction_tasks:
             self.class_labels += task_spec_cl_labels
 
         else:
             self.class_labels += [Label(1, 'lesion', (*self.orange, 1.), (1,2,3,4,5))]
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0, 'MS0', (*self.gray, 1.), (0,), (0,))]
             self.bin_labels += [binLabel(cll.id, cll.name, cll.color, cll.m_scores,
                                          tuple([ms for ms in cll.m_scores])) for cll in task_spec_cl_labels]
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             self.ms2bin_label = {ms: label for label in self.bin_labels for ms in label.m_scores}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
 
         if self.class_specific_seg:
             self.seg_labels = self.class_labels
         else:
             self.seg_labels = [  # id      #name           #color
                 Label(0, 'bg', (*self.gray, 0.)),
                 Label(1, 'fg', (*self.orange, 1.))
             ]
 
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
         # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be
         # evaluated in debugging
         self.class_cmap = {label.id: label.color for label in self.class_labels}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)  # for instance classification (excl background)
         self.num_seg_classes = len(self.seg_labels)  # incl background
 
 
         #########################
         #   Data Augmentation   #
         #########################
 
         self.da_kwargs={
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': True,
             'alpha':(0., 1500.),
             'sigma':(30., 50.),
             'do_rotation':True,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': True,
             'scale':(0.8, 1.1),
             'random_crop':False,
             'rand_crop_dist':  (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1}
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #################################
         #  Schedule / Selection / Optim #
         #################################
 
         self.num_epochs = 130 if self.dim == 2 else 150
         self.num_train_batches = 200 if self.dim == 2 else 200
         self.batch_size = 20 if self.dim == 2 else 8
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # only 'val_sampling', 'val_patient' not implemented
         if self.val_mode == 'val_patient':
             raise NotImplementedError
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 70
 
         self.save_n_models = 4
         # set a minimum epoch number for saving in case of instabilities in the first phase of training.
         self.min_save_thresh = 0 if self.dim == 2 else 0
         # criteria to average over for saving epochs, 'criterion':weight.
         if "class" in self.prediction_tasks:
             # 'criterion': weight
             if len(self.class_labels)==3:
                 self.model_selection_criteria = {"benign_ap": 0.5, "malignant_ap": 0.5}
             elif len(self.class_labels)==6:
                 self.model_selection_criteria = {str(label.name)+"_ap": 1./5 for label in self.class_labels if label.id!=0}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8}
 
         self.weight_decay = 0
         self.clip_norm = 200 if 'regression_ken_gal' in self.prediction_tasks else None  # number or None
 
         # int in [0, dataset_size]. select n patients from dataset for prototyping. If None, all data is used.
         self.select_prototype_subset = None #self.batch_size
 
         #########################
         #        Testing        #
         #########################
 
         # set the top-n-epochs to be saved for temporal averaging in testing.
         self.test_n_epochs = self.save_n_models
 
         self.test_aug_axes = (0,1,(0,1))  # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x).
         self.held_out_test_set = False
         self.max_test_patients = "all"  # "all" or number
 
         self.report_score_level = ['rois', 'patient']  # choose list from 'patient', 'rois'
         self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1
 
         self.metrics = ['ap', 'auc']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.1]  # list of ious to be evaluated for ap-scoring.
         self.min_det_thresh = 0.1  # minimum confidence value to select predictions for evaluation.
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = 0.1  # has to be larger than desired possible overlap iou of model predictions
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.n_test_plots = 1
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'detection_fpn': self.add_det_fpn_configs,
          'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs,
          'retina_unet': self.add_mrcnn_configs,
         }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         return float(np.digitize(np.mean(rg_val), self.bin_edges))
 
     def add_det_fpn_configs(self):
 
         self.learning_rate = [1e-4] * self.num_epochs
         self.dynamic_lr_scheduling = False
 
         # RoI score assigned to aggregation from pixel prediction (connected component). One of ['max', 'median'].
         self.score_det = 'max'
 
         # max number of roi candidates to identify per batch element and class.
         self.n_roi_candidates = 10 if self.dim == 2 else 30
 
         # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
         self.seg_loss_mode = 'wce'
 
         # if <1, false positive predictions in foreground are penalized less.
         self.fp_dice_weight = 1 if self.dim == 2 else 1
         if len(self.class_labels)==3:
             self.wce_weights = [1., 1., 1.] if self.seg_loss_mode=="dice_wce" else [0.1, 1., 1.]
         elif len(self.class_labels)==6:
             self.wce_weights = [1., 1., 1., 1., 1., 1.] if self.seg_loss_mode == "dice_wce" else [0.1, 1., 1., 1., 1., 1.]
         else:
             raise Exception("mismatch loss weights & nr of classes")
         self.detection_min_confidence = self.min_det_thresh
 
         self.head_classes = self.num_seg_classes
 
     def add_mrcnn_configs(self):
 
         # learning rate is a list with one entry per epoch.
         self.learning_rate = [1e-4] * self.num_epochs
         self.dynamic_lr_scheduling = False
         # disable the re-sampling of mask proposals to original size for speed-up.
         # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
         # mask-outputs are optional.
         self.return_masks_in_train = False
         self.return_masks_in_val = True
         self.return_masks_in_test = False
 
         # set number of proposal boxes to plot after each epoch.
         self.n_plot_rpn_props = 5 if self.dim == 2 else 30
 
         # number of classes for network heads: n_foreground_classes + 1 (background)
         self.head_classes = self.num_classes + 1
 
         self.frcnn_mode = False
 
         # feature map strides per pyramid level are inferred from architecture.
         self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
 
         # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
         # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
         self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
 
         # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
         self.pyramid_levels = [0, 1, 2, 3]
 
         # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
         self.n_rpn_features = 512 if self.dim == 2 else 128
 
         # anchor ratios and strides per position in feature maps.
         self.rpn_anchor_ratios = [0.5, 1, 2]
         self.rpn_anchor_stride = 1
 
         # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
         self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
 
         # loss sampling settings.
         self.rpn_train_anchors_per_image = 6  #per batch element
         self.train_rois_per_image = 6 #per batch element
         self.roi_positive_ratio = 0.5
         self.anchor_matching_iou = 0.7
 
         # factor of top-k candidates to draw from  per negative sample (stochastic-hard-example-mining).
         # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples.
         self.shem_poolsize = 10
 
         self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
         self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
         self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
         self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
         self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                                self.patch_size_3D[2], self.patch_size_3D[2]])
         if self.dim == 2:
             self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
             self.bbox_std_dev = self.bbox_std_dev[:4]
             self.window = self.window[:4]
             self.scale = self.scale[:4]
 
         # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
         self.pre_nms_limit = 3000 if self.dim == 2 else 6000
 
         # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
         # since proposals of the entire batch are forwarded through second stage in as one "batch".
         self.roi_chunk_size = 2500 if self.dim == 2 else 600
         self.post_nms_rois_training = 500 if self.dim == 2 else 75
         self.post_nms_rois_inference = 500
 
         # Final selection of detections (refine_detections)
         self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30  # per batch element and class.
         self.detection_nms_threshold = 1e-5  # needs to be > 0, otherwise all predictions are one cluster.
         self.model_min_confidence = 0.1
 
         if self.dim == 2:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride))]
                  for stride in self.backbone_strides['xy']])
         else:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride)),
                   int(np.ceil(self.patch_size[2] / stride_z))]
                  for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                              )])
 
         if self.model == 'retina_net' or self.model == 'retina_unet':
 
             self.focal_loss = True
 
             # implement extra anchor-scales according to retina-net publication.
             self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                             self.rpn_anchor_scales['xy']]
             self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                            self.rpn_anchor_scales['z']]
             self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
             self.n_rpn_features = 256 if self.dim == 2 else 128
 
             # pre-selection of detections for NMS-speedup. per entire batch.
             self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
             # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
             self.anchor_matching_iou = 0.5
 
             if self.model == 'retina_unet':
                 self.operate_stride1 = True
 
diff --git a/datasets/lidc/data_loader.py b/datasets/lidc/data_loader.py
index fb97815..4f5b3b0 100644
--- a/datasets/lidc/data_loader.py
+++ b/datasets/lidc/data_loader.py
@@ -1,1024 +1,1025 @@
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 '''
 Data Loader for the LIDC data set. This dataloader expects preprocessed data in .npy or .npz files per patient and
 a pandas dataframe containing the meta info e.g. file paths, and some ground-truth info like labels, foreground slice ids.
 
 LIDC 4-fold annotations storage capacity problem: keep segmentation gts compressed (npz), unpack at each batch generation.
 
 '''
 
 import plotting as plg
 
 import os
 import pickle
 import time
 from multiprocessing import Pool
 
 import numpy as np
 import pandas as pd
 from collections import OrderedDict
 
 # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
 from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
 from batchgenerators.transforms.abstract_transforms import Compose
 from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
 from batchgenerators.transforms.spatial_transforms import SpatialTransform
 from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
 
 
 import utils.dataloader_utils as dutils
 from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
-
+from utils.dataloader_utils import BatchGenerator as BatchGeneratorParent
 
 def save_obj(obj, name):
     """Pickle a python object."""
     with open(name + '.pkl', 'wb') as f:
         pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
 
 def vector(item):
     """ensure item is vector-like (list or array or tuple)
     :param item: anything
     """
     if not isinstance(item, (list, tuple, np.ndarray)):
         item = [item]
     return item
 
 
 class Dataset(dutils.Dataset):
     r"""Load a dict holding memmapped arrays and clinical parameters for each patient,
     evtly subset of those.
         If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to
         cf.data_dest.
     :param cf: config object.
     :param logger: logger.
     :param subset_ids: subset of patient/sample identifiers to load from whole set.
     :param data_sourcedir: directory in which to find data, defaults to cf.data_sourcedir if None.
     :return: dict with imgs, segs, pids, class_labels, observables
     """
 
     def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None, mode='train'):
         super(Dataset,self).__init__(cf, data_sourcedir)
         if mode == 'train' and not cf.training_gts == "merged":
             self.gt_dir = "patient_gts_sa"
             self.gt_kind = cf.training_gts
         else:
             self.gt_dir = "patient_gts_merged"
             self.gt_kind = "merged"
         if logger is not None:
             logger.info("loading {} ground truths for {}".format(self.gt_kind, 'training and validation' if mode=='train'
         else 'testing'))
 
         p_df = pd.read_pickle(os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name))
         #exclude_pids = ["0305a", "0447a"]  # due to non-bg segmentation but bg mal label in nodules 5728, 8840
         #p_df = p_df[~p_df.pid.isin(exclude_pids)]
 
         if subset_ids is not None:
             p_df = p_df[p_df.pid.isin(subset_ids)]
             if logger is not None:
                 logger.info('subset: selected {} instances from df'.format(len(p_df)))
         if cf.select_prototype_subset is not None:
             prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset]
             p_df = p_df[p_df.pid.isin(prototype_pids)]
             if logger is not None:
                 logger.warning('WARNING: using prototyping data subset of length {}!!!'.format(len(p_df)))
 
         pids = p_df.pid.tolist()
 
         # evtly copy data from data_sourcedir to data_dest
         if cf.server_env and not hasattr(cf, 'data_dir') and hasattr(cf, "data_dest"):
                 # copy and unpack images
                 file_subset = ["{}_img.npz".format(pid) for pid in pids if not
                 os.path.isfile(os.path.join(cf.data_dest,'{}_img.npy'.format(pid)))]
                 file_subset += [os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)]
                 self.copy_data(cf, file_subset=file_subset, keep_packed=False, del_after_unpack=True)
                 # copy and do not unpack segmentations
                 file_subset = [os.path.join(self.gt_dir, "{}_rois.np*".format(pid)) for pid in pids]
                 keep_packed = not cf.training_gts == "merged"
                 self.copy_data(cf, file_subset=file_subset, keep_packed=keep_packed, del_after_unpack=(not keep_packed))
         else:
             cf.data_dir = self.data_sourcedir
 
         ext = 'npy' if self.gt_kind == "merged" else 'npz'
         imgs = [os.path.join(self.data_dir, '{}_img.npy'.format(pid)) for pid in pids]
         segs = [os.path.join(self.data_dir, self.gt_dir, '{}_rois.{}'.format(pid, ext)) for pid in pids]
         orig_class_targets = p_df['class_target'].tolist()
 
         data = OrderedDict()
 
         if self.gt_kind == 'merged':
             for ix, pid in enumerate(pids):
                 data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid}
                 data[pid]['fg_slices'] = np.array(p_df['fg_slices'].tolist()[ix])
                 if 'class' in cf.prediction_tasks:
                     if len(cf.class_labels)==3:
                         # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2)
                         data[pid]['class_targets'] = np.array([2 if ii >= 3 else 1 for ii in orig_class_targets[ix]],
                                                               dtype='uint8')
                     elif len(cf.class_labels)==6:
                         # classify each malignancy score
                         data[pid]['class_targets'] = np.array([1 if ii==0.5 else np.round(ii) for ii in orig_class_targets[ix]], dtype='uint8')
                     else:
                         raise Exception("mismatch class labels and data-loading implementations.")
                 else:
                     data[pid]['class_targets'] = np.ones_like(np.array(orig_class_targets[ix]), dtype='uint8')
                 if any(['regression' in task for task in cf.prediction_tasks]):
                     data[pid]["regression_targets"] = np.array([vector(v) for v in orig_class_targets[ix]],
                                                                dtype='float16')
                     data[pid]["rg_bin_targets"] = np.array(
                         [cf.rg_val_to_bin_id(v) for v in data[pid]["regression_targets"]], dtype='uint8')
         else:
             for ix, pid in enumerate(pids):
                 data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid}
                 data[pid]['fg_slices'] = np.array(p_df['fg_slices'].values[ix])
                 if 'class' in cf.prediction_tasks:
                     # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2)
                     raise NotImplementedError
                     # todo need to consider bg
                     # data[pid]['class_targets'] = np.array(
                     #     [[2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]])
                 else:
                     data[pid]['class_targets'] = np.array(
                         [[1 if ii > 0 else 0 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='uint8')
                 if any(['regression' in task for task in cf.prediction_tasks]):
                     data[pid]["regression_targets"] = np.array(
                         [[vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='float16')
                     data[pid]["rg_bin_targets"] = np.array(
                         [[cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in data[pid]["regression_targets"]], dtype='uint8')
 
         cf.roi_items = cf.observables_rois[:]
         cf.roi_items += ['class_targets']
         if any(['regression' in task for task in cf.prediction_tasks]):
             cf.roi_items += ['regression_targets']
             cf.roi_items += ['rg_bin_targets']
 
         self.data = data
         self.set_ids = np.array(list(self.data.keys()))
         self.df = None
 
 # merged GTs
 class BatchGenerator_merged(dutils.BatchGenerator):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
     def __init__(self, cf, data, name="train"):
         super(BatchGenerator_merged, self).__init__(cf, data)
 
         self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch.
         self.p_fg = 0.5
         self.empty_samples_max_ratio = 0.6
 
         self.random_count = int(cf.batch_random_ratio * cf.batch_size)
         self.class_targets = {k: v["class_targets"] for (k, v) in self._data.items()}
 
 
         self.balance_target_distribution(plot=name=="train")
 
     def generate_train_batch(self):
 
         # samples patients towards equilibrium of foreground classes on a roi-level after sampling a random ratio
         # fully random patients
         batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
         # target-balanced patients
         batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count,
                                                    replace=False, p=self.p_probs))
 
         batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], []
         batch_roi_items = {name: [] for name in self.cf.roi_items}
         # record roi count of classes in batch
         batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         # empty count for full bg samples (empty slices in 2D/patients in 3D) per class
 
 
         for sample in range(self.batch_size):
             patient = self._data[batch_patient_ids[sample]]
 
             data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]
             seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
             batch_pids.append(patient['pid'])
             (c, y, x, z) = data.shape
 
             if self.cf.dim == 2:
 
                 elig_slices, choose_fg = [], False
                 if len(patient['fg_slices']) > 0:
                     if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                             np.random.rand(1)<=self.p_fg:
                         # fg is to be picked
                         for tix in np.argsort(batch_roi_counts):
                             # pick slices of patient that have roi of sought-for target
                             # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                             elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                 patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0])-1] ==
                                 self.unique_ts[tix]) > 0]
                             if len(elig_slices) > 0:
                                 choose_fg = True
                                 break
                     else:
                         # pick bg
                         elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
 
                 if len(elig_slices)>0:
                     sl_pick_ix = np.random.choice(elig_slices, size=None)
                 else:
                     sl_pick_ix = np.random.choice(z, size=None)
 
                 data = data[..., sl_pick_ix]
                 seg = seg[..., sl_pick_ix]
 
             # pad data if smaller than pre_crop_size.
             if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
                 new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)]
                 data = dutils.pad_nd_image(data, new_shape, mode='constant')
                 seg = dutils.pad_nd_image(seg, new_shape, mode='constant')
 
             # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
             crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps]
             if len(crop_dims) > 0:
                 if self.cf.dim == 3:
                     choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio)\
                                 or np.random.rand(1) <= self.p_fg
                 if choose_fg and np.any(seg):
                     available_roi_ids = np.unique(seg)[1:]
                     for tix in np.argsort(batch_roi_counts):
                         elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]]
                         if len(elig_roi_ids)>0:
                             seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                             break
                     roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                     assert seg[tuple(roi_anchor_pixel)] > 0
                     # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                     # distance to the desired ROI < patch_size /2.
                     # (here final patch size to account for center_crop after data augmentation).
                     sample_seg_center = {}
                     for ii in crop_dims:
                         low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                         high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2,
                                        roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                         # happens if lesion on the edge of the image. dont care about roi anymore,
                         # just make sure pre-crop is inside image.
                         if low >= high:
                             low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                             high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                         sample_seg_center[ii] = np.random.randint(low=low, high=high)
 
                 else:
                     # not guaranteed to be empty. probability of emptiness depends on the data.
                     sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2,
                                                            high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims}
 
                 for ii in crop_dims:
                     min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                     max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
                     data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
                     seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)
 
             batch_data.append(data)
             batch_segs.append(seg[np.newaxis])
             for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item
                     batch_roi_items[o].append(patient[o])
 
             if self.cf.dim == 3:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero==0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
             elif self.cf.dim == 2:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero == 0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
 
 
         data = np.array(batch_data).astype(np.float16)
         seg = np.array(batch_segs).astype(np.uint8)
         batch = {'data': data, 'seg': seg, 'pid': batch_pids,
                 'roi_counts':batch_roi_counts, 'empty_counts': batch_empty_counts}
         for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...)
             batch[key] = np.array(val)
 
         return batch
 
 class PatientBatchIterator_merged(dutils.PatientBatchIterator):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actualy evaluation (done in 3D),
     if willing to accept speed-loss during training.
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D .
     """
 
     def __init__(self, cf, data):  # threads in augmenter
         super(PatientBatchIterator_merged, self).__init__(cf, data)
         self.patient_ix = 0
         self.patch_size = cf.patch_size + [1] if cf.dim == 2 else cf.patch_size
 
     def generate_train_batch(self, pid=None):
 
         if pid is None:
             pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
 
         data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))
         seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
 
         # pad data if smaller than patch_size seen during training.
         if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]):
             new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)]
             data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
             seg = dutils.pad_nd_image(seg, new_shape)
 
         # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor.
         if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
             out_data = data[np.newaxis, np.newaxis]
             out_seg = seg[np.newaxis, np.newaxis]
             batch_3D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_3D[o] = np.array([patient[o]])
             converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_3D = converter(**batch_3D)
             batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
             for o in self.cf.roi_items:
                 batch_3D["patient_" + o] = batch_3D[o]
 
         if self.cf.dim == 2:
             out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis]  # (z, c, x, y )
             out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]
 
             batch_2D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_2D[o] = np.repeat(np.array([patient[o]]), out_data.shape[0], axis=0)
 
             converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_2D = converter(**batch_2D)
 
             if self.cf.merge_2D_to_3D_preds:
                 batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_3D[o]
             else:
                 batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_2D[o]
 
         out_batch = batch_3D if self.cf.dim == 3 else batch_2D
         out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})
 
         # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
         # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
         if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]):
             patient_batch = out_batch
             patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size)
             new_img_batch, new_seg_batch = [], []
 
             for cix, c in enumerate(patch_crop_coords_list):
 
                 seg_patch = seg[c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                 new_seg_batch.append(seg_patch)
 
                 tmp_c_5 = c[5]
 
                 new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5])
 
             data = np.array(new_img_batch)[:, np.newaxis]  # (n_patches, c, x, y, z)
             seg = np.array(new_seg_batch)[:, np.newaxis]  # (n_patches, 1, x, y, z)
             if self.cf.dim == 2:
                 # all patches have z dimension 1 (slices). discard dimension
                 data = data[..., 0]
                 seg = seg[..., 0]
 
             patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                            'pid': np.array([patient['pid']] * data.shape[0])}
             for o in self.cf.roi_items:
                 patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0)
             # patient-wise (orig) batch info for putting the patches back together after prediction
             for o in self.cf.roi_items:
                 patch_batch["patient_" + o] = patient_batch['patient_' + o]
                 if self.cf.dim == 2:
                     # this could also be named "unpatched_2d_roi_items"
                     patch_batch["patient_" + o + "_2d"] = patient_batch[o]
             # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288
             # and enables calculating test-dice/viewing patient-wise results in test
             # remove, but also remove dice from metrics, when like to save memory
             patch_batch['patient_data'] = patient_batch['data']
             patch_batch['patient_seg'] = patient_batch['seg']
             patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
             patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
             if self.cf.dim == 2:
                 patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
             patch_batch['original_img_shape'] = patient_batch['original_img_shape']
 
             converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False,
                                                            self.cf.class_specific_seg)
             patch_batch = converter(**patch_batch)
             out_batch = patch_batch
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return out_batch
 
 # single-annotator GTs
-class BatchGenerator_sa(dutils.BatchGenerator):
+class BatchGenerator_sa(BatchGeneratorParent):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
 
     # noinspection PyMethodOverriding
     def balance_target_distribution(self, rater, plot=False):
         """
         :param rater: for which rater slot to generate the distribution
         :param self.targets:  dic holding {patient_specifier : patient-wise-unique ROI targets}
         :param plot: whether to plot the generated patient distributions
         :return: probability distribution over all pids. draw without replace from this.
         """
         unique_ts = np.unique([v[rater] for pat in self.targets.values() for v in pat])
         sample_stats = pd.DataFrame(columns=[str(ix) + suffix for ix in unique_ts for suffix in ["", "_bg"]],
                                          index=list(self.targets.keys()))
         for pid in sample_stats.index:
             for targ in unique_ts:
                 fg_count = 0 if len(self.targets[pid]) == 0 else np.count_nonzero(self.targets[pid][:, rater] == targ)
                 sample_stats.loc[pid, str(targ)] = int(fg_count > 0)
                 sample_stats.loc[pid, str(targ) + "_bg"] = int(fg_count == 0)
 
         target_stats = sample_stats.agg(
             ("sum", lambda col: col.sum() / len(self._data)), axis=0, sort=False).rename({"<lambda>": "relative"})
 
         anchor = 1. - target_stats.loc["relative"].iloc[0]
         fg_bg_weights = anchor / target_stats.loc["relative"]
         cum_weights = anchor * len(fg_bg_weights)
         fg_bg_weights /= cum_weights
 
         p_probs = sample_stats.apply(self.sample_targets_to_weights, args=(fg_bg_weights,), axis=1).sum(axis=1)
         p_probs = p_probs / p_probs.sum()
         if plot:
             print("Rater: {}. Applying class-weights:\n {}".format(rater, fg_bg_weights))
         if len(sample_stats.columns) == 2:
             # assert that probs are calc'd correctly:
             # (p_probs * sample_stats["1"]).sum() == (p_probs * sample_stats["1_bg"]).sum()
             # only works if one label per patient (multi-label expectations depend on multi-label occurences).
             for rater in range(self.rater_bsize):
                 expectations = []
                 for targ in sample_stats.columns:
                     expectations.append((p_probs[rater] * sample_stats[targ]).sum())
                 assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format(
                     expectations)
 
         if plot:
             plg.plot_batchgen_distribution(self.cf, self.dataset_pids, p_probs, self.balance_target,
                                            out_file=os.path.join(self.plot_dir,
                                                                  "train_gen_distr_"+str(self.cf.fold)+"_rater"+str(rater)+".png"))
         return p_probs, unique_ts, sample_stats
 
 
 
     def __init__(self, cf, data, name="train"):
         super(BatchGenerator_sa, self).__init__(cf, data)
         self.name = name
         self.crop_margin = np.array(self.cf.patch_size) / 8.  # min distance of ROI center to edge of cropped_patch.
         self.p_fg = 0.5
         self.empty_samples_max_ratio = 0.6
 
         self.random_count = int(cf.batch_random_ratio * cf.batch_size)
 
         self.rater_bsize = 4
         unique_ts_total = set()
         self.p_probs = []
         self.sample_stats = []
 
         # todo resolve pickling error
         # p = Pool(processes=min(self.rater_bsize, cf.n_workers))
         # mp_res = p.starmap(self.balance_target_distribution, [(r, name=="train") for r in range(self.rater_bsize)])
         # p.close()
         # p.join()
         # for r, res in enumerate(mp_res):
         #     p_probs, unique_ts, sample_stats = res
         #     self.p_probs.append(p_probs)
         #     self.sample_stats.append(sample_stats)
         #     unique_ts_total.update(unique_ts)
 
         for r in range(self.rater_bsize):
             # todo multiprocess. takes forever
             p_probs, unique_ts, sample_stats = self.balance_target_distribution(r, plot=name == "train")
             self.p_probs.append(p_probs)
             self.sample_stats.append(sample_stats)
             unique_ts_total.update(unique_ts)
 
         self.unique_ts = sorted(list(unique_ts_total))
         self.stats = {"roi_counts": np.zeros(len(self.unique_ts,), dtype='uint32'),
                       "empty_counts": np.zeros(len(self.unique_ts,), dtype='uint32')}
 
     def generate_train_batch(self):
 
         rater = np.random.randint(self.rater_bsize)
 
         # samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio batch_random_ratio).
         # random patients
         batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
         # target-balanced patients
         batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False,
                                              p=self.p_probs[rater]))
 
         batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], []
         batch_roi_items = {name: [] for name in self.cf.roi_items}
         # record roi count of classes in batch
         batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         # empty count for full bg samples (empty slices in 2D/patients in 3D)
 
 
         for sample in range(self.batch_size):
 
             patient = self._data[batch_patient_ids[sample]]
 
             patient_balance_ts = np.array([roi[rater] for roi in patient[self.balance_target]])
             data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]
             seg = np.load(patient['seg'], mmap_mode='r')
             seg = np.transpose(seg[list(seg.keys())[0]][rater], axes=(1, 2, 0))
             batch_pids.append(patient['pid'])
             (c, y, x, z) = data.shape
 
             if self.cf.dim == 2:
 
                 elig_slices, choose_fg = [], False
                 if len(patient['fg_slices']) > 0:
                     if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                             np.random.rand(1) <= self.p_fg:
                         # fg is to be picked
                         for tix in np.argsort(batch_roi_counts):
                             # pick slices of patient that have roi of sought-for target
                             # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                             elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                 patient_balance_ts[np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                                 self.unique_ts[tix]) > 0]
                             if len(elig_slices) > 0:
                                 choose_fg = True
                                 break
                     else:
                         # pick bg
                         elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'][rater])
 
                 if len(elig_slices) > 0:
                     sl_pick_ix = np.random.choice(elig_slices, size=None)
                 else:
                     sl_pick_ix = np.random.choice(z, size=None)
 
                 data = data[..., sl_pick_ix]
                 seg = seg[..., sl_pick_ix]
 
             # pad data if smaller than pre_crop_size.
             if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
                 new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)]
                 data = dutils.pad_nd_image(data, new_shape, mode='constant')
                 seg = dutils.pad_nd_image(seg, new_shape, mode='constant')
 
             # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
             crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps]
             if len(crop_dims) > 0:
                 if self.cf.dim == 3:
                     choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \
                                 np.random.rand(1) <= self.p_fg
                 if choose_fg and np.any(seg):
                     available_roi_ids = np.unique(seg[seg>0])
                     assert np.all(patient_balance_ts[available_roi_ids-1]>0), "trying to choose roi with rating 0"
                     for tix in np.argsort(batch_roi_counts):
                         elig_roi_ids = available_roi_ids[ patient_balance_ts[available_roi_ids-1] == self.unique_ts[tix] ]
                         if len(elig_roi_ids)>0:
                             seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                             roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                             break
 
                     assert seg[tuple(roi_anchor_pixel)] > 0, "roi_anchor_pixel not inside roi: {}, pb_ts {}, elig ids {}".format(tuple(roi_anchor_pixel), patient_balance_ts, elig_roi_ids)
                     # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                     # distance to the desired ROI < patch_size /2.
                     # (here final patch size to account for center_crop after data augmentation).
                     sample_seg_center = {}
                     for ii in crop_dims:
                         low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                         high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2,
                                        roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                         # happens if lesion on the edge of the image. dont care about roi anymore,
                         # just make sure pre-crop is inside image.
                         if low >= high:
                             low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                             high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                         sample_seg_center[ii] = np.random.randint(low=low, high=high)
                 else:
                     # not guaranteed to be empty. probability of emptiness depends on the data.
                     sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2,
                                                            high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims}
 
                 for ii in crop_dims:
                     min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                     max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
                     data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
                     seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)
 
             batch_data.append(data)
             batch_segs.append(seg[np.newaxis])
             for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item
                 batch_roi_items[o].append([roi[rater] for roi in patient[o]])
 
             if self.cf.dim == 3:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero==0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
             elif self.cf.dim == 2:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero == 0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
 
 
         data = np.array(batch_data).astype('float16')
         seg = np.array(batch_segs).astype('uint8')
         batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'rater_id': rater,
                 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts}
         for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...)
             batch[key] = np.array(val)
 
         return batch
 
 class PatientBatchIterator_sa(dutils.PatientBatchIterator):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actual evaluation (done in 3D),
     if willing to accept speed loss during training.
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D .
 
     This is the data & gt loader for the 4-fold single-annotator GTs: each data input has separate annotations of 4 annotators.
     the way the pipeline is currently setup, the single-annotator GTs are only used if training with validation mode
     val_patient; during testing the Iterator with the merged GTs is used.
     # todo mode val_patient not implemented yet (since very slow). would need to sample from all available rater GTs.
     """
     def __init__(self, cf, data): #threads in augmenter
         super(PatientBatchIterator_sa, self).__init__(cf, data)
         self.cf = cf
         self.patient_ix = 0
         self.dataset_pids = list(self._data.keys())
         self.patch_size =  cf.patch_size+[1] if cf.dim==2 else cf.patch_size
 
         self.rater_bsize = 4
 
 
     def generate_train_batch(self, pid=None):
 
         if pid is None:
             pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
 
         data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))
         # all gts are 4-fold and npz!
         seg = np.load(patient['seg'], mmap_mode='r')
         seg = np.transpose(seg[list(seg.keys())[0]], axes=(0, 2, 3, 1))
 
         # pad data if smaller than patch_size seen during training.
         if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]):
             new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)]
             data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape.
             seg = dutils.pad_nd_image(seg, new_shape)
 
         # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor.
         if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
             out_data = data[np.newaxis, np.newaxis]
             out_seg = seg[:, np.newaxis]
             batch_3D = {'data': out_data, 'seg': out_seg}
 
             for item in self.cf.roi_items:
                 batch_3D[item] = []
             for r in range(self.rater_bsize):
                 for item in self.cf.roi_items:
                     batch_3D[item].append(np.array([roi[r] for roi in patient[item]]))
 
             converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_3D = converter(**batch_3D)
             batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
             for o in self.cf.roi_items:
                 batch_3D["patient_" + o] = batch_3D[o]
 
         if self.cf.dim == 2:
             out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis]  # (z, c, y, x )
             out_seg = np.transpose(seg, axes=(0, 3, 1, 2))[:, :, np.newaxis] # (n_raters, z, 1, y,x)
 
             batch_2D = {'data': out_data}
 
             for item in ["seg", "bb_target"]+self.cf.roi_items:
                 batch_2D[item] = []
 
             converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
             for r in range(self.rater_bsize):
                 tmp_batch = {"seg": out_seg[r]}
                 for item in self.cf.roi_items:
                     tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), out_data.shape[0], axis=0)
                 tmp_batch = converter(**tmp_batch)
                 for item in ["seg", "bb_target"]+self.cf.roi_items:
                     batch_2D[item].append(tmp_batch[item])
             # for item in ["seg", "bb_target"]+self.cf.roi_items:
             #     batch_2D[item] = np.array(batch_2D[item])
 
             if self.cf.merge_2D_to_3D_preds:
                 batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_3D[o]
             else:
                 batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_2D[o]
 
         out_batch = batch_3D if self.cf.dim == 3 else batch_2D
         out_batch.update({'pid': np.array([patient['pid']] * out_data.shape[0])})
 
         # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
         # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
         if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]):
             patient_batch = out_batch
             patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size)
             new_img_batch  = []
             new_seg_batch = []
 
             for cix, c in enumerate(patch_crop_coords_list):
                 seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                 new_seg_batch.append(seg_patch)
                 tmp_c_5 = c[5]
 
                 new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5])
 
             data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z)
             seg = np.transpose(np.array(new_seg_batch), axes=(1,0,2,3,4))[:,:,np.newaxis] # (n_raters, n_patches, x, y, z)
 
             if self.cf.dim == 2:
                 # all patches have z dimension 1 (slices). discard dimension
                 data = data[..., 0]
                 seg = seg[..., 0]
 
             patch_batch = {'data': data.astype('float32'),
                            'pid': np.array([patient['pid']] * data.shape[0])}
             # for o in self.cf.roi_items:
             #     patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0)
 
             converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False,
                                                            self.cf.class_specific_seg)
 
             for item in ["seg", "bb_target"]+self.cf.roi_items:
                 patch_batch[item] = []
             # coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1,
             # IndexError: index 2 is out of bounds for axis 1 with size 2
             for r in range(self.rater_bsize):
                 tmp_batch = {"seg": seg[r]}
                 for item in self.cf.roi_items:
                     tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), len(patch_crop_coords_list), axis=0)
                 tmp_batch = converter(**tmp_batch)
                 for item in ["seg", "bb_target"]+self.cf.roi_items:
                     patch_batch[item].append(tmp_batch[item])
 
             # patient-wise (orig) batch info for putting the patches back together after prediction
             for o in self.cf.roi_items:
                 patch_batch["patient_" + o] = patient_batch['patient_'+o]
                 if self.cf.dim==2:
                     # this could also be named "unpatched_2d_roi_items"
                     patch_batch["patient_"+o+"_2d"] = patient_batch[o]
             # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288
             # and enables calculating test-dice/viewing patient-wise results in test
             # remove, but also remove dice from metrics, if you like to save memory
             patch_batch['patient_data'] =  patient_batch['data']
             patch_batch['patient_seg'] = patient_batch['seg']
             patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
             patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
             if self.cf.dim==2:
                 patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
             patch_batch['original_img_shape'] = patient_batch['original_img_shape']
 
             out_batch = patch_batch
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return out_batch
 
 
 def create_data_gen_pipeline(cf, patient_data, is_training=True):
     """ create multi-threaded train/val/test batch generation and augmentation pipeline.
     :param cf: configs object.
     :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
     :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
     :return: multithreaded_generator
     """
     BG_name = "train" if is_training else "val"
     data_gen = BatchGenerator_merged(cf, patient_data, name=BG_name) if cf.training_gts=='merged' else \
         BatchGenerator_sa(cf, patient_data, name=BG_name)
 
     # add transformations to pipeline.
     my_transforms = []
     if is_training:
         if cf.da_kwargs["mirror"]:
             mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
             my_transforms.append(mirror_transform)
         spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                              patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                              do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                              alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                              do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                              angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                              do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                              random_crop=cf.da_kwargs['random_crop'])
 
         my_transforms.append(spatial_transform)
     else:
         my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
 
     if cf.create_bounding_box_targets:
         my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
     all_transforms = Compose(my_transforms)
 
-    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
+    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads,
+                                                     seeds=range(data_gen.n_filled_threads))
     return multithreaded_generator
 
 def get_train_generators(cf, logger,  data_statistics=True):
     """
     wrapper function for creating the training batch generator pipeline. returns the train/val generators.
     selects patients according to cv folds (generated by first run/fold of experiment):
     splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set)
     If cf.held_out_test_set is True, adds the test split to the training data.
     """
     dataset = Dataset(cf, logger)
 
     dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
     dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
     set_splits = dataset.fg.splits
 
     test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
     train_ids = np.concatenate(set_splits, axis=0)
 
     if cf.held_out_test_set:
         train_ids = np.concatenate((train_ids, test_ids), axis=0)
         test_ids = []
 
     train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids}
     val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids}
 
     logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
                                                                                     len(test_ids)))
     if data_statistics:
         dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids},
                                 plot_dir=os.path.join(cf.plot_dir,"dataset"))
 
     batch_gen = {}
     batch_gen['train'] = create_data_gen_pipeline(cf, train_data, is_training=True)
     batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, is_training=False)
     if cf.val_mode == 'val_patient':
         assert cf.training_gts == 'merged', 'val_patient not yet implemented for sa gts'
         batch_gen['val_patient'] = PatientBatchIterator_merged(cf, val_data) if cf.training_gts=='merged' \
             else PatientBatchIterator_sa(cf, val_data)
         batch_gen['n_val'] = len(val_data) if cf.max_val_patients=="all" else min(len(val_data), cf.max_val_patients)
     else:
         batch_gen['n_val'] = cf.num_val_batches
 
     return batch_gen
 
 def get_test_generator(cf, logger):
     """
     wrapper function for creating the test batch generator pipeline.
     selects patients according to cv folds (generated by first run/fold of experiment)
     If cf.held_out_test_set is True, gets the data from an external folder instead.
     """
     if cf.held_out_test_set:
         sourcedir = cf.test_data_sourcedir
         test_ids = None
     else:
         sourcedir = None
         with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
             set_splits = pickle.load(handle)
         test_ids = set_splits[cf.fold]
 
     test_data = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode="test").data
     logger.info("data set loaded with: {} test patients".format(len(test_ids)))
     batch_gen = {}
     batch_gen['test'] = PatientBatchIterator_merged(cf, test_data)
     batch_gen['n_test'] = len(test_ids) if cf.max_test_patients == "all" else min(cf.max_test_patients, len(test_ids))
     return batch_gen
 
 
 if __name__ == "__main__":
     import sys
     sys.path.append('../')
     import plotting as plg
     import utils.exp_utils as utils
     from configs import Configs
 
     cf = Configs()
     cf.batch_size = 3
     #dataset_path = os.path.dirname(os.path.realpath(__file__))
     #exp_path = os.path.join(dataset_path, "experiments/dev")
     #cf = utils.prep_exp(dataset_path, exp_path, server_env=False, use_stored_settings=False, is_training=True)
     cf.created_fold_id_pickle = False
     total_stime = time.time()
     times = {}
 
     # cf.server_env = True
     # cf.data_dir = "experiments/dev_data"
 
     # dataset = Dataset(cf)
     # patient = dataset['Master_00018']
     cf.exp_dir = "experiments/dev/"
     cf.plot_dir = cf.exp_dir + "plots"
     os.makedirs(cf.exp_dir, exist_ok=True)
     cf.fold = 0
     logger = utils.get_logger(cf.exp_dir)
     gens = get_train_generators(cf, logger)
     train_loader = gens['train']
 
 
 
     for i in range(1):
         stime = time.time()
         #ex_batch = next(train_loader)
         print("train batch", i)
         times["train_batch"] = time.time() - stime
         #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True)
     #
     # # with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file:
     # #    train_loader.generator.print_stats(logger, file)
     #
     val_loader = gens['val_sampling']
     stime = time.time()
     ex_batch = next(val_loader)
     times["val_batch"] = time.time() - stime
     stime = time.time()
     #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False,
     #               show_info=False)
     times["val_plot"] = time.time() - stime
     #
     test_loader = get_test_generator(cf, logger)["test"]
     stime = time.time()
     ex_batch = test_loader.generate_train_batch()
     times["test_batch"] = time.time() - stime
     stime = time.time()
     plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", get_time=False)#, sample_picks=[0,1,2,3])
     times["test_patchbatch_plot"] = time.time() - stime
 
     # ex_batch['data'] = ex_batch['patient_data']
     # ex_batch['seg'] = ex_batch['patient_seg']
     # ex_batch['bb_target'] = ex_batch['patient_bb_target']
     # for item in cf.roi_items:
     #     ex_batch[]
     # stime = time.time()
     # #ex_batch = next(test_loader)
     # ex_batch = next(test_loader)
     # plg.view_batch(cf, ex_batch, show_gt_labels=False, show_gt_boxes=True, patient_items=True,# vol_slice_picks=[146,148, 218,220],
     #                 out_file="experiments/dev/dev_expatientbatch.png")  # , sample_picks=[0,1,2,3])
     # times["test_patient_batch_plot"] = time.time() - stime
 
 
 
     print("Times recorded throughout:")
     for (k, v) in times.items():
         print(k, "{:.2f}".format(v))
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
diff --git a/datasets/lidc/preprocessing.py b/datasets/lidc/preprocessing.py
index 2f5efd4..a9c7f5d 100644
--- a/datasets/lidc/preprocessing.py
+++ b/datasets/lidc/preprocessing.py
@@ -1,478 +1,482 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 '''
 This preprocessing script loads nrrd files obtained by the data conversion tool: https://github.com/MIC-DKFZ/LIDC-IDRI-processing/tree/v1.0.1
 After applying preprocessing, images are saved as numpy arrays and the meta information for the corresponding patient is stored
 as a line in the dataframe saved as info_df.pickle.
 '''
 
 import os
 import sys
+import argparse
 import shutil
 import subprocess
 import pickle
 import time
 
 import SimpleITK as sitk
 import numpy as np
 from multiprocessing import Pool
 import pandas as pd
 import numpy.testing as npt
 from skimage.transform import resize
 
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append('../..')
 import data_manager as dmanager
 
 class AttributeDict(dict):
     __getattr__ = dict.__getitem__
     __setattr__ = dict.__setitem__
 
 def load_df(path):
     df = pd.read_pickle(path)
     print(df)
 
     return
 
 def resample_array(src_imgs, src_spacing, target_spacing):
     """ Resample a numpy array.
     :param src_imgs: source image.
     :param src_spacing: source image's spacing.
     :param target_spacing: spacing to resample source image to.
     :return:
     """
     src_spacing = np.round(src_spacing, 3)
     target_shape = [int(src_imgs.shape[ix] * src_spacing[::-1][ix] / target_spacing[::-1][ix]) for ix in range(len(src_imgs.shape))]
     for i in range(len(target_shape)):
         try:
             assert target_shape[i] > 0
         except:
             raise AssertionError("AssertionError:", src_imgs.shape, src_spacing, target_spacing)
 
     img = src_imgs.astype('float64')
     resampled_img = resize(img, target_shape, order=1, clip=True, mode='edge').astype('float32')
 
     return resampled_img
 
 class Preprocessor(object):
     """Preprocessor for LIDC raw data. Set in config: which ground truths to produce, choices are
         - "merged" for a single ground truth per input image, created by merging the given four rater annotations
             into one.
         - "single-annotator" for a four-fold ground truth per input image, created by leaving the each rater annotation
             separately.
     :param cf: config.
     :param exclude_inconsistents: bool or tuple, list, np.array, exclude patients that show technical inconsistencies
         in the raw files, likely due to file-naming mistakes. if bool and True: search for patients that have too many
         ratings per lesion or other inconstencies, exclude findings. if param is list of pids: exclude given pids.
     :param overwrite: look for patients that already exist in the pp dir. if overwrite is False, do not redo existing
         patients, otherwise ignore any existing files.
     :param max_count: maximum number of patients to preprocess.
     :param pids_subset: subset of pids to preprocess.
     """
 
     def __init__(self, cf, exclude_inconsistents=True, overwrite=False, max_count=None, pids_subset=None):
 
         self.cf = cf
 
         assert len(self.cf.gts_to_produce)>0, "need to specify which gts to produce, choices: 'merged', 'single_annotator'"
 
         self.paths = [os.path.join(cf.raw_data_dir, ii) for ii in os.listdir(cf.raw_data_dir)]
         if exclude_inconsistents:
             if isinstance(exclude_inconsistents, bool):
                 exclude_paths = self.exclude_too_many_ratings()
                 exclude_paths += self.verify_seg_label_pairings()
             else:
                 assert isinstance(exclude_inconsistents, (tuple,list,np.ndarray))
                 exclude_paths = exclude_inconsistents
             self.paths = [path for path in self.paths if path not in exclude_paths]
 
 
         if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce:
             self.pp_dir_sa = os.path.join(cf.pp_dir, "patient_gts_sa")
         if 'merged' in self.cf.gts_to_produce:
             self.pp_dir_merged = os.path.join(cf.pp_dir, "patient_gts_merged")
         orig_count = len(self.paths)
         # check if some patients already have ppd versions in destination dir
         if os.path.exists(cf.pp_dir) and not overwrite:
             fs_in_dir = os.listdir(cf.pp_dir)
             already_done =  [file.split("_")[0] for file in fs_in_dir if file.split("_")[1] == "img.npy"]
             if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce:
                 ext = '.npy' if hasattr(self.cf, "save_sa_segs_as") and (
                             self.cf.save_sa_segs_as == "npy" or self.cf.save_sa_segs_as == ".npy") else '.npz'
                 fs_in_dir = os.listdir(self.pp_dir_sa)
                 already_done = [ pid for pid in already_done if pid+"_rois"+ext in fs_in_dir and pid+"_meta_info.pickle" in fs_in_dir]
             if 'merged' in self.cf.gts_to_produce:
                 fs_in_dir = os.listdir(self.pp_dir_merged)
                 already_done = [pid for pid in already_done if
                                 pid + "_rois.npy" in fs_in_dir and pid+"_meta_info.pickle" in fs_in_dir]
 
             self.paths = [p for p in self.paths if not p.split(os.sep)[-1] in already_done]
             if len(self.paths)!=orig_count:
                 print("Due to existing ppd files: Selected a subset of {} patients from originally {}".format(len(self.paths), orig_count))
 
         if pids_subset:
             self.paths = [p for p in self.paths if p.split(os.sep)[-1] in pids_subset]
         if max_count is not None:
             self.paths = self.paths[:max_count]
 
         if not os.path.exists(cf.pp_dir):
             os.mkdir(cf.pp_dir)
         if ('single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce) and \
                 not os.path.exists(self.pp_dir_sa):
             os.mkdir(self.pp_dir_sa)
         if 'merged' in self.cf.gts_to_produce and not os.path.exists(self.pp_dir_merged):
             os.mkdir(self.pp_dir_merged)
 
 
     def exclude_too_many_ratings(self):
         """exclude a patient's full path (the patient folder) from further processing if patient has nodules with
             ratings of more than four raters (which is inconsistent with what the raw data is supposed to comprise,
             also rater ids appear multiple times on the same nodule in these cases motivating the assumption that
             the same rater issued more than one rating / mixed up files or annotations for a nodule).
         :return: paths to be excluded.
         """
         exclude_paths = []
         for path in self.paths:
             roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii])
             found = False
             for roi_id in roi_ids:
                 n_raters = len([ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii])
                 # assert n_raters<=4, "roi {} in path {} has {} raters".format(roi_id, path, n_raters)
                 if n_raters > 4:
                     print("roi {} in path {} has {} raters".format(roi_id, path, n_raters))
                     found = True
             if found:
                 exclude_paths.append(path)
         print("Patients excluded bc of too many raters:\n")
         for p in exclude_paths:
             print(p)
         print()
 
         return exclude_paths
 
     def analyze_lesion(self, pid, nodule_id):
         """print unique seg and counts of nodule nodule_id of patient pid.
         """
         nodule_id = nodule_id.lstrip("0")
         nodule_id_paths = [ii for ii in os.listdir(os.path.join(self.cf.raw_data_dir, pid)) if '.nii' in ii]
         nodule_id_paths = [ii for ii in nodule_id_paths if ii.split('_')[2].lstrip("0")==nodule_id]
         assert len(nodule_id_paths)==1
         nodule_path = nodule_id_paths[0]
 
         roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, nodule_path))
         roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8)
 
         print("pid {}, nodule {}, unique seg & counts: {}".format(pid, nodule_id, np.unique(roi_arr, return_counts=True)))
         return
 
     def verify_seg_label_pairing(self, path):
         """verifies that a nodule's segmentation has malignancy label > 0 if segmentation has foreground (>0 anywhere),
             and vice-versa that it has only background (==0 everywhere) if no malignancy label (==label 0) assigned.
         :param path: path to the patient folder.
         :return: df containing eventual inconsistency findings.
         """
 
         pid = path.split('/')[-1]
 
         df = pd.read_csv(os.path.join(self.cf.root_dir, 'characteristics.csv'), sep=';')
         df = df[df.PatientID == pid]
 
         findings_df = pd.DataFrame(columns=["problem", "pid", "roi_id", "nodule_id", "rater_ix", "seg_unique", "label"])
 
         print('verifying {}'.format(pid))
 
         roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii])
 
         for roi_id in roi_ids:
             roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii]
             nodule_ids = [rp.split('_')[2].lstrip("0") for rp in roi_id_paths]
             rater_ids = [rp.split('_')[1] for rp in roi_id_paths]
             rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0] for ii in nodule_ids]
 
             # check double existence of nodule ids
             uniq, counts = np.unique(nodule_ids, return_counts=True)
             if np.any([count>1 for count in counts]):
                 finding = ("same nodule id exists more than once", pid, roi_id, nodule_ids, "N/A", "N/A", "N/A")
                 print("not unique nodule id", finding)
                 findings_df.loc[findings_df.shape[0]] = finding
 
             # check double gradings of single rater for single roi
             uniq, counts = np.unique(rater_ids, return_counts=True)
             if np.any([count>1 for count in counts]):
                 finding = ("same roi_id exists more than once for a single rater", pid, roi_id, nodule_ids, rater_ids, "N/A", rater_labels)
                 print("more than one grading per roi per single rater", finding)
                 findings_df.loc[findings_df.shape[0]] = finding
 
 
             rater_segs = []
             for rp in roi_id_paths:
                 roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp))
                 roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8)
 
                 rater_segs.append(roi_arr)
             rater_segs = np.array(rater_segs)
             for r in range(rater_segs.shape[0]):
                 if np.sum(rater_segs[r])>0:
                     if rater_labels[r]<=0:
                         finding =  ("non-empty seg w/ bg label", pid, roi_id, nodule_ids[r], rater_ids[r], np.unique(rater_segs[r]), rater_labels[r])
                         print("{}: pid {}, nodule {}, rater {}, seg unique {}, label {}".format(
                             *finding))
                         findings_df.loc[findings_df.shape[0]] = finding
                 else:
                     if rater_labels[r]>0:
                         finding = ("empty seg w/ fg label", pid, roi_id, nodule_ids[r], rater_ids[r], np.unique(rater_segs[r]), rater_labels[r])
                         print("{}: pid {}, nodule {}, rater {}, seg unique {}, label {}".format(
                             *finding))
                         findings_df.loc[findings_df.shape[0]] = finding
 
         return findings_df
 
     def verify_seg_label_pairings(self, processes=os.cpu_count()):
         """wrapper to multi-process verification of seg-label pairings.
         """
 
         pool = Pool(processes=processes)
         findings_dfs = pool.map(self.verify_seg_label_pairing, self.paths, chunksize=1)
         pool.close()
         pool.join()
 
         findings_df = pd.concat(findings_dfs, axis=0)
         findings_df.to_pickle(os.path.join(self.cf.pp_dir, "verification_seg_label_pairings.pickle"))
         findings_df.to_csv(os.path.join(self.cf.pp_dir, "verification_seg_label_pairings.csv"))
 
         return findings_df.pid.tolist()
 
     def produce_sa_gt(self, path, pid, df, img_spacing, img_arr_shape):
         """ Keep annotations separate, i.e., every processed image has four final GTs.
             Images are always saved as npy. For meeting hard-disk-memory constraints, segmentations can optionally be
             saved as .npz instead of .npy. Dataloader is only implemented for reading .npz segs.
         """
 
         final_rois = np.zeros((4, *img_arr_shape), dtype='uint8')
         patient_mal_labels = []
         roi_ids = list(set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii]))
         roi_ids.sort() # just a precaution to have same order of lesions throughout separate runs
 
         rix = 1
         for roi_id in roi_ids:
             roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii]
             assert len(roi_id_paths)>0 and len(roi_id_paths)<=4, "pid {}: should find 0< n_rois <4, but found {}".format(pid, len(roi_id_paths))
 
             """ not strictly necessary precaution: in theory, segmentations of different raters could overlap also for 
                 *different* rois, i.e., a later roi of a rater could (partially) cover up / destroy the roi of another 
                 rater. practically this is unlikely as overlapping lesions of different raters should be regarded as the
                 same lesion, but safety first. hence, the order of raters is maintained across rois, i.e., rater 0 
                 (marked as rater 0 in roi's file name) always has slot 0 in rater_labels and rater_segs, thereby rois
                 are certain to not overlap.
             """
             rater_labels, rater_segs = np.zeros((4,), dtype='uint8'), np.zeros((4,*img_arr_shape), dtype="float32")
             for ix, rp in enumerate(roi_id_paths): # one roi path per rater
                 nodule_id = rp.split('_')[2].lstrip("0")
                 assert not (nodule_id=="5728" or nodule_id=="8840"), "nodule ids {}, {} should be excluded due to seg-mal-label inconsistency.".format(5728, 8840)
                 rater = int(rp.split('_')[1])
                 rater_label = df[df.NoduleID == int(nodule_id)].Malignancy.values[0]
                 rater_labels[rater] = rater_label
 
                 roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp))
                 for dim in range(len(img_arr_shape)):
                     npt.assert_almost_equal(roi.GetSpacing()[dim], img_spacing[dim])
                 roi_arr = sitk.GetArrayFromImage(roi)
                 roi_arr = resample_array(roi_arr, roi.GetSpacing(), self.cf.target_spacing)
                 assert roi_arr.shape == img_arr_shape, [roi_arr.shape, img_arr_shape, pid, roi.GetSpacing()]
                 assert not np.any(rater_segs[rater]), "overwriting existing rater's seg with roi {}".format(rp)
                 rater_segs[rater] = roi_arr
             rater_segs = np.array(rater_segs)
 
             # rename/remap the malignancy to be positive.
             roi_mal_labels = [ii if ii > -1 else 0 for ii in rater_labels]
             assert rater_segs.shape == final_rois.shape, "rater segs shape {}, final rois shp {}".format(rater_segs.shape, final_rois.shape)
 
             # assert non-zero rating has non-zero seg
             for rater in range(4):
                 if roi_mal_labels[rater]>0:
                     assert np.any(rater_segs[rater]>0), "rater {} mal label {} but uniq seg {}".format(rater, roi_mal_labels[rater], np.unique(rater_segs[rater]))
 
             # add the roi to patient. i.e., write current lesion into final labels and seg of whole patient.
             assert np.any(rater_segs), "empty segmentations for all raters should not exist in single-annotator mode, pid {}, rois: {}".format(pid, roi_id_paths)
             patient_mal_labels.append(roi_mal_labels)
             final_rois[rater_segs > 0] = rix
             rix += 1
 
 
         fg_slices = [[ii for ii in np.unique(np.argwhere(final_rois[r] != 0)[:, 0])] for r in range(4)]
         patient_mal_labels = np.array(patient_mal_labels)
         roi_ids = np.unique(final_rois[final_rois>0])
         assert len(roi_ids) == len(patient_mal_labels), "mismatch {} rois in seg, {} rois in mal labels".format(len(roi_ids), len(patient_mal_labels))
 
         if hasattr(self.cf, "save_sa_segs_as") and (self.cf.save_sa_segs_as=="npy" or self.cf.save_sa_segs_as==".npy"):
             np.save(os.path.join(self.pp_dir_sa, '{}_rois.npy'.format(pid)), final_rois)
         else:
             np.savez_compressed(os.path.join(self.cf.pp_dir, 'patient_gts_sa', '{}_rois.npz'.format(pid)), seg=final_rois)
         with open(os.path.join(self.pp_dir_sa, '{}_meta_info.pickle'.format(pid)), 'wb') as handle:
             meta_info_dict = {'pid': pid, 'class_target': patient_mal_labels, 'spacing': img_spacing,
                               'fg_slices': fg_slices}
             pickle.dump(meta_info_dict, handle)
 
     def produce_merged_gt(self, path, pid, df, img_spacing, img_arr_shape):
         """ process patient with merged annotations, i.e., only one final GT per image. save img and seg to npy, rest to
             metadata.
             annotations merging:
                 - segmentations: only regard a pixel as foreground if at least two raters found it be foreground.
                 - malignancy labels: average over all four rater votes. every rater who did not assign a finding or
                     assigned -1 to the RoI contributes to the average with a vote of 0.
 
         :param path: path to patient folder.
         """
 
         final_rois = np.zeros(img_arr_shape, dtype=np.uint8)
         patient_mal_labels = []
         roi_ids = set([ii.split('.')[0].split('_')[-1] for ii in os.listdir(path) if '.nii.gz' in ii])
 
         rix = 1
         for roi_id in roi_ids:
             roi_id_paths = [ii for ii in os.listdir(path) if '{}.nii'.format(roi_id) in ii]
             nodule_ids = [ii.split('_')[2].lstrip("0") for ii in roi_id_paths]
             rater_labels = [df[df.NoduleID == int(ii)].Malignancy.values[0] for ii in nodule_ids]
             rater_labels.extend([0] * (4 - len(rater_labels)))
             mal_label = np.mean([ii if ii > -1 else 0 for ii in rater_labels])
             rater_segs = []
             for rp in roi_id_paths:
                 roi = sitk.ReadImage(os.path.join(self.cf.raw_data_dir, pid, rp))
                 for dim in range(len(img_arr_shape)):
                     npt.assert_almost_equal(roi.GetSpacing()[dim], img_spacing[dim])
                 roi_arr = sitk.GetArrayFromImage(roi).astype(np.uint8)
                 roi_arr = resample_array(roi_arr, roi.GetSpacing(), self.cf.target_spacing)
                 assert roi_arr.shape == img_arr_shape, [roi_arr.shape, img_arr_shape, pid, roi.GetSpacing()]
                 rater_segs.append(roi_arr)
             rater_segs.extend([np.zeros_like(rater_segs[-1])] * (4 - len(roi_id_paths)))
             rater_segs = np.mean(np.array(rater_segs), axis=0)
             # annotations merging: if less than two raters found fg, set segmentation to bg.
             rater_segs[rater_segs < 0.5] = 0
             if np.sum(rater_segs) > 0:
                 patient_mal_labels.append(mal_label)
                 final_rois[rater_segs > 0] = rix
                 rix += 1
             else:
                 # indicate rois suppressed by majority voting of raters
                 print('suppressed roi!', roi_id_paths)
                 with open(os.path.join(self.pp_dir_merged, 'suppressed_rois.txt'), 'a') as handle:
                     handle.write(" ".join(roi_id_paths))
 
         fg_slices = [ii for ii in np.unique(np.argwhere(final_rois != 0)[:, 0])]
         patient_mal_labels = np.array(patient_mal_labels)
         assert len(patient_mal_labels) + 1 == len(np.unique(final_rois)), [len(patient_mal_labels), np.unique(final_rois), pid]
         assert final_rois.dtype == 'uint8'
         np.save(os.path.join(self.pp_dir_merged, '{}_rois.npy'.format(pid)), final_rois)
 
         with open(os.path.join(self.pp_dir_merged, '{}_meta_info.pickle'.format(pid)), 'wb') as handle:
             meta_info_dict = {'pid': pid, 'class_target': patient_mal_labels, 'spacing': img_spacing,
                               'fg_slices': fg_slices}
             pickle.dump(meta_info_dict, handle)
 
     def pp_patient(self, path):
 
         pid = path.split('/')[-1]
         img = sitk.ReadImage(os.path.join(path, '{}_ct_scan.nrrd'.format(pid)))
         img_arr = sitk.GetArrayFromImage(img)
         print('processing {} with GT(s) {}, spacing {} and img shape {}.'.format(
             pid, " and ".join(self.cf.gts_to_produce), img.GetSpacing(), img_arr.shape))
         img_arr = resample_array(img_arr, img.GetSpacing(), self.cf.target_spacing)
         img_arr = np.clip(img_arr, -1200, 600)
         #img_arr = (1200 + img_arr) / (600 + 1200) * 255  # a+x / (b-a) * (c-d) (c, d = new)
         img_arr = img_arr.astype(np.float32)
         img_arr = (img_arr - np.mean(img_arr)) / np.std(img_arr).astype('float16')
 
         df = pd.read_csv(os.path.join(self.cf.root_dir, 'characteristics.csv'), sep=';')
         df = df[df.PatientID == pid]
 
         np.save(os.path.join(self.cf.pp_dir, '{}_img.npy'.format(pid)), img_arr)
         if 'single_annotator' in self.cf.gts_to_produce or 'sa' in self.cf.gts_to_produce:
             self.produce_sa_gt(path, pid, df, img.GetSpacing(), img_arr.shape)
         if 'merged' in self.cf.gts_to_produce:
             self.produce_merged_gt(path, pid, df, img.GetSpacing(), img_arr.shape)
 
 
     def iterate_patients(self, processes=os.cpu_count()):
         pool = Pool(processes=processes)
         pool.map(self.pp_patient, self.paths, chunksize=1)
         pool.close()
         pool.join()
         print("finished processing raw patient data")
 
 
     def aggregate_meta_info(self):
         self.dfs = {}
         for gt_kind in self.cf.gts_to_produce:
             kind_dir = self.pp_dir_merged if gt_kind == "merged" else self.pp_dir_sa
             files = [os.path.join(kind_dir, f) for f in os.listdir(kind_dir) if 'meta_info.pickle' in f]
             self.dfs[gt_kind] = pd.DataFrame(columns=['pid', 'class_target', 'spacing', 'fg_slices'])
             for f in files:
                 with open(f, 'rb') as handle:
                     self.dfs[gt_kind].loc[len(self.dfs[gt_kind])] = pickle.load(handle)
 
             self.dfs[gt_kind].to_pickle(os.path.join(kind_dir, 'info_df.pickle'))
             print("aggregated meta info to df with length", len(self.dfs[gt_kind]))
 
     def convert_copy_npz(self):
         npz_dir = os.path.join(self.cf.pp_dir+'_npz')
         print("converting to npz dir", npz_dir)
         os.makedirs(npz_dir, exist_ok=True)
 
         dmanager.pack_dataset(self.cf.pp_dir, destination=npz_dir, recursive=True, verbose=False)
         if hasattr(self, 'pp_dir_merged'):
             subprocess.call('rsync -avh --exclude="*.npy" {} {}'.format(self.pp_dir_merged, npz_dir), shell=True)
         if hasattr(self, 'pp_dir_sa'):
             subprocess.call('rsync -avh --exclude="*.npy" {} {}'.format(self.pp_dir_sa, npz_dir), shell=True)
 
 
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-n', '--number', type=int, default=None, help='How many patients to maximally process.')
+    args = parser.parse_args()
     total_stime = time.time()
 
     import configs
-    cf = configs.configs()
+    cf = configs.Configs()
 
     # analysis finding: the following patients have unclear annotations. some raters gave more than one judgement
     # on the same roi.
     patients_to_exclude = ["0137a", "0404a", "0204a", "0252a", "0366a", "0863a", "0815a", "0060a", "0249a", "0436a", "0865a"]
     # further finding: the following patients contain nodules with segmentation-label inconsistencies
     # running Preprocessor.verify_seg_label_pairings() produces a data frame with detailed findings.
     patients_to_exclude += ["0305a", "0447a"]
     exclude_paths = [os.path.join(cf.raw_data_dir, pid) for pid in patients_to_exclude]
     # These pids are automatically found and excluded, when setting exclude_inconsistents=True at Preprocessor
     # initialization instead of passing the pre-compiled list.
 
 
-    pp = Preprocessor(cf, overwrite=True, exclude_inconsistents=exclude_paths, max_count=None, pids_subset=None)#["0998a"])
+    pp = Preprocessor(cf, overwrite=True, exclude_inconsistents=exclude_paths, max_count=args.number, pids_subset=None)#["0998a"])
     #pp.analyze_lesion("0305a", "5728")
     #pp.analyze_lesion("0305a", "5741")
     #pp.analyze_lesion("0447a", "8840")
 
     #pp.verify_seg_label_pairings()
     #load_df(os.path.join(cf.pp_dir, "verification_seg_label_pairings.pickle"))
     pp.iterate_patients(processes=8)
     # for i in ["/mnt/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI/new_nrrd/0305a",
     #           "/mnt/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI/new_nrrd/0447a"]:  #pp.paths[:1]:
     #      pp.pp_patient(i)
     pp.aggregate_meta_info()
     pp.convert_copy_npz()
 
 
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py
index 6cb5859..ab45c69 100644
--- a/datasets/toy/configs.py
+++ b/datasets/toy/configs.py
@@ -1,490 +1,488 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 from collections import namedtuple
 
 boxLabel = namedtuple('boxLabel', ["name", "color"])
 Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion'])
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals'])
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #         Prepro        #
         #########################
-
-        self.pp_rootdir = os.path.join('/media/gregor/HDD2TB/data/toy', "cyl1ps_dev")
+        self.pp_rootdir = os.path.join('/home/gregor/datasets/toy', "cyl1ps_dev")
         self.pp_npz_dir = self.pp_rootdir+"_npz"
 
         self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now)
         self.min_2d_radius = 6 #in pixels
         self.n_train_samples, self.n_test_samples = 1200, 1000
 
         # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes.
         self.pp_create_ohe_seg = False
         self.pp_empty_samples_ratio = 0.1
 
         self.pp_place_radii_mid_bin = True
         self.pp_only_distort_2d = True
         # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing.
         # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity.
         self.pp_blur_min_intensity = 0.2
 
         self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d)
         self.max_instances_per_class = self.max_instances_per_sample  # how many max instances per image per class
         self.noise_scale = 0.  # std-dev of gaussian noise
 
         self.ambigs_sampling = "gaussian" #"gaussian" or "uniform"
         """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from
             image by distinguishing it from the rest of the object.
             blurring width around edge will be shifted so that symmetric rel to orig radius.
             blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale
             since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius.
             if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2.
         """
         self.ambiguities = {
              #set which classes to apply which ambs to below in class labels
              #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'.
              #kind              #probability   #scale (gaussian std, relative to unperturbed value)
             #"outer_radius":     (1.,            0.5),
             #"outer_radius_xy":  (1.,            0.5),
             #"inner_radius":     (0.5,            0.1),
             #"radii_relations":  (0.5,            0.1),
             "radius_calib":     (1.,            1./6)
         }
 
         # shape choices: 'cylinder', 'block'
         #                        id,    name,       shape,      radius,                 color,              regression,     ambiguities,    gt_distortion
         self.pp_classes = [Label(1,     'cylinder', 'cylinder', ((6,6,1),(40,40,8)),    (*self.blue, 1.),   "radius_2d",    (),             ()),
                            #Label(2,      'block',      'block',        ((6,6,1),(40,40,8)),  (*self.aubergine,1.),  "radii_2d", (), ('radius_calib',))
             ]
 
 
         #########################
         #         I/O           #
         #########################
-
-        self.data_sourcedir = '/home/gregor/data/toy/cyl1ps_dev'
+        self.data_sourcedir = '/home/gregor/datasets/toy/cyl1ps_dev'
 
         if server_env:
             self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_dev_npz'
 
 
         self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test')
         self.data_sourcedir = os.path.join(self.data_sourcedir, "train")
 
         self.info_df_name = 'info_df.pickle'
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn'].
         self.model = 'mrcnn'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # one out of [2, 3]. dimension the model operates in.
         self.dim = 2
 
         # 'class', 'regression', 'regression_bin', 'regression_ken_gal'
         # currently only tested mode is a single-task at a time (i.e., only one task in below list)
         # but, in principle, tasks could be combined (e.g., object classes and regression per class)
         self.prediction_tasks = ['class', ]
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm'
         self.relu = 'relu'
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1  # length of regressor target vector
 
 
         #########################
         #      Data Loader      #
         #########################
 
         self.num_epochs = 32
         self.num_train_batches = 120 if self.dim == 2 else 80
-        self.batch_size = 16 if self.dim == 2 else 8
+        self.batch_size = 8 if self.dim == 2 else 4
 
         self.n_cv_splits = 4
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
         self.plot_bg_chan = 0
         self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
         self.patch_size_2D = self.pre_crop_size[:2]
         self.patch_size_3D = self.pre_crop_size[:2]+[8]
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.2
         self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets"
 
         self.observables_patient = []
         self.observables_rois = []
 
         self.seed = 3 #for generating folds
 
         #############################
         # Colors, Classes, Legends  #
         #############################
         self.plot_frequency = 1
 
         binary_bin_labels = [binLabel(1,  'r<=25',      (*self.green, 1.),      (1,25)),
                              binLabel(2,  'r>25',       (*self.red, 1.),        (25,))]
         quintuple_bin_labels = [binLabel(1,  'r2-10',   (*self.green, 1.),      (2,10)),
                                 binLabel(2,  'r10-20',  (*self.yellow, 1.),     (10,20)),
                                 binLabel(3,  'r20-30',  (*self.orange, 1.),     (20,30)),
                                 binLabel(4,  'r30-40',  (*self.bright_red, 1.), (30,40)),
                                 binLabel(5,  'r>40',    (*self.red, 1.), (40,))]
 
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_bin_labels = quintuple_bin_labels
 
         self.class_labels = [
             # regression: regression-task label, either value or "(x,y,z)_radius" or "radii".
             # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables!
             # gt_distortion: name of ambig to apply to gt only; needs to be iterable!
             #      #id  #name   #shape  #radius     #color              #regression #ambiguities    #gt_distortion
             Label(  0,  'bg',   None,   (0, 0, 0),  (*self.white, 0.),  (0, 0, 0),  (),             ())]
         if "class" in self.prediction_tasks:
             self.class_labels += self.pp_classes
         else:
             self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))]
 
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0,  'bg',       (*self.white, 1.),      (0,))]
             self.bin_labels += task_spec_bin_labels
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
             self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
 
         if self.class_specific_seg:
           self.seg_labels = self.class_labels
 
         self.box_type2label = {label.name: label for label in self.box_labels}
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)
         self.num_seg_classes = len(self.seg_labels)
 
         #########################
         #   Data Augmentation   #
         #########################
         self.do_aug = True
         self.da_kwargs = {
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': False,
             'alpha': (500., 1500.),
             'sigma': (40., 45.),
             'do_rotation': False,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': False,
             'scale': (0.8, 1.1),
             'random_crop': False,
             'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #########################
         #  Schedule / Selection #
         #########################
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 220  # if 'all' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 35 if self.dim==2 else 25
 
         self.save_n_models = 2
         self.min_save_thresh = 1 if self.dim == 2 else 1  # =wait time in epochs
         if "class" in self.prediction_tasks:
             self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()}
             self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()})
 
         self.lr_decay_factor = 0.5
         self.scheduling_patience = int(self.num_epochs / 5)
         self.weight_decay = 1e-5
         self.clip_norm = None  # number or None
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1)
         self.held_out_test_set = True
         self.max_test_patients = "all"  # number or "all" for all
 
         self.test_against_exact_gt = True # only True implemented
         self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario.
         self.report_score_level = ['rois']  # 'patient' or 'rois' (incl)
         self.patient_class_of_interest = 1
         self.patient_bin_of_interest = 2
 
         self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False
         self.metrics = ['ap', 'auc', 'dice']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.5]  # threshold(s) for considering a prediction as true positive
         self.min_det_thresh = 0.3
 
         self.model_max_iou_resolution = 0.2
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = self.model_max_iou_resolution  # has to be larger than desired possible overlap iou of model predictions
 
         self.merge_2D_to_3D_preds = False
         self.merge_3D_iou = self.model_max_iou_resolution
         self.n_test_plots = 1  # per fold and rank
 
         self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
         # is multiplied by (1 + nr of test augs)
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
          'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
          }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         #only meant for isotropic radii!!
         # only 2D radii (x and y dims) or 1D (x or y) are expected
         return np.round(np.digitize(rg_val, self.bin_edges).mean())
 
 
     def add_det_fpn_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = 'torch_loss'
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       self.n_roi_candidates = 4 if self.dim == 2 else 6
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
 
       self.fp_dice_weight = 1 if self.dim == 2 else 1
       # if <1, false positive predictions in foreground are penalized less.
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
     def add_det_unet_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = "torch_loss"
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
       self.n_roi_candidates = 4 if self.dim == 2 else 6
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
       # if <1, false positive predictions in foreground are penalized less.
       self.fp_dice_weight = 1 if self.dim == 2 else 1
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
       self.init_filts = 32
       self.kernel_size = 3  # ks for horizontal, normal convs
       self.kernel_size_m = 2  # ks for max pool
       self.pad = "same"  # "same" or integer, padding of horizontal convs
 
     def add_mrcnn_configs(self):
 
       self.learning_rate = [1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True  # with scheduler set in exec
       self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # number of classes for network heads: n_foreground_classes + 1 (background)
       self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2
 
       # feed +/- n neighbouring slices into channel dimension. set to None for no context.
       self.n_3D_context = None
       if self.n_3D_context is not None and self.dim == 2:
         self.n_channels *= (self.n_3D_context * 2 + 1)
 
       self.detect_while_training = True
       # disable the re-sampling of mask proposals to original size for speed-up.
       # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
       # mask outputs are optional.
       self.return_masks_in_train = True
       self.return_masks_in_val = True
       self.return_masks_in_test = True
 
       # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
       self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
       # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
       # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
       self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
       # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
       self.pyramid_levels = [0, 1, 2, 3]
       # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
       self.n_rpn_features = 512 if self.dim == 2 else 64
 
       # anchor ratios and strides per position in feature maps.
       self.rpn_anchor_ratios = [0.5, 1., 2.]
       self.rpn_anchor_stride = 1
       # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
       self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution)
 
       # loss sampling settings.
       self.rpn_train_anchors_per_image = 4
       self.train_rois_per_image = 6 # per batch_instance
       self.roi_positive_ratio = 0.5
       self.anchor_matching_iou = 0.8
 
       # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
       # where k<=#positive examples.
       self.shem_poolsize = 2
 
       self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
       self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
       self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
       self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
       self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                              self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
 
       if self.dim == 2:
         self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
         self.bbox_std_dev = self.bbox_std_dev[:4]
         self.window = self.window[:4]
         self.scale = self.scale[:4]
 
       self.plot_y_max = 1.5
       self.n_plot_rpn_props = 5 if self.dim == 2 else 30  # per batch_instance (slice in 2D / patient in 3D)
 
       # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
       self.pre_nms_limit = 2000 if self.dim == 2 else 4000
 
       # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
       # since proposals of the entire batch are forwarded through second stage as one "batch".
       self.roi_chunk_size = 1300 if self.dim == 2 else 500
       self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400
       self.post_nms_rois_inference = 200 * (self.head_classes-1)
 
       # Final selection of detections (refine_detections)
       self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class.
       self.detection_nms_threshold = self.model_max_iou_resolution  # needs to be > 0, otherwise all predictions are one cluster.
       self.model_min_confidence = 0.2  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
 
       if self.dim == 2:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride))]
            for stride in self.backbone_strides['xy']])
       else:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride)),
             int(np.ceil(self.patch_size[2] / stride_z))]
            for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                        )])
 
       if self.model == 'retina_net' or self.model == 'retina_unet':
         # whether to use focal loss or SHEM for loss-sample selection
         self.focal_loss = False
         # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
         self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                         self.rpn_anchor_scales['xy']]
         self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                        self.rpn_anchor_scales['z']]
         self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
         # pre-selection of detections for NMS-speedup. per entire batch.
         self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
         # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
         self.anchor_matching_iou = 0.7
 
         if self.model == 'retina_unet':
           self.operate_stride1 = True
diff --git a/datasets/toy/data_loader.py b/datasets/toy/data_loader.py
index f4a444c..3f5387b 100644
--- a/datasets/toy/data_loader.py
+++ b/datasets/toy/data_loader.py
@@ -1,595 +1,596 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 sys.path.append('../') # works on cluster indep from where sbatch job is started
 import plotting as plg
 from multiprocessing import Pool
 
 import numpy as np
 import os
 from multiprocessing import Lock
 from collections import OrderedDict
 import pandas as pd
 import pickle
 import time
 
 # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
 from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
 from batchgenerators.transforms.abstract_transforms import Compose
 from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
 from batchgenerators.transforms.spatial_transforms import SpatialTransform
 from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
 
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import utils.dataloader_utils as dutils
 from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
 
 
 def load_obj(file_path):
     with open(file_path, 'rb') as handle:
         return pickle.load(handle)
 
 class Dataset(dutils.Dataset):
     r""" Load a dict holding memmapped arrays and clinical parameters for each patient,
     evtly subset of those.
         If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to
         cf.data_dir.
     :param cf: config file
     :param folds: number of folds out of @params n_cv folds to include
     :param n_cv: number of total folds
     :return: dict with imgs, segs, pids, class_labels, observables
     """
 
     def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'):
         super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir)
 
         load_exact_gts = (mode=='test' or cf.val_mode=="val_patient") and self.cf.test_against_exact_gt
 
         p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name))
 
         if subset_ids is not None:
             p_df = p_df[p_df.pid.isin(subset_ids)]
             logger.info('subset: selected {} instances from df'.format(len(p_df)))
 
         pids = p_df.pid.tolist()
         #evtly copy data from data_sourcedir to data_dest
         if cf.server_env and not hasattr(cf, "data_dir"):
             file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids]
             file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids]
             file_subset += [cf.info_df_name]
             if load_exact_gts:
                 file_subset += [os.path.join(self.data_dir, '{}_exact_seg.*'.format(pid)) for pid in pids]
             self.copy_data(cf, file_subset=file_subset)
 
         img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids]
         seg_paths = [os.path.join(self.data_dir, '{}_seg.npy'.format(pid)) for pid in pids]
         if load_exact_gts:
             exact_seg_paths = [os.path.join(self.data_dir, '{}_exact_seg.npy'.format(pid)) for pid in pids]
 
         class_targets = p_df['class_ids'].tolist()
         rg_targets = p_df['regression_vectors'].tolist()
         if load_exact_gts:
             exact_rg_targets = p_df['undistorted_rg_vectors'].tolist()
         fg_slices = p_df['fg_slices'].tolist()
 
         self.data = OrderedDict()
         for ix, pid in enumerate(pids):
             self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid,
                               'fg_slices': np.array(fg_slices[ix])}
             if load_exact_gts:
                 self.data[pid]['exact_seg'] = exact_seg_paths[ix]
             if 'class' in self.cf.prediction_tasks:
                 self.data[pid]['class_targets'] = np.array(class_targets[ix], dtype='uint8')
             else:
                 self.data[pid]['class_targets'] = np.ones_like(np.array(class_targets[ix]), dtype='uint8')
             if load_exact_gts:
                 self.data[pid]['exact_class_targets'] = self.data[pid]['class_targets']
             if any(['regression' in task for task in self.cf.prediction_tasks]):
                 self.data[pid]['regression_targets'] = np.array(rg_targets[ix], dtype='float16')
                 self.data[pid]["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in rg_targets[ix]], dtype='uint8')
                 if load_exact_gts:
                     self.data[pid]['exact_regression_targets'] = np.array(exact_rg_targets[ix], dtype='float16')
                     self.data[pid]["exact_rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in exact_rg_targets[ix]],
                                                                 dtype='uint8')
 
 
         cf.roi_items = cf.observables_rois[:]
         cf.roi_items += ['class_targets']
         if any(['regression' in task for task in self.cf.prediction_tasks]):
             cf.roi_items += ['regression_targets']
             cf.roi_items += ['rg_bin_targets']
 
         self.set_ids = np.array(list(self.data.keys()))
         self.df = None
 
 class BatchGenerator(dutils.BatchGenerator):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
     def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, seed=0):
         super(BatchGenerator, self).__init__(cf, data, sample_pids_w_replace=sample_pids_w_replace,
                                              max_batches=max_batches, raise_stop_iteration=raise_stop_iteration,
                                              seed=seed)
 
         self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
         assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
 
         self.crop_margin = np.array(self.cf.patch_size) / 8.  # min distance of ROI center to edge of cropped_patch.
         self.p_fg = 0.5
         self.empty_samples_max_ratio = 0.6
 
         self.balance_target_distribution(plot=sample_pids_w_replace)
 
     def generate_train_batch(self):
         # everything done in here is per batch
         # print statements in here get confusing due to multithreading
 
         batch_pids = self.get_batch_pids()
 
         batch_data, batch_segs, batch_patient_targets = [], [], []
         batch_roi_items = {name: [] for name in self.cf.roi_items}
         # record roi count and empty count of classes in batch
         # empty count for no presence of resp. class in whole sample (empty slices in 2D/patients in 3D)
         batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
         batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32')
 
         for b in range(len(batch_pids)):
             patient = self._data[batch_pids[b]]
 
             data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
             seg =  np.load(patient['seg'], mmap_mode='r').astype('uint8')
 
             (c, y, x, z) = data.shape
             if self.cf.dim == 2:
                 elig_slices, choose_fg = [], False
                 if len(patient['fg_slices']) > 0:
                     if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or np.random.rand(
                             1) <= self.p_fg:
                         # fg is to be picked
                         for tix in np.argsort(batch_roi_counts):
                             # pick slices of patient that have roi of sought-for target
                             # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                             elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                 patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                                 self.unique_ts[tix]) > 0]
                             if len(elig_slices) > 0:
                                 choose_fg = True
                                 break
                     else:
                         # pick bg
                         elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
                 if len(elig_slices) > 0:
                     sl_pick_ix = np.random.choice(elig_slices, size=None)
                 else:
                     sl_pick_ix = np.random.choice(z, size=None)
                 data = data[..., sl_pick_ix]
                 seg = seg[..., sl_pick_ix]
 
             spatial_shp = data[0].shape
             assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg"
             if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
                 new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
                 data = dutils.pad_nd_image(data, (len(data), *new_shape))
                 seg = dutils.pad_nd_image(seg, new_shape)
 
             # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center,
             # if possible, to that pixel, so that img still contains ROI after pre-cropping
             dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))]
             if np.any(dim_cropflags):
                 # sample pixel from random ROI and shift center, if possible, to that pixel
                 if self.cf.dim==3:
                     choose_fg = np.any(batch_empty_counts/self.batch_size>=self.empty_samples_max_ratio) or \
                                 np.random.rand(1) <= self.p_fg
                 if choose_fg and np.any(seg):
                     available_roi_ids = np.unique(seg)[1:]
                     for tix in np.argsort(batch_roi_counts):
                         elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]]
                         if len(elig_roi_ids)>0:
                             seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                             break
                     roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                     assert seg[tuple(roi_anchor_pixel)] > 0
 
                     # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and
                     # distance to the selected ROI < patch_size /2
                     def get_cropped_centercoords(dim):
                         low = np.max((self.cf.pre_crop_size[dim] // 2,
                                       roi_anchor_pixel[dim] - (
                                                   self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                         high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2,
                                        roi_anchor_pixel[dim] + (
                                                    self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                         if low >= high:  # happens if lesion on the edge of the image.
                             low = self.cf.pre_crop_size[dim] // 2
                             high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2
 
                         assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format(
                             dim,
                             spatial_shp, patient['pid'], low, high)
                         return np.random.randint(low=low, high=high)
                 else:
                     # sample crop center regardless of ROIs, not guaranteed to be empty
                     def get_cropped_centercoords(dim):
                         return np.random.randint(low=self.cf.pre_crop_size[dim] // 2,
                                                  high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2)
 
                 sample_seg_center = {}
                 for dim in np.where(dim_cropflags)[0]:
                     sample_seg_center[dim] = get_cropped_centercoords(dim)
                     min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2)
                     max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2)
                     data = np.take(data, indices=range(min_, max_), axis=dim + 1)  # +1 for channeldim
                     seg = np.take(seg, indices=range(min_, max_), axis=dim)
 
             batch_data.append(data)
             batch_segs.append(seg[np.newaxis])
 
             for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable
                     batch_roi_items[o].append(patient[o])
 
             if self.cf.dim == 3:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero==0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
             elif self.cf.dim == 2:
                 for tix in range(len(self.unique_ts)):
                     non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
                     batch_roi_counts[tix] += non_zero
                     batch_empty_counts[tix] += int(non_zero == 0)
                     # todo remove assert when checked
                     if not np.any(seg):
                         assert non_zero==0
 
         batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'),
                  'pid': batch_pids,
                  'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts}
         for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic
             batch[key] = np.array(val)
 
         return batch
 
 class PatientBatchIterator(dutils.PatientBatchIterator):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D),
     if willing to accept speed-loss during training.
     Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are
     exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by
     the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs.
 
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D .
     """
 
     def __init__(self, cf, data, mode='test'):
         super(PatientBatchIterator, self).__init__(cf, data)
 
         self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D
         self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
         assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
 
         if (mode=="validation" and hasattr(self.cf, 'val_against_exact_gt') and self.cf.val_against_exact_gt) or \
                 (mode == 'test' and self.cf.test_against_exact_gt):
             self.gt_prefix = 'exact_'
             print("PatientIterator: Loading exact Ground Truths.")
         else:
             self.gt_prefix = ''
 
         self.patient_ix = 0  # running index over all patients in set
 
     def generate_train_batch(self, pid=None):
 
         if pid is None:
             pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
 
         # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling
         data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
         seg =  np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis]
 
         data_shp_raw = data.shape
         plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
         data = data[self.chans]
         discarded_chans = len(
             [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
         spatial_shp = data[0].shape  # spatial dims need to be in order x,y,z
         assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg"
 
         if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
             new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
             data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
             seg = dutils.pad_nd_image(seg, new_shape)
             if plot_bg is not None:
                 plot_bg = dutils.pad_nd_image(plot_bg, new_shape)
 
         if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
             # adds the batch dim here bc won't go through MTaugmenter
             out_data = data[np.newaxis]
             out_seg = seg[np.newaxis]
             if plot_bg is not None:
                out_plot_bg = plot_bg[np.newaxis]
             # data and seg shape: (1,c,x,y,z), where c=1 for seg
 
             batch_3D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_3D[o] = np.array([patient[self.gt_prefix+o]])
             converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_3D = converter(**batch_3D)
             batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
             for o in self.cf.roi_items:
                 batch_3D["patient_" + o] = batch_3D[o]
 
         if self.cf.dim == 2:
             out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32')  # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim
             out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8')  # (c,y,x,z) to (b=z,c,x,y)
 
             batch_2D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0)
             converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_2D = converter(**batch_2D)
 
             if plot_bg is not None:
                 out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32')
 
             if self.cf.merge_2D_to_3D_preds:
                 batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_3D[o]
             else:
                 batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_2D[o]
 
         out_batch = batch_3D if self.cf.dim == 3 else batch_2D
         out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})
 
         if self.cf.plot_bg_chan in self.chans and discarded_chans > 0:  # len(self.chans[:self.cf.plot_bg_chan])<data_shp_raw[0]:
             assert plot_bg is None
             plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
             out_plot_bg = plot_bg
         if plot_bg is not None:
             out_batch['plot_bg'] = out_plot_bg
 
         # eventual tiling into patches
         spatial_shp = out_batch["data"].shape[2:]
         if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
             patient_batch = out_batch
             print("patientiterator produced patched batch!")
             patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
             new_img_batch, new_seg_batch = [], []
 
             for c in patch_crop_coords_list:
                 new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                 seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                 new_seg_batch.append(seg_patch)
             shps = []
             for arr in new_img_batch:
                 shps.append(arr.shape)
 
             data = np.array(new_img_batch)  # (patches, c, x, y, z)
             seg = np.array(new_seg_batch)
             if self.cf.dim == 2:
                 # all patches have z dimension 1 (slices). discard dimension
                 data = data[..., 0]
                 seg = seg[..., 0]
             patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                            'pid': np.array([patient['pid']] * data.shape[0])}
             for o in self.cf.roi_items:
                 patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0)
             #patient-wise (orig) batch info for putting the patches back together after prediction
             for o in self.cf.roi_items:
                 patch_batch["patient_"+o] = patient_batch["patient_"+o]
                 if self.cf.dim == 2:
                     # this could also be named "unpatched_2d_roi_items"
                     patch_batch["patient_" + o + "_2d"] = patient_batch[o]
             patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
             patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
             if self.cf.dim == 2:
                 patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
             patch_batch['patient_data'] = patient_batch['data']
             patch_batch['patient_seg'] = patient_batch['seg']
             patch_batch['original_img_shape'] = patient_batch['original_img_shape']
             if plot_bg is not None:
                 patch_batch['patient_plot_bg'] = patient_batch['plot_bg']
 
             converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False,
                                                            class_specific_seg=self.cf.class_specific_seg)
 
             patch_batch = converter(**patch_batch)
             out_batch = patch_batch
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return out_batch
 
 
 def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs):
     """
     create mutli-threaded train/val/test batch generation and augmentation pipeline.
     :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
     :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
     :return: multithreaded_generator
     """
 
     # create instance of batch generator as first element in pipeline.
     data_gen = BatchGenerator(cf, patient_data, **kwargs)
 
     my_transforms = []
     if do_aug:
         if cf.da_kwargs["mirror"]:
             mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
             my_transforms.append(mirror_transform)
 
         spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                              patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                              do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                              alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                              do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                              angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                              do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                              random_crop=cf.da_kwargs['random_crop'])
 
         my_transforms.append(spatial_transform)
     else:
         my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
 
     my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
     all_transforms = Compose(my_transforms)
     # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
-    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
+    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads,
+                                                     seeds=range(data_gen.n_filled_threads))
     return multithreaded_generator
 
 def get_train_generators(cf, logger, data_statistics=False):
     """
     wrapper function for creating the training batch generator pipeline. returns the train/val generators.
     selects patients according to cv folds (generated by first run/fold of experiment):
     splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set)
     If cf.hold_out_test_set is True, adds the test split to the training data.
     """
     dataset = Dataset(cf, logger)
     dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
     dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
     set_splits = dataset.fg.splits
 
     test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
     train_ids = np.concatenate(set_splits, axis=0)
 
     if cf.held_out_test_set:
         train_ids = np.concatenate((train_ids, test_ids), axis=0)
         test_ids = []
 
     train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids}
     val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids}
 
     logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
                                                                                     len(test_ids)))
     if data_statistics:
         dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=
         os.path.join(cf.plot_dir,"dataset"))
 
 
 
     batch_gen = {}
     batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True)
     if cf.val_mode == 'val_patient':
         batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation')
         batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else min(len(val_ids), cf.max_val_patients)
     elif cf.val_mode == 'val_sampling':
         batch_gen['n_val'] = int(np.ceil(len(val_data)/cf.batch_size)) if cf.num_val_batches == "all" else cf.num_val_batches
         # in current setup, val loader is used like generator. with max_batches being applied in train routine.
         batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False,
                                                              max_batches=None, raise_stop_iteration=False)
 
     return batch_gen
 
 def get_test_generator(cf, logger):
     """
     if get_test_generators is possibly called multiple times in server env, every time of
     Dataset initiation rsync will check for copying the data; this should be okay
     since rsync will not copy if files already exist in destination.
     """
 
     if cf.held_out_test_set:
         sourcedir = cf.test_data_sourcedir
         test_ids = None
     else:
         sourcedir = None
         with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
             set_splits = pickle.load(handle)
         test_ids = set_splits[cf.fold]
 
     test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test')
     logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids)))
     batch_gen = {}
     batch_gen['test'] = PatientBatchIterator(cf, test_set.data)
     batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \
         min(cf.max_test_patients, len(test_set.set_ids))
 
     return batch_gen
 
 
 if __name__=="__main__":
 
     import utils.exp_utils as utils
     from datasets.toy.configs import Configs
 
     cf = Configs()
 
     total_stime = time.time()
     times = {}
 
     # cf.server_env = True
     # cf.data_dir = "experiments/dev_data"
 
     cf.exp_dir = "experiments/dev/"
     cf.plot_dir = cf.exp_dir + "plots"
     os.makedirs(cf.exp_dir, exist_ok=True)
     cf.fold = 0
     logger = utils.get_logger(cf.exp_dir)
     gens = get_train_generators(cf, logger)
     train_loader = gens['train']
     for i in range(0):
         stime = time.time()
         print("producing training batch nr ", i)
         ex_batch = next(train_loader)
         times["train_batch"] = time.time() - stime
         #experiments/dev/dev_exbatch_{}.png".format(i)
         plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False)
 
 
     val_loader = gens['val_sampling']
     stime = time.time()
     for i in range(1):
         ex_batch = next(val_loader)
         times["val_batch"] = time.time() - stime
         stime = time.time()
         #"experiments/dev/dev_exvalbatch_{}.png"
         plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True)
         times["val_plot"] = time.time() - stime
     #
     test_loader = get_test_generator(cf, logger)["test"]
     stime = time.time()
     ex_batch = test_loader.generate_train_batch(pid=None)
     times["test_batch"] = time.time() - stime
     stime = time.time()
     plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0)
     times["test_patchbatch_plot"] = time.time() - stime
 
 
 
     print("Times recorded throughout:")
     for (k, v) in times.items():
         print(k, "{:.2f}".format(v))
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/datasets/toy/generate_toys.py b/datasets/toy/generate_toys.py
index 0a3faeb..acaf3ba 100644
--- a/datasets/toy/generate_toys.py
+++ b/datasets/toy/generate_toys.py
@@ -1,399 +1,402 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """ Generate a data set of toy examples. Examples can be cylinders, spheres, blocks, diamonds.
     Distortions may be applied, e.g., noise to the radius ground truths.
     Settings are configured in configs file.
 """
 
 import plotting as plg
 import os
 import shutil
 import warnings
 import time
 from multiprocessing import Pool
 
 import numpy as np
 import pandas as pd
 
 import data_manager as dmanager
 
 
 for msg in ["RuntimeWarning: divide by zero encountered in true_divide.*",]:
     warnings.filterwarnings("ignore", msg)
 
 
 class ToyGenerator(object):
     """ Generator of toy data set.
         A train and a test split with certain nr of samples are created and saved to disk. Samples can contain varying
         number of objects. Objects have shapes cylinder or block (diamond, ellipsoid, torus not fully implemented).
 
         self.mp_args holds image split and id, objects are then randomly drawn into each image. Multi-processing is
         enabled for parallel creation of images, final .npy-files can then be converted to .npz.
     """
     def __init__(self, cf):
         """
         :param cf: configs file holding object specifications and output directories.
         """
 
         self.cf = cf
 
         self.n_train, self.n_test = cf.n_train_samples, cf.n_test_samples
         self.sample_size = cf.pre_crop_size
         self.dim = len(self.sample_size)
         self.class_radii = np.array([label.radius for label in self.cf.pp_classes if label.id!=0])
         self.class_id2label = {label.id: label for label in self.cf.pp_classes}
 
         self.mp_args = []
         # count sample ids consecutively over train, test splits within on dataset (one shape kind)
         self.last_s_id = 0
         for split in ["train", "test"]:
             self.set_splits_info(split)
 
     def set_splits_info(self, split):
         """ Set info for data set splits, i.e., directory and nr of samples.
         :param split: name of split, in {"train", "test"}.
         """
         out_dir = os.path.join(self.cf.pp_rootdir, split)
         os.makedirs(out_dir, exist_ok=True)
 
         n_samples = self.n_train if "train" in split else self.n_test
         req_exact_gt = "test" in split
 
         self.mp_args += [[out_dir, self.last_s_id+running_id, req_exact_gt] for running_id in range(n_samples)]
         self.last_s_id+= n_samples
 
     def generate_sample_radii(self, class_ids, shapes):
 
         # the radii set in labels are ranges to sample from in the form [(min_x,min_y,min_z), (max_x,max_y,max_z)]
         all_radii = []
         for ix, cl_radii in enumerate([self.class_radii[cl_id - 1].transpose() for cl_id in class_ids]):
             if "cylinder" in shapes[ix] or "block" in shapes[ix]:
                 # maintain 2D aspect ratio
                 sample_radii = [np.random.uniform(*cl_radii[0])] * 2
                 assert len(sample_radii) == 2, "upper sr {}, cl_radii {}".format(sample_radii, cl_radii)
                 if self.cf.pp_place_radii_mid_bin:
                     bef_conv_r = np.copy(sample_radii)
                     bin_id =  self.cf.rg_val_to_bin_id(bef_conv_r)
                     assert np.isscalar(bin_id)
                     sample_radii = self.cf.bin_id2rg_val[bin_id]*2
                     assert len(sample_radii) == 2, "mid before sr {}, sr {}, rgv2bid {}, cl_radii {},  bid2rgval {}".format(bef_conv_r, sample_radii, bin_id, cl_radii,
                                                                                                              self.cf.bin_id2rg_val[bin_id])
             else:
                 raise NotImplementedError("requested object shape {}".format(shapes[ix]))
             if self.dim == 3:
                 assert len(sample_radii) == 2, "lower sr {}, cl_radii {}".format(sample_radii, cl_radii)
                 #sample_radii += [np.random.uniform(*cl_radii[2])]
                 sample_radii = np.concatenate((sample_radii, np.random.uniform(*cl_radii[2], size=1)))
             all_radii.append(sample_radii)
 
         return all_radii
 
     def apply_gt_distort(self, class_id, radii, radii_divs, outer_min_radii=None, outer_max_radii=None):
         """ Apply a distortion to the ground truth (gt). This is motivated by investigating the effects of noisy labels.
             GTs that can be distorted are the object radii and ensuing GT quantities like segmentation and regression
             targets.
         :param class_id: class id of object.
         :param radii: radii of object. This is in the abstract sense, s.t. for a block-shaped object radii give the side
             lengths.
         :param radii_divs: radii divisors, i.e., fractions to take from radii to get inner radii of hole-shaped objects,
             like a torus.
         :param outer_min_radii: min radii assignable when distorting gt.
         :param outer_max_radii: max radii assignable when distorting gt.
         :return:
         """
         applied_gt_distort = False
         for ambig in self.class_id2label[class_id].gt_distortion:
             if self.cf.ambiguities[ambig][0] > np.random.rand():
                 if ambig == "outer_radius":
                     radii = radii * abs(np.random.normal(1., self.cf.ambiguities["outer_radius"][1]))
                     applied_gt_distort = True
                 if ambig == "radii_relations":
                     radii = radii * abs(np.random.normal(1.,self.cf.ambiguities["radii_relations"][1],size=len(radii)))
                     applied_gt_distort = True
                 if ambig == "inner_radius":
                     radii_divs = radii_divs * abs(np.random.normal(1., self.cf.ambiguities["inner_radius"][1]))
                     applied_gt_distort = True
                 if ambig == "radius_calib":
                     if self.cf.ambigs_sampling=="uniform":
                         radii = abs(np.random.uniform(outer_min_radii, outer_max_radii))
                     elif self.cf.ambigs_sampling=="gaussian":
                         distort = abs(np.random.normal(1, scale=self.cf.ambiguities["radius_calib"][1], size=None))
                         assert len(radii) == self.dim, "radii {}".format(radii)
                         radii *= [distort, distort, 1.] if self.cf.pp_only_distort_2d else distort
                     applied_gt_distort = True
         return radii, radii_divs, applied_gt_distort
 
     def draw_object(self, img, seg, undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort,
                                  roi_ix, class_id, shape, radii, center):
         """ Draw a single object into the given image and add it to the corresponding ground truths.
         :param img: image (volume) to hold the object.
         :param seg: pixel-wise labelling of the image, possibly distorted if gt distortions are applied.
         :param undistorted_seg: certainly undistorted, i.e., exact segmentation of object.
         :param ics: indices which mark the positions within the image.
         :param regress_targets: regression targets (e.g., 2D radii of object), evtly distorted.
         :param undistorted_rg_targets: undistorted regression targets.
         :param applied_gt_distort: boolean, whether or not gt distortion was applied.
         :param roi_ix: running index of object in whole image.
         :param class_id: class id of object.
         :param shape: shape of object (e.g., whether to draw a cylinder, or block, or ...).
         :param radii: radii of object (in an abstract sense, i.e., radii are side lengths in case of block shape).
         :param center: center of object in image coordinates.
         :return: img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort, which are now
             extended are amended to reflect the new object.
         """
 
         radii_blur = hasattr(self.cf, "ambiguities") and hasattr(self.class_id2label[class_id],
                                                                  "gt_distortion") and 'radius_calib' in \
                      self.class_id2label[class_id].gt_distortion
 
         if radii_blur:
             blur_width = self.cf.ambiguities['radius_calib'][1]
             if self.cf.ambigs_sampling == "uniform":
                 blur_width *= np.sqrt(12)
             if self.cf.pp_only_distort_2d:
                 outer_max_radii = np.concatenate((radii[:2] + blur_width * radii[:2], [radii[2]]))
                 outer_min_radii = np.concatenate((radii[:2] - blur_width * radii[:2], [radii[2]]))
                 #print("belt width ", outer_max_radii - outer_min_radii)
             else:
                 outer_max_radii = radii + blur_width * radii
                 outer_min_radii = radii - blur_width * radii
         else:
             outer_max_radii, outer_min_radii = radii, radii
 
         if "ellipsoid" in shape or "torus" in shape:
             # sphere equation: (x-h)**2 + (y-k)**2 - (z-l)**2 = r**2
             # ellipsoid equation: ((x-h)/a)**2+((y-k)/b)**2+((z-l)/c)**2 <= 1; a, b, c the "radii"/ half-length of principal axes
             obj = ((ics - center) / radii) ** 2
         elif "diamond" in shape:
             # diamond equation: (|x-h|)/a+(|y-k|)/b+(|z-l|)/c <= 1
             obj = abs(ics - center) / radii
         elif "cylinder" in shape:
             # cylinder equation:((x-h)/a)**2 + ((y-k)/b)**2 <= 1 while |z-l| <= c
             obj = ((ics - center).astype("float64") / radii) ** 2
             # set z values s.t. z slices outside range are sorted out
             obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1)
             if radii_blur:
                 inner_obj = ((ics - center).astype("float64") / outer_min_radii) ** 2
                 inner_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_min_radii[2], 0., 1.1)
                 outer_obj = ((ics - center).astype("float64") / outer_max_radii) ** 2
                 outer_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_max_radii[2], 0., 1.1)
                 # radial dists: sqrt( (x-h)**2 + (y-k)**2 + (z-l)**2 )
                 obj_radial_dists = np.sqrt(np.sum((ics - center).astype("float64")**2, axis=1))
         elif "block" in shape:
             # block equation: (|x-h|)/a+(|y-k|)/b <= 1 while  |z-l| <= c
             obj = abs(ics - center) / radii
             obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1)
             if radii_blur:
                 inner_obj = abs(ics - center) / outer_min_radii
                 inner_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_min_radii[2], 0., 1.1)
                 outer_obj = abs(ics - center) / outer_max_radii
                 outer_obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= outer_max_radii[2], 0., 1.1)
                 obj_radial_dists = np.sum(abs(ics - center), axis=1).astype("float64")
         else:
             raise Exception("Invalid object shape '{}'".format(shape))
 
         # create the "original" GT, i.e., the actually true object and draw it into undistorted seg.
         obj = (np.sum(obj, axis=1) <= 1)
         obj = obj.reshape(seg[0].shape)
         slices_to_discard = np.where(np.count_nonzero(np.count_nonzero(obj, axis=0), axis=0) <= self.cf.min_2d_radius)[0]
         obj[..., slices_to_discard] = 0
         undistorted_radii = np.copy(radii)
         undistorted_seg[class_id][obj] = roi_ix + 1
         obj = obj.astype('float64')
 
         if radii_blur:
             inner_obj = np.sum(inner_obj, axis=1) <= 1
             outer_obj = (np.sum(outer_obj, axis=1) <= 1) & ~inner_obj
             obj_radial_dists[outer_obj] = obj_radial_dists[outer_obj] / max(obj_radial_dists[outer_obj])
             intensity_slope = self.cf.pp_blur_min_intensity - 1.
             # intensity(r) = (i(r_max)-i(0))/r_max * r + i(0), where i(0)==1.
             obj_radial_dists[outer_obj] = obj_radial_dists[outer_obj] * intensity_slope + 1.
             inner_obj = inner_obj.astype('float64')
             #outer_obj, obj_radial_dists = outer_obj.reshape(seg[0].shape), obj_radial_dists.reshape(seg[0].shape)
             inner_obj += np.where(outer_obj, obj_radial_dists, 0.)
             obj = inner_obj.reshape(seg[0].shape)
         if not np.any(obj):
             print("An object was completely discarded due to min 2d radius requirement, discarded slices: {}.".format(
                 slices_to_discard))
         # draw the evtly blurred obj into image.
         img += obj * (class_id + 1.)
 
         if hasattr(self.cf, "ambiguities") and hasattr(self.class_id2label[class_id], "gt_distortion"):
             radii_divs = [None]  # dummy since not implemented yet
             radii, radii_divs, applied_gt_distort = self.apply_gt_distort(class_id, radii, radii_divs,
                                                                           outer_min_radii, outer_max_radii)
             if applied_gt_distort:
                 if "ellipsoid" in shape or "torus" in shape:
                     obj = ((ics - center) / radii) ** 2
                 elif 'diamond' in shape:
                     obj = abs(ics - center) / radii
                 elif "cylinder" in shape:
                     obj = ((ics - center) / radii) ** 2
                     obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1)
                 elif "block" in shape:
                     obj = abs(ics - center) / radii
                     obj[:, -1] = np.where(abs((ics - center)[:, -1]) <= radii[2], 0., 1.1)
                 obj = (np.sum(obj, axis=1) <= 1).reshape(seg[0].shape)
                 obj[..., slices_to_discard] = False
 
         if self.class_id2label[class_id].regression == "radii":
             regress_targets.append(radii)
             undistorted_rg_targets.append(undistorted_radii)
         elif self.class_id2label[class_id].regression == "radii_2d":
             regress_targets.append(radii[:2])
             undistorted_rg_targets.append(undistorted_radii[:2])
         elif self.class_id2label[class_id].regression == "radius_2d":
             regress_targets.append(radii[:1])
             undistorted_rg_targets.append(undistorted_radii[:1])
         else:
             regress_targets.append(self.class_id2label[class_id].regression)
             undistorted_rg_targets.append(self.class_id2label[class_id].regression)
 
         seg[class_id][obj.astype('bool')] = roi_ix + 1
 
         return  img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort
 
     def create_sample(self, args):
         """ Create a single sample and save to file. One sample is one image (volume) containing none, one, or multiple
             objects.
         :param args: out_dir: directory where to save sample, s_id: id of the sample.
         :return: specs that identify this single created image
         """
         out_dir, s_id, req_exact_gt = args
 
         print('processing {} {}'.format(out_dir, s_id))
         img = np.random.normal(loc=0.0, scale=self.cf.noise_scale, size=self.sample_size)
         img[img<0.] = 0.
         # one-hot-encoded seg
         seg = np.zeros((self.cf.num_classes+1, *self.sample_size)).astype('uint8')
         undistorted_seg = np.copy(seg)
         applied_gt_distort = False
 
         if hasattr(self.cf, "pp_empty_samples_ratio") and self.cf.pp_empty_samples_ratio >= np.random.rand():
             # generate fully empty sample
             class_ids, regress_targets, undistorted_rg_targets = [], [], []
         else:
             class_choices = np.repeat(np.arange(1, self.cf.num_classes+1), self.cf.max_instances_per_class)
             n_insts = np.random.randint(1, self.cf.max_instances_per_sample + 1)
             class_ids = np.random.choice(class_choices, size=n_insts, replace=False)
             shapes = np.array([self.class_id2label[cl_id].shape for cl_id in class_ids])
             all_radii = self.generate_sample_radii(class_ids, shapes)
 
             # reorder s.t. larger objects are drawn first (in order to not fully cover smaller objects)
             order = np.argsort(-1*np.prod(all_radii,axis=1))
             class_ids = class_ids[order]; all_radii = np.array(all_radii)[order]; shapes = shapes[order]
 
             regress_targets, undistorted_rg_targets = [], []
             # indices ics equal positions within img/volume
             ics = np.argwhere(np.ones(seg[0].shape))
             for roi_ix, class_id in enumerate(class_ids):
                 radii = all_radii[roi_ix]
                 # enforce distance between object center and image edge relative to radii.
                 margin_r_divisor = (2, 2, 4)
                 center = [np.random.randint(radii[dim] / margin_r_divisor[dim], img.shape[dim] -
                                             radii[dim] / margin_r_divisor[dim]) for dim in range(len(img.shape))]
 
                 img, seg, undistorted_seg, regress_targets, undistorted_rg_targets, applied_gt_distort = \
                     self.draw_object(img, seg, undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort,
                                  roi_ix, class_id, shapes[roi_ix], radii, center)
 
         fg_slices = np.where(np.sum(np.sum(np.sum(seg,axis=0), axis=0), axis=0))[0]
         if self.cf.pp_create_ohe_seg:
             img = img[np.newaxis]
         else:
             # choosing rois to keep by smaller radius==higher prio needs to be ensured during roi generation,
             # smaller objects need to be drawn later (==higher roi id)
             seg = seg.max(axis=0)
             seg_ids = np.unique(seg)
             if len(seg_ids) != len(class_ids) + 1:
                 # in this case an object was completely covered by a succeeding object
                 print("skipping corrupt sample")
                 print("seg ids {}, class_ids {}".format(seg_ids, class_ids))
                 return None
             if not applied_gt_distort:
                 assert np.all(np.flatnonzero(img>0) == np.flatnonzero(seg>0))
                 assert np.all(np.array(regress_targets).flatten()==np.array(undistorted_rg_targets).flatten())
 
         # save the img
         out_path = os.path.join(out_dir, '{}.npy'.format(s_id))
         np.save(out_path, img.astype('float16'))
 
         # exact GT
         if req_exact_gt:
             if not self.cf.pp_create_ohe_seg:
                 undistorted_seg = undistorted_seg.max(axis=0)
             np.save(os.path.join(out_dir, '{}_exact_seg.npy'.format(s_id)), undistorted_seg)
         else:
             # if hasattr(self.cf, 'ambiguities') and \
             #     np.any([hasattr(label, "gt_distortion") and len(label.gt_distortion)>0 for label in self.class_id2label.values()]):
             # save (evtly) distorted GT
             np.save(os.path.join(out_dir, '{}_seg.npy'.format(s_id)), seg)
 
 
         return [out_dir, out_path, class_ids, regress_targets, fg_slices, undistorted_rg_targets, str(s_id)]
 
     def create_sets(self, processes=os.cpu_count()):
         """ Create whole training and test set, save to files under given directory cf.out_dir.
         :param processes: nr of parallel processes.
         """
 
 
         print('starting creation of {} images.'.format(len(self.mp_args)))
         shutil.copyfile("configs.py", os.path.join(self.cf.pp_rootdir, 'applied_configs.py'))
         pool = Pool(processes=processes)
-        imgs_info = pool.map(self.create_sample, self.mp_args)
+        try:
+            imgs_info = pool.map(self.create_sample, self.mp_args)
+        except AttributeError as e:
+            raise AttributeError("{}\nAre configs tasks = ['class', 'regression'] (both)?".format(e))
         imgs_info = [img for img in imgs_info if img is not None]
         pool.close()
         pool.join()
         print("created a total of {} samples.".format(len(imgs_info)))
 
         self.df = pd.DataFrame.from_records(imgs_info, columns=['out_dir', 'path', 'class_ids', 'regression_vectors',
                                                                 'fg_slices', 'undistorted_rg_vectors', 'pid'])
 
         for out_dir, group_df in self.df.groupby("out_dir"):
             group_df.to_pickle(os.path.join(out_dir, 'info_df.pickle'))
 
 
     def convert_copy_npz(self):
         """ Convert a copy of generated .npy-files to npz and save in .npz-directory given in configs.
         """
         if hasattr(self.cf, "pp_npz_dir") and self.cf.pp_npz_dir:
             for out_dir, group_df in self.df.groupby("out_dir"):
                 rel_dir = os.path.relpath(out_dir, self.cf.pp_rootdir).split(os.sep)
                 npz_out_dir = os.path.join(self.cf.pp_npz_dir, str(os.sep).join(rel_dir))
                 print("npz out dir: ", npz_out_dir)
                 os.makedirs(npz_out_dir, exist_ok=True)
                 group_df.to_pickle(os.path.join(npz_out_dir, 'info_df.pickle'))
                 dmanager.pack_dataset(out_dir, npz_out_dir, recursive=True, verbose=False)
         else:
             print("Did not convert .npy-files to .npz because npz directory not set in configs.")
 
 
 if __name__ == '__main__':
     import configs as cf
     cf = cf.Configs()
     total_stime = time.time()
 
     toy_gen = ToyGenerator(cf)
     toy_gen.create_sets()
     toy_gen.convert_copy_npz()
 
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
diff --git a/exec.py b/exec.py
index 155100e..4d89fcd 100644
--- a/exec.py
+++ b/exec.py
@@ -1,341 +1,342 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """ execution script. this where all routines come together and the only script you need to call.
     refer to parse args below to see options for execution.
 """
 
 import plotting as plg
 
 import os
 import warnings
 import argparse
 import time
 
 import torch
 
 import utils.exp_utils as utils
 from evaluator import Evaluator
 from predictor import Predictor
 
 
 for msg in ["Attempting to set identical bottom==top results",
             "This figure includes Axes that are not compatible with tight_layout",
             "Data has no positive values, and therefore cannot be log-scaled.",
             ".*invalid value encountered in true_divide.*"]:
     warnings.filterwarnings("ignore", msg)
 
 
 def train(cf, logger):
     """
     performs the training routine for a given fold. saves plots and selected parameters to the experiment dir
     specified in the configs. logs to file and tensorboard.
     """
     logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
         cf.dim, cf.fold, cf.exp_dir, cf.model))
     logger.time("train_val")
 
     # -------------- inits and settings -----------------
     net = model.net(cf, logger).cuda()
     if cf.optimizer == "ADAM":
         optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
     elif cf.optimizer == "SGD":
         optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3)
     if cf.dynamic_lr_scheduling:
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
                                                                     patience=cf.scheduling_patience)
     model_selector = utils.ModelSelector(cf, logger)
 
     starting_epoch = 1
-    if cf.resume_from_checkpoint:
-        starting_epoch = utils.load_checkpoint(cf.resume_from_checkpoint, net, optimizer)
-        logger.info('resumed from checkpoint {} at epoch {}'.format(cf.resume_from_checkpoint, starting_epoch))
+    if cf.resume:
+        checkpoint_path = os.path.join(cf.fold_dir, "last_state.pth")
+        starting_epoch, net, optimizer, model_selector = \
+            utils.load_checkpoint(checkpoint_path, net, optimizer, model_selector)
+        logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))
 
     # prepare monitoring
     monitor_metrics = utils.prepare_monitoring(cf)
 
     logger.info('loading dataset and initializing batch generators...')
     batch_gen = data_loader.get_train_generators(cf, logger)
 
     # -------------- training -----------------
     for epoch in range(starting_epoch, cf.num_epochs + 1):
 
         logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs))
         logger.time("train_epoch")
 
         net.train()
 
         train_results_list = []
         train_evaluator = Evaluator(cf, logger, mode='train')
 
         for i in range(cf.num_train_batches):
             logger.time("train_batch_loadfw")
             batch = next(batch_gen['train'])
             batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts']
             batch_gen['train'].generator.stats['empty_counts'] += batch['empty_counts']
 
             logger.time("train_batch_loadfw")
             logger.time("train_batch_netfw")
             results_dict = net.train_forward(batch)
             logger.time("train_batch_netfw")
             logger.time("train_batch_bw")
             optimizer.zero_grad()
             results_dict['torch_loss'].backward()
             if cf.clip_norm:
                 torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping
             optimizer.step()
             train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict
             if not cf.server_env:
                 print("\rFinished training batch " +
                       "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches,
                                                                                              logger.get_time("train_batch_loadfw")+
                                                                                              logger.get_time("train_batch_netfw")
                                                                                              +logger.time("train_batch_bw"),
                                                                                              logger.get_time("train_batch_loadfw",reset=True),
                                                                                              logger.get_time("train_batch_netfw", reset=True),
                                                                                              logger.get_time("train_batch_bw", reset=True)), end="", flush=True)
         print()
 
         #--------------- train eval ----------------
         if (epoch-1)%cf.plot_frequency==0:
             # view an example batch
             utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels,
                                     show_gt_labels=True, get_time="train-example plot",
                                     out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold)))
 
 
         logger.time("evals")
         _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])
         logger.time("evals")
         logger.time("train_epoch", toggle=False)
         del train_results_list
 
         #----------- validation ------------
         logger.info('starting validation in mode {}.'.format(cf.val_mode))
         logger.time("val_epoch")
         with torch.no_grad():
             net.eval()
             val_results_list = []
             val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
             val_predictor = Predictor(cf, net, logger, mode='val')
 
             for i in range(batch_gen['n_val']):
                 logger.time("val_batch")
                 batch = next(batch_gen[cf.val_mode])
                 if cf.val_mode == 'val_patient':
                     results_dict = val_predictor.predict_patient(batch)
                 elif cf.val_mode == 'val_sampling':
                     results_dict = net.train_forward(batch, is_validation=True)
                 val_results_list.append([results_dict, batch["pid"]])
                 if not cf.server_env:
                     print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch',
                                                                               i + 1, batch_gen['n_val'],
                                                                               logger.time("val_batch")), end="", flush=True)
             print()
 
             #------------ val eval -------------
             if (epoch - 1) % cf.plot_frequency == 0:
                 utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels,
                                         show_gt_labels=True, get_time="val-example plot",
                                         out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold)))
 
             logger.time("evals")
             _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
 
             model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
             del val_results_list
             #----------- monitoring -------------
             monitor_metrics.update({"lr": 
                 {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}})
             logger.metrics2tboard(monitor_metrics, global_step=epoch)
             logger.time("evals")
 
             logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format(
                 epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"),
                 logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"),
                 logger.get_time("val_epoch", reset=True)/batch_gen["n_val"]))
             logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True)))
 
         #-------------- scheduling -----------------
         if not cf.dynamic_lr_scheduling:
             for param_group in optimizer.param_groups:
                 param_group['lr'] = cf.learning_rate[epoch-1]
         else:
             scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
 
     logger.time("train_val")
     logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True)))
     batch_gen['train'].generator.print_stats(logger, plot=True)
 
 def test(cf, logger, max_fold=None):
     """performs testing for a given fold (or held out set). saves stats in evaluator.
     """
     logger.time("test_fold")
     logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir))
     net = model.net(cf, logger).cuda()
     batch_gen = data_loader.get_test_generator(cf, logger)
 
     test_predictor = Predictor(cf, net, logger, mode='test')
     test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr(
         cf, "eval_test_separately") or not cf.eval_test_separately)
 
     if test_results_list is not None:
         test_evaluator = Evaluator(cf, logger, mode='test')
         test_evaluator.evaluate_predictions(test_results_list)
         test_evaluator.score_test_df(max_fold=max_fold)
 
     logger.info('Testing of fold {} took {}.\n'.format(cf.fold, logger.get_time("test_fold", reset=True, format="hms")))
 
 if __name__ == '__main__':
     stime = time.time()
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--dataset_name', type=str, default='toy',
                         help="path to the dataset-specific code in source_dir/datasets")
     parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/regrcnn/datasets/toy/experiments/dev',
                         help='path to experiment dir. will be created if non existent.')
     parser.add_argument('-m', '--mode', type=str,  default='train_test', help='one out of: create_exp, analysis, train, train_test, or test')
     parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.')
     parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.')
     parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config")
     parser.add_argument('--use_stored_settings', default=False, action='store_true',
                         help='load configs from existing exp_dir instead of source dir. always done for testing, '
                              'but can be set to true to do the same for training. useful in job scheduler environment, '
                              'where source code might change before the job actually runs.')
-    parser.add_argument('--resume_from_checkpoint', type=str, default=None,
-                        help='path to checkpoint. if resuming from checkpoint, the desired fold still needs to be parsed via --folds.')
+    parser.add_argument('--resume', action="store_true", default=False,
+                        help='if given, resume from checkpoint(s) of the specified folds.')
     parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything")
 
     args = parser.parse_args()
     args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name
     folds = args.folds
-    resume_from_checkpoint = None if args.resume_from_checkpoint in ['None', 'none'] else args.resume_from_checkpoint
+    resume = None if args.resume in ['None', 'none'] else args.resume
 
     if args.mode == 'create_exp':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False)
         logger = utils.get_logger(cf.exp_dir, cf.server_env, -1)
         logger.info('created experiment directory at {}'.format(args.exp_dir))
 
     elif args.mode == 'train' or args.mode == 'train_test':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings)
         if args.dev:
             folds = [0,1]
-            cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 1
+            cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 2, 0, 1
             cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1
             cf.test_n_epochs =  cf.save_n_models
             cf.max_test_patients = 1
             torch.backends.cudnn.benchmark = cf.dim==3
         else:
             torch.backends.cudnn.benchmark = cf.cuda_benchmark
         if args.data_dest is not None:
             cf.data_dest = args.data_dest
             
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
         data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py'))
         model = utils.import_module('model', cf.model_path)
         logger.info("loaded model from {}".format(cf.model_path))
         if folds is None:
             folds = range(cf.n_cv_splits)
 
         for fold in folds:
             """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation,
             one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the
             splits. k==folds, fold in [0,folds) says which split is used for testing.
             """
             cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold
             logger.set_logfile(fold=fold)
-            cf.resume_from_checkpoint = resume_from_checkpoint
+            cf.resume = resume
+
             if not os.path.exists(cf.fold_dir):
                 os.mkdir(cf.fold_dir)
             train(cf, logger)
-            cf.resume_from_checkpoint = None
+            cf.resume = None
             if args.mode == 'train_test':
                 test(cf, logger)
 
     elif args.mode == 'test':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False)
         if args.data_dest is not None:
             cf.data_dest = args.data_dest
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
         data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py'))
         model = utils.import_module('model', cf.model_path)
         logger.info("loaded model from {}".format(cf.model_path))
 
         fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
                      os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
         if folds is None:
             folds = range(cf.n_cv_splits)
         if args.dev:
             folds = folds[:2]
             cf.batch_size, cf.max_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2
         else:
             torch.backends.cudnn.benchmark = cf.cuda_benchmark
         for fold in folds:
             cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold
             logger.set_logfile(fold=fold)
             if cf.fold_dir in fold_dirs:
                 test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs]))
             else:
                 logger.info("Skipping fold {} since no model parameters found.".format(fold))
     # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation.
     elif args.mode == 'analysis':
         """ analyse already saved predictions.
         """
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False)
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
 
         if cf.held_out_test_set and not cf.eval_test_fold_wise:
             predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
             results_list = predictor.load_saved_predictions()
             logger.info('starting evaluation...')
             cf.fold = 0
             evaluator = Evaluator(cf, logger, mode='test')
             evaluator.evaluate_predictions(results_list)
             evaluator.score_test_df(max_fold=0)
         else:
             fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
                          os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
             if args.dev:
                 fold_dirs = fold_dirs[:1]
             if folds is None:
                 folds = range(cf.n_cv_splits)
             for fold in folds:
                 cf.fold = fold; cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold))
                 logger.set_logfile(fold=fold)
                 if cf.fold_dir in fold_dirs:
                     predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
                     results_list = predictor.load_saved_predictions()
                     # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds
                     # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x
                     logger.info('starting evaluation...')
                     evaluator = Evaluator(cf, logger, mode='test')
                     evaluator.evaluate_predictions(results_list)
                     max_fold = max([int(f[-1]) for f in fold_dirs])
                     evaluator.score_test_df(max_fold=max_fold)
                 else:
                     logger.info("Skipping fold {} since no model parameters found.".format(fold))
     else:
         raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode))
         
     mins, secs = divmod((time.time() - stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t))
     del logger
     torch.cuda.empty_cache()
 
-
-
diff --git a/models/mrcnn.py b/models/mrcnn.py
index e0b7982..e3bbb30 100644
--- a/models/mrcnn.py
+++ b/models/mrcnn.py
@@ -1,752 +1,752 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """
 Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
 published under MIT license.
 """
 import os
 from multiprocessing import  Pool
 import time
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils
 
 import utils.model_utils as mutils
 import utils.exp_utils as utils
 
 
 
 class RPN(nn.Module):
     """
     Region Proposal Network.
     """
 
     def __init__(self, cf, conv):
 
         super(RPN, self).__init__()
         self.dim = conv.dim
 
         self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu)
         self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
         self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
 
 
     def forward(self, x):
         """
         :param x: input feature maps (b, in_channels, y, x, (z))
         :return: rpn_class_logits (b, 2, n_anchors)
         :return: rpn_probs_logits (b, 2, n_anchors)
         :return: rpn_bbox (b, 2 * dim, n_anchors)
         """
 
         # Shared convolutional base of the RPN.
         x = self.conv_shared(x)
 
         # Anchor Score. (batch, anchors per location * 2, y, x, (z)).
         rpn_class_logits = self.conv_class(x)
         # Reshape to (batch, 2, anchors)
         axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
         rpn_class_logits = rpn_class_logits.permute(*axes)
         rpn_class_logits = rpn_class_logits.contiguous()
         rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2)
 
         # Softmax on last dimension (fg vs. bg).
         rpn_probs = F.softmax(rpn_class_logits, dim=2)
 
         # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z))
         rpn_bbox = self.conv_bbox(x)
 
         # Reshape to (batch, 2*dim, anchors)
         rpn_bbox = rpn_bbox.permute(*axes)
         rpn_bbox = rpn_bbox.contiguous()
         rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2)
 
         return [rpn_class_logits, rpn_probs, rpn_bbox]
 
 
 
 class Classifier(nn.Module):
     """
     Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a
     shared convolutional base and finally branches off the classifier- and regression head.
     """
     def __init__(self, cf, conv):
         super(Classifier, self).__init__()
 
         self.cf = cf
         self.dim = conv.dim
         self.in_channels = cf.end_filts
         self.pool_size = cf.pool_size
         self.pyramid_levels = cf.pyramid_levels
         # instance_norm does not work with spatial dims (1, 1, (1))
         norm = cf.norm if cf.norm != 'instance_norm' else None
 
         self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu)
         self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu)
         self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim)
 
 
         if 'regression' in self.cf.prediction_tasks:
             self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * cf.regression_n_features)
             self.rg_n_feats = cf.regression_n_features
         #classify into bins of regression values
         elif 'regression_bin' in self.cf.prediction_tasks:
             self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * len(cf.bin_labels))
             self.rg_n_feats = len(cf.bin_labels)
         else:
             self.linear_regressor = lambda x: torch.zeros((x.shape[0], cf.head_classes * 1), dtype=torch.float32).fill_(float('NaN')).cuda()
             self.rg_n_feats = 1 #cf.regression_n_features
         if 'class' in self.cf.prediction_tasks:
             self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes)
         else:
             assert cf.head_classes == 2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes".format(cf.head_classes)
             self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda()
 
 
     def forward(self, x, rois):
         """
         :param x: input feature maps (b, in_channels, y, x, (z))
         :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
         the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements
         have been merged to one vector, while the origin info has been stored for re-allocation.
         :return: mrcnn_class_logits (n_proposals, n_head_classes)
         :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
         """
         x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
         x = self.conv1(x)
         x = self.conv2(x)
         x = x.view(-1, self.in_channels * 4)
 
         mrcnn_bbox = self.linear_bbox(x)
         mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2)
         mrcnn_class_logits = self.linear_class(x)
         mrcnn_regress = self.linear_regressor(x)
         mrcnn_regress = mrcnn_regress.view(mrcnn_regress.size()[0], -1, self.rg_n_feats)
 
         return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress]
 
 
 class Mask(nn.Module):
     """
     Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the
     output logits to allow for overlapping classes.
     """
     def __init__(self, cf, conv):
         super(Mask, self).__init__()
         self.pool_size = cf.mask_pool_size
         self.pyramid_levels = cf.pyramid_levels
         self.dim = conv.dim
         self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
         self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
         self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
         self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
         if conv.dim == 2:
-            self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
+            self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) # todo why no norm here?
         else:
             self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
 
         self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True)
         self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None)
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, x, rois):
         """
         :param x: input feature maps (b, in_channels, y, x, (z))
         :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
         the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements
         have been merged to one vector, while the origin info has been stored for re-allocation.
         :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z))
         """
         x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
         x = self.conv4(x)
         x = self.relu(self.deconv(x))
         x = self.conv5(x)
         x = self.sigmoid(x)
         return x
 
 
 ############################################################
 #  Loss Functions
 ############################################################
 
 def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize):
     """
     :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
     :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier.
     :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining).
     :return: loss: torch tensor
     :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
     """
 
     # Filter out netural anchors
     pos_indices = torch.nonzero(rpn_match == 1)
     neg_indices = torch.nonzero(rpn_match == -1)
 
     # loss for positive samples
     if not 0 in pos_indices.size():
         pos_indices = pos_indices.squeeze(1)
         roi_logits_pos = rpn_class_logits[pos_indices]
         pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda())
     else:
         pos_loss = torch.FloatTensor([0]).cuda()
 
     # loss for negative samples: draw hard negative examples (SHEM)
     # that match the number of positive samples, but at least 1.
     if not 0 in neg_indices.size():
         neg_indices = neg_indices.squeeze(1)
         roi_logits_neg = rpn_class_logits[neg_indices]
         negative_count = np.max((1, pos_indices.cpu().data.numpy().size))
         roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
         neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
         neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda())
         np_neg_ix = neg_ix.cpu().data.numpy()
         #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count)
     else:
         neg_loss = torch.FloatTensor([0]).cuda()
         np_neg_ix = np.array([]).astype('int32')
 
     loss = (pos_loss + neg_loss) / 2
     return loss, np_neg_ix
 
 
 def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match):
     """
     :param rpn_target_deltas:   (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))).
     Uses 0 padding to fill in unsed bbox deltas.
     :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
     :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
     :return: loss: torch 1D tensor.
     """
     if not 0 in torch.nonzero(rpn_match == 1).size():
 
         indices = torch.nonzero(rpn_match == 1).squeeze(1)
         # Pick bbox deltas that contribute to the loss
         rpn_pred_deltas = rpn_pred_deltas[indices]
         # Trim target bounding box deltas to the same length as rpn_bbox.
         target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :]
         # Smooth L1 loss
         loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas)
     else:
         loss = torch.FloatTensor([0]).cuda()
 
     return loss
 
 def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids):
     """
     :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dh)))
     :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dh)))
     :param target_class_ids: (n_sampled_rois)
     :return: loss: torch 1D tensor.
     """
     if not 0 in torch.nonzero(target_class_ids > 0).size():
         positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
         positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
         target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach()
         pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :]
         loss = F.smooth_l1_loss(pred_bbox, target_bbox)
     else:
         loss = torch.FloatTensor([0]).cuda()
 
     return loss
 
 def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids):
     """
     :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array.
     :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1].
     :param target_class_ids: (n_sampled_rois)
     :return: loss: torch 1D tensor.
     """
     #print("targ masks", target_masks.unique(return_counts=True))
     if not 0 in torch.nonzero(target_class_ids > 0).size():
         # Only positive ROIs contribute to the loss. And only
         # the class-specific mask of each ROI.
         positive_ix = torch.nonzero(target_class_ids > 0)[:, 0]
         positive_class_ids = target_class_ids[positive_ix].long()
         y_true = target_masks[positive_ix, :, :].detach()
         y_pred = pred_masks[positive_ix, positive_class_ids, :, :]
         loss = F.binary_cross_entropy(y_pred, y_true)
     else:
         loss = torch.FloatTensor([0]).cuda()
 
     return loss
 
 def compute_mrcnn_class_loss(tasks, pred_class_logits, target_class_ids):
     """
     :param pred_class_logits: (n_sampled_rois, n_classes)
     :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension.
     :return: loss: torch 1D tensor.
     """
     if 'class' in tasks and not 0 in target_class_ids.size():
         loss = F.cross_entropy(pred_class_logits, target_class_ids.long())
     else:
         loss = torch.FloatTensor([0.]).cuda()
 
     return loss
 
 def compute_mrcnn_regression_loss(tasks, pred, target, target_class_ids):
     """regression loss is a distance metric between target vector and predicted regression vector.
     :param pred: (n_sampled_rois, n_classes, [n_rg_feats if real regression or 1 if rg_bin task)
     :param target: (n_sampled_rois, [n_rg_feats or n_rg_bins])
     :return: differentiable loss, torch 1D tensor on cuda
     """
 
     if not 0 in target.shape and not 0 in torch.nonzero(target_class_ids > 0).shape:
         positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
         positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
         target = target[positive_roi_ix].detach()
         pred = pred[positive_roi_ix, positive_roi_class_ids]
         if "regression_bin" in tasks:
             loss = F.cross_entropy(pred, target.long())
         else:
             loss = F.smooth_l1_loss(pred, target)
             #loss = F.mse_loss(pred, target)
     else:
         loss = torch.FloatTensor([0.]).cuda()
 
     return loss
 
 ############################################################
 #  Detection Layer
 ############################################################
 
 def compute_roi_scores(tasks, batch_rpn_proposals, mrcnn_cl_logits):
     """ Depending on the predicition tasks: if no class prediction beyong fg/bg (--> means no additional class
         head was applied) use RPN objectness scores as roi scores, otherwise class head scores.
     :param cf:
     :param batch_rpn_proposals:
     :param mrcnn_cl_logits:
     :return:
     """
     if not 'class' in tasks:
         scores = batch_rpn_proposals[:, :, -1].view(-1, 1)
         scores = torch.cat((1 - scores, scores), dim=1)
     else:
         scores = F.softmax(mrcnn_cl_logits, dim=1)
 
     return scores
 
 ############################################################
 #  MaskRCNN Class
 ############################################################
 
 class net(nn.Module):
 
 
     def __init__(self, cf, logger):
 
         super(net, self).__init__()
         self.cf = cf
         self.logger = logger
         self.build()
 
         loss_order = ['rpn_class', 'rpn_bbox', 'mrcnn_bbox', 'mrcnn_mask', 'mrcnn_class', 'mrcnn_rg']
         if hasattr(cf, "mrcnn_loss_weights"):
             # bring into right order
             self.loss_weights = np.array([cf.mrcnn_loss_weights[k] for k in loss_order])
         else:
             self.loss_weights = np.array([1.]*len(loss_order))
 
         if self.cf.weight_init=="custom":
             logger.info("Tried to use custom weight init which is not defined. Using pytorch default.")
         elif self.cf.weight_init:
             mutils.initialize_weights(self)
         else:
             logger.info("using default pytorch weight init")
 
     def build(self):
         """Build Mask R-CNN architecture."""
 
         # Image size must be dividable by 2 multiple times.
         h, w = self.cf.patch_size[:2]
         if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5):
             raise Exception("Image size must be divisible by 2 at least 5 times "
                             "to avoid fractions when downscaling and upscaling."
                             "For example, use 256, 288, 320, 384, 448, 512, ... etc.,i.e.,"
                             "any number x*32 will do!")
 
         # instantiate abstract multi-dimensional conv generator and load backbone module.
         backbone = utils.import_module('bbone', self.cf.backbone_path)
         self.logger.info("loaded backbone from {}".format(self.cf.backbone_path))
         conv = backbone.ConvGenerator(self.cf.dim)
 
         # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head
         self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
         self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
         self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda()
         self.rpn = RPN(self.cf, conv)
         self.classifier = Classifier(self.cf, conv)
         self.mask = Mask(self.cf, conv)
 
     def forward(self, img, is_training=True):
         """
         :param img: input images (b, c, y, x, (z)).
         :return: rpn_pred_logits: (b, n_anchors, 2)
         :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
         :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting.
         :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
         :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
         """
         # extract features.
         fpn_outs = self.fpn(img)
         rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels]
         self.mrcnn_feature_maps = rpn_feature_maps
 
         # loop through pyramid layers and apply RPN.
         layer_outputs = [ self.rpn(p_feats) for p_feats in rpn_feature_maps ]
 
         # concatenate layer outputs.
         # convert from list of lists of level outputs to list of lists of outputs across levels.
         # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
         outputs = list(zip(*layer_outputs))
         outputs = [torch.cat(list(o), dim=1) for o in outputs]
         rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs
         #
         # # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier.
         proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference
         batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas,
                                                                             proposal_count, self.anchors, self.cf)
 
         # merge batch dimension of proposals while storing allocation info in coordinate dimension.
         batch_ixs = torch.arange(
             batch_normed_props.shape[0]).cuda().unsqueeze(1).repeat(1,batch_normed_props.shape[1]).view(-1).float()
         rpn_rois = batch_normed_props[:, :, :-1].view(-1, batch_normed_props[:, :, :-1].shape[2])
         self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1)
 
         # this is the first of two forward passes in the second stage, where no activations are stored for backprop.
         # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.)
         # for inference/monitoring as well as sampling of rois for the loss functions.
         # processed in chunks of roi_chunk_size to re-adjust to gpu-memory.
         chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size)
         bboxes_list, class_logits_list, regressions_list = [], [], []
         with torch.no_grad():
             for chunk in chunked_rpn_rois:
                 chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk)
                 bboxes_list.append(chunk_bboxes)
                 class_logits_list.append(chunk_class_logits)
                 regressions_list.append(chunk_regressions)
         mrcnn_bbox = torch.cat(bboxes_list, 0)
         mrcnn_class_logits = torch.cat(class_logits_list, 0)
         mrcnn_regressions = torch.cat(regressions_list, 0)
         self.mrcnn_roi_scores = compute_roi_scores(self.cf.prediction_tasks, batch_normed_props, mrcnn_class_logits)
 
         # refine classified proposals, filter and return final detections.
         # returns (cf.max_inst_per_batch_element, n_coords+1+...)
         detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores,
                                        mrcnn_regressions)
 
         # forward remaining detections through mask-head to generate corresponding masks.
         scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2
         scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda()
 
         # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ix
         detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale
         with torch.no_grad():
             detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes)
 
         return [rpn_pred_logits, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks]
 
 
     def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions=None):
         """
         this is the second forward pass through the second stage (features from stage one are re-used).
         samples few rois in loss_example_mining and forwards only those for loss computation.
         :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
         :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
         :param batch_gt_masks: (b, n(b), c, y, x (,z)) list over batch elements. Each element holds n_gt_rois(b)
                 (i.e., dependent on the batch element) binary masks of shape (c, y, x, (z)).
         :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores.
         :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
         :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal.
         :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals.
         :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement.
         :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals.
         :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting.
         """
         # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses.
         sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \
             mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks,
                                        self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions)
 
         # re-use feature maps and RPN output from first forward pass.
         sample_proposals = self.rpn_rois_batch_info[sample_ics]
         if not 0 in sample_proposals.size():
             sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals)
             sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals)
         else:
             sample_logits = torch.FloatTensor().cuda()
             sample_deltas = torch.FloatTensor().cuda()
             sample_regressions = torch.FloatTensor().cuda()
             sample_mask = torch.FloatTensor().cuda()
 
         return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals,
                 sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions]
 
     def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True):
         """
         Restores batch dimension of merged detections, unmolds detections, creates and fills results dict.
         :param img_shape:
         :param detections: shape (n_final_detections, len(info)), where
             info=( y1, x1, y2, x2, (z1,z2), batch_ix, pred_class_id, pred_score )
         :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
         :param box_results_list: None or list of output boxes for monitoring/plotting.
         each element is a list of boxes per batch element.
         :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
         :return: results_dict: dictionary with keys:
                  'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                           [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                  'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now.
                  class-specific return of masks will come with implementation of instance segmentation evaluation.
         """
 
         detections = detections.cpu().data.numpy()
         if self.cf.dim == 2:
             detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy()
         else:
             detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy()
         # det masks shape now (n_dets, y,x(,z), n_classes)
         # restore batch dimension of merged detections using the batch_ix info.
         batch_ixs = detections[:, self.cf.dim*2]
         detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])]
         mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])]
         # mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size
 
         if box_results_list == None: # for test_forward, where no previous list exists.
             box_results_list =  [[] for _ in range(img_shape[0])]
         # seg_logits == seg_probs in mrcnn since mask head finishes with sigmoid (--> image space = [0,1])
         seg_probs = []
         # loop over batch and unmold detections.
         for ix in range(img_shape[0]):
 
             # final masks are one-hot encoded (b, n_classes, y, x, (z))
             final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:]))
             #+1 for bg, 0.5 bc mask head classifies only bg/fg with logits between 0,1--> bg is <0.5
             if self.cf.num_classes + 1 != self.cf.num_seg_classes:
                 self.logger.warning("n of roi-classifier head classes {} doesnt match cf.num_seg_classes {}".format(
                     self.cf.num_classes + 1, self.cf.num_seg_classes))
 
             if not 0 in detections[ix].shape:
                 boxes = detections[ix][:, :self.cf.dim*2].astype(np.int32)
                 class_ids = detections[ix][:, self.cf.dim*2 + 1].astype(np.int32)
                 scores = detections[ix][:, self.cf.dim*2 + 2]
                 masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids]
                 regressions = detections[ix][:,self.cf.dim*2+3:]
 
                 # Filter out detections with zero area. Often only happens in early
                 # stages of training when the network weights are still a bit random.
                 if self.cf.dim == 2:
                     exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
                 else:
                     exclude_ix = np.where(
                         (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0]
 
                 if exclude_ix.shape[0] > 0:
                     boxes = np.delete(boxes, exclude_ix, axis=0)
                     masks = np.delete(masks, exclude_ix, axis=0)
                     class_ids = np.delete(class_ids, exclude_ix, axis=0)
                     scores = np.delete(scores, exclude_ix, axis=0)
                     regressions = np.delete(regressions, exclude_ix, axis=0)
 
                 # Resize masks to original image size and set boundary threshold.
                 if return_masks:
                     for i in range(masks.shape[0]): #masks per this batch instance/element/image
                         # Convert neural network mask to full size mask
                         if self.cf.dim == 2:
                             full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:])
                         else:
                             full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:])
                         # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class
                         # has the max seg_logit value over all instances of that class in one sample
                         final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0)
                     final_masks[0] = np.full(final_masks[0].shape, 0.49999999) #effectively min_det_thres at 0.5 per pixel
 
                 # add final predictions to results.
                 if not 0 in boxes.shape:
                     for ix2, coords in enumerate(boxes):
                         box = {'box_coords': coords, 'box_type': 'det', 'box_score': scores[ix2],
                                'box_pred_class_id': class_ids[ix2]}
                         #if (hasattr(self.cf, "convert_cl_to_rg") and self.cf.convert_cl_to_rg):
                         if "regression_bin" in self.cf.prediction_tasks:
                             # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin represents
                             box['rg_bin'] = regressions[ix2].argmax()
                             box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']]
                         else:
                             box['regression'] = regressions[ix2]
                             if hasattr(self.cf, "rg_val_to_bin_id") and \
                                     any(['regression' in task for task in self.cf.prediction_tasks]):
                                 box.update({'rg_bin': self.cf.rg_val_to_bin_id(regressions[ix2])})
 
                         box_results_list[ix].append(box)
 
             # if no detections were made--> keep full bg mask (zeros).
             seg_probs.append(final_masks)
 
         # create and fill results dictionary.
         results_dict = {}
         results_dict['boxes'] = box_results_list
         results_dict['seg_preds'] = np.array(seg_probs)
 
         return results_dict
 
 
     def train_forward(self, batch, is_validation=False):
         """
         train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
         for processing, computes losses, and stores outputs in a dictionary.
         :param batch: dictionary containing 'data', 'seg', etc.
             batch['roi_masks']: (b, n(b), c, h(n), w(n) (z(n))) list like roi_labels but with arrays (masks) inplace of
         integers. c==channels of the raw segmentation.
         :return: results_dict: dictionary with keys:
                 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                         [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
                 'torch_loss': 1D torch tensor for backprop.
                 'class_loss': classification loss for monitoring.
         """
         img = batch['data']
         gt_boxes = batch['bb_target']
         #axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
         #gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))]
         gt_masks = batch['roi_masks']
         gt_class_ids = batch['class_targets']
         if 'regression' in self.cf.prediction_tasks:
             gt_regressions = batch["regression_targets"]
         elif 'regression_bin' in self.cf.prediction_tasks:
             gt_regressions = batch["rg_bin_targets"]
         else:
             gt_regressions = None
 
         img = torch.from_numpy(img).cuda().float()
         batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
         batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()
 
         # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
         box_results_list = [[] for _ in range(img.shape[0])]
 
         #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
         # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop.
         rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img)
 
         mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \
         mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \
             self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions)
         # loop over batch
         for b in range(img.shape[0]):
             if len(gt_boxes[b]) > 0:
                 # add gt boxes to output list
                 for tix in range(len(gt_boxes[b])):
                     gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]}
                     for name in self.cf.roi_items:
                         gt_box.update({name: batch[name][b][tix]})
                     box_results_list[b].append(gt_box)
 
                 # match gt boxes with anchors to generate targets for RPN losses.
                 rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])
 
                 # add positive anchors used for loss to output list for monitoring.
                 pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
                 for p in pos_anchors:
                     box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
 
             else:
                 rpn_match = np.array([-1]*self.np_anchors.shape[0])
                 rpn_target_deltas = np.array([0])
 
             rpn_match_gpu = torch.from_numpy(rpn_match).cuda()
             rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()
 
             # compute RPN losses.
             rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match_gpu, self.cf.shem_poolsize)
             rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match_gpu)
             batch_rpn_class_loss += rpn_class_loss /img.shape[0]
             batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0]
 
             # add negative anchors used for loss to output list for monitoring.
             # neg_anchor_ix = neg_ix come from shem and mark positions in roi_probs_neg = rpn_class_logits[neg_indices]
             # with neg_indices = rpn_match == -1
             neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[rpn_match == -1][neg_anchor_ix], img.shape[2:])
             for n in neg_anchors:
                 box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})
 
             # add highest scoring proposals to output list for monitoring.
             rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
             for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
                 box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})
 
         # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
         if not 0 in sample_proposals.shape:
             rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
             for ix, r in enumerate(rois):
                 box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
                                             'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})
 
         # compute mrcnn losses.
         mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids)
         mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids)
         mrcnn_regressions_loss = compute_mrcnn_regression_loss(self.cf.prediction_tasks, mrcnn_regressions, target_regressions, target_class_ids)
         # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
         # In this case, the mask_loss is taken out of training.
         if self.cf.frcnn_mode:
             mrcnn_mask_loss = torch.FloatTensor([0]).cuda()
         else:
             mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids)
 
         loss = batch_rpn_class_loss + batch_rpn_bbox_loss +\
                mrcnn_bbox_loss + mrcnn_mask_loss +  mrcnn_class_loss + mrcnn_regressions_loss
 
         # run unmolding of predictions for monitoring and merge all results to one dictionary.
         return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train
         results_dict = self.get_results(img.shape, detections, detection_masks, box_results_list,
                                         return_masks=return_masks)
+        results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis]
 
-        #results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis]
         if 'dice' in self.cf.metrics:
             results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                 results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)
 
         results_dict['torch_loss'] = loss
         results_dict['class_loss'] = mrcnn_class_loss.item()
         results_dict['bbox_loss'] = mrcnn_bbox_loss.item()
         results_dict['mask_loss'] = mrcnn_mask_loss.item()
         results_dict['rg_loss'] = mrcnn_regressions_loss.item()
         results_dict['rpn_class_loss'] = rpn_class_loss.item()
         results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item()
         return results_dict
 
 
     def test_forward(self, batch, return_masks=True):
         """
         test method. wrapper around forward pass of network without usage of any ground truth information.
         prepares input data for processing and stores outputs in a dictionary.
         :param batch: dictionary containing 'data'
         :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
         :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
         """
         img = batch['data']
         img = torch.from_numpy(img).float().cuda()
         _, _, _, detections, detection_masks = self.forward(img)
         results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks)
 
         return results_dict
\ No newline at end of file
diff --git a/predictor.py b/predictor.py
index b69f821..99035bd 100644
--- a/predictor.py
+++ b/predictor.py
@@ -1,1003 +1,1005 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import os
 from multiprocessing import Pool
 import pickle
 import time
 
 import numpy as np
 import torch
 from scipy.stats import norm
 from collections import OrderedDict
 
 import plotting as plg
 import utils.model_utils as mutils
 import utils.exp_utils as utils
 
 
 def get_mirrored_patch_crops(patch_crops, org_img_shape):
     mirrored_patch_crops = []
     mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]]
                                  if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2],
                                                        ii[3], ii[4], ii[5]]
                                  for ii in patch_crops])
 
     mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]]
                                  if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3],
                                                        org_img_shape[3] - ii[2], ii[4], ii[5]]
                                  for ii in patch_crops])
 
     mirrored_patch_crops.append([[org_img_shape[2] - ii[1],
                                   org_img_shape[2] - ii[0],
                                   org_img_shape[3] - ii[3],
                                   org_img_shape[3] - ii[2]]
                                  if len(ii) == 4 else
                                  [org_img_shape[2] - ii[1],
                                   org_img_shape[2] - ii[0],
                                   org_img_shape[3] - ii[3],
                                   org_img_shape[3] - ii[2], ii[4], ii[5]]
                                  for ii in patch_crops])
 
     return mirrored_patch_crops
 
 def get_mirrored_patch_crops_ax_dep(patch_crops, org_img_shape, mirror_axes):
     mirrored_patch_crops = []
     for ax_ix, axes in enumerate(mirror_axes):
         if isinstance(axes, (int, float)) and int(axes) == 0:
             mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]]
                                          if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0],
                                                                ii[2], ii[3], ii[4], ii[5]]
                                          for ii in patch_crops])
         elif isinstance(axes, (int, float)) and int(axes) == 1:
             mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]]
                                          if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3],
                                                                org_img_shape[3] - ii[2], ii[4], ii[5]]
                                          for ii in patch_crops])
         elif hasattr(axes, "__iter__") and (tuple(axes) == (0, 1) or tuple(axes) == (1, 0)):
             mirrored_patch_crops.append([[org_img_shape[2] - ii[1],
                                           org_img_shape[2] - ii[0],
                                           org_img_shape[3] - ii[3],
                                           org_img_shape[3] - ii[2]]
                                          if len(ii) == 4 else
                                          [org_img_shape[2] - ii[1],
                                           org_img_shape[2] - ii[0],
                                           org_img_shape[3] - ii[3],
                                           org_img_shape[3] - ii[2], ii[4], ii[5]]
                                          for ii in patch_crops])
         else:
             raise Exception("invalid mirror axes {} in get mirrored patch crops".format(axes))
 
     return mirrored_patch_crops
 
 def apply_wbc_to_patient(inputs):
     """
     wrapper around prediction box consolidation: weighted box clustering (wbc). processes a single patient.
     loops over batch elements in patient results (1 in 3D, slices in 2D) and foreground classes,
     aggregates and stores results in new list.
     :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is
                                  one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D
                                  predictions, and a dummy batch dimension of 1 for 3D predictions.
     :return. pid: string. patient id.
     """
     regress_flag, in_patient_results_list, pid, class_dict, clustering_iou, n_ens = inputs
     out_patient_results_list = [[] for _ in range(len(in_patient_results_list))]
 
     for bix, b in enumerate(in_patient_results_list):
 
         for cl in list(class_dict.keys()):
 
             boxes = [(ix, box) for ix, box in enumerate(b) if
                      (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
             box_coords = np.array([b[1]['box_coords'] for b in boxes])
             box_scores = np.array([b[1]['box_score'] for b in boxes])
             box_center_factor = np.array([b[1]['box_patch_center_factor'] for b in boxes])
             box_n_overlaps = np.array([b[1]['box_n_overlaps'] for b in boxes])
             try:
                 box_patch_id = np.array([b[1]['patch_id'] for b in boxes])
             except KeyError: #backward compatibility for already saved pred results ... omg
                 box_patch_id = np.array([b[1]['ens_ix'] for b in boxes])
             box_regressions = np.array([b[1]['regression'] for b in boxes]) if regress_flag else None
             box_rg_bins = np.array([b[1]['rg_bin'] if 'rg_bin' in b[1].keys() else float('NaN') for b in boxes])
             box_rg_uncs = np.array([b[1]['rg_uncertainty'] if 'rg_uncertainty' in b[1].keys() else float('NaN') for b in boxes])
 
             if 0 not in box_scores.shape:
                 keep_scores, keep_coords, keep_n_missing, keep_regressions, keep_rg_bins, keep_rg_uncs = \
                     weighted_box_clustering(box_coords, box_scores, box_center_factor, box_n_overlaps, box_rg_bins, box_rg_uncs,
                                              box_regressions, box_patch_id, clustering_iou, n_ens)
 
 
                 for boxix in range(len(keep_scores)):
                     clustered_box = {'box_type': 'det', 'box_coords': keep_coords[boxix],
                                      'box_score': keep_scores[boxix], 'cluster_n_missing': keep_n_missing[boxix],
                                      'box_pred_class_id': cl}
                     if regress_flag:
                         clustered_box.update({'regression': keep_regressions[boxix],
                                               'rg_uncertainty': keep_rg_uncs[boxix],
                                               'rg_bin': keep_rg_bins[boxix]})
 
                     out_patient_results_list[bix].append(clustered_box)
 
         # add gt boxes back to new output list.
         out_patient_results_list[bix].extend([box for box in b if box['box_type'] == 'gt'])
 
     return [out_patient_results_list, pid]
 
 
 def weighted_box_clustering(box_coords, scores, box_pc_facts, box_n_ovs, box_rg_bins, box_rg_uncs,
                              box_regress, box_patch_id, thresh, n_ens):
     """Consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal ensembling.
     clusters predictions together with iou > thresh (like in NMS). Output score and coordinate for one cluster are the
     average weighted by individual patch center factors (how trustworthy is this candidate measured by how centered
     its position within the patch is) and the size of the corresponding box.
     The number of expected predictions at a position is n_data_aug * n_temp_ens * n_overlaps_at_position
     (1 prediction per unique patch). Missing predictions at a cluster position are defined as the number of unique
     patches in the cluster, which did not contribute any predict any boxes.
     :param dets: (n_dets, (y1, x1, y2, x2, (z1), (z2), scores, box_pc_facts, box_n_ovs).
     :param box_coords: y1, x1, y2, x2, (z1), (z2).
     :param scores: confidence scores.
     :param box_pc_facts: patch-center factors from position on patch tiles.
     :param box_n_ovs: number of patch overlaps at box position.
     :param box_rg_bins: regression bin predictions.
     :param box_rg_uncs: (n_dets,) regression uncertainties (from model mrcnn_aleatoric).
     :param box_regress: (n_dets, n_regression_features).
     :param box_patch_id: ensemble index.
     :param thresh: threshold for iou_matching.
     :param n_ens: number of models, that are ensembled. (-> number of expected predictions per position).
     :return: keep_scores: (n_keep)  new scores of boxes to be kept.
     :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept.
     """
 
     dim = 2 if box_coords.shape[1] == 4 else 3
     y1 = box_coords[:,0]
     x1 = box_coords[:,1]
     y2 = box_coords[:,2]
     x2 = box_coords[:,3]
 
     areas = (y2 - y1 + 1) * (x2 - x1 + 1)
     if dim == 3:
         z1 = box_coords[:, 4]
         z2 = box_coords[:, 5]
         areas *= (z2 - z1 + 1)
 
     # order is the sorted index.  maps order to index o[1] = 24 (rank1, ix 24)
     order = scores.argsort()[::-1]
 
     keep_scores = []
     keep_coords = []
     keep_n_missing = []
     keep_regress = []
     keep_rg_bins = []
     keep_rg_uncs = []
 
     while order.size > 0:
         i = order[0]  # highest scoring element
         yy1 = np.maximum(y1[i], y1[order])
         xx1 = np.maximum(x1[i], x1[order])
         yy2 = np.minimum(y2[i], y2[order])
         xx2 = np.minimum(x2[i], x2[order])
 
         w = np.maximum(0, xx2 - xx1 + 1)
         h = np.maximum(0, yy2 - yy1 + 1)
         inter = w * h
 
         if dim == 3:
             zz1 = np.maximum(z1[i], z1[order])
             zz2 = np.minimum(z2[i], z2[order])
             d = np.maximum(0, zz2 - zz1 + 1)
             inter *= d
 
         # overlap between currently highest scoring box and all boxes.
         ovr = inter / (areas[i] + areas[order] - inter)
         ovr_fl = inter.astype('float64') / (areas[i] + areas[order] - inter.astype('float64'))
         assert np.all(ovr==ovr_fl), "ovr {}\n ovr_float {}".format(ovr, ovr_fl)
         # get all the predictions that match the current box to build one cluster.
         matches = np.nonzero(ovr > thresh)[0]
 
         match_n_ovs = box_n_ovs[order[matches]]
         match_pc_facts = box_pc_facts[order[matches]]
         match_patch_id = box_patch_id[order[matches]]
         match_ov_facts = ovr[matches]
         match_areas = areas[order[matches]]
         match_scores = scores[order[matches]]
 
         # weight all scores in cluster by patch factors, and size.
         match_score_weights = match_ov_facts * match_areas * match_pc_facts
         match_scores *= match_score_weights
 
         # for the weighted average, scores have to be divided by the number of total expected preds at the position
         # of the current cluster. 1 Prediction per patch is expected. therefore, the number of ensembled models is
         # multiplied by the mean overlaps of  patches at this position (boxes of the cluster might partly be
         # in areas of different overlaps).
         n_expected_preds = n_ens * np.mean(match_n_ovs)
         # the number of missing predictions is obtained as the number of patches,
         # which did not contribute any prediction to the current cluster.
         n_missing_preds = np.max((0, n_expected_preds - np.unique(match_patch_id).shape[0]))
 
         # missing preds are given the mean weighting
         # (expected prediction is the mean over all predictions in cluster).
         denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights)
 
         # compute weighted average score for the cluster
         avg_score = np.sum(match_scores) / denom
 
         # compute weighted average of coordinates for the cluster. now only take existing
         # predictions into account.
         avg_coords = [np.sum(y1[order[matches]] * match_scores) / np.sum(match_scores),
                       np.sum(x1[order[matches]] * match_scores) / np.sum(match_scores),
                       np.sum(y2[order[matches]] * match_scores) / np.sum(match_scores),
                       np.sum(x2[order[matches]] * match_scores) / np.sum(match_scores)]
 
         if dim == 3:
             avg_coords.append(np.sum(z1[order[matches]] * match_scores) / np.sum(match_scores))
             avg_coords.append(np.sum(z2[order[matches]] * match_scores) / np.sum(match_scores))
 
         if box_regress is not None:
             # compute wt. avg. of regression vectors (component-wise average)
             avg_regress = np.sum(box_regress[order[matches]] * match_scores[:, np.newaxis], axis=0) / np.sum(
                 match_scores)
             avg_rg_bins = np.round(np.sum(box_rg_bins[order[matches]] * match_scores) / np.sum(match_scores))
             avg_rg_uncs = np.sum(box_rg_uncs[order[matches]] * match_scores) / np.sum(match_scores)
         else:
             avg_regress = np.array(float('NaN'))
             avg_rg_bins = np.array(float('NaN'))
             avg_rg_uncs = np.array(float('NaN'))
 
         # some clusters might have very low scores due to high amounts of missing predictions.
         # filter out the with a conservative threshold, to speed up evaluation.
         if avg_score > 0.01:
             keep_scores.append(avg_score)
             keep_coords.append(avg_coords)
             keep_n_missing.append((n_missing_preds / n_expected_preds * 100))  # relative
             keep_regress.append(avg_regress)
             keep_rg_uncs.append(avg_rg_uncs)
             keep_rg_bins.append(avg_rg_bins)
 
         # get index of all elements that were not matched and discard all others.
         inds = np.nonzero(ovr <= thresh)[0]
         inds_where = np.where(ovr<=thresh)[0]
         assert np.all(inds == inds_where), "inds_nonzero {} \ninds_where {}".format(inds, inds_where)
         order = order[inds]
 
     return keep_scores, keep_coords, keep_n_missing, keep_regress, keep_rg_bins, keep_rg_uncs
 
 
 def apply_nms_to_patient(inputs):
 
     in_patient_results_list, pid, class_dict, iou_thresh = inputs
     out_patient_results_list = []
 
 
     # collect box predictions over batch dimension (slices) and store slice info as slice_ids.
     for batch in in_patient_results_list:
         batch_el_boxes = []
         for cl in list(class_dict.keys()):
             det_boxes = [box for box in batch if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
 
             box_coords = np.array([box['box_coords'] for box in det_boxes])
             box_scores = np.array([box['box_score'] for box in det_boxes])
             if 0 not in box_scores.shape:
                 keep_ix = mutils.nms_numpy(box_coords, box_scores, iou_thresh)
             else:
                 keep_ix = []
 
             batch_el_boxes += [det_boxes[ix] for ix in keep_ix]
 
         batch_el_boxes += [box for box in batch if box['box_type'] == 'gt']
         out_patient_results_list.append(batch_el_boxes)
 
     assert len(in_patient_results_list) == len(out_patient_results_list), "batch dim needs to be maintained, in: {}, out {}".format(len(in_patient_results_list), len(out_patient_results_list))
 
     return [out_patient_results_list, pid]
 
 def nms_2to3D(dets, thresh):
     """
     Merges 2D boxes to 3D cubes. For this purpose, boxes of all slices are regarded as lying in one slice.
     An adaptation of Non-maximum suppression is applied where clusters are found (like in NMS) with the extra constraint
     that suppressed boxes have to have 'connected' z coordinates w.r.t the core slice (cluster center, highest
     scoring box, the prevailing box). 'connected' z-coordinates are determined
     as the z-coordinates with predictions until the first coordinate for which no prediction is found.
 
     example: a cluster of predictions was found overlap > iou thresh in xy (like NMS). The z-coordinate of the highest
     scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57.
     Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not because nothing was
     found in 47, so 47 is a 'hole', which interrupts the connection). Only the boxes corresponding to these coordinates
     are suppressed. All others are kept for building of further clusters.
 
     This algorithm works better with a certain min_confidence of predictions, because low confidence (e.g. noisy/cluttery)
     predictions can break the relatively strong assumption of defining cubes' z-boundaries at the first 'hole' in the cluster.
 
     :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id)
     :param thresh: iou matchin threshold (like in NMS).
     :return: keep: (n_keep,) 1D tensor of indices to be kept.
     :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes.
     """
 
     y1 = dets[:, 0]
     x1 = dets[:, 1]
     y2 = dets[:, 2]
     x2 = dets[:, 3]
     assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: 
         where maximum is taken needs to be the lower coordinate"""
     scores = dets[:, -2]
     slice_id = dets[:, -1]
 
     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
     order = scores.argsort()[::-1]
 
     keep = []
     keep_z = []
 
     while order.size > 0:  # order is the sorted index.  maps order to index: order[1] = 24 means (rank1, ix 24)
         i = order[0]  # highest scoring element
         yy1 = np.maximum(y1[i], y1[order])  # highest scoring element still in >order<, is compared to itself: okay?
         xx1 = np.maximum(x1[i], x1[order])
         yy2 = np.minimum(y2[i], y2[order])
         xx2 = np.minimum(x2[i], x2[order])
 
         h = np.maximum(0.0, yy2 - yy1 + 1)
         w = np.maximum(0.0, xx2 - xx1 + 1)
         inter = h * w
 
         iou = inter / (areas[i] + areas[order] - inter)
         matches = np.argwhere(
             iou > thresh)  # get all the elements that match the current box and have a lower score
 
         slice_ids = slice_id[order[matches]]
         core_slice = slice_id[int(i)]
         upper_holes = [ii for ii in np.arange(core_slice, np.max(slice_ids)) if ii not in slice_ids]
         lower_holes = [ii for ii in np.arange(np.min(slice_ids), core_slice) if ii not in slice_ids]
         max_valid_slice_id = np.min(upper_holes) if len(upper_holes) > 0 else np.max(slice_ids)
         min_valid_slice_id = np.max(lower_holes) if len(lower_holes) > 0 else np.min(slice_ids)
         z_matches = matches[(slice_ids <= max_valid_slice_id) & (slice_ids >= min_valid_slice_id)]
 
         # expand by one z voxel since box content is surrounded w/o overlap, i.e., z-content computed as z2-z1
         z1 = np.min(slice_id[order[z_matches]]) - 1
         z2 = np.max(slice_id[order[z_matches]]) + 1
 
         keep.append(i)
         keep_z.append([z1, z2])
         order = np.delete(order, z_matches, axis=0)
 
     return keep, keep_z
 
 def apply_2d_3d_merging_to_patient(inputs):
     """
     wrapper around 2Dto3D merging operation. Processes a single patient. Takes 2D patient results (slices in batch dimension)
     and returns 3D patient results (dummy batch dimension of 1). Applies an adaption of Non-Maximum Surpression
     (Detailed methodology is described in nms_2to3D).
     :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each box is
                                  one dictionary: [[box_0, ...], [box_n,...]].
     :return. pid: string. patient id.
     """
 
     in_patient_results_list, pid, class_dict, merge_3D_iou = inputs
     out_patient_results_list = []
 
     for cl in list(class_dict.keys()):
         det_boxes, slice_ids = [], []
         # collect box predictions over batch dimension (slices) and store slice info as slice_ids.
         for batch_ix, batch in enumerate(in_patient_results_list):
             batch_element_det_boxes = [(ix, box) for ix, box in enumerate(batch) if
                                        (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
             det_boxes += batch_element_det_boxes
             slice_ids += [batch_ix] * len(batch_element_det_boxes)
 
         box_coords = np.array([batch[1]['box_coords'] for batch in det_boxes])
         box_scores = np.array([batch[1]['box_score'] for batch in det_boxes])
         slice_ids = np.array(slice_ids)
 
         if 0 not in box_scores.shape:
             keep_ix, keep_z = nms_2to3D(
                 np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou)
         else:
             keep_ix, keep_z = [], []
 
         # store kept predictions in new results list and add corresponding z-dimension info to coordinates.
         for kix, kz in zip(keep_ix, keep_z):
             keep_box = det_boxes[kix][1]
             keep_box['box_coords'] = list(keep_box['box_coords']) + kz
             out_patient_results_list.append(keep_box)
 
     gt_boxes = [box for b in in_patient_results_list for box in b if box['box_type'] == 'gt']
     if len(gt_boxes) > 0:
         assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D."
     out_patient_results_list += gt_boxes
 
     return [[out_patient_results_list], pid]  # additional list wrapping is extra batch dim.
 
 
 class Predictor:
     """
 	    Prediction pipeline:
 	    - receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader.
 	    - forwards patches through model in chunks of batch_size. (method: batch_tiling_forward)
 	    - unmolds predictions (boxes and segmentations) to original patient coordinates. (method: spatial_tiling_forward)
 
 	    Ensembling (mode == 'test'):
 	    - for inference, forwards 4 mirrored versions of image to through model and unmolds predictions afterwards
 	      accordingly (method: data_aug_forward)
 	    - for inference, loads multiple parameter-sets of the trained model corresponding to different epochs. for each
 	      parameter-set loops over entire test set, runs prediction pipeline for each patient. (method: predict_test_set)
 
 	    Consolidation of predictions:
 	    - consolidates a patient's predictions (boxes, segmentations) collected over patches, data_aug- and temporal ensembling,
 	      performs clustering and weighted averaging (external function: apply_wbc_to_patient) to obtain consistent outptus.
 	    - for 2D networks, consolidates box predictions to 3D cubes via clustering (adaption of non-maximum surpression).
 	      (external function: apply_2d_3d_merging_to_patient)
 
 	    Ground truth handling:
 	    - dissmisses any ground truth boxes returned by the model (happens in validation mode, patch-based groundtruth)
 	    - if provided by data loader, adds patient-wise ground truth to the final predictions to be passed to the evaluator.
     """
     def __init__(self, cf, net, logger, mode):
 
         self.cf = cf
         self.batch_size = cf.batch_size
         self.logger = logger
         self.mode = mode
         self.net = net
         self.n_ens = 1
         self.rank_ix = '0'
         self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks])
 
         if self.cf.merge_2D_to_3D_preds:
             assert self.cf.dim == 2, "Merge 2Dto3D only valid for 2D preds, but current dim is {}.".format(self.cf.dim)
 
         if self.mode == 'test':
+            last_state_path = os.path.join(self.cf.fold_dir, 'last_state.pth')
             try:
-                self.epoch_ranking = np.load(os.path.join(self.cf.fold_dir, 'epoch_ranking.npy'))[:cf.test_n_epochs]
-            except:
-                raise RuntimeError('no epoch ranking file in fold directory. '
+                self.model_index = torch.load(last_state_path)["model_index"]
+                self.model_index = self.model_index[self.model_index["rank"] <= self.cf.test_n_epochs]
+            except FileNotFoundError:
+                raise FileNotFoundError('no last_state/model_index file in fold directory. '
                                    'seems like you are trying to run testing without prior training...')
             self.n_ens = cf.test_n_epochs
             if self.cf.test_aug_axes is not None:
                 self.n_ens *= (len(self.cf.test_aug_axes)+1)
             self.example_plot_dir = os.path.join(cf.test_dir, "example_plots")
             os.makedirs(self.example_plot_dir, exist_ok=True)
 
     def batch_tiling_forward(self, batch):
         """
         calls the actual network forward method. in patch-based prediction, the batch dimension might be overladed
         with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of
         batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are discarded).
         test mode calls the test forward method, no ground truth required / involved.
         :return. results_dict: stores the results for one patient. dictionary with keys:
                  - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                             one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                             and a dummy batch dimension of 1 for 3D predictions.
                  - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
                  - loss / class_loss (only in validation mode)
         """
 
         img = batch['data']
 
         if img.shape[0] <= self.batch_size:
 
             if self.mode == 'val':
                 # call training method to monitor losses
                 results_dict = self.net.train_forward(batch, is_validation=True)
                 # discard returned ground-truth boxes (also training info boxes).
                 results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']]
             elif self.mode == 'test':
                 results_dict = self.net.test_forward(batch, return_masks=self.cf.return_masks_in_test)
 
         else: # needs batch tiling
             split_ixs = np.split(np.arange(img.shape[0]), np.arange(img.shape[0])[::self.batch_size])
             chunk_dicts = []
             for chunk_ixs in split_ixs[1:]:  # first split is elements before 0, so empty
                 b = {k: batch[k][chunk_ixs] for k in batch.keys()
                      if (isinstance(batch[k], np.ndarray) and batch[k].shape[0] == img.shape[0])}
                 if self.mode == 'val':
                     chunk_dicts += [self.net.train_forward(b, is_validation=True)]
                 else:
                     chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)]
 
             results_dict = {}
             # flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...])
             results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']]
             results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']])
 
             if self.mode == 'val':
                 # if hasattr(self.cf, "losses_to_monitor"):
                 #     loss_names = self.cf.losses_to_monitor
                 # else:
                 #     loss_names = {name for dic in chunk_dicts for name in dic if 'loss' in name}
                 # estimate patient loss by mean over batch_chunks. Most similar to training loss.
                 results_dict['torch_loss'] = torch.mean(torch.cat([d['torch_loss'] for d in chunk_dicts]))
                 results_dict['class_loss'] = np.mean([d['class_loss'] for d in chunk_dicts])
                 # discard returned ground-truth boxes (also training info boxes).
                 results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']]
 
         return results_dict
 
     def spatial_tiling_forward(self, batch, patch_crops = None, n_aug='0'):
         """
         forwards batch to batch_tiling_forward method and receives and returns a dictionary with results.
         if patch-based prediction, the results received from batch_tiling_forward will be on a per-patch-basis.
         this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates.
         Patch-origin information of all box-predictions will be needed for consolidation, hence it is stored as
         'patch_id', which is a unique string for each patch (also takes current data aug and temporal epoch instances
         into account). all box predictions get additional information about the amount overlapping patches at the
         respective position (used for consolidation).
         :return. results_dict: stores the results for one patient. dictionary with keys:
                  - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                             one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                             and a dummy batch dimension of 1 for 3D predictions.
                  - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
                  - monitor_values (only in validation mode)
         returned dict is a flattened version with 1 batch instance (3D) or slices (2D)
         """
 
         if patch_crops is not None:
             #print("patch_crops not None, applying patch center factor")
 
             patches_dict = self.batch_tiling_forward(batch)
             results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]}
             #bc of ohe--> channel dim of seg has size num_classes
             out_seg_shape = list(batch['original_img_shape'])
             out_seg_shape[1] = patches_dict["seg_preds"].shape[1]
             out_seg_preds = np.zeros(out_seg_shape, dtype=np.float16)
             patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8')
             for pix, pc in enumerate(patch_crops):
                 if self.cf.dim == 3:
                     out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix]
                     patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1
                 elif self.cf.dim == 2:
                     out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix]
                     patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1
 
             out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0]
             results_dict['seg_preds'] = out_seg_preds
 
             for pix, pc in enumerate(patch_crops):
                 patch_boxes = patches_dict['boxes'][pix]
                 for box in patch_boxes:
 
                     # add unique patch id for consolidation of predictions.
                     box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix)
                     # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers.
                     # hence they will be down-weighted for consolidation, using the 'box_patch_center_factor', which is
                     # obtained by a gaussian distribution over positions in the patch and average over spatial dimensions.
                     # Also the info 'box_n_overlaps' is stored for consolidation, which represents the amount of
                     # overlapping patches at the box's position.
 
                     c = box['box_coords']
                     #box_centers = np.array([(c[ii] + c[ii+2])/2 for ii in range(len(c)//2)])
                     box_centers = [(c[ii] + c[ii + 2]) / 2 for ii in range(2)]
                     if self.cf.dim == 3:
                         box_centers.append((c[4] + c[5]) / 2)
                     box['box_patch_center_factor'] = np.mean(
                         [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in
                          zip(box_centers, np.array(self.cf.patch_size) / 2)])
                     if self.cf.dim == 3:
                         c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]])
                         int_c = [int(np.floor(ii)) if ix%2 == 0 else int(np.ceil(ii))  for ix, ii in enumerate(c)]
                         box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, int_c[1]:int_c[3], int_c[0]:int_c[2], int_c[4]:int_c[5]])
                         results_dict['boxes'][0].append(box)
                     else:
                         c += np.array([pc[0], pc[2], pc[0], pc[2]])
                         int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)]
                         box['box_n_overlaps'] = np.mean(
                             patch_overlap_map[pc[4], :, int_c[1]:int_c[3], int_c[0]:int_c[2]])
                         results_dict['boxes'][pc[4]].append(box)
 
             if self.mode == 'val':
                 results_dict['torch_loss'] = patches_dict['torch_loss']
                 results_dict['class_loss'] = patches_dict['class_loss']
 
         else:
             results_dict = self.batch_tiling_forward(batch)
             for b in results_dict['boxes']:
                 for box in b:
                     box['box_patch_center_factor'] = 1
                     box['box_n_overlaps'] = 1
                     box['patch_id'] = self.rank_ix + '_' + n_aug
 
         return results_dict
 
     def data_aug_forward(self, batch):
         """
         in val_mode: passes batch through to spatial_tiling method without data_aug.
         in test_mode: if cf.test_aug is set in configs, createst 4 mirrored versions of the input image,
         passes all of them to the next processing step (spatial_tiling method) and re-transforms returned predictions
         to original image version.
         :return. results_dict: stores the results for one patient. dictionary with keys:
                  - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                             one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                             and a dummy batch dimension of 1 for 3D predictions.
                  - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
                  - loss / class_loss (only in validation mode)
         """
         patch_crops = batch['patch_crop_coords'] if self.patched_patient else None
         results_list = [self.spatial_tiling_forward(batch, patch_crops)]
         org_img_shape = batch['original_img_shape']
 
         if self.mode == 'test' and self.cf.test_aug_axes is not None:
             if isinstance(self.cf.test_aug_axes, (int, float)):
                 self.cf.test_aug_axes = (self.cf.test_aug_axes,)
             #assert np.all(np.array(self.cf.test_aug_axes)<self.cf.dim), "test axes {} need to be spatial axes".format(self.cf.test_aug_axes)
 
             if self.patched_patient:
                 # apply mirror transformations to patch-crop coordinates, for correct tiling in spatial_tiling method.
                 mirrored_patch_crops = get_mirrored_patch_crops_ax_dep(patch_crops, batch['original_img_shape'],
                                                                        self.cf.test_aug_axes)
                 self.logger.info("mirrored patch crop coords for patched patient in test augs!")
             else:
                 mirrored_patch_crops = [None] * 3
 
             img = np.copy(batch['data'])
 
             for n_aug, sp_axis in enumerate(self.cf.test_aug_axes):
                 #sp_axis = np.array(axis) #-2 #spatial axis index
                 axis = np.array(sp_axis)+2
                 if isinstance(sp_axis, (int, float)):
                     # mirroring along one axis at a time
                     batch['data'] = np.flip(img, axis=axis).copy()
                     chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[n_aug], n_aug=str(n_aug))
                     # re-transform coordinates.
                     for ix in range(len(chunk_dict['boxes'])):
                         for boxix in range(len(chunk_dict['boxes'][ix])):
                             coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy()
                             coords[sp_axis] = org_img_shape[axis] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis+2]
                             coords[sp_axis+2] = org_img_shape[axis] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis]
                             assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']]
                             assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']]
                             chunk_dict['boxes'][ix][boxix]['box_coords'] = coords
                     # re-transform segmentation predictions.
                     chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=axis)
 
                 elif hasattr(sp_axis, "__iter__") and tuple(sp_axis)==(0,1) or tuple(sp_axis)==(1,0):
                     #NEED: mirrored patch crops are given as [(y-axis), (x-axis), (y-,x-axis)], obey this order!
                     # mirroring along two axes at same time
                     batch['data'] = np.flip(np.flip(img, axis=axis[0]), axis=axis[1]).copy()
                     chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[n_aug], n_aug=str(n_aug))
                     # re-transform coordinates.
                     for ix in range(len(chunk_dict['boxes'])):
                         for boxix in range(len(chunk_dict['boxes'][ix])):
                             coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy()
                             coords[sp_axis[0]] = org_img_shape[axis[0]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[0]+2]
                             coords[sp_axis[0]+2] = org_img_shape[axis[0]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[0]]
                             coords[sp_axis[1]] = org_img_shape[axis[1]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[1]+2]
                             coords[sp_axis[1]+2] = org_img_shape[axis[1]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[1]]
                             assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']]
                             assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']]
                             chunk_dict['boxes'][ix][boxix]['box_coords'] = coords
                     # re-transform segmentation predictions.
                     chunk_dict['seg_preds'] = np.flip(np.flip(chunk_dict['seg_preds'], axis=axis[0]), axis=axis[1]).copy()
 
                 else:
                     raise Exception("Invalid axis type {} in test augs".format(type(axis)))
                 results_list.append(chunk_dict)
 
             batch['data'] = img
 
         # aggregate all boxes/seg_preds per batch element from data_aug predictions.
         results_dict = {}
         results_dict['boxes'] = [[item for d in results_list for item in d['boxes'][batch_instance]]
                                  for batch_instance in range(org_img_shape[0])]
         # results_dict['seg_preds'] = np.array([[item for d in results_list for item in d['seg_preds'][batch_instance]]
         #                                       for batch_instance in range(org_img_shape[0])])
         results_dict['seg_preds'] = np.stack([dic['seg_preds'] for dic in results_list], axis=1)
         # needs segs probs in seg_preds entry:
         results_dict['seg_preds'] = np.sum(results_dict['seg_preds'], axis=1) #add up seg probs from different augs per class
 
         if self.mode == 'val':
             results_dict['torch_loss'] = results_list[0]['torch_loss']
             results_dict['class_loss'] = results_list[0]['class_loss']
 
         return results_dict
 
     def load_saved_predictions(self):
         """loads raw predictions saved by self.predict_test_set. aggregates and/or merges 2D boxes to 3D cubes for
             evaluation (if model predicts 2D but evaluation is run in 3D), according to settings config.
         :return: list_of_results_per_patient: list over patient results. each entry is a dict with keys:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                        (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
             - 'batch_dices': dice scores as recorded in raw prediction results.
             - 'seg_preds': not implemented yet. could replace dices by seg preds to have raw seg info available, however
                 would consume critically large memory amount. todo evaluation of instance/semantic segmentation.
         """
 
         results_file = 'pred_results.pkl' if not self.cf.held_out_test_set else 'pred_results_held_out.pkl'
         if not self.cf.held_out_test_set or self.cf.eval_test_fold_wise:
             self.logger.info("loading saved predictions of fold {}".format(self.cf.fold))
             with open(os.path.join(self.cf.fold_dir, results_file), 'rb') as handle:
                 results_list = pickle.load(handle)
             box_results_list = [(res_dict["boxes"], pid) for res_dict, pid in results_list]
 
             da_factor = len(self.cf.test_aug_axes)+1 if self.cf.test_aug_axes is not None else 1
             self.n_ens = self.cf.test_n_epochs * da_factor
             self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format(
                 len(results_list), self.n_ens))
         else:
             self.logger.info("loading saved predictions of hold-out test set")
             fold_dirs = sorted([os.path.join(self.cf.exp_dir, f) for f in os.listdir(self.cf.exp_dir) if
                                 os.path.isdir(os.path.join(self.cf.exp_dir, f)) and f.startswith("fold")])
 
             results_list = []
             folds_loaded = 0
             for fold in range(self.cf.n_cv_splits):
                 fold_dir = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold))
                 if fold_dir in fold_dirs:
                     with open(os.path.join(fold_dir, results_file), 'rb') as handle:
                         fold_list = pickle.load(handle)
                         results_list += fold_list
                         folds_loaded += 1
                 else:
                     self.logger.info("Skipping fold {} since no saved predictions found.".format(fold))
             box_results_list = []
             for res_dict, pid in results_list: #without filtering gt out:
                 box_results_list.append((res_dict['boxes'], pid))
                 #it's usually not right to filter out gts here, is it?
 
             da_factor = len(self.cf.test_aug_axes)+1 if self.cf.test_aug_axes is not None else 1
             self.n_ens = self.cf.test_n_epochs * da_factor * folds_loaded
 
         # -------------- aggregation of boxes via clustering -----------------
 
         if self.cf.clustering == "wbc":
             self.logger.info('applying WBC to test-set predictions with iou {} and n_ens {} over {} patients'.format(
                 self.cf.clustering_iou, self.n_ens, len(box_results_list)))
 
             mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens] for ii
                          in box_results_list]
             del box_results_list
             pool = Pool(processes=self.cf.n_workers)
             box_results_list = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1)
             pool.close()
             pool.join()
             del mp_inputs
         elif self.cf.clustering == "nms":
             self.logger.info('applying standard NMS to test-set predictions with iou {} over {} patients.'.format(
                 self.cf.clustering_iou, len(box_results_list)))
             pool = Pool(processes=self.cf.n_workers)
             mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in box_results_list]
             del box_results_list
             box_results_list = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1)
             pool.close()
             pool.join()
             del mp_inputs
 
         if self.cf.merge_2D_to_3D_preds:
             self.logger.info('applying 2Dto3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
             pool = Pool(processes=self.cf.n_workers)
             mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in box_results_list]
             box_results_list = pool.map(apply_2d_3d_merging_to_patient, mp_inputs, chunksize=1)
             pool.close()
             pool.join()
             del mp_inputs
 
         for ix in range(len(results_list)):
             assert np.all(results_list[ix][1] == box_results_list[ix][1]), "pid mismatch between loaded and aggregated results"
             results_list[ix][0]["boxes"] = box_results_list[ix][0]
 
         return results_list # holds (results_dict, pid)
 
     def predict_patient(self, batch):
         """
         predicts one patient.
         called either directly via loop over validation set in exec.py (mode=='val')
         or from self.predict_test_set (mode=='test).
         in val mode:  adds 3D ground truth info to predictions and runs consolidation and 2Dto3D merging of predictions.
         in test mode: returns raw predictions (ground truth addition, consolidation, 2D to 3D merging are
                       done in self.predict_test_set, because patient predictions across several epochs might be needed
                       to be collected first, in case of temporal ensembling).
         :return. results_dict: stores the results for one patient. dictionary with keys:
                  - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                             one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                             (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
                  - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
                  - loss / class_loss (only in validation mode)
         """
-        if self.mode=="test":
-            self.logger.info('predicting patient {} for fold {} '.format(np.unique(batch['pid']), self.cf.fold))
+        #if self.mode=="test":
+        #    self.logger.info('predicting patient {} for fold {} '.format(np.unique(batch['pid']), self.cf.fold))
 
         # True if patient is provided in patches and predictions need to be tiled.
         self.patched_patient = 'patch_crop_coords' in list(batch.keys())
 
         # forward batch through prediction pipeline.
         results_dict = self.data_aug_forward(batch)
         #has seg probs in entry 'seg_preds'
 
         if self.mode == 'val':
             for b in range(batch['patient_bb_target'].shape[0]):
                 for t in range(len(batch['patient_bb_target'][b])):
                     gt_box = {'box_type': 'gt', 'box_coords': batch['patient_bb_target'][b][t],
                               'class_targets': batch['patient_class_targets'][b][t]}
                     for name in self.cf.roi_items:
                         gt_box.update({name : batch['patient_'+name][b][t]})
                     results_dict['boxes'][b].append(gt_box)
 
             if 'dice' in self.cf.metrics:
                 if self.patched_patient:
                     assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape."
                 results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                     results_dict['seg_preds'], batch["patient_seg"] if self.patched_patient else batch['seg'],
                     self.cf.num_seg_classes, convert_to_ohe=True)
             if self.patched_patient and self.cf.clustering == "wbc":
                 wbc_input = [self.regress_flag, results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.clustering_iou, self.n_ens]
                 results_dict['boxes'] = apply_wbc_to_patient(wbc_input)[0]
             elif self.patched_patient:
                 nms_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.clustering_iou]
                 results_dict['boxes'] = apply_nms_to_patient(nms_inputs)[0]
 
             if self.cf.merge_2D_to_3D_preds:
                 results_dict['2D_boxes'] = results_dict['boxes']
                 merge_dims_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.merge_3D_iou]
                 results_dict['boxes'] = apply_2d_3d_merging_to_patient(merge_dims_inputs)[0]
 
         return results_dict
 
     def predict_test_set(self, batch_gen, return_results=True):
         """
         wrapper around test method, which loads multiple (or one) epoch parameters (temporal ensembling), loops through
         the test set and collects predictions per patient. Also flattens the results per patient and epoch
         and adds optional ground truth boxes for evaluation. Saves out the raw result list for later analysis and
         optionally consolidates and returns predictions immediately.
         :return: (optionally) list_of_results_per_patient: list over patient results. each entry is a dict with keys:
                  - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                             one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                             (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
                  - 'seg_preds': not implemented yet. todo evaluation of instance/semantic segmentation.
         """
 
         # -------------- raw predicting -----------------
         dict_of_patients_results = OrderedDict()
         set_of_result_types = set()
+
+        self.model_index = self.model_index.sort_values(by="rank")
         # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling).
-        weight_paths = [os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch)) for epoch in self.epoch_ranking]
+        weight_paths = [os.path.join(self.cf.fold_dir, file_name) for file_name in self.model_index["file_name"]]
 
 
         for rank_ix, weight_path in enumerate(weight_paths):
             self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path)))
             self.net.load_state_dict(torch.load(weight_path))
             self.net.eval()
             self.rank_ix = str(rank_ix)
             with torch.no_grad():
                 plot_batches = np.random.choice(np.arange(batch_gen['n_test']), size=self.cf.n_test_plots, replace=False)
                 for i in range(batch_gen['n_test']):
                     batch = next(batch_gen['test'])
                     pid = np.unique(batch['pid'])
                     assert len(pid)==1
                     pid = pid[0]
 
                     if not pid in dict_of_patients_results.keys():  # store batch info in patient entry of results dict.
                         dict_of_patients_results[pid] = {}
                         dict_of_patients_results[pid]['results_dicts'] = []
                         dict_of_patients_results[pid]['patient_bb_target'] = batch['patient_bb_target']
 
                         for name in self.cf.roi_items:
                             dict_of_patients_results[pid]["patient_"+name] = batch["patient_"+name]
                     stime = time.time()
                     results_dict = self.predict_patient(batch) #only holds "boxes", "seg_preds"
                     # needs ohe seg probs in seg_preds entry:
                     results_dict['seg_preds'] = np.argmax(results_dict['seg_preds'], axis=1)[:,np.newaxis]
                     self.logger.info("predicting patient {} with weight rank {} (progress: {}/{}) took {:.2f}s".format(
                         str(pid), rank_ix, (rank_ix)*batch_gen['n_test']+(i+1), len(weight_paths)*batch_gen['n_test'], time.time()-stime))
 
                     if i in plot_batches and (not self.patched_patient or 'patient_data' in batch.keys()):
                         try:
                             # view qualitative results of random test case
                             self.logger.time("test_plot")
                             out_file = os.path.join(self.example_plot_dir,
                                                     'batch_example_test_{}_rank_{}.png'.format(self.cf.fold, rank_ix))
                             utils.split_off_process(plg.view_batch, self.cf, batch, results_dict,
                                                     has_colorchannels=self.cf.has_colorchannels,
                                                     show_gt_labels=True, show_seg_ids='dice' in self.cf.metrics,
                                                     get_time="test-example plot", out_file=out_file)
-                            self.logger.info("split-off example test plot {} in {:.2f}s".format(
-                                os.path.basename(out_file), self.logger.time("test_plot")))
                         except Exception as e:
                             self.logger.info("WARNING: error in view_batch: {}".format(e))
 
                     if 'dice' in self.cf.metrics:
                         if self.patched_patient:
                             assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape."
                         results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'],
                                 batch["patient_seg"] if self.patched_patient else batch['seg'],
                                 self.cf.num_seg_classes, convert_to_ohe=True)
 
                     dict_of_patients_results[pid]['results_dicts'].append({k:v for k,v in results_dict.items()
                                                                            if k in ["boxes", "batch_dices"]})
                     # collect result types to know which ones to look for when saving
                     set_of_result_types.update(dict_of_patients_results[pid]['results_dicts'][-1].keys())
 
 
 
         # -------------- re-order, save raw results -----------------
         self.logger.info('finished predicting test set. starting aggregation of predictions.')
         results_per_patient = []
         for pid, p_dict in dict_of_patients_results.items():
         # dict_of_patients_results[pid]['results_list'] has length batch['n_test']
 
             results_dict = {}
             # collect all boxes/seg_preds of same batch_instance over temporal instances.
             b_size = len(p_dict['results_dicts'][0]["boxes"])
             for res_type in [rtype for rtype in set_of_result_types if rtype in ["boxes", "batch_dices"]]:#, "seg_preds"]]:
                 if not 'batch' in res_type: #assume it's results on batch-element basis
                     results_dict[res_type] = [[item for rank_dict in p_dict['results_dicts'] for item in rank_dict[res_type][batch_instance]]
                                              for batch_instance in range(b_size)]
                 else:
                     results_dict[res_type] = []
                     for dict in p_dict['results_dicts']:
                         if 'dice' in res_type:
                             item = dict[res_type] #dict['batch_dices'] has shape (num_seg_classes,)
                             assert len(item) == self.cf.num_seg_classes, \
                                 "{}, {}".format(len(item), self.cf.num_seg_classes)
                         else:
                             raise NotImplementedError
                         results_dict[res_type].append(item)
                     # rdict[dice] shape (n_rank_epochs (n_saved_ranks), nsegclasses)
                     # calc mean over test epochs so inline with shape from sampling
                     results_dict[res_type] = np.mean(results_dict[res_type], axis=0) #maybe error type with other than dice
 
             if not hasattr(self.cf, "eval_test_separately") or not self.cf.eval_test_separately:
                 # add unpatched 2D or 3D (if dim==3 or merge_2D_to_3D) ground truth boxes for evaluation.
                 for b in range(p_dict['patient_bb_target'].shape[0]):
                     for targ in range(len(p_dict['patient_bb_target'][b])):
                         gt_box = {'box_type': 'gt', 'box_coords':p_dict['patient_bb_target'][b][targ],
                                   'class_targets': p_dict['patient_class_targets'][b][targ]}
                         for name in self.cf.roi_items:
                             gt_box.update({name: p_dict["patient_"+name][b][targ]})
                         results_dict['boxes'][b].append(gt_box)
 
             results_per_patient.append([results_dict, pid])
 
         out_string = 'pred_results_held_out' if self.cf.held_out_test_set else 'pred_results'
         with open(os.path.join(self.cf.fold_dir, '{}.pkl'.format(out_string)), 'wb') as handle:
             pickle.dump(results_per_patient, handle)
 
         if return_results:
             # -------------- results processing, clustering, etc. -----------------
             final_patient_box_results = [ (res_dict["boxes"], pid) for res_dict,pid in results_per_patient ]
             if self.cf.clustering == "wbc":
                 self.logger.info('applying WBC to test-set predictions with iou = {} and n_ens = {}.'.format(
                     self.cf.clustering_iou, self.n_ens))
                 mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens] for ii in final_patient_box_results]
                 del final_patient_box_results
                 pool = Pool(processes=self.cf.n_workers)
                 final_patient_box_results = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1)
                 pool.close()
                 pool.join()
                 del mp_inputs
             elif self.cf.clustering == "nms":
                 self.logger.info('applying standard NMS to test-set predictions with iou = {}.'.format(self.cf.clustering_iou))
                 pool = Pool(processes=self.cf.n_workers)
                 mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in final_patient_box_results]
                 del final_patient_box_results
                 final_patient_box_results = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1)
                 pool.close()
                 pool.join()
                 del mp_inputs
 
             if self.cf.merge_2D_to_3D_preds:
                 self.logger.info('applying 2D-to-3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
                 mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in final_patient_box_results]
                 del final_patient_box_results
                 pool = Pool(processes=self.cf.n_workers)
                 final_patient_box_results = pool.map(apply_2d_3d_merging_to_patient, mp_inputs, chunksize=1)
                 pool.close()
                 pool.join()
                 del mp_inputs
             # final_patient_box_results holds [avg_boxes, pid] if wbc
             for ix in range(len(results_per_patient)):
                 assert results_per_patient[ix][1] == final_patient_box_results[ix][1], "should be same pid"
                 results_per_patient[ix][0]["boxes"] = final_patient_box_results[ix][0]
             # results_per_patient = [(res_dict["boxes"] = boxes, pid) for (boxes,pid) in final_patient_box_results]
 
             return results_per_patient # holds list of (results_dict, pid)
diff --git a/requirements.txt b/requirements.txt
index f8f0800..6d9b914 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,62 +1,65 @@
 absl-py==0.8.1
 backcall==0.1.0
-batchgenerators==0.19.3
+batchgenerators==0.19.7
 cachetools==3.1.1
 certifi==2019.11.28
 chardet==3.0.4
 cycler==0.10.0
 Cython==0.29.14
 decorator==4.4.1
 future==0.18.2
 google-auth==1.7.2
 google-auth-oauthlib==0.4.1
 grpcio==1.25.0
 idna==2.8
 imageio==2.6.1
 ipython==7.9.0
 ipython-genutils==0.2.0
 jedi==0.15.1
 joblib==0.14.0
 kiwisolver==1.1.0
 linecache2==1.0.0
 Markdown==3.1.1
 matplotlib==3.1.2
 networkx==2.4
 nms-extension==0.0.0
 numpy==1.17.4
 oauthlib==3.1.0
 pandas==0.25.3
 parso==0.5.1
 pexpect==4.7.0
 pickleshare==0.7.5
 Pillow==6.2.1
 prompt-toolkit==2.0.10
 protobuf==3.11.1
 psutil==5.7.0
 ptyprocess==0.6.0
 pyasn1==0.4.8
 pyasn1-modules==0.2.7
 Pygments==2.5.2
 pyparsing==2.4.5
 python-dateutil==2.8.1
 pytz==2019.3
 PyWavelets==1.1.1
 RegRCNN==0.0.2
 requests==2.22.0
 requests-oauthlib==1.3.0
+RoIAlign-extension-2D==0.0.0
+RoIAlign-extension-3D==0.0.0
 rsa==4.0
 scikit-image==0.16.2
 scikit-learn==0.21.3
 scipy==1.3.1
 SimpleITK==1.2.3
 six==1.13.0
 tensorboard==2.0.2
+threadpoolctl==2.0.0
 torch==1.3.1
 torchvision==0.4.2
 tqdm==4.39.0
 traceback2==1.4.0
 traitlets==4.3.3
 unittest2==1.1.0
 urllib3==1.25.7
 wcwidth==0.1.7
 Werkzeug==0.16.0
diff --git a/shell_scripts/cluster_runner_meddec.sh b/shell_scripts/cluster_runner_meddec.sh
index d884226..0cee463 100644
--- a/shell_scripts/cluster_runner_meddec.sh
+++ b/shell_scripts/cluster_runner_meddec.sh
@@ -1,65 +1,64 @@
 #!/bin/bash
 
 #Usage:
 # -->not true?: this script has to be started from the same directory the python files called below lie in (e.g. exec.py lies in meddetectiontkit).
 # part of the slurm-job name you pass to sbatch will be the experiment folder's name.
 # you need to pass 3 positional arguments to this script (cluster_runner_..sh #1 #2 #3):
 # -#1 source directory in which main source code (framework) is located (e.g. medicaldetectiontoolkit/)
 # -#2 the exp_dir where job-specific code was copied before by create_exp and exp results are safed by exec.py
 # -#3 absolute path to dataset-specific code in source dir
 # -#4 mode to run
 # -#5 folds to run on
 
 source_dir=${1}
 exp_dir=${2}
 dataset_abs_path=${3}
 mode=${4}
 folds=${5}
 resume=$6
 
 #known problem: trap somehow does not execute the rm -r tmp_dir command when using scancel on job
 #trap clean_up EXIT KILL TERM ABRT QUIT
 
 job_dir=/ssd/ramien/${LSB_JOBID}
 
 tmp_dir_data=${job_dir}/data
 mkdir $tmp_dir_data
 
 tmp_dir_cache=${job_dir}/cache
 mkdir $tmp_dir_cache
 CUDA_CACHE_PATH=$tmp_dir_cache
 export CUDA_CACHE_PATH
 
 
 #data must not lie permantly on nodes' ssd, only during training time
 #needs to be named with the SLURM_JOB_ID to not be automatically removed
 #can permanently lie on /datasets drive --> copy from there before every experiment
 #files on datasets are saved as npz (compressed) --> use data_manager.py to copy and unpack into .npy; is done implicitly in exec.py
 
 #(tensorboard --logdir ${exp_dir}/.. --port 1337 || echo "tboard startup failed")& # || tensorboard --logdir ${exp_dir}/.. --port 1338)&
 #tboard_pid=$!
 
 #clean_up() {
 #	rm -rf ${job_dir};
 #}
 
 export OMP_NUM_THREADS=1 # this is a work-around fix for batchgenerators to deal with numpy-inherent multi-threading.
 
+launch_opts=${source_dir}/exec.py --use_stored_settings --server_env --dataset_name ${dataset_abs_path} --data_dest ${tmp_dir_data} --exp_dir ${exp_dir} --mode ${mode}
+
+if [ ! -z "${resume}" ]; then
+  launch_opts=${launch_opts} --resume
+  echo "Resuming from checkpoint(s)."
+fi
+
 if [ ! -z "${folds}" ]; then
-	if [ -z "${resume}" ]; then
-		resume='None'
-	else
-		resume=${exp_dir}"/fold_${folds}/last_state.pth"
-		echo "Resuming from checkpoint at ${resume}."
-	fi
-	python ${source_dir}/exec.py --use_stored_settings --server_env --dataset_name ${dataset_abs_path} --data_dest ${tmp_dir_data} --exp_dir ${exp_dir} --mode ${mode} --folds ${folds} --resume_from_checkpoint ${resume}
-	
-else
-	python ${source_dir}/exec.py --use_stored_settings --server_env --dataset_name ${dataset_abs_path} --data_dest ${tmp_dir_data} --exp_dir ${exp_dir} --mode ${mode}
-	
+  launch_opts=${launch_opts} --folds ${folds}
 fi
 
+python ${launch_opts}
+
 
 
 
 
diff --git a/unittests.py b/unittests.py
index 10f29d3..3b6c7a1 100644
--- a/unittests.py
+++ b/unittests.py
@@ -1,569 +1,625 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import unittest
 
 import os
 import pickle
 import time
 from multiprocessing import  Pool
 import subprocess
 from pathlib import Path
 
 import numpy as np
 import pandas as pd
 import torch
 import torchvision as tv
 
 import tqdm
 
 import plotting as plg
 import utils.exp_utils as utils
 import utils.model_utils as mutils
 
 """ Note on unittests: run this file either in the way intended for unittests by starting the script with
     python -m unittest unittests.py or start it as a normal python file as python unittests.py.
     You can selective run single tests by calling python -m unittest unittests.TestClassOfYourChoice, where 
     TestClassOfYourChoice is the name of the test defined below, e.g., CompareFoldSplits.
 """
 
 
 
 def inspect_info_df(pp_dir):
     """ use your debugger to look into the info df of a pp dir.
     :param pp_dir: preprocessed-data directory
     """
 
     info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle"))
 
     return
 
 
 def generate_boxes(count, dim=2, h=100, w=100, d=20, normalize=False, on_grid=False, seed=0):
     """ generate boxes of format [y1, x1, y2, x2, (z1, z2)].
     :param count: nr of boxes
     :param dim: dimension of boxes (2 or 3)
     :return: boxes in format (n_boxes, 4 or 6), scores
     """
     np.random.seed(seed)
     if on_grid:
         lower_y = np.random.randint(0, h // 2, (count,))
         lower_x = np.random.randint(0, w // 2, (count,))
         upper_y = np.random.randint(h // 2, h, (count,))
         upper_x = np.random.randint(w // 2, w, (count,))
         if dim == 3:
             lower_z = np.random.randint(0, d // 2, (count,))
             upper_z = np.random.randint(d // 2, d, (count,))
     else:
         lower_y = np.random.rand(count) * h / 2.
         lower_x = np.random.rand(count) * w / 2.
         upper_y = (np.random.rand(count) + 1.) * h / 2.
         upper_x = (np.random.rand(count) + 1.) * w / 2.
         if dim == 3:
             lower_z = np.random.rand(count) * d / 2.
             upper_z = (np.random.rand(count) + 1.) * d / 2.
 
     if dim == 3:
         boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x, lower_z, upper_z)))
         # add an extreme box that tests the boundaries
         boxes = np.concatenate((boxes, np.array([[0., 0., h, w, 0, d]])))
     else:
         boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x)))
         boxes = np.concatenate((boxes, np.array([[0., 0., h, w]])))
 
     scores = np.random.rand(count + 1)
     if normalize:
         divisor = np.array([h, w, h, w, d, d]) if dim == 3 else np.array([h, w, h, w])
         boxes = boxes / divisor
     return boxes, scores
 
 #------- perform integrity checks on data set(s) -----------
 class VerifyLIDCSAIntegrity(unittest.TestCase):
     """ Perform integrity checks on preprocessed single-annotator GTs of LIDC data set.
     """
     @staticmethod
     def check_patient_sa_gt(pid, pp_dir, check_meta_files, check_info_df):
 
         faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids'])
 
         all_segs = np.load(os.path.join(pp_dir, pid + "_rois.npz"), mmap_mode='r')
         all_segs = all_segs[list(all_segs.keys())[0]]
         all_roi_ids = np.unique(all_segs[all_segs > 0])
         assert len(all_roi_ids) == np.max(all_segs), "roi ids not consecutive"
         if check_meta_files:
             meta_file = os.path.join(pp_dir, pid + "_meta_info.pickle")
             with open(meta_file, "rb") as handle:
                 info = pickle.load(handle)
             assert info["pid"] == pid, "wrong pid in meta_file"
             all_cl_targets = info["class_target"]
         if check_info_df:
             info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle"))
             pid_info = info_df[info_df.pid == pid]
             assert len(pid_info) == 1, "found {} entries for pid {} in info df, expected exactly 1".format(len(pid_info),
                                                                                                            pid)
             if check_meta_files:
                 assert pid_info[
                            "class_target"] == all_cl_targets, "meta_info and info_df class targets mismatch:\n{}\n{}".format(
                     pid_info["class_target"], all_cl_targets)
             all_cl_targets = pid_info["class_target"].iloc[0]
         assert len(all_roi_ids) == len(all_cl_targets)
         for rater in range(4):
             seg = all_segs[rater]
             roi_ids = np.unique(seg[seg > 0])
             cl_targs = np.array([roi[rater] for roi in all_cl_targets])
             assert np.count_nonzero(cl_targs) == len(roi_ids), "rater {} has targs {} but roi ids {}".format(rater, cl_targs, roi_ids)
             assert len(cl_targs) >= len(roi_ids), "not all marked rois have a label"
             for zeroix_roi_id, rating in enumerate(cl_targs):
                 if not ((rating > 0) == (np.any(seg == zeroix_roi_id + 1))):
                     print("\n\nFAULTY CASE:", end=" ", )
                     print("pid {}, rater {}, cl_targs {}, ids {}\n".format(pid, rater, cl_targs, roi_ids))
                     faulty_cases = faulty_cases.append(
                         {'pid': pid, 'rater': rater, 'cl_targets': cl_targs, 'roi_ids': roi_ids}, ignore_index=True)
         print("finished checking pid {}, {} faulty cases".format(pid, len(faulty_cases)))
         return faulty_cases
 
     def check_sa_gts(cf, pp_dir, pid_subset=None, check_meta_files=False, check_info_df=True, processes=os.cpu_count()):
         report_name = "verify_seg_label_pairings.csv"
         pids = {file_name.split("_")[0] for file_name in os.listdir(pp_dir) if file_name not in [report_name, "info_df.pickle"]}
         if pid_subset is not None:
             pids = [pid for pid in pids if pid in pid_subset]
 
 
         faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids'])
 
         p = Pool(processes=processes)
         mp_args = zip(pids, [pp_dir]*len(pids), [check_meta_files]*len(pids), [check_info_df]*len(pids))
         patient_cases = p.starmap(self.check_patient_sa_gt, mp_args)
         p.close(); p.join()
         faulty_cases = faulty_cases.append(patient_cases, sort=False)
 
 
         print("\n\nfaulty case count {}".format(len(faulty_cases)))
         print(faulty_cases)
         findings_file = os.path.join(pp_dir, "verify_seg_label_pairings.csv")
         faulty_cases.to_csv(findings_file)
 
         assert len(faulty_cases)==0, "there was a faulty case in data set {}.\ncheck {}".format(pp_dir, findings_file)
 
     def test(self):
         pp_root = "/media/gregor/HDD2TB/Documents/data/"
         pp_dir = "lidc/pp_20190805"
         gt_dir = os.path.join(pp_root, pp_dir, "patient_gts_sa")
         self.check_sa_gts(gt_dir, check_meta_files=True, check_info_df=False, pid_subset=None)  # ["0811a", "0812a"])
 
 #------ compare segmentation gts of preprocessed data sets ------
 class CompareSegGTs(unittest.TestCase):
     """ load and compare pre-processed gts by dice scores of segmentations.
 
     """
     @staticmethod
     def group_seg_paths(ref_path, comp_paths):
         # not working recursively
         ref_files = [fn for fn in os.listdir(ref_path) if
                      os.path.isfile(os.path.join(ref_path, fn)) and 'seg' in fn and fn.endswith('.npy')]
 
         comp_files = [[os.path.join(c_path, fn) for c_path in comp_paths] for fn in ref_files]
 
         ref_files = [os.path.join(ref_path, fn) for fn in ref_files]
 
         return zip(ref_files, comp_files)
 
     @staticmethod
     def load_calc_dice(paths):
         dices = []
         ref_seg = np.load(paths[0])[np.newaxis, np.newaxis]
         n_classes = len(np.unique(ref_seg))
         ref_seg = mutils.get_one_hot_encoding(ref_seg, n_classes)
 
         for c_file in paths[1]:
             c_seg = np.load(c_file)[np.newaxis, np.newaxis]
             assert n_classes == len(np.unique(c_seg)), "unequal nr of objects/classes betw segs {} {}".format(paths[0],
                                                                                                               c_file)
             c_seg = mutils.get_one_hot_encoding(c_seg, n_classes)
 
             dice = mutils.dice_per_batch_inst_and_class(c_seg, ref_seg, n_classes, convert_to_ohe=False)
             dices.append(dice)
         print("processed ref_path {}".format(paths[0]))
         return np.mean(dices), np.std(dices)
 
     def iterate_files(self, grouped_paths, processes=os.cpu_count()):
         p = Pool(processes)
 
         means_stds = np.array(p.map(self.load_calc_dice, grouped_paths))
 
         p.close(); p.join()
         min_dice = np.min(means_stds[:, 0])
         print("min mean dice {:.2f}, max std {:.4f}".format(min_dice, np.max(means_stds[:, 1])))
         assert min_dice > 1-1e5, "compared seg gts have insufficient minimum mean dice overlap of {}".format(min_dice)
 
     def test(self):
         ref_path = '/media/gregor/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071'
         comp_paths = ['/media/gregor/HDD2TB/Documents/data/prostate/data_t2_190419_ps384_gs6071', ]
         paths = self.group_seg_paths(ref_path, comp_paths)
         self.iterate_files(paths)
 
 #------- check if cross-validation fold splits of different experiments are identical ----------
 class CompareFoldSplits(unittest.TestCase):
     """ Find evtl. differences in cross-val file splits across different experiments.
     """
     @staticmethod
     def group_id_paths(ref_exp_dir, comp_exp_dirs):
 
         f_name = 'fold_ids.pickle'
 
         ref_paths = os.path.join(ref_exp_dir, f_name)
         assert os.path.isfile(ref_paths), "ref file {} does not exist.".format(ref_paths)
 
 
         ref_paths = [ref_paths for comp_ed in comp_exp_dirs]
         comp_paths = [os.path.join(comp_ed, f_name) for comp_ed in comp_exp_dirs]
 
         return zip(ref_paths, comp_paths)
 
     @staticmethod
     def comp_fold_ids(mp_input):
         fold_ids1, fold_ids2 = mp_input
         with open(fold_ids1, 'rb') as f:
             fold_ids1 = pickle.load(f)
         try:
             with open(fold_ids2, 'rb') as f:
                 fold_ids2 = pickle.load(f)
         except FileNotFoundError:
             print("comp file {} does not exist.".format(fold_ids2))
             return
 
         n_splits = len(fold_ids1)
         assert n_splits == len(fold_ids2), "mismatch n splits: ref has {}, comp {}".format(n_splits, len(fold_ids2))
         split_diffs = [np.setdiff1d(fold_ids1[s], fold_ids2[s]) for s in range(n_splits)]
         all_equal = np.any(split_diffs)
         return (split_diffs, all_equal)
 
     def iterate_exp_dirs(self, ref_exp, comp_exps, processes=os.cpu_count()):
 
         grouped_paths = list(self.group_id_paths(ref_exp, comp_exps))
         print("performing {} comparisons of cross-val file splits".format(len(grouped_paths)))
         p = Pool(processes)
         split_diffs = p.map(self.comp_fold_ids, grouped_paths)
         p.close(); p.join()
 
         df = pd.DataFrame(index=range(0,len(grouped_paths)), columns=["ref", "comp", "all_equal"])#, "diffs"])
         for ix, (ref, comp) in enumerate(grouped_paths):
             df.iloc[ix] = [ref, comp, split_diffs[ix][1]]#, split_diffs[ix][0]]
 
         print("Any splits not equal?", df.all_equal.any())
         assert not df.all_equal.any(), "a split set is different from reference split set, {}".format(df[~df.all_equal])
 
     def test(self):
         exp_parent_dir = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/'
         ref_exp = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn2d_cl_bs10'
         comp_exps = [os.path.join(exp_parent_dir, p) for p in os.listdir(exp_parent_dir)]
         comp_exps = [p for p in comp_exps if os.path.isdir(p) and p != ref_exp]
         self.iterate_exp_dirs(ref_exp, comp_exps)
 
 
 #------- check if cross-validation fold splits of a single experiment are actually incongruent (as required) ----------
 class VerifyFoldSplits(unittest.TestCase):
     """ Check, for a single fold_ids file, i.e., for a single experiment, if the assigned folds (assignment of data
         identifiers) is actually incongruent. No overlaps between folds are required for a correct cross validation.
     """
     @staticmethod
     def verify_fold_ids(splits):
         for i, split1 in enumerate(splits):
             for j, split2 in enumerate(splits):
                 if j > i:
                     inter = np.intersect1d(split1, split2)
                     if len(inter) > 0:
                         raise Exception("Split {} and {} intersect by pids {}".format(i, j, inter))
     def test(self):
         exp_dir = "/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/dev"
         check_file = os.path.join(exp_dir, 'fold_ids.pickle')
         with open(check_file, 'rb') as handle:
             splits = pickle.load(handle)
         self.verify_fold_ids(splits)
 
 # -------- check own nms CUDA implement against own numpy implement ------
 class CheckNMSImplementation(unittest.TestCase):
 
     @staticmethod
     def assert_res_equality(keep_ics1, keep_ics2, boxes, scores, tolerance=0, names=("res1", "res2")):
         """
         :param keep_ics1: keep indices (results), torch.Tensor of shape (n_ics,)
         :param keep_ics2:
         :return:
         """
         keep_ics1, keep_ics2 = keep_ics1.cpu().numpy(), keep_ics2.cpu().numpy()
         discrepancies = np.setdiff1d(keep_ics1, keep_ics2)
         try:
             checks = np.array([
                 len(discrepancies) <= tolerance
             ])
         except:
             checks = np.zeros((1,)).astype("bool")
         msgs = np.array([
             """{}: {} \n{}: {} \nboxes: {}\n {}\n""".format(names[0], keep_ics1, names[1], keep_ics2, boxes,
                                                             scores)
         ])
 
         assert np.all(checks), "NMS: results mismatch: " + "\n".join(msgs[~checks])
 
     def single_case(self, count=20, dim=3, threshold=0.2, seed=0):
         boxes, scores = generate_boxes(count, dim, seed=seed, h=320, w=280, d=30)
 
         keep_numpy = torch.tensor(mutils.nms_numpy(boxes, scores, threshold))
 
         # for some reason torchvision nms requires box coords as floats.
         boxes = torch.from_numpy(boxes).type(torch.float32)
         scores = torch.from_numpy(scores).type(torch.float32)
         if dim == 2:
             """need to wait until next pytorch release where they fixed nms on cpu (currently they have >= where it
             needs to be >.
             """
             keep_ops = tv.ops.nms(boxes, scores, threshold)
             # self.assert_res_equality(keep_numpy, keep_ops, boxes, scores, tolerance=0, names=["np", "ops"])
             pass
 
         boxes = boxes.cuda()
         scores = scores.cuda()
         keep = self.nms_ext.nms(boxes, scores, threshold)
         self.assert_res_equality(keep_numpy, keep, boxes, scores, tolerance=0, names=["np", "cuda"])
 
     def test(self, n_cases=200, box_count=30, threshold=0.5):
         # dynamically import module so that it doesn't affect other tests if import fails
         self.nms_ext = utils.import_module("nms_ext", 'custom_extensions/nms/nms.py')
         # change seed to something fix if you want exactly reproducible test
         seed0 = np.random.randint(50)
         print("NMS test progress (done/total box configurations) 2D:", end="\n")
         for i in tqdm.tqdm(range(n_cases)):
             self.single_case(count=box_count, dim=2, threshold=threshold, seed=seed0+i)
         print("NMS test progress (done/total box configurations) 3D:", end="\n")
         for i in tqdm.tqdm(range(n_cases)):
             self.single_case(count=box_count, dim=3, threshold=threshold, seed=seed0+i)
 
         return
 
 class CheckRoIAlignImplementation(unittest.TestCase):
 
     def prepare(self, dim=2):
 
         b, c, h, w = 1, 3, 50, 50
         # feature map, (b, c, h, w(, z))
         if dim == 2:
             fmap = torch.rand(b, c, h, w).cuda()
             # rois = torch.tensor([[
             #     [0.1, 0.1, 0.3, 0.3],
             #     [0.2, 0.2, 0.4, 0.7],
             #     [0.5, 0.7, 0.7, 0.9],
             # ]]).cuda()
             pool_size = (7, 7)
             rois = generate_boxes(5, dim=dim, h=h, w=w, on_grid=True, seed=np.random.randint(50))[0]
         elif dim == 3:
             d = 20
             fmap = torch.rand(b, c, h, w, d).cuda()
             # rois = torch.tensor([[
             #     [0.1, 0.1, 0.3, 0.3, 0.1, 0.1],
             #     [0.2, 0.2, 0.4, 0.7, 0.2, 0.4],
             #     [0.5, 0.0, 0.7, 1.0, 0.4, 0.5],
             #     [0.0, 0.0, 0.9, 1.0, 0.0, 1.0],
             # ]]).cuda()
             pool_size = (7, 7, 3)
             rois = generate_boxes(5, dim=dim, h=h, w=w, d=d, on_grid=True, seed=np.random.randint(50),
                                   normalize=False)[0]
         else:
             raise ValueError("dim needs to be 2 or 3")
 
         rois = [torch.from_numpy(rois).type(dtype=torch.float32).cuda(), ]
         fmap.requires_grad_(True)
         return fmap, rois, pool_size
 
     def check_2d(self):
         """ check vs torchvision ops not possible as on purpose different approach.
         :return:
         """
         raise NotImplementedError
         # fmap, rois, pool_size = self.prepare(dim=2)
         # ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1)
         # align_ext = ra_object(fmap, rois)
         # loss_ext = align_ext.sum()
         # loss_ext.backward()
         #
         # rois_swapped = [rois[0][:, [1,3,0,2]]]
         # align_ops = tv.ops.roi_align(fmap, rois_swapped, pool_size)
         # loss_ops = align_ops.sum()
         # loss_ops.backward()
         #
         # assert (loss_ops == loss_ext), "sum of roialign ops and extension 2D diverges"
         # assert (align_ops == align_ext).all(), "ROIAlign failed 2D test"
 
     def check_3d(self):
         fmap, rois, pool_size = self.prepare(dim=3)
         ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1)
         align_ext = ra_object(fmap, rois)
         loss_ext = align_ext.sum()
         loss_ext.backward()
 
         align_np = mutils.roi_align_3d_numpy(fmap.cpu().detach().numpy(), [roi.cpu().numpy() for roi in rois],
                                              pool_size)
         align_np = np.squeeze(align_np)  # remove singleton batch dim
 
         align_ext = align_ext.cpu().detach().numpy()
         assert np.allclose(align_np, align_ext, rtol=1e-5,
                            atol=1e-8), "RoIAlign differences in numpy and CUDA implement"
 
     def specific_example_check(self):
         # dummy input
         self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py')
         exp = 6
         pool_size = (2,2)
         fmap = torch.arange(exp**2).view(exp,exp).unsqueeze(0).unsqueeze(0).cuda().type(dtype=torch.float32)
 
         boxes = torch.tensor([[1., 1., 5., 5.]]).cuda()/exp
         ind = torch.tensor([0.]*len(boxes)).cuda().type(torch.float32)
         y_exp, x_exp = fmap.shape[2:]  # exp = expansion
         boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
         boxes = torch.cat((ind.unsqueeze(1), boxes), dim=1)
         aligned_tv = tv.ops.roi_align(fmap, boxes, output_size=pool_size, sampling_ratio=-1)
         aligned = self.ra_ext.roi_align_2d(fmap, boxes, output_size=pool_size, sampling_ratio=-1)
 
         boxes_3d = torch.cat((boxes, torch.tensor([[-1.,1.]]*len(boxes)).cuda()), dim=1)
         fmap_3d = fmap.unsqueeze(dim=-1)
         pool_size = (*pool_size,1)
         ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1.,)
         aligned_3d = ra_object(fmap_3d, boxes_3d)
 
-        expected_res = torch.tensor([[[[10.5000, 12.5000],
-                                       [22.5000, 24.5000]]]]).cuda()
-        expected_res_3d = torch.tensor([[[[[10.5000],[12.5000]],
-                                          [[22.5000],[24.5000]]]]]).cuda()
+        # expected_res = torch.tensor([[[[10.5000, 12.5000], # this would be with an alternative grid-point setting
+        #                                [22.5000, 24.5000]]]]).cuda()
+        expected_res = torch.tensor([[[[14., 16.],
+                                       [26., 28.]]]]).cuda()
+        expected_res_3d = torch.tensor([[[[[14.],[16.]],
+                                          [[26.],[28.]]]]]).cuda()
         assert torch.all(aligned==expected_res), "2D RoIAlign check vs. specific example failed. res: {}\n expected: {}\n".format(aligned, expected_res)
         assert torch.all(aligned_3d==expected_res_3d), "3D RoIAlign check vs. specific example failed. res: {}\n expected: {}\n".format(aligned_3d, expected_res_3d)
 
     def manual_check(self):
         """ print examples from a toy batch to file.
         :return:
         """
         self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py')
         # actual mrcnn mask input
         from datasets.toy import configs
         cf = configs.Configs()
         cf.exp_dir = "datasets/toy/experiments/dev/"
         cf.plot_dir = cf.exp_dir + "plots"
         os.makedirs(cf.exp_dir, exist_ok=True)
         cf.fold = 0
         cf.n_workers = 1
         logger = utils.get_logger(cf.exp_dir)
         data_loader = utils.import_module('data_loader', os.path.join("datasets", "toy", 'data_loader.py'))
         batch_gen = data_loader.get_train_generators(cf, logger=logger)
         batch = next(batch_gen['train'])
         roi_mask = np.zeros((1, 320, 200))
         bb_target = (np.array([50, 40, 90, 120])).astype("int")
         roi_mask[:, bb_target[0]+1:bb_target[2]+1, bb_target[1]+1:bb_target[3]+1] = 1.
         #batch = {"roi_masks": np.array([np.array([roi_mask, roi_mask]), np.array([roi_mask])]), "bb_target": [[bb_target, bb_target + 25], [bb_target-20]]}
         #batch_boxes_cor = [torch.tensor(batch_el_boxes).cuda().float() for batch_el_boxes in batch_cor["bb_target"]]
         batch_boxes = [torch.tensor(batch_el_boxes).cuda().float() for batch_el_boxes in batch["bb_target"]]
         #import IPython; IPython.embed()
         for b in range(len(batch_boxes)):
             roi_masks = batch["roi_masks"][b]
             #roi_masks_cor = batch_cor["roi_masks"][b]
             if roi_masks.sum()>0:
                 boxes = batch_boxes[b]
                 roi_masks = torch.tensor(roi_masks).cuda().type(dtype=torch.float32)
                 box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).type(dtype=torch.float32)
                 masks = tv.ops.roi_align(roi_masks, [boxes], cf.mask_shape)
                 masks = masks.squeeze(1)
                 masks = torch.round(masks)
                 masks_own = self.ra_ext.roi_align_2d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape)
                 boxes = boxes.type(torch.int)
                 #print("check roi mask", roi_masks[0, 0, boxes[0][0]:boxes[0][2], boxes[0][1]:boxes[0][3]].sum(), (boxes[0][2]-boxes[0][0]) * (boxes[0][3]-boxes[0][1]))
                 #print("batch masks", batch["roi_masks"])
                 masks_own = masks_own.squeeze(1)
                 masks_own = torch.round(masks_own)
                 #import IPython; IPython.embed()
                 for mix, mask in enumerate(masks):
                     fig = plg.plt.figure()
                     ax = fig.add_subplot()
                     ax.imshow(roi_masks[mix][0].cpu().numpy(), cmap="gray", vmin=0.)
                     ax.axis("off")
                     y1, x1, y2, x2 = boxes[mix]
                     bbox = plg.mpatches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=0.9, edgecolor="c", facecolor='none')
                     ax.add_patch(bbox)
                     x1, y1, x2, y2 = boxes[mix]
                     bbox = plg.mpatches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=0.9, edgecolor="r",
                                                   facecolor='none')
                     ax.add_patch(bbox)
                     debug_dir = Path("/home/gregor/Documents/regrcnn/datasets/toy/experiments/debugroial")
                     os.makedirs(debug_dir, exist_ok=True)
                     plg.plt.savefig(debug_dir/"mask_b{}_{}.png".format(b, mix))
                     plg.plt.imsave(debug_dir/"mask_b{}_{}_pooled_tv.png".format(b, mix), mask.cpu().numpy(), cmap="gray", vmin=0.)
                     plg.plt.imsave(debug_dir/"mask_b{}_{}_pooled_own.png".format(b, mix), masks_own[mix].cpu().numpy(), cmap="gray", vmin=0.)
         return
 
     def test(self):
         # dynamically import module so that it doesn't affect other tests if import fails
         self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py')
 
         self.specific_example_check()
 
         # 2d test
         #self.check_2d()
 
         # 3d test
         self.check_3d()
 
         return
 
 
 class CheckRuntimeErrors(unittest.TestCase):
     """ Check if minimal examples of the exec.py module finish without runtime errors.
         This check requires a working path to data in the toy-dataset configs.
     """
 
     def test(self):
         cf = utils.import_module("toy_cf", 'datasets/toy/configs.py').Configs()
         exp_dir = "./unittesting/"
         #checks = {"retina_net": False, "mrcnn": False}
         #print("Testing for runtime errors with models {}".format(list(checks.keys())))
         #for model in tqdm.tqdm(list(checks.keys())):
             # cf.model = model
             # cf.model_path = 'models/{}.py'.format(cf.model if not 'retina' in cf.model else 'retina_net')
             # cf.model_path = os.path.join(cf.source_dir, cf.model_path)
             # {'mrcnn': cf.add_mrcnn_configs,
             #  'retina_net': cf.add_mrcnn_configs, 'retina_unet': cf.add_mrcnn_configs,
             #  'detection_unet': cf.add_det_unet_configs, 'detection_fpn': cf.add_det_fpn_configs
             #  }[model]()
         # todo change structure of configs-handling with exec.py so that its dynamically parseable instead of needing to
         # todo be changed in the file all the time.
         checks = {cf.model:False}
         completed_process = subprocess.run("python exec.py --dev --dataset_name toy -m train_test --exp_dir {}".format(exp_dir),
                                            shell=True, capture_output=True, text=True)
         if completed_process.returncode!=0:
             print("Runtime test of model {} failed due to\n{}".format(cf.model, completed_process.stderr))
         else:
             checks[cf.model] = True
         subprocess.call("rm -rf {}".format(exp_dir), shell=True)
         assert all(checks.values()), "A runtime test crashed."
 
+class MulithreadedDataiterator(unittest.TestCase):
+
+    def test(self):
+        print("Testing multithreaded iterator.")
+
+
+        dataset = "toy"
+        exp_dir = Path("datasets/{}/experiments/dev".format(dataset))
+        cf_file = utils.import_module("cf_file", exp_dir/"configs.py")
+        cf = cf_file.Configs()
+        dloader = utils.import_module('data_loader', 'datasets/{}/data_loader.py'.format(dataset))
+        cf.exp_dir = Path(exp_dir)
+        cf.n_workers = 5
+
+        cf.batch_size = 3
+        cf.fold = 0
+        cf.plot_dir = cf.exp_dir / "plots"
+        logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
+        cf.num_val_batches = "all"
+        cf.val_mode = "val_sampling"
+        cf.n_workers = 8
+        batch_gens = dloader.get_train_generators(cf, logger, data_statistics=False)
+        val_loader = batch_gens["val_sampling"]
+
+        for epoch in range(4):
+            produced_ids = []
+            for i in range(batch_gens['n_val']):
+                batch = next(val_loader)
+                produced_ids.append(batch["pid"])
+            uni, cts = np.unique(np.concatenate(produced_ids), return_counts=True)
+            assert np.all(cts < 3), "with batch size one: every item should occur exactly once.\n uni {}, cts {}".format(
+                uni[cts>2], cts[cts>2])
+            #assert len(np.setdiff1d(val_loader.generator.dataset_pids, uni))==0, "not all val pids were shown."
+            assert len(np.setdiff1d(uni, val_loader.generator.dataset_pids))==0, "pids shown that are not val set. impossible?"
+
+        cf.n_workers = os.cpu_count()
+        cf.batch_size = int(val_loader.generator.dataset_length / cf.n_workers) + 2
+        val_loader = dloader.create_data_gen_pipeline(cf, val_loader.generator._data, do_aug=False, sample_pids_w_replace=False,
+                                                             max_batches=None, raise_stop_iteration=True)
+        for epoch in range(2):
+            produced_ids = []
+            for b, batch in enumerate(val_loader):
+                produced_ids.append(batch["pid"])
+            uni, cts = np.unique(np.concatenate(produced_ids), return_counts=True)
+            assert np.all(cts == 1), "with batch size one: every item should occur exactly once.\n uni {}, cts {}".format(
+                uni[cts>1], cts[cts>1])
+            assert len(np.setdiff1d(val_loader.generator.dataset_pids, uni))==0, "not all val pids were shown."
+            assert len(np.setdiff1d(uni, val_loader.generator.dataset_pids))==0, "pids shown that are not val set. impossible?"
+
+
+
+
+        pass
+
 
 if __name__=="__main__":
     stime = time.time()
 
     t = CheckRoIAlignImplementation()
     t.manual_check()
     #unittest.main()
 
     mins, secs = divmod((time.time() - stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/utils/dataloader_utils.py b/utils/dataloader_utils.py
index 0724b28..f3e5850 100644
--- a/utils/dataloader_utils.py
+++ b/utils/dataloader_utils.py
@@ -1,723 +1,729 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 import plotting as plg
 
 import os
 from multiprocessing import Pool, Lock
 import pickle
 import warnings
 
 import numpy as np
 import pandas as pd
 from batchgenerators.transforms.abstract_transforms import AbstractTransform
 from scipy.ndimage.measurements import label as lb
 from torch.utils.data import Dataset as torchDataset
 from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
 
 import utils.exp_utils as utils
 import data_manager as dmanager
 
 
 for msg in ["This figure includes Axes that are not compatible with tight_layout",
             "Data has no positive values, and therefore cannot be log-scaled."]:
     warnings.filterwarnings("ignore", msg)
 
 
 class AttributeDict(dict):
     __getattr__ = dict.__getitem__
     __setattr__ = dict.__setitem__
 
 ##################################
 #  data loading, organisation  #
 ##################################
 
 
 class fold_generator:
     """
     generates splits of indices for a given length of a dataset to perform n-fold cross-validation.
     splits each fold into 3 subsets for training, validation and testing.
     This form of cross validation uses an inner loop test set, which is useful if test scores shall be reported on a
     statistically reliable amount of patients, despite limited size of a dataset.
     If hold out test set is provided and hence no inner loop test set needed, just add test_idxs to the training data in the dataloader.
     This creates straight-forward train-val splits.
     :returns names list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix.
     """
     def __init__(self, seed, n_splits, len_data):
         """
         :param seed: Random seed for splits.
         :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation
         :param len_data: number of elements in the dataset.
         """
         self.tr_ix = []
         self.val_ix = []
         self.te_ix = []
         self.slicer = None
         self.missing = 0
         self.fold = 0
         self.len_data = len_data
         self.n_splits = n_splits
         self.myseed = seed
         self.boost_val = 0
 
     def init_indices(self):
 
         t = list(np.arange(self.l))
         # round up to next splittable data amount.
         split_length = int(np.ceil(len(t) / float(self.n_splits)))
         self.slicer = split_length
         self.mod = len(t) % self.n_splits
         if self.mod > 0:
             # missing is the number of folds, in which the new splits are reduced to account for missing data.
             self.missing = self.n_splits - self.mod
 
         self.te_ix = t[:self.slicer]
         self.tr_ix = t[self.slicer:]
         self.val_ix = self.tr_ix[:self.slicer]
         self.tr_ix = self.tr_ix[self.slicer:]
 
     def new_fold(self):
 
         slicer = self.slicer
         if self.fold < self.missing :
             slicer = self.slicer - 1
 
         temp = self.te_ix
 
         # catch exception mod == 1: test set collects 1+ data since walk through both roudned up splits.
         # account for by reducing last fold split by 1.
         if self.fold == self.n_splits-2 and self.mod ==1:
             temp += self.val_ix[-1:]
             self.val_ix = self.val_ix[:-1]
 
         self.te_ix = self.val_ix
         self.val_ix = self.tr_ix[:slicer]
         self.tr_ix = self.tr_ix[slicer:] + temp
 
 
     def get_fold_names(self):
         names_list = []
         rgen = np.random.RandomState(self.myseed)
         cv_names = np.arange(self.len_data)
 
         rgen.shuffle(cv_names)
         self.l = len(cv_names)
         self.init_indices()
 
         for split in range(self.n_splits):
             train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix]
             names_list.append([train_names, val_names, test_names, self.fold])
             self.new_fold()
             self.fold += 1
 
         return names_list
 
 
 
 class FoldGenerator():
     r"""takes a set of elements (identifiers) and randomly splits them into the specified amt of subsets.
     """
 
     def __init__(self, identifiers, seed, n_splits=5):
         self.ids = np.array(identifiers)
         self.n_splits = n_splits
         self.seed = seed
 
     def generate_splits(self, n_splits=None):
         if n_splits is None:
             n_splits = self.n_splits
 
         rgen = np.random.RandomState(self.seed)
         rgen.shuffle(self.ids)
         self.splits = list(np.array_split(self.ids, n_splits, axis=0))  # already returns list, but to be sure
         return self.splits
 
 
 class Dataset(torchDataset):
     r"""Parent Class for actual Dataset classes to inherit from!
     """
     def __init__(self, cf, data_sourcedir=None):
         super(Dataset, self).__init__()
         self.cf = cf
 
         self.data_sourcedir = cf.data_sourcedir if data_sourcedir is None else data_sourcedir
         self.data_dir = cf.data_dir if hasattr(cf, 'data_dir') else self.data_sourcedir
 
         self.data_dest = cf.data_dest if hasattr(cf, "data_dest") else self.data_sourcedir
 
         self.data = {}
         self.set_ids = []
 
     def copy_data(self, cf, file_subset, keep_packed=False, del_after_unpack=False):
         if os.path.normpath(self.data_sourcedir) != os.path.normpath(self.data_dest):
             self.data_sourcedir = os.path.join(self.data_sourcedir, '')
             args = AttributeDict({
                     "source" :  self.data_sourcedir,
                     "destination" : self.data_dest,
                     "recursive" : True,
                     "cp_only_npz" : False,
                     "keep_packed" : keep_packed,
                     "del_after_unpack" : del_after_unpack,
                     "threads" : 16 if self.cf.server_env else os.cpu_count()
                     })
             dmanager.copy(args, file_subset=file_subset)
             self.data_dir = self.data_dest
 
 
 
     def __len__(self):
         return len(self.data)
     def __getitem__(self, id):
         """Return a sample of the dataset, i.e.,the dict of the id
         """
         return self.data[id]
     def __iter__(self):
         return self.data.__iter__()
 
     def init_FoldGenerator(self, seed, n_splits):
         self.fg = FoldGenerator(self.set_ids, seed=seed, n_splits=n_splits)
 
     def generate_splits(self, check_file):
         if not os.path.exists(check_file):
             self.fg.generate_splits()
             with open(check_file, 'wb') as handle:
                 pickle.dump(self.fg.splits, handle)
         else:
             with open(check_file, 'rb') as handle:
                 self.fg.splits = pickle.load(handle)
 
     def calc_statistics(self, subsets=None, plot_dir=None, overall_stats=True):
 
         if self.df is None:
             self.df = pd.DataFrame()
             balance_t = self.cf.balance_target if hasattr(self.cf, "balance_target") else "class_targets"
             self.df._metadata.append(balance_t)
             if balance_t=="class_targets":
                 mapper = lambda cl_id: self.cf.class_id2label[cl_id]
                 labels = self.cf.class_id2label.values()
             elif balance_t=="rg_bin_targets":
                 mapper = lambda rg_bin: self.cf.bin_id2label[rg_bin]
                 labels = self.cf.bin_id2label.values()
             # elif balance_t=="regression_targets":
             #     # todo this wont work
             #     mapper = lambda rg_val: AttributeDict({"name":rg_val}) #self.cf.bin_id2label[self.cf.rg_val_to_bin_id(rg_val)]
             #     labels = self.cf.bin_id2label.values()
             elif balance_t=="lesion_gleasons":
                 mapper = lambda gs: self.cf.gs2label[gs]
                 labels = self.cf.gs2label.values()
             else:
                 mapper = lambda x: AttributeDict({"name":x})
                 labels = None
             for pid, subj_data in self.data.items():
                 unique_ts, counts = np.unique(subj_data[balance_t], return_counts=True)
                 self.df = self.df.append(pd.DataFrame({"pid": [pid],
                                                        **{mapper(unique_ts[i]).name: [counts[i]] for i in
                                                           range(len(unique_ts))}}), ignore_index=True, sort=True)
             self.df = self.df.fillna(0)
 
         if overall_stats:
             df = self.df.drop("pid", axis=1)
             df = df.reindex(sorted(df.columns), axis=1).astype('uint32')
             print("Overall dataset roi counts per target kind:"); print(df.sum())
         if subsets is not None:
             self.df["subset"] = np.nan
             self.df["display_order"] = np.nan
             for ix, (subset, pids) in enumerate(subsets.items()):
                 self.df.loc[self.df.pid.isin(pids), "subset"] = subset
                 self.df.loc[self.df.pid.isin(pids), "display_order"] = ix
             df = self.df.groupby("subset").agg("sum").drop("pid", axis=1, errors='ignore').astype('int64')
             df = df.sort_values(by=['display_order']).drop('display_order', axis=1)
             df = df.reindex(sorted(df.columns), axis=1)
 
             print("Fold {} dataset roi counts per target kind:".format(self.cf.fold)); print(df)
         if plot_dir is not None:
             os.makedirs(plot_dir, exist_ok=True)
             if subsets is not None:
                 plg.plot_fold_stats(self.cf, df, labels, os.path.join(plot_dir, "data_stats_fold_" + str(self.cf.fold))+".pdf")
             if overall_stats:
                 plg.plot_data_stats(self.cf, df, labels, os.path.join(plot_dir, 'data_stats_overall.pdf'))
 
         return df, labels
 
 
 def get_class_balanced_patients(all_pids, class_targets, batch_size, num_classes, random_ratio=0):
     '''
     samples towards equilibrium of classes (on basis of total RoI counts). for highly imbalanced dataset, this might be a too strong requirement.
     :param class_targets: dic holding {patient_specifier : ROI class targets}, list position of ROI target corresponds to respective seg label - 1
     :param batch_size:
     :param num_classes:
     :return:
     '''
     # assert len(all_pids)>=batch_size, "not enough eligible pids {} to form a single batch of size {}".format(len(all_pids), batch_size)
     class_counts = {k: 0 for k in range(1,num_classes+1)}
     not_picked = np.array(all_pids)
     batch_patients = np.empty((batch_size,), dtype=not_picked.dtype)
     rarest_class = np.random.randint(1,num_classes+1)
 
     for ix in range(batch_size):
         if len(not_picked) == 0:
             warnings.warn("Dataset too small to generate batch with unique samples; => recycling.")
             not_picked = np.array(all_pids)
 
         np.random.shuffle(not_picked) #this could actually go outside(above) the loop.
         pick = not_picked[0]
         for cand in not_picked:
             if np.count_nonzero(class_targets[cand] == rarest_class) > 0:
                 pick = cand
                 cand_rarest_class = np.argmin([np.count_nonzero(class_targets[cand] == cl) for cl in
                                                range(1,num_classes+1)])+1
                 # if current batch already bigger than the batch random ratio, then
                 # check that weakest class in this patient is not the weakest in current batch (since needs to be boosted)
                 # also that at least one roi of this patient belongs to weakest class. If True, keep patient, else keep looking.
                 if (cand_rarest_class != rarest_class and np.count_nonzero(class_targets[cand] == rarest_class) > 0) \
                         or ix < int(batch_size * random_ratio):
                     break
 
         for c in range(1,num_classes+1):
             class_counts[c] += np.count_nonzero(class_targets[pick] == c)
         if not ix < int(batch_size * random_ratio) and class_counts[rarest_class] == 0:  # means searched thru whole set without finding rarest class
             print("Class {} not represented in current dataset.".format(rarest_class))
         rarest_class = np.argmin(([class_counts[c] for c in range(1,num_classes+1)]))+1
         batch_patients[ix] = pick
         not_picked = not_picked[not_picked != pick]  # removes pick
 
     return batch_patients
 
 
 class BatchGenerator(SlimDataLoaderBase):
     """
     create the training/validation batch generator. Randomly sample batch_size patients
     from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array.
     :param data: data dictionary as provided by 'load_dataset'
     :param img_modalities: list of strings ['adc', 'b1500'] from config
     :param batch_size: number of patients to sample for the batch
     :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.)
     :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array.
     """
 
     def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, n_threads=None, seed=0):
         if n_threads is None:
             n_threads = cf.n_workers
         super(BatchGenerator, self).__init__(data, cf.batch_size, number_of_threads_in_multithreaded=n_threads)
         self.cf = cf
         self.random_count = int(cf.batch_random_ratio * cf.batch_size)
         self.plot_dir = os.path.join(self.cf.plot_dir, 'train_generator')
         os.makedirs(self.plot_dir, exist_ok=True)
         self.max_batches = max_batches
         self.raise_stop = raise_stop_iteration
         self.thread_id = 0
         self.batches_produced = 0
 
         self.dataset_length = len(self._data)
         self.dataset_pids = list(self._data.keys())
+        self.n_filled_threads = min(int(self.dataset_length/self.batch_size), self.number_of_threads_in_multithreaded)
+        if self.n_filled_threads != self.number_of_threads_in_multithreaded:
+            print("Adjusting nr of threads from {} to {}.".format(self.number_of_threads_in_multithreaded,
+                                                                  self.n_filled_threads))
+
         self.rgen = np.random.RandomState(seed=seed)
         self.eligible_pids = self.rgen.permutation(self.dataset_pids.copy())
-        self.eligible_pids = np.array_split(self.eligible_pids, self.number_of_threads_in_multithreaded)
+        self.eligible_pids = np.array_split(self.eligible_pids, self.n_filled_threads)
         self.eligible_pids = sorted(self.eligible_pids, key=len, reverse=True)
+
         self.sample_pids_w_replace = sample_pids_w_replace
         if not self.sample_pids_w_replace:
-            assert len(self.dataset_pids) / self.number_of_threads_in_multithreaded >= self.batch_size, \
+            assert len(self.dataset_pids) / self.n_filled_threads >= self.batch_size, \
                 "at least one batch needed per thread. dataset size: {}, n_threads: {}, batch_size: {}.".format(
-                    len(self.dataset_pids), self.number_of_threads_in_multithreaded, self.batch_size)
+                    len(self.dataset_pids), self.n_filled_threads, self.batch_size)
             self.lock = Lock()
 
         if hasattr(cf, "balance_target"):
             # WARNING: "balance targets are only implemented for 1-d targets (or 1-component vectors)"
             self.balance_target = cf.balance_target
         else:
             self.balance_target = "class_targets"
         self.targets = {k:v[self.balance_target] for (k,v) in self._data.items()}
 
     def set_thread_id(self, thread_id):
         self.thread_ids = self.eligible_pids[thread_id]
         self.thread_id  = thread_id
 
     def reset(self):
         self.batches_produced = 0
         self.thread_ids = self.rgen.permutation(self.eligible_pids[self.thread_id])
 
     @staticmethod
     def sample_targets_to_weights(targets, fg_bg_weights):
         weights = targets * fg_bg_weights
         return weights
 
     def balance_target_distribution(self, plot=False):
         """Impose a drawing distribution over samples.
          Distribution should be designed so that classes' fg and bg examples are (as good as possible) shown in
          equal frequency. Since we are dealing with rois, fg/bg weights count a sample (e.g., a patient) with
          **at least** one occurrence as fg, otherwise bg. For fg weights among classes, each RoI counts.
 
         :param all_pids:
         :param self.targets:  dic holding {patient_specifier : patient-wise-unique ROI targets}
         :return: probability distribution over all pids. draw without replace from this.
         """
         self.unique_ts = np.unique([v for pat in self.targets.values() for v in pat])
         self.sample_stats = pd.DataFrame(columns=[str(ix)+suffix for ix in self.unique_ts for suffix in ["", "_bg"]], index=list(self.targets.keys()))
         for pid in self.sample_stats.index:
             for targ in self.unique_ts:
                 fg_count = np.count_nonzero(self.targets[pid] == targ)
                 self.sample_stats.loc[pid, str(targ)] = int(fg_count > 0)
                 self.sample_stats.loc[pid, str(targ)+"_bg"] = int(fg_count == 0)
 
         self.targ_stats = self.sample_stats.agg(
             ("sum", lambda col: col.sum() / len(self._data)), axis=0, sort=False).rename({"<lambda>": "relative"})
 
         anchor = 1. - self.targ_stats.loc["relative"].iloc[0]
         self.fg_bg_weights = anchor / self.targ_stats.loc["relative"]
         cum_weights = anchor * len(self.fg_bg_weights)
         self.fg_bg_weights /= cum_weights
 
         self.p_probs = self.sample_stats.apply(self.sample_targets_to_weights, args=(self.fg_bg_weights,), axis=1).sum(axis=1)
         self.p_probs = self.p_probs / self.p_probs.sum()
         if plot:
             print("Applying class-weights:\n {}".format(self.fg_bg_weights))
         if len(self.sample_stats.columns) == 2:
             # assert that probs are calc'd correctly:
             # (self.p_probs * self.sample_stats["1"]).sum() == (self.p_probs * self.sample_stats["1_bg"]).sum()
             # only works if one label per patient (multi-label expectations depend on multi-label occurences).
             expectations = []
             for targ in self.sample_stats.columns:
                 expectations.append((self.p_probs * self.sample_stats[targ]).sum())
             assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format(expectations)
 
         self.stats = {"roi_counts": np.zeros(len(self.unique_ts,), dtype='uint32'),
                       "empty_counts": np.zeros(len(self.unique_ts,), dtype='uint32')}
 
         if plot:
             os.makedirs(self.plot_dir, exist_ok=True)
             plg.plot_batchgen_distribution(self.cf, self.dataset_pids, self.p_probs, self.balance_target,
                                            out_file=os.path.join(self.plot_dir,
                                                                  "train_gen_distr_"+str(self.cf.fold)+".png"))
         return self.p_probs
 
     def get_batch_pids(self):
-        if self.max_batches is not None and self.batches_produced * self.number_of_threads_in_multithreaded \
+        if self.max_batches is not None and self.batches_produced * self.n_filled_threads \
                 + self.thread_id >= self.max_batches:
             self.reset()
             raise StopIteration
 
         if self.sample_pids_w_replace:
             # fully random patients
             batch_pids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
             # target-balanced patients
             batch_pids += list(np.random.choice(
                 self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs))
         else:
             with self.lock:
                 if len(self.thread_ids) == 0:
                     if self.raise_stop:
                         self.reset()
                         raise StopIteration
                     else:
                         self.thread_ids = self.rgen.permutation(self.eligible_pids[self.thread_id])
                 batch_pids = self.thread_ids[:self.batch_size]
                 # batch_pids = np.random.choice(self.thread_ids, size=self.batch_size, replace=False)
                 self.thread_ids = [pid for pid in self.thread_ids if pid not in batch_pids]
         self.batches_produced += 1
 
         return batch_pids
 
     def generate_train_batch(self):
         # to be overriden by child
         # everything done in here is per batch
         # print statements in here get confusing due to multithreading
         raise NotImplementedError
 
     def print_stats(self, logger=None, file=None, plot_file=None, plot=True):
         print_f = utils.CombinedPrinter(logger, file)
 
         print_f('\n***Final Training Stats***')
         total_count = np.sum(self.stats['roi_counts'])
         for tix, count in enumerate(self.stats['roi_counts']):
             #name = self.cf.class_dict[tix] if self.balance_target=="class_targets" else str(self.unique_ts[tix])
             name=str(self.unique_ts[tix])
             print_f('{}: {} rois seen ({:.1f}%).'.format(name, count, count / total_count * 100))
         total_samples = self.cf.num_epochs*self.cf.num_train_batches*self.cf.batch_size
         empties = [
         '{}: {} ({:.1f}%)'.format(str(name), self.stats['empty_counts'][tix],
                                     self.stats['empty_counts'][tix]/total_samples*100)
             for tix, name in enumerate(self.unique_ts)
         ]
         empties = ", ".join(empties)
         print_f('empty samples seen: {}\n'.format(empties))
         if plot:
             if plot_file is None:
                 plot_file = os.path.join(self.plot_dir, "train_gen_stats_{}.png".format(self.cf.fold))
                 os.makedirs(self.plot_dir, exist_ok=True)
             plg.plot_batchgen_stats(self.cf, self.stats, empties, self.balance_target, self.unique_ts, plot_file)
 
 class PatientBatchIterator(SlimDataLoaderBase):
     """
     creates a val/test generator. Step through the dataset and return dictionaries per patient.
     2D is a special case of 3D patching with patch_size[2] == 1 (slices)
     Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets.
     Appends patient targets anyway for evaluation.
     For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions.
 
     This iterator/these batches are not intended to go through MTaugmenter afterwards
     """
 
     def __init__(self, cf, data):
         super(PatientBatchIterator, self).__init__(data, 0)
         self.cf = cf
 
         self.dataset_length = len(self._data)
         self.dataset_pids = list(self._data.keys())
 
     def generate_train_batch(self, pid=None):
         # to be overriden by child
 
         return
 
 ###################################
 #  transforms, image manipulation #
 ###################################
 
 def get_patch_crop_coords(img, patch_size, min_overlap=30):
     """
     _:param img (y, x, (z))
     _:param patch_size: list of len 2 (2D) or 3 (3D).
     _:param min_overlap: minimum required overlap of patches.
     If too small, some areas are poorly represented only at edges of single patches.
     _:return ndarray: shape (n_patches, 2*dim). crop coordinates for each patch.
     """
     crop_coords = []
     for dim in range(len(img.shape)):
         n_patches = int(np.ceil(img.shape[dim] / patch_size[dim]))
 
         # no crops required in this dimension, add image shape as coordinates.
         if n_patches == 1:
             crop_coords.append([(0, img.shape[dim])])
             continue
 
         # fix the two outside patches to coords patchsize/2 and interpolate.
         center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
 
         if (patch_size[dim] - center_dists) < min_overlap:
             n_patches += 1
             center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
 
         patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)])
         dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers]
         crop_coords.append(dim_crop_coords)
 
     coords_mesh_grid = []
     for ymin, ymax in crop_coords[0]:
         for xmin, xmax in crop_coords[1]:
             if len(crop_coords) == 3 and patch_size[2] > 1:
                 for zmin, zmax in crop_coords[2]:
                     coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax])
             elif len(crop_coords) == 3 and patch_size[2] == 1:
                 for zmin in range(img.shape[2]):
                     coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1])
             else:
                 coords_mesh_grid.append([ymin, ymax, xmin, xmax])
     return np.array(coords_mesh_grid).astype(int)
 
 def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
     """
     one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee
 
     :param image: nd image. can be anything
     :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
     len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of
     the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
     Example:
     image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
     image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
 
     :param mode: see np.pad for documentation
     :param return_slicer: if True then this function will also return what coords you will need to use when cropping back
     to original shape
     :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
     divisibly by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
     be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
     :param kwargs: see np.pad for documentation
     """
     if kwargs is None:
         kwargs = {}
 
     if new_shape is not None:
         old_shape = np.array(image.shape[-len(new_shape):])
     else:
         assert shape_must_be_divisible_by is not None
         assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
         new_shape = image.shape[-len(shape_must_be_divisible_by):]
         old_shape = new_shape
 
     num_axes_nopad = len(image.shape) - len(new_shape)
 
     new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]
 
     if not isinstance(new_shape, np.ndarray):
         new_shape = np.array(new_shape)
 
     if shape_must_be_divisible_by is not None:
         if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
             shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
         else:
             assert len(shape_must_be_divisible_by) == len(new_shape)
 
         for i in range(len(new_shape)):
             if new_shape[i] % shape_must_be_divisible_by[i] == 0:
                 new_shape[i] -= shape_must_be_divisible_by[i]
 
         new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))])
 
     difference = new_shape - old_shape
     pad_below = difference // 2
     pad_above = difference // 2 + difference % 2
     pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])
     res = np.pad(image, pad_list, mode, **kwargs)
     if not return_slicer:
         return res
     else:
         pad_list = np.array(pad_list)
         pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
         slicer = list(slice(*i) for i in pad_list)
         return res, slicer
 
 def convert_seg_to_bounding_box_coordinates(data_dict, dim, roi_item_keys, get_rois_from_seg=False,
                                                 class_specific_seg=False):
     '''adapted from batchgenerators
 
     :param data_dict: seg: segmentation with labels indicating roi_count (get_rois_from_seg=False) or classes (get_rois_from_seg=True),
         class_targets: list where list index corresponds to roi id (roi_count)
     :param dim:
     :param roi_item_keys: keys of the roi-wise items in data_dict to process
     :param n_rg_feats: nr of regression vector features
     :param get_rois_from_seg:
     :return: coords (y1,x1,y2,x2 (,z1,z2)) where the segmentation GT is framed by +1 voxel, i.e., for an object with
         z-extensions z1=0 through z2=5, bbox target coords will be z1=-1, z2=6. (analogically for x,y).
         data_dict['roi_masks']: (b, n(b), c, h(n), w(n) (z(n))) list like roi_labels but with arrays (masks) inplace of
         integers. c==1 if segmentation not one-hot encoded.
     '''
 
     bb_target = []
     roi_masks = []
     roi_items = {name:[] for name in roi_item_keys}
     out_seg = np.copy(data_dict['seg'])
     for b in range(data_dict['seg'].shape[0]):
 
         p_coords_list = [] #p for patient?
         p_roi_masks_list = []
         p_roi_items_lists = {name:[] for name in roi_item_keys}
 
         if np.sum(data_dict['seg'][b] != 0) > 0:
             if get_rois_from_seg:
                 clusters, n_cands = lb(data_dict['seg'][b])
                 data_dict['class_targets'][b] = [data_dict['class_targets'][b]] * n_cands
             else:
                 n_cands = int(np.max(data_dict['seg'][b]))
 
             rois = np.array(
                 [(data_dict['seg'][b] == ii) * 1 for ii in range(1, n_cands + 1)], dtype='uint8')  # separate clusters
 
             for rix, r in enumerate(rois):
                 if np.sum(r != 0) > 0:  # check if the roi survived slicing (3D->2D) and data augmentation (cropping etc.)
                     seg_ixs = np.argwhere(r != 0)
                     coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1,
                                   np.max(seg_ixs[:, 2]) + 1]
                     if dim == 3:
                         coord_list.extend([np.min(seg_ixs[:, 3]) - 1, np.max(seg_ixs[:, 3]) + 1])
 
                     p_coords_list.append(coord_list)
                     p_roi_masks_list.append(r)
                     # add background class = 0. rix is a patient wide index of lesions. since 'class_targets' is
                     # also patient wide, this assignment is not dependent on patch occurrences.
                     for name in roi_item_keys:
                         p_roi_items_lists[name].append(data_dict[name][b][rix])
 
                     assert data_dict["class_targets"][b][rix]>=1, "convertsegtobbox produced bg roi w cl targ {} and unique roi seg {}".format(data_dict["class_targets"][b][rix], np.unique(r))
 
 
                 if class_specific_seg:
                     out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_targets'][b][rix]
 
             if not class_specific_seg:
                 out_seg[b][data_dict['seg'][b] > 0] = 1
 
             bb_target.append(np.array(p_coords_list))
             roi_masks.append(np.array(p_roi_masks_list))
             for name in roi_item_keys:
                 roi_items[name].append(np.array(p_roi_items_lists[name]))
 
 
         else:
             bb_target.append([])
             roi_masks.append(np.zeros_like(data_dict['seg'][b], dtype='uint8')[None])
             for name in roi_item_keys:
                 roi_items[name].append(np.array([]))
 
     if get_rois_from_seg:
         data_dict.pop('class_targets', None)
 
     data_dict['bb_target'] = np.array(bb_target)
     data_dict['roi_masks'] = np.array(roi_masks)
     data_dict['seg'] = out_seg
     for name in roi_item_keys:
         data_dict[name] = np.array(roi_items[name])
 
 
     return data_dict
 
 class ConvertSegToBoundingBoxCoordinates(AbstractTransform):
     """ Converts segmentation masks into bounding box coordinates.
     """
 
     def __init__(self, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False):
         self.dim = dim
         self.roi_item_keys = roi_item_keys
         self.get_rois_from_seg = get_rois_from_seg
         self.class_specific_seg = class_specific_seg
 
     def __call__(self, **data_dict):
         return convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.roi_item_keys, self.get_rois_from_seg,
                                                        self.class_specific_seg)
 
 
 
 
 
 #############################
 #  data packing / unpacking # not used, data_manager.py used instead
 #############################
 
 def get_case_identifiers(folder):
     case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")]
     return case_identifiers
 
 
 def convert_to_npy(npz_file):
     if not os.path.isfile(npz_file[:-3] + "npy"):
         a = np.load(npz_file)['data']
         np.save(npz_file[:-3] + "npy", a)
 
 
 def unpack_dataset(folder, threads=8):
     case_identifiers = get_case_identifiers(folder)
     p = Pool(threads)
     npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers]
     p.map(convert_to_npy, npz_files)
     p.close()
     p.join()
 
 
 def delete_npy(folder):
     case_identifiers = get_case_identifiers(folder)
     npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers]
     npy_files = [i for i in npy_files if os.path.isfile(i)]
     for n in npy_files:
         os.remove(n)
\ No newline at end of file
diff --git a/utils/exp_utils.py b/utils/exp_utils.py
index e528993..759deff 100644
--- a/utils/exp_utils.py
+++ b/utils/exp_utils.py
@@ -1,679 +1,691 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 # import plotting as plg
 
 import sys
 import os
 import subprocess
 from multiprocessing import Process
 import threading
 import pickle
 import importlib.util
 import psutil
 import time
+import nvidia_smi
 
 import logging
 from torch.utils.tensorboard import SummaryWriter
 
 from collections import OrderedDict
 import numpy as np
 import pandas as pd
 import torch
 
 
 def import_module(name, path):
     """
     correct way of importing a module dynamically in python 3.
     :param name: name given to module instance.
     :param path: path to module.
     :return: module: returned module instance.
     """
     spec = importlib.util.spec_from_file_location(name, path)
     module = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(module)
     return module
 
 
 def save_obj(obj, name):
     """Pickle a python object."""
     with open(name + '.pkl', 'wb') as f:
         pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
 
 
 def load_obj(file_path):
     with open(file_path, 'rb') as handle:
         return pickle.load(handle)
 
 
 def IO_safe(func, *args, _tries=5, _raise=True, **kwargs):
     """ Wrapper calling function func with arguments args and keyword arguments kwargs to catch input/output errors
         on cluster.
     :param func: function to execute (intended to be read/write operation to a problematic cluster drive, but can be
         any function).
     :param args: positional args of func.
     :param kwargs: kw args of func.
     :param _tries: how many attempts to make executing func.
     """
     for _try in range(_tries):
         try:
             return func(*args, **kwargs)
         except OSError as e:  # to catch cluster issues with network drives
             if _raise:
                 raise e
             else:
                 print("After attempting execution {} time{}, following error occurred:\n{}".format(_try + 1,
                                                                                                    "" if _try == 0 else "s",
                                                                                                    e))
                 continue
 
-def split_off_process(target, *args, **kwargs):
+def split_off_process(target, *args, daemon=False, **kwargs):
     """Start a process that won't block parent script.
-    No join(), no return value. Before parent exits, it waits for this to finish.
+    No join(), no return value. If daemon=False: before parent exits, it waits for this to finish.
     """
-    p = Process(target=target, args=tuple(args), kwargs=kwargs, daemon=False)
+    p = Process(target=target, args=tuple(args), kwargs=kwargs, daemon=daemon)
     p.start()
-
+    return p
 
 def query_nvidia_gpu(device_id, d_keyword=None, no_units=False):
     """
     :param device_id:
     :param d_keyword: -d, --display argument (keyword(s) for selective display), all are selected if None
     :return: dict of gpu-info items
     """
     cmd = ['nvidia-smi', '-i', str(device_id), '-q']
     if d_keyword is not None:
         cmd += ['-d', d_keyword]
     outp = subprocess.check_output(cmd).strip().decode('utf-8').split("\n")
     outp = [x for x in outp if len(x) > 0]
     headers = [ix for ix, item in enumerate(outp) if len(item.split(":")) == 1] + [len(outp)]
 
     out_dict = {}
     for lix, hix in enumerate(headers[:-1]):
         head = outp[hix].strip().replace(" ", "_").lower()
         out_dict[head] = {}
         for lix2 in range(hix, headers[lix + 1]):
             try:
                 key, val = [x.strip().lower() for x in outp[lix2].split(":")]
                 if no_units:
                     val = val.split()[0]
                 out_dict[head][key] = val
             except:
                 pass
 
     return out_dict
 
 
 class CombinedPrinter(object):
     """combined print function.
     prints to logger and/or file if given, to normal print if non given.
 
     """
 
     def __init__(self, logger=None, file=None):
 
         if logger is None and file is None:
             self.out = [print]
         elif logger is None:
             self.out = [file.write]
         elif file is None:
             self.out = [logger.info]
         else:
             self.out = [logger.info, file.write]
 
     def __call__(self, string):
         for fct in self.out:
             fct(string)
 
 
 class Nvidia_GPU_Logger(object):
     def __init__(self):
         self.count = None
 
     def get_vals(self):
 
-        cmd = ['nvidia-settings', '-t', '-q', 'GPUUtilization']
-        gpu_util = subprocess.check_output(cmd).strip().decode('utf-8').split(",")
-        gpu_util = dict([f.strip().split("=") for f in gpu_util])
-        cmd[-1] = 'UsedDedicatedGPUMemory'
-        gpu_used_mem = subprocess.check_output(cmd).strip().decode('utf-8')
-        current_vals = {"gpu_mem_alloc": gpu_used_mem, "gpu_graphics_util": int(gpu_util['graphics']),
-                        "gpu_mem_util": gpu_util['memory'], "time": time.time()}
+        # cmd = ['nvidia-settings', '-t', '-q', 'GPUUtilization']
+        # gpu_util = subprocess.check_output(cmd).strip().decode('utf-8').split(",")
+        # gpu_util = dict([f.strip().split("=") for f in gpu_util])
+        # cmd[-1] = 'UsedDedicatedGPUMemory'
+        # gpu_used_mem = subprocess.check_output(cmd).strip().decode('utf-8')
+
+
+        nvidia_smi.nvmlInit()
+        # card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate
+        self.gpu_handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
+        util_res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.gpu_handle)
+        #mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(self.gpu_handle)
+        # current_vals = {"gpu_mem_alloc": mem_res.used / (1024**2), "gpu_graphics_util": int(gpu_util['graphics']),
+        #                 "gpu_mem_util": gpu_util['memory'], "time": time.time()}
+        current_vals = {"gpu_graphics_util": float(util_res.gpu),
+                        "time": time.time()}
         return current_vals
 
     def loop(self, interval):
         i = 0
         while True:
-            self.get_vals()
+            current_vals = self.get_vals()
             self.log["time"].append(time.time())
-            self.log["gpu_util"].append(self.current_vals["gpu_graphics_util"])
+            self.log["gpu_util"].append(current_vals["gpu_graphics_util"])
             if self.count is not None:
                 i += 1
                 if i == self.count:
                     exit(0)
             time.sleep(self.interval)
 
     def start(self, interval=1.):
         self.interval = interval
         self.start_time = time.time()
         self.log = {"time": [], "gpu_util": []}
         if self.interval is not None:
             thread = threading.Thread(target=self.loop)
             thread.daemon = True
             thread.start()
 
 class CombinedLogger(object):
     """Combine console and tensorboard logger and record system metrics.
     """
 
     def __init__(self, name, log_dir, server_env=True, fold="all", sysmetrics_interval=2):
         self.pylogger = logging.getLogger(name)
         self.tboard = SummaryWriter(log_dir=os.path.join(log_dir, "tboard"))
         self.times = {}
         self.log_dir = log_dir
         self.fold = str(fold)
         self.server_env = server_env
 
         self.pylogger.setLevel(logging.DEBUG)
         self.log_file = os.path.join(log_dir, "fold_"+self.fold, 'exec.log')
         os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
         self.pylogger.addHandler(logging.FileHandler(self.log_file))
         if not server_env:
             self.pylogger.addHandler(ColorHandler())
         else:
             self.pylogger.addHandler(logging.StreamHandler())
         self.pylogger.propagate = False
 
         # monitor system metrics (cpu, mem, ...)
         if not server_env and sysmetrics_interval > 0:
             self.sysmetrics = pd.DataFrame(
                 columns=["global_step", "rel_time", r"CPU (%)", "mem_used (GB)", r"mem_used (%)",
                          r"swap_used (GB)", r"gpu_utilization (%)"], dtype="float16")
             for device in range(torch.cuda.device_count()):
                 self.sysmetrics[
                     "mem_allocd (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan
                 self.sysmetrics[
                     "mem_cached (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan
             self.sysmetrics_start(sysmetrics_interval)
             pass
         else:
             print("NOT logging sysmetrics")
 
     def __getattr__(self, attr):
         """delegate all undefined method requests to objects of
         this class in order pylogger, tboard (first find first serve).
         E.g., combinedlogger.add_scalars(...) should trigger self.tboard.add_scalars(...)
         """
         for obj in [self.pylogger, self.tboard]:
             if attr in dir(obj):
                 return getattr(obj, attr)
         print("logger attr not found")
         #raise AttributeError("CombinedLogger has no attribute {}".format(attr))
 
     def set_logfile(self, fold=None, log_file=None):
         if fold is not None:
             self.fold = str(fold)
         if log_file is None:
             self.log_file = os.path.join(self.log_dir, "fold_"+self.fold, 'exec.log')
         else:
             self.log_file = log_file
         os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
         for hdlr in self.pylogger.handlers:
             hdlr.close()
         self.pylogger.handlers = []
         self.pylogger.addHandler(logging.FileHandler(self.log_file))
         if not self.server_env:
             self.pylogger.addHandler(ColorHandler())
         else:
             self.pylogger.addHandler(logging.StreamHandler())
 
     def time(self, name, toggle=None):
         """record time-spans as with a stopwatch.
         :param name:
         :param toggle: True^=On: start time recording, False^=Off: halt rec. if None determine from current status.
         :return: either start-time or last recorded interval
         """
         if toggle is None:
             if name in self.times.keys():
                 toggle = not self.times[name]["toggle"]
             else:
                 toggle = True
 
         if toggle:
             if not name in self.times.keys():
                 self.times[name] = {"total": 0, "last": 0}
             elif self.times[name]["toggle"] == toggle:
                 self.info("restarting running stopwatch")
             self.times[name]["last"] = time.time()
             self.times[name]["toggle"] = toggle
             return time.time()
         else:
             if toggle == self.times[name]["toggle"]:
                 self.info("WARNING: tried to stop stopped stop watch: {}.".format(name))
             self.times[name]["last"] = time.time() - self.times[name]["last"]
             self.times[name]["total"] += self.times[name]["last"]
             self.times[name]["toggle"] = toggle
             return self.times[name]["last"]
 
     def get_time(self, name=None, kind="total", format=None, reset=False):
         """
         :param name:
         :param kind: 'total' or 'last'
         :param format: None for float, "hms"/"ms" for (hours), mins, secs as string
         :param reset: reset time after retrieving
         :return:
         """
         if name is None:
             times = self.times
             if reset:
                 self.reset_time()
             return times
 
         else:
             if self.times[name]["toggle"]:
                 self.time(name, toggle=False)
             time = self.times[name][kind]
             if format == "hms":
                 m, s = divmod(time, 60)
                 h, m = divmod(m, 60)
                 time = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(m), int(s))
             elif format == "ms":
                 m, s = divmod(time, 60)
                 time = "{:02d}m:{:02d}s".format(int(m), int(s))
             if reset:
                 self.reset_time(name)
             return time
 
     def reset_time(self, name=None):
         if name is None:
             self.times = {}
         else:
             del self.times[name]
 
     def sysmetrics_update(self, global_step=None):
         if global_step is None:
             global_step = time.strftime("%x_%X")
         mem = psutil.virtual_memory()
         mem_used = (mem.total - mem.available)
         gpu_vals = self.gpu_logger.get_vals()
         rel_time = time.time() - self.sysmetrics_start_time
         self.sysmetrics.loc[len(self.sysmetrics)] = [global_step, rel_time,
                                                      psutil.cpu_percent(), mem_used / 1024 ** 3,
                                                      mem_used / mem.total * 100,
                                                      psutil.swap_memory().used / 1024 ** 3,
                                                      int(gpu_vals['gpu_graphics_util']),
                                                      *[torch.cuda.memory_allocated(d) / 1024 ** 3 for d in
                                                        range(torch.cuda.device_count())],
                                                      *[torch.cuda.memory_cached(d) / 1024 ** 3 for d in
                                                        range(torch.cuda.device_count())]
                                                      ]
         return self.sysmetrics.loc[len(self.sysmetrics) - 1].to_dict()
 
     def sysmetrics2tboard(self, metrics=None, global_step=None, suptitle=None):
         tag = "per_time"
         if metrics is None:
             metrics = self.sysmetrics_update(global_step=global_step)
             tag = "per_epoch"
 
         if suptitle is not None:
             suptitle = str(suptitle)
         elif self.fold != "":
             suptitle = "Fold_" + str(self.fold)
         if suptitle is not None:
             self.tboard.add_scalars(suptitle + "/System_Metrics/" + tag,
                                     {k: v for (k, v) in metrics.items() if (k != "global_step"
                                                                             and k != "rel_time")}, global_step)
 
     def sysmetrics_loop(self):
         try:
             os.nice(-19)
             self.info("Logging system metrics with superior process priority.")
         except:
             self.info("Logging system metrics without superior process priority.")
         while True:
             metrics = self.sysmetrics_update()
             self.sysmetrics2tboard(metrics, global_step=metrics["rel_time"])
             # print("thread alive", self.thread.is_alive())
             time.sleep(self.sysmetrics_interval)
 
     def sysmetrics_start(self, interval):
         if interval is not None and interval > 0:
             self.sysmetrics_interval = interval
             self.gpu_logger = Nvidia_GPU_Logger()
             self.sysmetrics_start_time = time.time()
-            self.thread = threading.Thread(target=self.sysmetrics_loop)
-            self.thread.daemon = True
-            self.thread.start()
+            self.sys_metrics_process = split_off_process(target=self.sysmetrics_loop, daemon=True)
+            # self.thread = threading.Thread(target=self.sysmetrics_loop)
+            # self.thread.daemon = True
+            # self.thread.start()
 
     def sysmetrics_save(self, out_file):
         self.sysmetrics.to_pickle(out_file)
 
     def metrics2tboard(self, metrics, global_step=None, suptitle=None):
         """
         :param metrics: {'train': dataframe, 'val':df}, df as produced in
             evaluator.py.evaluate_predictions
         """
         # print("metrics", metrics)
         if global_step is None:
             global_step = len(metrics['train'][list(metrics['train'].keys())[0]]) - 1
         if suptitle is not None:
             suptitle = str(suptitle)
         else:
             suptitle = "Fold_" + str(self.fold)
 
         for key in ['train', 'val']:
             # series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k}
             loss_series = {}
             unc_series = {}
             bin_stat_series = {}
             mon_met_series = {}
             for tag, val in metrics[key].items():
                 val = val[-1]  # maybe remove list wrapping, recording in evaluator?
                 if 'bin_stats' in tag.lower() and not np.isnan(val):
                     bin_stat_series["{}".format(tag.split("/")[-1])] = val
                 elif 'uncertainty' in tag.lower() and not np.isnan(val):
                     unc_series["{}".format(tag)] = val
                 elif 'loss' in tag.lower() and not np.isnan(val):
                     loss_series["{}".format(tag)] = val
                 elif not np.isnan(val):
                     mon_met_series["{}".format(tag)] = val
 
             self.tboard.add_scalars(suptitle + "/Binary_Statistics/{}".format(key), bin_stat_series, global_step)
             self.tboard.add_scalars(suptitle + "/Uncertainties/{}".format(key), unc_series, global_step)
             self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step)
             self.tboard.add_scalars(suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series, global_step)
         self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step)
         return
 
     def batchImgs2tboard(self, batch, results_dict, cmap, boxtype2color, img_bg=False, global_step=None):
         raise NotImplementedError("not up-to-date, problem with importing plotting-file, torchvision dependency.")
         if len(batch["seg"].shape) == 5:  # 3D imgs
             slice_ix = np.random.randint(batch["seg"].shape[-1])
             seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :, slice_ix], cmap)
             seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :, slice_ix], cmap)
 
             mod_img = plg.mod_to_rgb(batch["data"][:, 0, :, :, slice_ix]) if img_bg else None
 
         elif len(batch["seg"].shape) == 4:
             seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :], cmap)
             seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :], cmap)
             mod_img = plg.mod_to_rgb(batch["data"][:, 0]) if img_bg else None
         else:
             raise Exception("batch content has wrong format: {}".format(batch["seg"].shape))
 
         # from here on only works in 2D
         seg_gt = np.transpose(seg_gt, axes=(0, 3, 1, 2))  # previous shp: b,x,y,c
         seg_pred = np.transpose(seg_pred, axes=(0, 3, 1, 2))
 
         seg = np.concatenate((seg_gt, seg_pred), axis=0)
         # todo replace torchvision (tv) dependency
         seg = tv.utils.make_grid(torch.from_numpy(seg), nrow=2)
         self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.", seg, global_step=global_step)
 
         if img_bg:
             bg_img = np.transpose(mod_img, axes=(0, 3, 1, 2))
         else:
             bg_img = seg_gt
         box_imgs = plg.draw_boxes_into_batch(bg_img, results_dict["boxes"], boxtype2color)
         box_imgs = tv.utils.make_grid(torch.from_numpy(box_imgs), nrow=4)
         self.tboard.add_image("Batch bboxes", box_imgs, global_step=global_step)
 
         return
 
     def __del__(self):  # otherwise might produce multiple prints e.g. in ipython console
+        self.sys_metrics_process.terminate()
         for hdlr in self.pylogger.handlers:
             hdlr.close()
         self.pylogger.handlers = []
         del self.pylogger
         self.tboard.close()
 
 
 def get_logger(exp_dir, server_env=False, sysmetrics_interval=2):
     log_dir = os.path.join(exp_dir, "logs")
     logger = CombinedLogger('Reg R-CNN', log_dir, server_env=server_env,
                             sysmetrics_interval=sysmetrics_interval)
     print("logging to {}".format(logger.log_file))
     return logger
 
 
 def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True):
     """
     I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir.
     This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime.
     Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone).
     Provides robust structure for cloud deployment.
     :param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp)
     :param exp_path: path to experiment directory.
     :param server_env: boolean flag. pass to configs script for cloud deployment.
     :param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing
         experiment directory, else creates experiment directory on the fly using configs/model scripts from source code.
     :param is_training: boolean flag. distinguishes train vs. inference mode.
     :return: configs object.
     """
 
     if is_training:
 
         if use_stored_settings:
             cf_file = import_module('cf', os.path.join(exp_path, 'configs.py'))
             cf = cf_file.Configs(server_env)
             # in this mode, previously saved model and backbone need to be found in exp dir.
             if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \
                     not os.path.isfile(os.path.join(exp_path, 'backbone.py')):
                 raise Exception(
                     "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.")
             cf.model_path = os.path.join(exp_path, 'model.py')
             cf.backbone_path = os.path.join(exp_path, 'backbone.py')
         else:  # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model
             os.makedirs(exp_path, exist_ok=True)
             # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.)
             subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')),
                             shell=True)
             subprocess.call(
                 'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')),
                 shell=True)
             cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py'))
             cf = cf_file.Configs(server_env)
             subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True)
             subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True)
             if os.path.isfile(os.path.join(exp_path, "fold_ids.pickle")):
                 subprocess.call('rm {}'.format(os.path.join(exp_path, "fold_ids.pickle")), shell=True)
 
     else:  # testing, use model and backbone stored in exp dir.
         cf_file = import_module('cf', os.path.join(exp_path, 'configs.py'))
         cf = cf_file.Configs(server_env)
         cf.model_path = os.path.join(exp_path, 'model.py')
         cf.backbone_path = os.path.join(exp_path, 'backbone.py')
 
     cf.exp_dir = exp_path
     cf.test_dir = os.path.join(cf.exp_dir, 'test')
     cf.plot_dir = os.path.join(cf.exp_dir, 'plots')
     if not os.path.exists(cf.test_dir):
         os.mkdir(cf.test_dir)
     if not os.path.exists(cf.plot_dir):
         os.mkdir(cf.plot_dir)
     cf.experiment_name = exp_path.split("/")[-1]
     cf.dataset_name = dataset_path
     cf.server_env = server_env
     cf.created_fold_id_pickle = False
 
     return cf
 
 
 class ModelSelector:
     '''
     saves a checkpoint after each epoch as 'last_state' (can be loaded to continue interrupted training).
     saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be ensembled
     to improve performance.
     '''
 
     def __init__(self, cf, logger):
 
         self.cf = cf
-        self.saved_epochs = [-1] * cf.save_n_models
         self.logger = logger
 
+        self.model_index = pd.DataFrame(columns=["rank", "score", "criteria_values", "file_name"],
+                                        index=pd.RangeIndex(self.cf.min_save_thresh, self.cf.num_epochs, name="epoch"))
+
     def run_model_selection(self, net, optimizer, monitor_metrics, epoch):
         """rank epoch via weighted mean from self.cf.model_selection_criteria: {criterion : weight}
         :param net:
         :param optimizer:
         :param monitor_metrics:
         :param epoch:
         :return:
         """
         crita = self.cf.model_selection_criteria  # shorter alias
+        metrics =  monitor_metrics['val']
+
+        epoch_score = np.sum([metrics[criterion][-1] * weight for criterion, weight in crita.items() if
+                              not np.isnan(metrics[criterion][-1])])
+        if not self.cf.resume:
+            epoch_score_check = np.sum([metrics[criterion][epoch] * weight for criterion, weight in crita.items() if
+                                  not np.isnan(metrics[criterion][epoch])])
+            assert np.all(epoch_score == epoch_score_check)
 
-        non_nan_scores = {}
-        for criterion in crita.keys():
-            # exclude first entry bc its dummy None entry
-            non_nan_scores[criterion] = [0 if (ii is None or np.isnan(ii)) else ii for ii in
-                                         monitor_metrics['val'][criterion]][1:]
-            n_epochs = len(non_nan_scores[criterion])
-        epochs_scores = []
-        for e_ix in range(n_epochs):
-            epochs_scores.append(np.sum([weight * non_nan_scores[criterion][e_ix] for
-                                         criterion, weight in crita.items()]) / len(crita.keys()))
-
-        # ranking of epochs according to model_selection_criterion
-        epoch_ranking = np.argsort(epochs_scores)[::-1] + 1  # epochs start at 1
-
-        # if set in configs, epochs < min_save_thresh are discarded from saving process.
-        epoch_ranking = epoch_ranking[epoch_ranking >= self.cf.min_save_thresh]
-
-        # check if current epoch is among the top-k epchs.
-        if epoch in epoch_ranking[:self.cf.save_n_models]:
+        self.model_index.loc[epoch, ["score", "criteria_values"]] = epoch_score, {cr: metrics[cr][-1] for cr in crita.keys()}
+
+        nonna_ics = self.model_index["score"].dropna(axis=0).index
+        order = np.argsort(self.model_index.loc[nonna_ics, "score"].to_numpy(), kind="stable")[::-1]
+        self.model_index.loc[nonna_ics, "rank"] = np.argsort(order) + 1 # no zero-indexing for ranks (best rank is 1).
+
+        rank = int(self.model_index.loc[epoch, "rank"])
+        if rank <= self.cf.save_n_models:
+            name = '{}_best_params.pth'.format(epoch)
             if self.cf.server_env:
-                IO_safe(torch.save, net.state_dict(),
-                        os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch)))
-                # save epoch_ranking to keep info for inference.
-                IO_safe(np.save, os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
+                IO_safe(torch.save, net.state_dict(), os.path.join(self.cf.fold_dir, name))
             else:
-                torch.save(net.state_dict(), os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch)))
-                np.save(os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
-            self.logger.info(
-                "saving current epoch {} at rank {}".format(epoch, np.argwhere(epoch_ranking == epoch)))
-            # delete params of the epoch that just fell out of the top-k epochs.
-            for se in [int(ii.split('_')[0]) for ii in os.listdir(self.cf.fold_dir) if 'best_params' in ii]:
-                if se in epoch_ranking[self.cf.save_n_models:]:
-                    subprocess.call('rm {}'.format(os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(se))),
-                                    shell=True)
-                    self.logger.info('deleting epoch {} at rank {}'.format(se, np.argwhere(epoch_ranking == se)))
+                torch.save(net.state_dict(), os.path.join(self.cf.fold_dir, name))
+            self.model_index.loc[epoch, "file_name"] = name
+            self.logger.info("saved current epoch {} at rank {}".format(epoch, rank))
+
+            clean_up = self.model_index.dropna(axis=0, subset=["file_name"])
+            clean_up = clean_up[clean_up["rank"] > self.cf.save_n_models]
+            if clean_up.size > 0:
+                file_name = clean_up["file_name"].to_numpy().item()
+                subprocess.call("rm {}".format(os.path.join(self.cf.fold_dir, file_name)), shell=True)
+                self.logger.info("removed outranked epoch {} at {}".format(clean_up.index.values.item(),
+                                                                       os.path.join(self.cf.fold_dir, file_name)))
+                self.model_index.loc[clean_up.index, "file_name"] = np.nan
 
         state = {
             'epoch': epoch,
             'state_dict': net.state_dict(),
             'optimizer': optimizer.state_dict(),
+            'model_index': self.model_index,
         }
 
         if self.cf.server_env:
             IO_safe(torch.save, state, os.path.join(self.cf.fold_dir, 'last_state.pth'))
         else:
             torch.save(state, os.path.join(self.cf.fold_dir, 'last_state.pth'))
 
-
-def load_checkpoint(checkpoint_path, net, optimizer):
+def load_checkpoint(checkpoint_path, net, optimizer, model_selector):
     checkpoint = torch.load(checkpoint_path)
     net.load_state_dict(checkpoint['state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer'])
-    return checkpoint['epoch']
+    model_selector.model_index = checkpoint["model_index"]
+    return checkpoint['epoch'] + 1, net, optimizer, model_selector
 
 
 def prepare_monitoring(cf):
     """
     creates dictionaries, where train/val metrics are stored.
     """
     metrics = {}
     # first entry for loss dict accounts for epoch starting at 1.
     metrics['train'] = OrderedDict()  # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] )
     metrics['val'] = OrderedDict()  # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] )
     metric_classes = []
     if 'rois' in cf.report_score_level:
         metric_classes.extend([v for k, v in cf.class_dict.items()])
         if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately:
             metric_classes.extend([v for k, v in cf.bin_dict.items()])
     if 'patient' in cf.report_score_level:
         metric_classes.extend(['patient_' + cf.class_dict[cf.patient_class_of_interest]])
         if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately:
             metric_classes.extend(['patient_' + cf.bin_dict[cf.patient_bin_of_interest]])
     for cl in metric_classes:
         for m in cf.metrics:
             metrics['train'][cl + '_' + m] = [np.nan]
             metrics['val'][cl + '_' + m] = [np.nan]
 
     return metrics
 
 
 class _AnsiColorizer(object):
     """
     A colorizer is an object that loosely wraps around a stream, allowing
     callers to write text to the stream in a particular color.
 
     Colorizer classes must implement C{supported()} and C{write(text, color)}.
     """
     _colors = dict(black=30, red=31, green=32, yellow=33,
                    blue=34, magenta=35, cyan=36, white=37, default=39)
 
     def __init__(self, stream):
         self.stream = stream
 
     @classmethod
     def supported(cls, stream=sys.stdout):
         """
         A class method that returns True if the current platform supports
         coloring terminal output using this method. Returns False otherwise.
         """
         if not stream.isatty():
             return False  # auto color only on TTYs
         try:
             import curses
         except ImportError:
             return False
         else:
             try:
                 try:
                     return curses.tigetnum("colors") > 2
                 except curses.error:
                     curses.setupterm()
                     return curses.tigetnum("colors") > 2
             except:
                 raise
                 # guess false in case of error
                 return False
 
     def write(self, text, color):
         """
         Write the given text to the stream in the given color.
 
         @param text: Text to be written to the stream.
 
         @param color: A string label for a color. e.g. 'red', 'white'.
         """
         color = self._colors[color]
         self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text))
 
 
 class ColorHandler(logging.StreamHandler):
 
     def __init__(self, stream=sys.stdout):
         super(ColorHandler, self).__init__(_AnsiColorizer(stream))
 
     def emit(self, record):
         msg_colors = {
             logging.DEBUG: "green",
             logging.INFO: "default",
             logging.WARNING: "red",
             logging.ERROR: "red"
         }
         color = msg_colors.get(record.levelno, "blue")
         self.stream.write(record.msg + "\n", color)
diff --git a/utils/model_utils.py b/utils/model_utils.py
index da1f34a..e951ec7 100644
--- a/utils/model_utils.py
+++ b/utils/model_utils.py
@@ -1,1527 +1,1527 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """
 Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
 published under MIT license.
 """
 import warnings
 warnings.filterwarnings('ignore', '.*From scipy 0.13.0, the output shape of zoom()*')
 
 import numpy as np
 import scipy.misc
 import scipy.ndimage
 import scipy.interpolate
 from scipy.ndimage.measurements import label as lb
 import torch
 
 import tqdm
 
 from custom_extensions.nms import nms
 from custom_extensions.roi_align import roi_align
 
 ############################################################
 #  Segmentation Processing
 ############################################################
 
 def sum_tensor(input, axes, keepdim=False):
     axes = np.unique(axes)
     if keepdim:
         for ax in axes:
             input = input.sum(ax, keepdim=True)
     else:
         for ax in sorted(axes, reverse=True):
             input = input.sum(int(ax))
     return input
 
 def get_one_hot_encoding(y, n_classes):
     """
     transform a numpy label array to a one-hot array of the same shape.
     :param y: array of shape (b, 1, y, x, (z)).
     :param n_classes: int, number of classes to unfold in one-hot encoding.
     :return y_ohe: array of shape (b, n_classes, y, x, (z))
     """
 
     dim = len(y.shape) - 2
     if dim == 2:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32')
     elif dim == 3:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32')
     else:
         raise Exception("invalid dimensions {} encountered".format(y.shape))
     for cl in np.arange(n_classes):
         y_ohe[:, cl][y[:, 0] == cl] = 1
     return y_ohe
 
 def dice_per_batch_inst_and_class(pred, y, n_classes, convert_to_ohe=True, smooth=1e-8):
     '''
     computes dice scores per batch instance and class.
     :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1)
     :param y: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes]
     :param n_classes: int
     :return: dice scores of shape (b, c)
     '''
     if convert_to_ohe:
         pred = get_one_hot_encoding(pred, n_classes)
         y = get_one_hot_encoding(y, n_classes)
     axes = tuple(range(2, len(pred.shape)))
     intersect = np.sum(pred*y, axis=axes)
     denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes)
     dice = (2.0*intersect + smooth) / (denominator + smooth)
     return dice
 
 def dice_per_batch_and_class(pred, targ, n_classes, convert_to_ohe=True, smooth=1e-8):
     '''
     computes dice scores per batch and class.
     :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1)
     :param targ: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes])
     :param n_classes: int
     :param smooth: Laplacian smooth, https://en.wikipedia.org/wiki/Additive_smoothing
     :return: dice scores of shape (b, c)
     '''
     if convert_to_ohe:
         pred = get_one_hot_encoding(pred, n_classes)
         targ = get_one_hot_encoding(targ, n_classes)
     axes = (0, *list(range(2, len(pred.shape)))) #(0,2,3(,4))
 
     intersect = np.sum(pred * targ, axis=axes)
 
     denominator = np.sum(pred, axis=axes) + np.sum(targ, axis=axes)
     dice = (2.0 * intersect + smooth) / (denominator + smooth)
 
     assert dice.shape==(n_classes,), "dice shp {}".format(dice.shape)
     return dice
 
 
 def batch_dice(pred, y, false_positive_weight=1.0, smooth=1e-6):
     '''
     compute soft dice over batch. this is a differentiable score and can be used as a loss function.
     only dice scores of foreground classes are returned, since training typically
     does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of.
     This way, single patches with missing foreground classes can not produce faulty gradients.
     :param pred: (b, c, y, x, (z)), softmax probabilities (network output).
     :param y: (b, c, y, x, (z)), one hote encoded segmentation mask.
     :param false_positive_weight: float [0,1]. For weighting of imbalanced classes,
     reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances.
     :return: soft dice score (float).This function discards the background score and returns the mena of foreground scores.
     '''
 
     if len(pred.size()) == 4:
         axes = (0, 2, 3)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth))[1:]) #only fg dice here.
 
     elif len(pred.size()) == 5:
         axes = (0, 2, 3, 4)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth))[1:]) #only fg dice here.
     else:
         raise ValueError('wrong input dimension in dice loss')
 
 
 ############################################################
 #  Bounding Boxes
 ############################################################
 
 def compute_iou_2D(box, boxes, box_area, boxes_area):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2] THIS IS THE GT BOX
     boxes: [boxes_count, (y1, x1, y2, x2)]
     box_area: float. the area of 'box'
     boxes_area: array of length boxes_count.
 
     Note: the areas are passed in rather than calculated here for
           efficency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
     union = box_area + boxes_area[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 def compute_iou_3D(box, boxes, box_volume, boxes_volume):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2, z1, z2] (typically gt box)
     boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)]
     box_area: float. the area of 'box'
     boxes_area: array of length boxes_count.
 
     Note: the areas are passed in rather than calculated here for
           efficency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     z1 = np.maximum(box[4], boxes[:, 4])
     z2 = np.minimum(box[5], boxes[:, 5])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) * np.maximum(z2 - z1, 0)
     union = box_volume + boxes_volume[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 
 def compute_overlaps(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)]. / 3D: (z1, z2))
     For better performance, pass the largest set first and the smaller second.
     :return: (#boxes1, #boxes2), ious of each box of 1 machted with each of 2
     """
     # Areas of anchors and GT boxes
     if boxes1.shape[1] == 4:
         area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
         area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(overlaps.shape[1]):
             box2 = boxes2[i] #this is the gt box
             overlaps[:, i] = compute_iou_2D(box2, boxes1, area2[i], area1)
         return overlaps
 
     else:
         # Areas of anchors and GT boxes
         volume1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) * (boxes1[:, 5] - boxes1[:, 4])
         volume2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) * (boxes2[:, 5] - boxes2[:, 4])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(boxes2.shape[0]):
             box2 = boxes2[i]  # this is the gt box
             overlaps[:, i] = compute_iou_3D(box2, boxes1, volume2[i], volume1)
         return overlaps
 
 
 
 def box_refinement(box, gt_box):
     """Compute refinement needed to transform box to gt_box.
     box and gt_box are [N, (y1, x1, y2, x2)] / 3D: (z1, z2))
     """
     height = box[:, 2] - box[:, 0]
     width = box[:, 3] - box[:, 1]
     center_y = box[:, 0] + 0.5 * height
     center_x = box[:, 1] + 0.5 * width
 
     gt_height = gt_box[:, 2] - gt_box[:, 0]
     gt_width = gt_box[:, 3] - gt_box[:, 1]
     gt_center_y = gt_box[:, 0] + 0.5 * gt_height
     gt_center_x = gt_box[:, 1] + 0.5 * gt_width
 
     dy = (gt_center_y - center_y) / height
     dx = (gt_center_x - center_x) / width
     dh = torch.log(gt_height / height)
     dw = torch.log(gt_width / width)
     result = torch.stack([dy, dx, dh, dw], dim=1)
 
     if box.shape[1] > 4:
         depth = box[:, 5] - box[:, 4]
         center_z = box[:, 4] + 0.5 * depth
         gt_depth = gt_box[:, 5] - gt_box[:, 4]
         gt_center_z = gt_box[:, 4] + 0.5 * gt_depth
         dz = (gt_center_z - center_z) / depth
         dd = torch.log(gt_depth / depth)
         result = torch.stack([dy, dx, dz, dh, dw, dd], dim=1)
 
     return result
 
 
 
 def unmold_mask_2D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2 = bbox
     out_zoom = [y2 - y1, x2 - x1]
     zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]
 
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:2]) #only y,x
     full_mask[y1:y2, x1:x2] = mask
     return full_mask
 
 
 def unmold_mask_2D_torch(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2 = bbox
     out_zoom = [(y2 - y1).float(), (x2 - x1).float()]
     zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]
 
     mask = mask.unsqueeze(0).unsqueeze(0)
     mask = torch.nn.functional.interpolate(mask, scale_factor=zoom_factor)
     mask = mask[0][0]
     #mask = scipy.ndimage.zoom(mask.cpu().numpy(), zoom_factor, order=1).astype(np.float32)
     #mask = torch.from_numpy(mask).cuda()
     # Put the mask in the right location.
     full_mask = torch.zeros(image_shape[:2])  # only y,x
     full_mask[y1:y2, x1:x2] = mask
     return full_mask
 
 
 
 def unmold_mask_3D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2, z1, z2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2, z1, z2 = bbox
     out_zoom = [y2 - y1, x2 - x1, z2 - z1]
     zoom_factor = [i/j for i,j in zip(out_zoom, mask.shape)]
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:3])
     full_mask[y1:y2, x1:x2, z1:z2] = mask
     return full_mask
 
 def nms_numpy(box_coords, scores, thresh):
     """ non-maximum suppression on 2D or 3D boxes in numpy.
     :param box_coords: [y1,x1,y2,x2 (,z1,z2)] with y1<=y2, x1<=x2, z1<=z2.
     :param scores: ranking scores (higher score == higher rank) of boxes.
     :param thresh: IoU threshold for clustering.
     :return:
     """
     y1 = box_coords[:, 0]
     x1 = box_coords[:, 1]
     y2 = box_coords[:, 2]
     x2 = box_coords[:, 3]
     assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: 
             coordinates of which maxima are taken need to be the lower coordinates"""
     areas = (x2 - x1) * (y2 - y1)
 
     is_3d = box_coords.shape[1] == 6
     if is_3d: # 3-dim case
         z1 = box_coords[:, 4]
         z2 = box_coords[:, 5]
         assert np.all(z1<=z2), """"the definition of the coordinates is crucially important here: 
            coordinates of which maxima are taken need to be the lower coordinates"""
         areas *= (z2 - z1)
 
     order = scores.argsort()[::-1]
 
     keep = []
     while order.size > 0:  # order is the sorted index.  maps order to index: order[1] = 24 means (rank1, ix 24)
         i = order[0] # highest scoring element
         yy1 = np.maximum(y1[i], y1[order])  # highest scoring element still in >order<, is compared to itself, that is okay.
         xx1 = np.maximum(x1[i], x1[order])
         yy2 = np.minimum(y2[i], y2[order])
         xx2 = np.minimum(x2[i], x2[order])
 
         h = np.maximum(0.0, yy2 - yy1)
         w = np.maximum(0.0, xx2 - xx1)
         inter = h * w
 
         if is_3d:
             zz1 = np.maximum(z1[i], z1[order])
             zz2 = np.minimum(z2[i], z2[order])
             d = np.maximum(0.0, zz2 - zz1)
             inter *= d
 
         iou = inter / (areas[i] + areas[order] - inter)
 
         non_matches = np.nonzero(iou <= thresh)[0]  # get all elements that were not matched and discard all others.
         order = order[non_matches]
         keep.append(i)
 
     return keep
 
 
 
 ############################################################
 #  M-RCNN
 ############################################################
 
 def refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, batch_anchors, cf):
     """
     Receives anchor scores and selects a subset to pass as proposals
     to the second stage. Filtering is done based on anchor scores and
     non-max suppression to remove overlaps. It also applies bounding
     box refinment details to anchors.
     :param rpn_pred_probs: (b, n_anchors, 2)
     :param rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
     :return: batch_normalized_props: Proposals in normalized coordinates (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
     :return: batch_out_proposals: Box coords + RPN foreground scores
     for monitoring/plotting (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
     """
     std_dev = torch.from_numpy(cf.rpn_bbox_std_dev[None]).float().cuda()
     norm = torch.from_numpy(cf.scale).float().cuda()
     anchors = batch_anchors.clone()
 
 
 
     batch_scores = rpn_pred_probs[:, :, 1]
     # norm deltas
     batch_deltas = rpn_pred_deltas * std_dev
     batch_normalized_props = []
     batch_out_proposals = []
 
     # loop over batch dimension.
     for ix in range(batch_scores.shape[0]):
 
         scores = batch_scores[ix]
         deltas = batch_deltas[ix]
 
         # improve performance by trimming to top anchors by score
         # and doing the rest on the smaller subset.
         pre_nms_limit = min(cf.pre_nms_limit, anchors.size()[0])
         scores, order = scores.sort(descending=True)
         order = order[:pre_nms_limit]
         scores = scores[:pre_nms_limit]
         deltas = deltas[order, :]
 
         # apply deltas to anchors to get refined anchors and filter with non-maximum suppression.
         if batch_deltas.shape[-1] == 4:
             boxes = apply_box_deltas_2D(anchors[order, :], deltas)
             boxes = clip_boxes_2D(boxes, cf.window)
         else:
             boxes = apply_box_deltas_3D(anchors[order, :], deltas)
             boxes = clip_boxes_3D(boxes, cf.window)
         # boxes are y1,x1,y2,x2, torchvision-nms requires x1,y1,x2,y2, but consistent swap x<->y is irrelevant.
         keep = nms.nms(boxes, scores, cf.rpn_nms_threshold)
 
 
         keep = keep[:proposal_count]
         boxes = boxes[keep, :]
         rpn_scores = scores[keep][:, None]
 
         # pad missing boxes with 0.
         if boxes.shape[0] < proposal_count:
             n_pad_boxes = proposal_count - boxes.shape[0]
             zeros = torch.zeros([n_pad_boxes, boxes.shape[1]]).cuda()
             boxes = torch.cat([boxes, zeros], dim=0)
             zeros = torch.zeros([n_pad_boxes, rpn_scores.shape[1]]).cuda()
             rpn_scores = torch.cat([rpn_scores, zeros], dim=0)
 
         # concat box and score info for monitoring/plotting.
         batch_out_proposals.append(torch.cat((boxes, rpn_scores), 1).cpu().data.numpy())
         # normalize dimensions to range of 0 to 1.
         normalized_boxes = boxes / norm
         where = normalized_boxes <=1
         assert torch.all(where), "normalized box coords >1 found:\n {}\n".format(normalized_boxes[where])
         #assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found"
 
         # add again batch dimension
         batch_normalized_props.append(torch.cat((normalized_boxes, rpn_scores), 1).unsqueeze(0))
 
     batch_normalized_props = torch.cat(batch_normalized_props)
     batch_out_proposals = np.array(batch_out_proposals)
 
     return batch_normalized_props, batch_out_proposals
 
 def pyramid_roi_align(feature_maps, rois, pool_size, pyramid_levels, dim):
     """
     Implements ROI Pooling on multiple levels of the feature pyramid.
     :param feature_maps: list of feature maps, each of shape (b, c, y, x , (z))
     :param rois: proposals (normalized coords.) as returned by RPN. contain info about original batch element allocation.
     (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs)
     :param pool_size: list of poolsizes in dims: [x, y, (z)]
     :param pyramid_levels: list. [0, 1, 2, ...]
     :return: pooled: pooled feature map rois (n_proposals, c, poolsize_y, poolsize_x, (poolsize_z))
 
     Output:
     Pooled regions in the shape: [num_boxes, height, width, channels].
     The width and height are those specific in the pool_shape in the layer
     constructor.
     """
     boxes = rois[:, :dim*2]
     batch_ixs = rois[:, dim*2]
 
     # Assign each ROI to a level in the pyramid based on the ROI area.
     if dim == 2:
         y1, x1, y2, x2 = boxes.chunk(4, dim=1)
     else:
         y1, x1, y2, x2, z1, z2 = boxes.chunk(6, dim=1)
 
     h = y2 - y1
     w = x2 - x1
 
     # Equation 1 in https://arxiv.org/abs/1612.03144. Account for
     # the fact that our coordinates are normalized here.
     # divide sqrt(h*w) by 1 instead image_area.
     roi_level = (4 + torch.log2(torch.sqrt(h*w))).round().int().clamp(pyramid_levels[0], pyramid_levels[-1])
     # if Pyramid contains additional level P6, adapt the roi_level assignment accordingly.
     if len(pyramid_levels) == 5:
         roi_level[h*w > 0.65] = 5
 
     # Loop through levels and apply ROI pooling to each.
     pooled = []
     box_to_level = []
     fmap_shapes = [f.shape for f in feature_maps]
     for level_ix, level in enumerate(pyramid_levels):
         ix = roi_level == level
         if not ix.any():
             continue
         ix = torch.nonzero(ix)[:, 0]
         level_boxes = boxes[ix, :]
         # re-assign rois to feature map of original batch element.
         ind = batch_ixs[ix].int()
 
         # Keep track of which box is mapped to which level
         box_to_level.append(ix)
 
         # Stop gradient propogation to ROI proposals
         level_boxes = level_boxes.detach()
         if len(pool_size) == 2:
             # remap to feature map coordinate system
             y_exp, x_exp = fmap_shapes[level_ix][2:]  # exp = expansion
             level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
             pooled_features = roi_align.roi_align_2d(feature_maps[level_ix],
                                                      torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                      pool_size)
         else:
             y_exp, x_exp, z_exp = fmap_shapes[level_ix][2:]
             level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
             pooled_features = roi_align.roi_align_3d(feature_maps[level_ix],
                                                      torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                      pool_size)
         pooled.append(pooled_features)
 
 
     # Pack pooled features into one tensor
     pooled = torch.cat(pooled, dim=0)
 
     # Pack box_to_level mapping into one array and add another
     # column representing the order of pooled boxes
     box_to_level = torch.cat(box_to_level, dim=0)
 
     # Rearrange pooled features to match the order of the original boxes
     _, box_to_level = torch.sort(box_to_level)
     pooled = pooled[box_to_level, :, :]
 
     return pooled
 
 
 def roi_align_3d_numpy(input: np.ndarray, rois, output_size: tuple,
                        spatial_scale: float = 1., sampling_ratio: int = -1) -> np.ndarray:
     """ This fct mainly serves as a verification method for 3D CUDA implementation of RoIAlign, it's highly
         inefficient due to the nested loops.
     :param input:  (ndarray[N, C, H, W, D]): input feature map
     :param rois: list (N,K(n), 6), K(n) = nr of rois in batch-element n, single roi of format (y1,x1,y2,x2,z1,z2)
     :param output_size:
     :param spatial_scale:
     :param sampling_ratio:
     :return: (List[N, K(n), C, output_size[0], output_size[1], output_size[2]])
     """
 
     out_height, out_width, out_depth = output_size
 
     coord_grid = tuple([np.linspace(0, input.shape[dim] - 1, num=input.shape[dim]) for dim in range(2, 5)])
     pooled_rois = [[]] * len(rois)
     assert len(rois) == input.shape[0], "batch dim mismatch, rois: {}, input: {}".format(len(rois), input.shape[0])
     print("Numpy 3D RoIAlign progress:", end="\n")
     for b in range(input.shape[0]):
         for roi in tqdm.tqdm(rois[b]):
             y1, x1, y2, x2, z1, z2 = np.array(roi) * spatial_scale
             roi_height = max(float(y2 - y1), 1.)
             roi_width = max(float(x2 - x1), 1.)
             roi_depth = max(float(z2 - z1), 1.)
 
             if sampling_ratio <= 0:
                 sampling_ratio_h = int(np.ceil(roi_height / out_height))
                 sampling_ratio_w = int(np.ceil(roi_width / out_width))
                 sampling_ratio_d = int(np.ceil(roi_depth / out_depth))
             else:
                 sampling_ratio_h = sampling_ratio_w = sampling_ratio_d = sampling_ratio  # == n points per bin
 
             bin_height = roi_height / out_height
             bin_width = roi_width / out_width
             bin_depth = roi_depth / out_depth
 
             n_points = sampling_ratio_h * sampling_ratio_w * sampling_ratio_d
             pooled_roi = np.empty((input.shape[1], out_height, out_width, out_depth), dtype="float32")
             for chan in range(input.shape[1]):
                 lin_interpolator = scipy.interpolate.RegularGridInterpolator(coord_grid, input[b, chan],
                                                                              method="linear")
                 for bin_iy in range(out_height):
                     for bin_ix in range(out_width):
                         for bin_iz in range(out_depth):
 
                             bin_val = 0.
                             for i in range(sampling_ratio_h):
                                 for j in range(sampling_ratio_w):
                                     for k in range(sampling_ratio_d):
                                         loc_ijk = [
-                                            y1 + bin_iy * bin_height + (i + 0.5)* ((bin_height -1) / sampling_ratio_h),
-                                            x1 + bin_ix * bin_width + (j + 0.5) * ((bin_width -1) / sampling_ratio_w),
-                                            z1 + bin_iz * bin_depth + (k + 0.5) * ((bin_depth -1) / sampling_ratio_d)]
+                                            y1 + bin_iy * bin_height + (i + 0.5) * (bin_height / sampling_ratio_h),
+                                            x1 + bin_ix * bin_width +  (j + 0.5) * (bin_width / sampling_ratio_w),
+                                            z1 + bin_iz * bin_depth +  (k + 0.5) * (bin_depth / sampling_ratio_d)]
                                         # print("loc_ijk", loc_ijk)
                                         if not (np.any([c < -1.0 for c in loc_ijk]) or loc_ijk[0] > input.shape[2] or
                                                 loc_ijk[1] > input.shape[3] or loc_ijk[2] > input.shape[4]):
                                             for catch_case in range(3):
                                                 # catch on-border cases
                                                 if int(loc_ijk[catch_case]) == input.shape[catch_case + 2] - 1:
                                                     loc_ijk[catch_case] = input.shape[catch_case + 2] - 1
                                             bin_val += lin_interpolator(loc_ijk)
                             pooled_roi[chan, bin_iy, bin_ix, bin_iz] = bin_val / n_points
 
             pooled_rois[b].append(pooled_roi)
 
     return np.array(pooled_rois)
 
 def refine_detections(cf, batch_ixs, rois, deltas, scores, regressions):
     """
     Refine classified proposals (apply deltas to rpn rois), filter overlaps (nms) and return final detections.
 
     :param rois: (n_proposals, 2 * dim) normalized boxes as proposed by RPN. n_proposals = batch_size * POST_NMS_ROIS
     :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by mrcnn bbox regressor.
     :param batch_ixs: (n_proposals) batch element assignment info for re-allocation.
     :param scores: (n_proposals, n_classes) probabilities for all classes per roi as predicted by mrcnn classifier.
     :param regressions: (n_proposals, n_classes, regression_features (+1 for uncertainty if predicted) regression vector
     :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, *regression vector features))
     """
     # class IDs per ROI. Since scores of all classes are of interest (not just max class), all are kept at this point.
     class_ids = []
     fg_classes = cf.head_classes - 1
     # repeat vectors to fill in predictions for all foreground classes.
     for ii in range(1, fg_classes + 1):
         class_ids += [ii] * rois.shape[0]
     class_ids = torch.from_numpy(np.array(class_ids)).cuda()
 
     batch_ixs = batch_ixs.repeat(fg_classes)
     rois = rois.repeat(fg_classes, 1)
     deltas = deltas.repeat(fg_classes, 1, 1)
     scores = scores.repeat(fg_classes, 1)
     regressions = regressions.repeat(fg_classes, 1, 1)
 
     # get class-specific scores and  bounding box deltas
     idx = torch.arange(class_ids.size()[0]).long().cuda()
     # using idx instead of slice [:,] squashes first dimension.
     #len(class_ids)>scores.shape[1] --> probs is broadcasted by expansion from fg_classes-->len(class_ids)
     batch_ixs = batch_ixs[idx]
     deltas_specific = deltas[idx, class_ids]
     class_scores = scores[idx, class_ids]
     regressions = regressions[idx, class_ids]
 
     # apply bounding box deltas. re-scale to image coordinates.
     std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda()
     scale = torch.from_numpy(cf.scale).float().cuda()
     refined_rois = apply_box_deltas_2D(rois, deltas_specific * std_dev) * scale if cf.dim == 2 else \
         apply_box_deltas_3D(rois, deltas_specific * std_dev) * scale
 
     # round and cast to int since we're dealing with pixels now
     refined_rois = clip_to_window(cf.window, refined_rois)
     refined_rois = torch.round(refined_rois)
 
     # filter out low confidence boxes
     keep = idx
     keep_bool = (class_scores >= cf.model_min_confidence)
     if not 0 in torch.nonzero(keep_bool).size():
 
         score_keep = torch.nonzero(keep_bool)[:, 0]
         pre_nms_class_ids = class_ids[score_keep]
         pre_nms_rois = refined_rois[score_keep]
         pre_nms_scores = class_scores[score_keep]
         pre_nms_batch_ixs = batch_ixs[score_keep]
 
         for j, b in enumerate(unique1d(pre_nms_batch_ixs)):
 
             bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0]
             bix_class_ids = pre_nms_class_ids[bixs]
             bix_rois = pre_nms_rois[bixs]
             bix_scores = pre_nms_scores[bixs]
 
             for i, class_id in enumerate(unique1d(bix_class_ids)):
 
                 ixs = torch.nonzero(bix_class_ids == class_id)[:, 0]
                 # nms expects boxes sorted by score.
                 ix_rois = bix_rois[ixs]
                 ix_scores = bix_scores[ixs]
                 ix_scores, order = ix_scores.sort(descending=True)
                 ix_rois = ix_rois[order, :]
 
                 class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold)
 
                 # map indices back.
                 class_keep = keep[score_keep[bixs[ixs[order[class_keep]]]]]
                 # merge indices over classes for current batch element
                 b_keep = class_keep if i == 0 else unique1d(torch.cat((b_keep, class_keep)))
 
             # only keep top-k boxes of current batch-element
             top_ids = class_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element]
             b_keep = b_keep[top_ids]
 
             # merge indices over batch elements.
             batch_keep = b_keep  if j == 0 else unique1d(torch.cat((batch_keep, b_keep)))
 
         keep = batch_keep
 
     else:
         keep = torch.tensor([0]).long().cuda()
 
     # arrange output
     output = [refined_rois[keep], batch_ixs[keep].unsqueeze(1)]
     output += [class_ids[keep].unsqueeze(1).float(), class_scores[keep].unsqueeze(1)]
     output += [regressions[keep]]
 
     result = torch.cat(output, dim=1)
     # shape: (n_keeps, catted feats), catted feats: [0:dim*2] are box_coords, [dim*2] are batch_ics,
     # [dim*2+1] are class_ids, [dim*2+2] are scores, [dim*2+3:] are regression vector features (incl uncertainty)
     return result
 
 
 def loss_example_mining(cf, batch_proposals, batch_gt_boxes, batch_gt_masks, batch_roi_scores,
                            batch_gt_class_ids, batch_gt_regressions):
     """
     Subsamples proposals for mrcnn losses and generates targets. Sampling is done per batch element, seems to have positive
     effects on training, as opposed to sampling over entire batch. Negatives are sampled via stochastic hard-example mining
     (SHEM), where a number of negative proposals is drawn from larger pool of highest scoring proposals for stochasticity.
     Scoring is obtained here as the max over all foreground probabilities as returned by mrcnn_classifier (worked better than
     loss-based class-balancing methods like "online hard-example mining" or "focal loss".)
 
     Classification-regression duality: regressions can be given along with classes (at least fg/bg, only class scores
     are used for ranking).
 
     :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs).
     boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS.
     :param mrcnn_class_logits: (n_proposals, n_classes)
     :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
     :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, c, y, x, (z))
     :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
         if no classes predicted (only fg/bg from RPN): expected as pseudo classes [0, 1] for bg, fg.
     :param batch_gt_regressions: list over b elements. Each element is a regression target vector. if None--> pseudo
     :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions.
     :return: target_class_ids: (n_sampled_rois)containing target class labels of sampled proposals.
     :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement.
     :return: target_masks: (n_sampled_rois, y, x, (z)) containing target masks of sampled proposals.
     """
     # normalization of target coordinates
     #global sample_regressions
     if cf.dim == 2:
         h, w = cf.patch_size
         scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda()
     else:
         h, w, z = cf.patch_size
         scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda()
 
     positive_count = 0
     negative_count = 0
     sample_positive_indices = []
     sample_negative_indices = []
     sample_deltas = []
     sample_masks = []
     sample_class_ids = []
     if batch_gt_regressions is not None:
         sample_regressions = []
     else:
         target_regressions = torch.FloatTensor().cuda()
 
     std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda()
 
     # loop over batch and get positive and negative sample rois.
     for b in range(len(batch_gt_boxes)):
 
         gt_masks = torch.from_numpy(batch_gt_masks[b]).float().cuda()
         gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda()
         if batch_gt_regressions is not None:
             gt_regressions = torch.from_numpy(batch_gt_regressions[b]).float().cuda()
 
         #if np.any(batch_gt_class_ids[b] > 0):  # skip roi selection for no gt images.
         if np.any([len(coords)>0 for coords in batch_gt_boxes[b]]):
             gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale
         else:
             gt_boxes = torch.FloatTensor().cuda()
 
         # get proposals and indices of current batch element.
         proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1]
         batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1)
 
         # Compute overlaps matrix [proposals, gt_boxes]
         if not 0 in gt_boxes.size():
             if gt_boxes.shape[1] == 4:
                 assert cf.dim == 2, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim)
                 overlaps = bbox_overlaps_2D(proposals, gt_boxes)
             else:
                 assert cf.dim == 3, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim)
                 overlaps = bbox_overlaps_3D(proposals, gt_boxes)
 
             # Determine positive and negative ROIs
             roi_iou_max = torch.max(overlaps, dim=1)[0]
             # 1. Positive ROIs are those with >= 0.5 IoU with a GT box
             positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3)
             # 2. Negative ROIs are those with < 0.1 with every GT box.
             negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01)
         else:
             positive_roi_bool = torch.FloatTensor().cuda()
             negative_roi_bool = torch.from_numpy(np.array([1]*proposals.shape[0])).cuda()
 
         # Sample Positive ROIs
         if not 0 in torch.nonzero(positive_roi_bool).size():
             positive_indices = torch.nonzero(positive_roi_bool).squeeze(1)
             positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio)
             rand_idx = torch.randperm(positive_indices.size()[0])
             rand_idx = rand_idx[:positive_samples].cuda()
             positive_indices = positive_indices[rand_idx]
             positive_samples = positive_indices.size()[0]
             positive_rois = proposals[positive_indices, :]
             # Assign positive ROIs to GT boxes.
             positive_overlaps = overlaps[positive_indices, :]
             roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1]
             roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :]
             roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment]
             if batch_gt_regressions is not None:
                 roi_gt_regressions = gt_regressions[roi_gt_box_assignment]
 
             # Compute bbox refinement targets for positive ROIs
             deltas = box_refinement(positive_rois, roi_gt_boxes)
             deltas /= std_dev
 
             roi_masks = gt_masks[roi_gt_box_assignment]
             assert roi_masks.shape[1] == 1, "gt masks have more than one channel --> is this desired?"
             # Compute mask targets
             boxes = positive_rois
             box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).float()
 
             if len(cf.mask_shape) == 2:
                 y_exp, x_exp = roi_masks.shape[2:]  # exp = expansion
                 boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
                 masks = roi_align.roi_align_2d(roi_masks,
                                                torch.cat((box_ids, boxes), dim=1),
                                                cf.mask_shape)
             else:
                 y_exp, x_exp, z_exp = roi_masks.shape[2:]  # exp = expansion
                 boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
                 masks = roi_align.roi_align_3d(roi_masks,
                                                torch.cat((box_ids, boxes), dim=1),
                                                cf.mask_shape)
 
             masks = masks.squeeze(1)
             # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
             # binary cross entropy loss.
             masks = torch.round(masks)
 
             sample_positive_indices.append(batch_element_indices[positive_indices])
             sample_deltas.append(deltas)
             sample_masks.append(masks)
             sample_class_ids.append(roi_gt_class_ids)
             if batch_gt_regressions is not None:
                 sample_regressions.append(roi_gt_regressions)
             positive_count += positive_samples
         else:
             positive_samples = 0
 
         # Sample negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM.
         if not 0 in torch.nonzero(negative_roi_bool).size():
             negative_indices = torch.nonzero(negative_roi_bool).squeeze(1)
             r = 1.0 / cf.roi_positive_ratio
             b_neg_count = np.max((int(r * positive_samples - positive_samples), 1))
             roi_scores_neg = batch_roi_scores[batch_element_indices[negative_indices]]
             raw_sampled_indices = shem(roi_scores_neg, b_neg_count, cf.shem_poolsize)
             sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]])
             negative_count  += raw_sampled_indices.size()[0]
 
     if len(sample_positive_indices) > 0:
         target_deltas = torch.cat(sample_deltas)
         target_masks = torch.cat(sample_masks)
         target_class_ids = torch.cat(sample_class_ids)
         if batch_gt_regressions is not None:
             target_regressions = torch.cat(sample_regressions)
 
     # Pad target information with zeros for negative ROIs.
     if positive_count > 0 and negative_count > 0:
         sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0)
         zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
         target_deltas = torch.cat([target_deltas, zeros], dim=0)
         zeros = torch.zeros(negative_count, *cf.mask_shape).cuda()
         target_masks = torch.cat([target_masks, zeros], dim=0)
         zeros = torch.zeros(negative_count).int().cuda()
         target_class_ids = torch.cat([target_class_ids, zeros], dim=0)
         if batch_gt_regressions is not None:
             # regression targets need to have 0 as background/negative with below practice
             if 'regression_bin' in cf.prediction_tasks:
                 zeros = torch.zeros(negative_count, dtype=torch.float).cuda()
             else:
                 zeros = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda()
             target_regressions = torch.cat([target_regressions, zeros], dim=0)
 
     elif positive_count > 0:
         sample_indices = torch.cat(sample_positive_indices)
     elif negative_count > 0:
         sample_indices = torch.cat(sample_negative_indices)
         target_deltas = torch.zeros(negative_count, cf.dim * 2).cuda()
         target_masks = torch.zeros(negative_count, *cf.mask_shape).cuda()
         target_class_ids = torch.zeros(negative_count).int().cuda()
         if batch_gt_regressions is not None:
             if 'regression_bin' in cf.prediction_tasks:
                 target_regressions = torch.zeros(negative_count, dtype=torch.float).cuda()
             else:
                 target_regressions = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda()
     else:
         sample_indices = torch.LongTensor().cuda()
         target_class_ids = torch.IntTensor().cuda()
         target_deltas = torch.FloatTensor().cuda()
         target_masks = torch.FloatTensor().cuda()
         target_regressions = torch.FloatTensor().cuda()
 
     return sample_indices, target_deltas, target_masks, target_class_ids, target_regressions
 
 ############################################################
 #  Anchors
 ############################################################
 
 def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
     """
     scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width] spatial shape of the feature map over which
             to generate anchors.
     feature_stride: Stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
     scales, ratios = np.meshgrid(np.array(scales), np.array(ratios))
     scales = scales.flatten()
     ratios = ratios.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales / np.sqrt(ratios)
     widths = scales * np.sqrt(ratios)
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride
     shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
 
     # Reshape to get a list of (y, x) and a list of (h, w)
     box_centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
     box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])
 
     # Convert to corner coordinates (y1, x1, y2, x2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1)
     return boxes
 
 
 
 def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride):
     """
     scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width] spatial shape of the feature map over which
             to generate anchors.
     feature_stride: Stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
 
     scales_xy, ratios_meshed = np.meshgrid(np.array(scales_xy), np.array(ratios))
     scales_xy = scales_xy.flatten()
     ratios_meshed = ratios_meshed.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales_xy / np.sqrt(ratios_meshed)
     widths = scales_xy * np.sqrt(ratios_meshed)
     depths = np.tile(np.array(scales_z), len(ratios_meshed)//np.array(scales_z)[..., None].shape[0])
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy #translate from fm positions to input coords.
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy
     shifts_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z)
     shifts_x, shifts_y, shifts_z = np.meshgrid(shifts_x, shifts_y, shifts_z)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
     box_depths, box_centers_z = np.meshgrid(depths, shifts_z)
 
     # Reshape to get a list of (y, x, z) and a list of (h, w, d)
     box_centers = np.stack(
         [box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3])
     box_sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3])
 
     # Convert to corner coordinates (y1, x1, y2, x2, z1, z2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                             box_centers + 0.5 * box_sizes], axis=1)
 
     boxes = np.transpose(np.array([boxes[:, 0], boxes[:, 1], boxes[:, 3], boxes[:, 4], boxes[:, 2], boxes[:, 5]]), axes=(1, 0))
     return boxes
 
 
 def generate_pyramid_anchors(logger, cf):
     """Generate anchors at different levels of a feature pyramid. Each scale
     is associated with a level of the pyramid, but each ratio is used in
     all levels of the pyramid.
 
     from configs:
     :param scales: cf.RPN_ANCHOR_SCALES , for conformity with retina nets: scale entries need to be list, e.g. [[4], [8], [16], [32]]
     :param ratios: cf.RPN_ANCHOR_RATIOS , e.g. [0.5, 1, 2]
     :param feature_shapes: cf.BACKBONE_SHAPES , e.g.  [array of shapes per feature map] [80, 40, 20, 10, 5]
     :param feature_strides: cf.BACKBONE_STRIDES , e.g. [2, 4, 8, 16, 32, 64]
     :param anchors_stride: cf.RPN_ANCHOR_STRIDE , e.g. 1
     :return anchors: (N, (y1, x1, y2, x2, (z1), (z2)). All generated anchors in one array. Sorted
     with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on.
     """
     scales = cf.rpn_anchor_scales
     ratios = cf.rpn_anchor_ratios
     feature_shapes = cf.backbone_shapes
     anchor_stride = cf.rpn_anchor_stride
     pyramid_levels = cf.pyramid_levels
     feature_strides = cf.backbone_strides
 
     logger.info("anchor scales {} and feature map shapes {}".format(scales, feature_shapes))
     expected_anchors = [np.prod(feature_shapes[level]) * len(ratios) * len(scales['xy'][level]) for level in pyramid_levels]
 
     anchors = []
     for lix, level in enumerate(pyramid_levels):
         if len(feature_shapes[level]) == 2:
             anchors.append(generate_anchors(scales['xy'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], anchor_stride))
         elif len(feature_shapes[level]) == 3:
             anchors.append(generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], feature_strides['z'][level], anchor_stride))
         else:
             raise Exception("invalid feature_shapes[{}] size {}".format(level, feature_shapes[level]))
         logger.info("level {}: expected anchors {}, built anchors {}.".format(level, expected_anchors[lix], anchors[-1].shape))
 
     out_anchors = np.concatenate(anchors, axis=0)
     logger.info("Total: expected anchors {}, built anchors {}.".format(np.sum(expected_anchors), out_anchors.shape))
 
     return out_anchors
 
 
 
 def apply_box_deltas_2D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 4] where each row is y1, x1, y2, x2
     deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     height *= torch.exp(deltas[:, 2])
     width *= torch.exp(deltas[:, 3])
     # Convert back to y1, x1, y2, x2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     y2 = y1 + height
     x2 = x1 + width
     result = torch.stack([y1, x1, y2, x2], dim=1)
     return result
 
 
 
 def apply_box_deltas_3D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2
     deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     depth = boxes[:, 5] - boxes[:, 4]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     center_z = boxes[:, 4] + 0.5 * depth
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     center_z += deltas[:, 2] * depth
     height *= torch.exp(deltas[:, 3])
     width *= torch.exp(deltas[:, 4])
     depth *= torch.exp(deltas[:, 5])
     # Convert back to y1, x1, y2, x2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     z1 = center_z - 0.5 * depth
     y2 = y1 + height
     x2 = x1 + width
     z2 = z1 + depth
     result = torch.stack([y1, x1, y2, x2, z1, z2], dim=1)
     return result
 
 
 
 def clip_boxes_2D(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2
     window: [4] in the form y1, x1, y2, x2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3]))], 1)
     return boxes
 
 def clip_boxes_3D(boxes, window):
     """
     boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2
     window: [6] in the form y1, x1, y2, x2, z1, z2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3])),
          boxes[:, 4].clamp(float(window[4]), float(window[5])),
          boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1)
     return boxes
 
 from matplotlib import pyplot as plt
 
 
 def clip_boxes_numpy(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D.
     window: iamge shape (y, x, (z))
     """
     if boxes.shape[1] == 4:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
             np.clip(boxes[:, 1], 0, window[0])[:, None],
             np.clip(boxes[:, 2], 0, window[1])[:, None],
             np.clip(boxes[:, 3], 0, window[1])[:, None]), 1
         )
 
     else:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
              np.clip(boxes[:, 1], 0, window[0])[:, None],
              np.clip(boxes[:, 2], 0, window[1])[:, None],
              np.clip(boxes[:, 3], 0, window[1])[:, None],
              np.clip(boxes[:, 4], 0, window[2])[:, None],
              np.clip(boxes[:, 5], 0, window[2])[:, None]), 1
         )
 
     return boxes
 
 
 
 def bbox_overlaps_2D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)].
     """
     # 1. Tile boxes2 and repeate boxes1. This allows us to compare
     # every boxes1 against every boxes2 without loops.
     # TF doesn't have an equivalent to np.repeate() so simulate it
     # using tf.tile() and tf.reshape.
 
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
 
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,4)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2 = boxes1.chunk(4, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2 = boxes2.chunk(4, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     #--> expects x1<x2 & y1<y2
     zeros = torch.zeros(y1.size()[0], requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros)
 
     # 3. Compute unions
     b1_area = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)
     b2_area = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)
     union = b1_area[:,0] + b2_area[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     assert torch.all(iou<=1), "iou score>1 produced in bbox_overlaps_2D"
     overlaps = iou.view(boxes2_repeat, boxes1_repeat) #--> per gt box: ious of all proposal boxes with that gt box
 
     return overlaps
 
 def bbox_overlaps_3D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2, z1, z2)].
     """
     # 1. Tile boxes2 and repeate boxes1. This allows us to compare
     # every boxes1 against every boxes2 without loops.
     # TF doesn't have an equivalent to np.repeate() so simulate it
     # using tf.tile() and tf.reshape.
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,6)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2, b1_z1, b1_z2 = boxes1.chunk(6, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2, b2_z1, b2_z2 = boxes2.chunk(6, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     z1 = torch.max(b1_z1, b2_z1)[:, 0]
     z2 = torch.min(b1_z2, b2_z2)[:, 0]
     zeros = torch.zeros(y1.size()[0], requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros)
 
     # 3. Compute unions
     b1_volume = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)  * (b1_z2 - b1_z1)
     b2_volume = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)  * (b2_z2 - b2_z1)
     union = b1_volume[:,0] + b2_volume[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     overlaps = iou.view(boxes2_repeat, boxes1_repeat)
     return overlaps
 
 def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None):
     """Given the anchors and GT boxes, compute overlaps and identify positive
     anchors and deltas to refine them to match their corresponding GT boxes.
 
     anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))]
     gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))]
     gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN,
     set all positive matches to 1 (foreground)
 
     Returns:
     anchor_class_matches: [N] (int32) matches between anchors and GT boxes.
                1 = positive anchor, -1 = negative anchor, 0 = neutral
     anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas.
     """
 
     anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32)
     anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim))
     anchor_matching_iou = cf.anchor_matching_iou
 
     if gt_boxes is None:
         anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1)
         return anchor_class_matches, anchor_delta_targets
 
     # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground)
     if gt_class_ids is None:
         gt_class_ids = np.array([1] * len(gt_boxes))
 
     # Compute overlaps [num_anchors, num_gt_boxes]
     overlaps = compute_overlaps(anchors, gt_boxes)
 
     # Match anchors to GT Boxes
     # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive.
     # If an anchor overlaps a GT box with IoU < 0.1 then it's negative.
     # Neutral anchors are those that don't match the conditions above,
     # and they don't influence the loss function.
     # However, don't keep any GT box unmatched (rare, but happens). Instead,
     # match it to the closest anchor (even if its max IoU is < 0.1).
 
     # 1. Set negative anchors first. They get overwritten below if a GT box is
     # matched to them. Skip boxes in crowd areas.
     anchor_iou_argmax = np.argmax(overlaps, axis=1)
     anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
     if anchors.shape[1] == 4:
         anchor_class_matches[(anchor_iou_max < 0.1)] = -1
     elif anchors.shape[1] == 6:
         anchor_class_matches[(anchor_iou_max < 0.01)] = -1
     else:
         raise ValueError('anchor shape wrong {}'.format(anchors.shape))
 
     # 2. Set an anchor for each GT box (regardless of IoU value).
     gt_iou_argmax = np.argmax(overlaps, axis=0)
     for ix, ii in enumerate(gt_iou_argmax):
         anchor_class_matches[ii] = gt_class_ids[ix]
 
     # 3. Set anchors with high overlap as positive.
     above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou)
     anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]]
 
     # Subsample to balance positive anchors.
     ids = np.where(anchor_class_matches > 0)[0]
     extra = len(ids) - (cf.rpn_train_anchors_per_image // 2)
     if extra > 0:
         # Reset the extra ones to neutral
         ids = np.random.choice(ids, extra, replace=False)
         anchor_class_matches[ids] = 0
 
     # Leave all negative proposals negative for now and sample from them later in online hard example mining.
     # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes.
     ids = np.where(anchor_class_matches > 0)[0]
     ix = 0  # index into anchor_delta_targets
     for i, a in zip(ids, anchors[ids]):
         # closest gt box (it might have IoU < anchor_matching_iou)
         gt = gt_boxes[anchor_iou_argmax[i]]
 
         # convert coordinates to center plus width/height.
         gt_h = gt[2] - gt[0]
         gt_w = gt[3] - gt[1]
         gt_center_y = gt[0] + 0.5 * gt_h
         gt_center_x = gt[1] + 0.5 * gt_w
         # Anchor
         a_h = a[2] - a[0]
         a_w = a[3] - a[1]
         a_center_y = a[0] + 0.5 * a_h
         a_center_x = a[1] + 0.5 * a_w
 
         if cf.dim == 2:
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
             ]
 
         else:
             gt_d = gt[5] - gt[4]
             gt_center_z = gt[4] + 0.5 * gt_d
             a_d = a[5] - a[4]
             a_center_z = a[4] + 0.5 * a_d
 
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 (gt_center_z - a_center_z) / a_d,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
                 np.log(gt_d / a_d)
             ]
 
         # normalize.
         anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev
         ix += 1
 
     return anchor_class_matches, anchor_delta_targets
 
 
 
 def clip_to_window(window, boxes):
     """
         window: (y1, x1, y2, x2) / 3D: (z1, z2). The window in the image we want to clip to.
         boxes: [N, (y1, x1, y2, x2)]  / 3D: (z1, z2)
     """
     boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[2]))
     boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[3]))
     boxes[:, 2] = boxes[:, 2].clamp(float(window[0]), float(window[2]))
     boxes[:, 3] = boxes[:, 3].clamp(float(window[1]), float(window[3]))
 
     if boxes.shape[1] > 5:
         boxes[:, 4] = boxes[:, 4].clamp(float(window[4]), float(window[5]))
         boxes[:, 5] = boxes[:, 5].clamp(float(window[4]), float(window[5]))
 
     return boxes
 
 ############################################################
 #  Connected Componenent Analysis
 ############################################################
 
 def get_coords(binary_mask, n_components, dim):
     """
     loops over batch to perform connected component analysis on binary input mask. computes box coordinates around
     n_components - biggest components (rois).
     :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class.
     :param n_components: int. number of components to extract per batch element and class.
     :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2))
     :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2))
     """
     assert len(binary_mask.shape)==dim+1
     binary_mask = binary_mask.astype('uint8')
     batch_coords = []
     batch_components = []
     for ix,b in enumerate(binary_mask):
         clusters, n_cands = lb(b)  # performs connected component analysis.
         uniques, counts = np.unique(clusters, return_counts=True)
         keep_uniques = uniques[1:][np.argsort(counts[1:])[::-1]][:n_components] #only keep n_components largest components
         p_components = np.array([(clusters == ii) * 1 for ii in keep_uniques])  # separate clusters and concat
         p_coords = []
         if p_components.shape[0] > 0:
             for roi in p_components:
                 mask_ixs = np.argwhere(roi != 0)
 
                 # get coordinates around component.
                 roi_coords = [np.min(mask_ixs[:, 0]) - 1, np.min(mask_ixs[:, 1]) - 1, np.max(mask_ixs[:, 0]) + 1,
                                np.max(mask_ixs[:, 1]) + 1]
                 if dim == 3:
                     roi_coords += [np.min(mask_ixs[:, 2]), np.max(mask_ixs[:, 2])+1]
                 p_coords.append(roi_coords)
 
             p_coords = np.array(p_coords)
 
             #clip coords.
             p_coords[p_coords < 0] = 0
             p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2]
             if dim == 3:
                 p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1]
 
         batch_coords.append(p_coords)
         batch_components.append(p_components)
     return batch_coords, batch_components
 
 
 # noinspection PyCallingNonCallable
 def get_coords_gpu(binary_mask, n_components, dim):
     """
     loops over batch to perform connected component analysis on binary input mask. computes box coordiantes around
     n_components - biggest components (rois).
     :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class.
     :param n_components: int. number of components to extract per batch element and class.
     :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2))
     :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2))
     """
     raise Exception("throws floating point exception")
     assert len(binary_mask.shape)==dim+1
     binary_mask = binary_mask.type(torch.uint8)
     batch_coords = []
     batch_components = []
     for ix,b in enumerate(binary_mask):
         clusters, n_cands = lb(b.cpu().data.numpy())  # peforms connected component analysis.
         clusters = torch.from_numpy(clusters).cuda()
         uniques = torch.unique(clusters)
         counts = torch.stack([(clusters==unique).sum() for unique in uniques])
         keep_uniques = uniques[1:][torch.sort(counts[1:])[1].flip(0)][:n_components] #only keep n_components largest components
         p_components = torch.cat([(clusters == ii).unsqueeze(0) for ii in keep_uniques]).cuda()  # separate clusters and concat
         p_coords = []
         if p_components.shape[0] > 0:
             for roi in p_components:
                 mask_ixs = torch.nonzero(roi)
 
                 # get coordinates around component.
                 roi_coords = [torch.min(mask_ixs[:, 0]) - 1, torch.min(mask_ixs[:, 1]) - 1,
                               torch.max(mask_ixs[:, 0]) + 1,
                               torch.max(mask_ixs[:, 1]) + 1]
                 if dim == 3:
                     roi_coords += [torch.min(mask_ixs[:, 2]), torch.max(mask_ixs[:, 2])+1]
                 p_coords.append(roi_coords)
 
             p_coords = torch.tensor(p_coords)
 
             #clip coords.
             p_coords[p_coords < 0] = 0
             p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2]
             if dim == 3:
                 p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1]
 
         batch_coords.append(p_coords)
         batch_components.append(p_components)
     return batch_coords, batch_components
 
 
 ############################################################
 #  Pytorch Utility Functions
 ############################################################
 
 def unique1d(tensor):
     """discard all elements of tensor that occur more than once; make tensor unique.
     :param tensor:
     :return:
     """
     if tensor.size()[0] == 0 or tensor.size()[0] == 1:
         return tensor
     tensor = tensor.sort()[0]
     unique_bool = tensor[1:] != tensor[:-1]
     first_element = torch.tensor([True], dtype=torch.bool, requires_grad=False)
     if tensor.is_cuda:
         first_element = first_element.cuda()
     unique_bool = torch.cat((first_element, unique_bool), dim=0)
     return tensor[unique_bool.data]
 
 
 def intersect1d(tensor1, tensor2):
     aux = torch.cat((tensor1, tensor2), dim=0)
     aux = aux.sort(descending=True)[0]
     return aux[:-1][(aux[1:] == aux[:-1]).data]
 
 
 
 def shem(roi_probs_neg, negative_count, poolsize):
     """
     stochastic hard example mining: from a list of indices (referring to non-matched predictions),
     determine a pool of highest scoring (worst false positives) of size negative_count*poolsize.
     Then, sample n (= negative_count) predictions of this pool as negative examples for loss.
     :param roi_probs_neg: tensor of shape (n_predictions, n_classes).
     :param negative_count: int.
     :param poolsize: int.
     :return: (negative_count).  indices refer to the positions in roi_probs_neg. If pool smaller than expected due to
     limited negative proposals availabel, this function will return sampled indices of number < negative_count without
     throwing an error.
     """
     # sort according to higehst foreground score.
     probs, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True)
     select = torch.tensor((poolsize * int(negative_count), order.size()[0])).min().int()
 
     pool_indices = order[:select]
     rand_idx = torch.randperm(pool_indices.size()[0])
     return pool_indices[rand_idx[:negative_count].cuda()]
 
 
 ############################################################
 #  Weight Init
 ############################################################
 
 
 def initialize_weights(net):
     """Initialize model weights. Current Default in Pytorch (version 0.4.1) is initialization from a uniform distriubtion.
     Will expectably be changed to kaiming_uniform in future versions.
     """
     init_type = net.cf.weight_init
 
     for m in [module for module in net.modules() if type(module) in [torch.nn.Conv2d, torch.nn.Conv3d,
                                                                      torch.nn.ConvTranspose2d,
                                                                      torch.nn.ConvTranspose3d,
                                                                      torch.nn.Linear]]:
         if init_type == 'xavier_uniform':
             torch.nn.init.xavier_uniform_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == 'xavier_normal':
             torch.nn.init.xavier_normal_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == "kaiming_uniform":
             torch.nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 torch.nn.init.uniform_(m.bias, -bound, bound)
 
         elif init_type == "kaiming_normal":
             torch.nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 torch.nn.init.normal_(m.bias, -bound, bound)
     net.logger.info("applied {} weight init.".format(init_type))
\ No newline at end of file