diff --git a/.gitignore b/.gitignore
index 758000e..792f3c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,12 +1,13 @@
 *.pyc
 *.pickle
 *.ipynb_checkpoints*
 *.npy
 *.pkl
 *.log
 *.png
 *.jpg
 __pycache__/*
 .idea/*
+*.so
 
 !/assets/*
diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda.cu b/custom_extensions/roi_align/src/RoIAlign_cuda.cu
index d768f2d..39426bf 100644
--- a/custom_extensions/roi_align/src/RoIAlign_cuda.cu
+++ b/custom_extensions/roi_align/src/RoIAlign_cuda.cu
@@ -1,422 +1,422 @@
 /*
 ROIAlign implementation in CUDA from pytorch framework
 (https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019)
 
 */
 
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <typeinfo>
 #include "cuda_helpers.h"
 
 template <typename T>
 __device__ T bilinear_interpolate(
     const T* input,
     const int height,
     const int width,
     T y,
     T x,
     const int index /* index for debug only*/) {
   // deal with cases where the inverse-mapped sampling point is outside the feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width) {
     // empty
     return 0;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
 
   int y_low = (int)y;
   int x_low = (int)x;
   int y_high;
   int x_high;
 
   if (y_low >= height - 1) {
     y_high = y_low = height - 1;
     y = (T)y_low;
   } else {
     y_high = y_low + 1;
   }
 
   if (x_low >= width - 1) {
     x_high = x_low = width - 1;
     x = (T)x_low;
   } else {
     x_high = x_low + 1;
   }
 
   T ly = y - y_low;
   T lx = x - x_low;
   T hy = 1. - ly, hx = 1. - lx;
 
   // do bilinear interpolation
   T v1 = input[y_low * width + x_low];
   T v2 = input[y_low * width + x_high];
   T v3 = input[y_high * width + x_low];
   T v4 = input[y_high * width + x_high];
   T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
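   // the weights are the areas of the opposite sub-rectangles and sum to 1;
   // e.g., ly = 0.3, lx = 0.7 gives w1 = 0.21, w2 = 0.49, w3 = 0.09, w4 = 0.21.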
 
   T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
 
   return val;
 }
 
 template <typename T>
 __global__ void RoIAlignForward(
     const int nthreads,
     const T* input,
     const T spatial_scale,
     const int channels,
     const int height,
     const int width,
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
     const T* rois,
     T* output) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     const int pw = index % pooled_width;
     const int ph = (index / pooled_width) % pooled_height;
     const int c = (index / pooled_width / pooled_height) % channels;
     const int n = index / pooled_width / pooled_height / channels;
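     // i.e., index = ((n * channels + c) * pooled_height + ph) * pooled_width + pw,
     // the flat offset of element (n, c, ph, pw) in the contiguous output.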
 
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];
 
     // Do not use rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
 
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
 
     const T* offset_input =
         input + (roi_batch_ind * channels + c) * height * width;
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0)
         ? sampling_ratio
         : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
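     // note: with sampling_ratio <= 0 the grid adapts to the bin size, i.e. roughly
     // one sampling point per feature-map pixel covered by the bin (ceil(bin_size)).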
 
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
     T output_val = 0.;
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
       for (int ix = 0; ix < roi_bin_grid_w; ix++) {
         const T x = roi_start_w + pw * bin_size_w +
-            static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
         T val = bilinear_interpolate(offset_input, height, width, y, x, index);
         output_val += val;
       }
     }
     output_val /= count;
 
     output[index] = output_val;
   }
 }
 
 template <typename T>
 __device__ void bilinear_interpolate_gradient(
     const int height,
     const int width,
     T y,
     T x,
     T& w1,
     T& w2,
     T& w3,
     T& w4,
     int& x_low,
     int& x_high,
     int& y_low,
     int& y_high,
     const int index /* index for debug only*/) {
   // deal with cases where the inverse-mapped sampling point is outside the feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width) {
     // empty
     w1 = w2 = w3 = w4 = 0.;
     x_low = x_high = y_low = y_high = -1;
     return;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
 
   y_low = (int)y;
   x_low = (int)x;
 
   if (y_low >= height - 1) {
     y_high = y_low = height - 1;
     y = (T)y_low;
   } else {
     y_high = y_low + 1;
   }
 
   if (x_low >= width - 1) {
     x_high = x_low = width - 1;
     x = (T)x_low;
   } else {
     x_high = x_low + 1;
   }
 
   T ly = y - y_low;
   T lx = x - x_low;
   T hy = 1. - ly, hx = 1. - lx;
 
   // reference in forward
   // T v1 = input[y_low * width + x_low];
   // T v2 = input[y_low * width + x_high];
   // T v3 = input[y_high * width + x_low];
   // T v4 = input[y_high * width + x_high];
   // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
 
   w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
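   // since the forward output is linear in the four corner values, these weights
   // are exactly the bilinear interpolation weights from the forward pass.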
 
   return;
 }
 
 template <typename T>
 __global__ void RoIAlignBackward(
     const int nthreads,
     const T* grad_output,
     const T spatial_scale,
     const int channels,
     const int height,
     const int width,
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
     T* grad_input,
     const T* rois,
     const int n_stride,
     const int c_stride,
     const int h_stride,
     const int w_stride)
 {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
 
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];
 
     // Do not use rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
 
     T* offset_grad_input =
         grad_input + ((roi_batch_ind * channels + c) * height * width);
 
     // We need to index the gradient using the tensor strides to access the
     // correct values.
     int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin =
         offset_grad_output[ph * h_stride + pw * w_stride];
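     // e.g., for a contiguous grad of shape (n_rois, channels, pooled_height, pooled_width):
     // w_stride = 1, h_stride = pooled_width, c_stride = pooled_height * pooled_width.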
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0)
         ? sampling_ratio
         : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
 
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
 
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
       for (int ix = 0; ix < roi_bin_grid_w; ix++) {
         const T x = roi_start_w + pw * bin_size_w  +
-            static_cast<T>(ix + .5f) * (bin_size_w- 1.f) / static_cast<T>(roi_bin_grid_w);
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
 
         T w1, w2, w3, w4;
         int x_low, x_high, y_low, y_high;
 
         bilinear_interpolate_gradient(
             height,
             width,
             y,
             x,
             w1,
             w2,
             w3,
             w4,
             x_low,
             x_high,
             y_low,
             y_high,
             index);
 
         T g1 = grad_output_this_bin * w1 / count;
         T g2 = grad_output_this_bin * w2 / count;
         T g3 = grad_output_this_bin * w3 / count;
         T g4 = grad_output_this_bin * w4 / count;
 
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
           atomicAdd(
               offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
           atomicAdd(
               offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
           atomicAdd(
               offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
           atomicAdd(
               offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
         } // if
       } // ix
     } // iy
   } // CUDA_1D_KERNEL_LOOP
 } // RoIAlignBackward
 
 at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
                                 const int pooled_height, const int pooled_width, const int sampling_ratio) {
   /*
    input: feature-map tensor, shape (batch, n_channels, y, x)
    */
   AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_forward_cuda";
   at::checkAllSameGPU(c, {input_t, rois_t});
   at::checkAllSameType(c, {input_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(input.device());
 
   int num_rois = rois.size(0);
   int channels = input.size(1);
   int height = input.size(2);
   int width = input.size(3);
 
   at::Tensor output = at::zeros(
       {num_rois, channels, pooled_height, pooled_width}, input.options());
 
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
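   // launch config: one thread per output element, at most 4096 blocks of 512 threads;
   // CUDA_1D_KERNEL_LOOP (from cuda_helpers.h) is expected to grid-stride over any remainder.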
 
   if (output.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return output;
   }
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
     RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.contiguous().data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         pooled_height,
         pooled_width,
         sampling_ratio,
         rois.contiguous().data_ptr<scalar_t>(),
         output.data_ptr<scalar_t>());
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return output;
 }
 
 at::Tensor ROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const float spatial_scale,
     const int pooled_height,
     const int pooled_width,
     const int batch_size,
     const int channels,
     const int height,
     const int width,
     const int sampling_ratio) {
   AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_backward_cuda";
   at::checkAllSameGPU(c, {grad_t, rois_t});
   at::checkAllSameType(c, {grad_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(grad.device());
 
   at::Tensor grad_input =
       at::zeros({batch_size, channels, height, width}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
   // handle possibly empty gradients
   if (grad.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }
 
   int n_stride = grad.stride(0);
   int c_stride = grad.stride(1);
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
     RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         pooled_height,
         pooled_width,
         sampling_ratio,
         grad_input.data_ptr<scalar_t>(),
         rois.contiguous().data_ptr<scalar_t>(),
         n_stride,
         c_stride,
         h_stride,
         w_stride);
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
\ No newline at end of file
diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu b/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
index 0c75a34..182274f 100644
--- a/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
+++ b/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
@@ -1,487 +1,487 @@
 /*
 ROIAlign implementation in CUDA from pytorch framework
 (https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019)
 
 Adapted for additional 3D capability by G. Ramien, DKFZ Heidelberg
 */
 
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <cstdio>
 #include "cuda_helpers.h"
 
 /*-------------- gpu kernels -----------------*/
 
 template <typename T>
 __device__ T linear_interpolate(const T xl, const T val_low, const T val_high){
 
   T val = (val_high - val_low) * xl + val_low;
   return val;
 }
 
 template <typename T>
 __device__ T trilinear_interpolate(const T* input, const int height, const int width, const int depth,
                 T y, T x, T z, const int index /* index for debug only*/) {
   // deal with cases where the inverse-mapped sampling point is outside the feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) {
     // empty
     return 0;
   }
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
   if (z <= 0)
     z = 0;
 
   int y0 = (int)y;
   int x0 = (int)x;
   int z0 = (int)z;
   int y1;
   int x1;
   int z1;
 
   if (y0 >= height - 1) {
   /* if the lower grid point y0 is already at the last row index (or beyond), clamp y0, y1, and the point y itself to height-1 */
     y1 = y0 = height - 1;
     y = (T)y0;
   } else {
     /* y1 is one pixel from y0, y is the actual point somewhere in between  */
     y1 = y0 + 1;
   }
   if (x0 >= width - 1) {
     x1 = x0 = width - 1;
     x = (T)x0;
   } else {
     x1 = x0 + 1;
   }
   if (z0 >= depth - 1) {
     z1 = z0 = depth - 1;
     z = (T)z0;
   } else {
     z1 = z0 + 1;
   }
 
 
   // do linear interpolation of x values
   // distance of actual point to lower boundary point, already normalized since x_high - x0 = 1
   T dis = x - x0;
   /*  accessing element b,c,y,x,z in 1D-rolled-out array of a tensor with dimensions (B, C, Y, X, Z):
       tensor[b,c,y,x,z] = arr[ (((b*C+c)*Y+y)*X + x)*Z + z ] = arr[ alpha + (y*X + x)*Z + z ]
       with alpha = batch&channel locator = (b*C+c)*YXZ.
-      hence, as current input pointer is already offset by alpha: y,x,z at input[( y*X + x)*Z + z], where
+      hence, as current input pointer is already offset by alpha: y,x,z is at input[( y*X + x)*Z + z], where
       X = width, Z = depth.
   */
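   // worked example (assuming Y = 4, X = 5, Z = 6): voxel (y, x, z) = (3, 4, 5) of this
   // block lives at input[(3*5 + 4)*6 + 5] = input[119], the last of 4*5*6 = 120 voxels.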
   T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]);
   T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]);
   T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]);
   T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]);
 
   // linear interpol of y values = bilinear interpol of f(x,y)
   dis = y - y0;
   T xy0 = linear_interpolate(dis, x00, x10);
   T xy1 = linear_interpolate(dis, x01, x11);
 
   // linear interpol of z value = trilinear interpol of f(x,y,z)
   dis = z - z0;
   T xyz = linear_interpolate(dis, xy0, xy1);
 
   return xyz;
 }
 
 template <typename T>
 __device__ void trilinear_interpolate_gradient(const int height, const int width, const int depth, T y, T x, T z,
     T& g000, T& g001, T& g010, T& g100, T& g011, T& g101, T& g110, T& g111,
     int& x0, int& x1, int& y0, int& y1, int& z0, int&z1, const int index /* index for debug only*/)
 {
   // deal with cases where the inverse-mapped sampling point is outside the feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) {
     // empty
     g000 = g001 = g010 = g100 = g011 = g101 = g110 = g111 = 0.;
     x0 = x1 = y0 = y1 = z0 = z1 = -1;
     return;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
   if (z <= 0)
     z = 0;
 
   y0 = (int)y;
   x0 = (int)x;
   z0 = (int)z;
 
   if (y0 >= height - 1) {
     y1 = y0 = height - 1;
     y = (T)y0;
   } else {
     y1 = y0 + 1;
   }
 
   if (x0 >= width - 1) {
     x1 = x0 = width - 1;
     x = (T)x0;
   } else {
     x1 = x0 + 1;
   }
 
   if (z0 >= depth - 1) {
     z1 = z0 = depth - 1;
     z = (T)z0;
   } else {
     z1 = z0 + 1;
   }
 
   // forward calculations are added as hints
   T dis_x = x - x0;
   //T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]); // v000, v100
   //T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]); // v010, v110
   //T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]); // v001, v101
   //T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]); // v011, v111
 
   // linear interpol of y values = bilinear interpol of f(x,y)
   T dis_y = y - y0;
   //T xy0 = linear_interpolate(dis, x00, x10);
   //T xy1 = linear_interpolate(dis, x01, x11);
 
   // linear interpol of z value = trilinear interpol of f(x,y,z)
   T dis_z = z - z0;
   //T xyz = linear_interpolate(dis, xy0, xy1);
 
   /* need: grad_i := d(xyz)/d(v_i) with v_i = input_value_i  for all i = 0,..,7 (eight input values --> eight-entry gradient)
      d(lin_interp(dis,x,y))/dx = (-dis +1) and d(lin_interp(dis,x,y))/dy = dis --> derivatives are indep of x,y.
      notation: gxyz = gradient for d(trilin_interp)/d(input_value_at_xyz)
      below grads were calculated by hand
      save time by reusing (1-dis_x) = 1-x+x0 = x1-x =: dis_x1 */
   T dis_x1 = (1-dis_x), dis_y1 = (1-dis_y), dis_z1 = (1-dis_z);
 
   g000 = dis_z1 * dis_y1  * dis_x1;
   g001 = dis_z  * dis_y1  * dis_x1;
   g010 = dis_z1 * dis_y   * dis_x1;
   g100 = dis_z1 * dis_y1  * dis_x;
   g011 = dis_z  * dis_y   * dis_x1;
   g101 = dis_z  * dis_y1  * dis_x;
   g110 = dis_z1 * dis_y   * dis_x;
   g111 = dis_z  * dis_y   * dis_x;
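   // sanity check: the eight weights are the volumes of the opposite sub-cuboids of
   // the unit cell and sum to 1, mirroring the trilinear weights of the forward pass.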
 
   return;
 }
 
 template <typename T>
 __global__ void RoIAlignForward(const int nthreads, const T* input, const T spatial_scale, const int channels,
     const int height, const int width, const int depth, const int pooled_height, const int pooled_width,
     const int pooled_depth, const int sampling_ratio, const T* rois, T* output)
 {
 
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw, pd) is an element in the pooled output
     int pd =  index % pooled_depth;
     int pw = (index / pooled_depth) % pooled_width;
     int ph = (index / pooled_depth / pooled_width) % pooled_height;
     int c  = (index / pooled_depth / pooled_width / pooled_height) % channels;
     int n  =  index / pooled_depth / pooled_width / pooled_height / channels;
 
 
     // rois are (b, y1, x1, y2, x2, z1, z2) --> tensor of shape (n_rois, 7)
     const T* offset_rois = rois + n * 7;
     int roi_batch_ind = offset_rois[0];
     // Do not use rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
     T roi_start_d = offset_rois[5] * spatial_scale;
     T roi_end_d = offset_rois[6] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_depth = max(roi_end_d - roi_start_d, (T)1.);
 
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
     T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);
 
     const T* offset_input =
         input + (roi_batch_ind * channels + c) * height * width * depth;
 
     // We use roi_bin_grid to sample the grid and mimic integral
     // roi_bin_grid == number of sampling points per bin, >= 1
     int roi_bin_grid_h =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
     int roi_bin_grid_d =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth);
 
     // We do average (integral) pooling inside a bin
     const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g., = 8 for a 2x2x2 grid
 
     T output_val = 0.;
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5, always in the middle of two grid pointsk
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5, always in the middle of two grid points
 
       for (int ix = 0; ix < roi_bin_grid_w; ix++)
       {
         const T x = roi_start_w + pw * bin_size_w +
-            static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
 
         for (int iz = 0; iz < roi_bin_grid_d; iz++)
         {
           const T z = roi_start_d + pd * bin_size_d +
-              static_cast<T>(iz + .5f) * (bin_size_d - 1.f) / static_cast<T>(roi_bin_grid_d);
+              static_cast<T>(iz + .5f) * bin_size_d / static_cast<T>(roi_bin_grid_d);
           // TODO verify trilinear interpolation
           T val = trilinear_interpolate(offset_input, height, width, depth, y, x, z, index);
           output_val += val;
         } // z iterator and calc+add value
       } // x iterator
     } // y iterator
     output_val /= n_voxels;
 
     output[index] = output_val;
   }
 }
 
 template <typename T>
 __global__ void RoIAlignBackward(const int nthreads, const T* grad_output, const T spatial_scale, const int channels,
     const int height, const int width, const int depth, const int pooled_height, const int pooled_width,
     const int pooled_depth, const int sampling_ratio, T* grad_input, const T* rois,
     const int n_stride, const int c_stride, const int h_stride, const int w_stride, const int d_stride)
 {
 
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw, pd) is an element in the pooled output
     int pd =  index % pooled_depth;
     int pw = (index / pooled_depth) % pooled_width;
     int ph = (index / pooled_depth / pooled_width) % pooled_height;
     int c  = (index / pooled_depth / pooled_width / pooled_height) % channels;
     int n  =  index / pooled_depth / pooled_width / pooled_height / channels;
 
 
     const T* offset_rois = rois + n * 7;
     int roi_batch_ind = offset_rois[0];
 
     // Do not use rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
     T roi_start_d = offset_rois[5] * spatial_scale;
     T roi_end_d = offset_rois[6] * spatial_scale;
 
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T roi_depth = max(roi_end_d - roi_start_d, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
     T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);
 
     // offset: index b,c,y,x,z of tensor of shape (B,C,Y,X,Z) is
     // b*C*Y*X*Z + c * Y*X*Z + y * X*Z + x *Z + z = (b*C+c)Y*X*Z + ...
     T* offset_grad_input =
         grad_input + ((roi_batch_ind * channels + c) * height * width * depth);
 
     // We need to index the gradient using the tensor strides to access the correct values.
     int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin = offset_grad_output[ph * h_stride + pw * w_stride + pd * d_stride];
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
     int roi_bin_grid_d = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth);
 
     // We do average (integral) pooling inside a bin
     const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g., = 8 for a 2x2x2 grid
 
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
 
       for (int ix = 0; ix < roi_bin_grid_w; ix++)
       {
         const T x = roi_start_w + pw * bin_size_w +
-          static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
+          static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
 
         for (int iz = 0; iz < roi_bin_grid_d; iz++)
         {
           const T z = roi_start_d + pd * bin_size_d +
-              static_cast<T>(iz + .5f) * (bin_size_d - 1.f) / static_cast<T>(roi_bin_grid_d);
+              static_cast<T>(iz + .5f) * bin_size_d / static_cast<T>(roi_bin_grid_d);
 
           T g000, g001, g010, g100, g011, g101, g110, g111; // will hold the current partial derivatives
           int x0, x1, y0, y1, z0, z1;
           /* notation: gxyz = gradient at xyz, where x,y,z need to lie on feature-map grid (i.e., =x0,x1 etc.) */
           trilinear_interpolate_gradient(height, width, depth, y, x, z,
                                          g000, g001, g010, g100, g011, g101, g110, g111,
                                          x0, x1, y0, y1, z0, z1, index);
           /* chain rule: derivatives (i.e., the gradient) of trilin_interpolate(v1,v2,v3,v4,...) (div by n_voxels
              as we actually need gradient of whole roi_align) are multiplied with gradient so far*/
           g000 *= grad_output_this_bin / n_voxels;
           g001 *= grad_output_this_bin / n_voxels;
           g010 *= grad_output_this_bin / n_voxels;
           g100 *= grad_output_this_bin / n_voxels;
           g011 *= grad_output_this_bin / n_voxels;
           g101 *= grad_output_this_bin / n_voxels;
           g110 *= grad_output_this_bin / n_voxels;
           g111 *= grad_output_this_bin / n_voxels;
 
           if (x0 >= 0 && x1 >= 0 && y0 >= 0 && y1 >= 0 && z0 >= 0 && z1 >= 0)
           { // atomicAdd(address, content) reads content under address, adds content to it, while: no other thread
             // can interfere with the memory at address during this operation (thread lock, therefore "atomic").
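             // atomics are needed because sampling points from different bins and
             // ROIs can map to the same input voxel, so writes would otherwise race.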
             atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z0, static_cast<T>(g000));
             atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z1, static_cast<T>(g001));
             atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z0, static_cast<T>(g010));
             atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z0, static_cast<T>(g100));
             atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z1, static_cast<T>(g011));
             atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z1, static_cast<T>(g101));
             atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z0, static_cast<T>(g110));
             atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z1, static_cast<T>(g111));
           } // if
         } // iz
       } // ix
     } // iy
   } // CUDA_1D_KERNEL_LOOP
 } // RoIAlignBackward
 
 
 /*----------- wrapper functions ----------------*/
 
 at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
                                 const int pooled_height, const int pooled_width, const int pooled_depth, const int sampling_ratio) {
   /*
    input: feature-map tensor, shape (batch, n_channels, y, x, z)
    */
   AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_forward_cuda";
   at::checkAllSameGPU(c, {input_t, rois_t});
   at::checkAllSameType(c, {input_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(input.device());
 
   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
   auto width = input.size(3);
   auto depth = input.size(4);
 
   at::Tensor output = at::zeros(
       {num_rois, channels, pooled_height, pooled_width, pooled_depth}, input.options());
 
   auto output_size = num_rois * channels * pooled_height * pooled_width * pooled_depth;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
   dim3 block(512);
 
   if (output.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return output;
   }
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign forward in 3d", [&] {
     RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.contiguous().data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         depth,
         pooled_height,
         pooled_width,
         pooled_depth,
         sampling_ratio,
         rois.contiguous().data_ptr<scalar_t>(),
         output.data_ptr<scalar_t>());
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return output;
 }
 
 at::Tensor ROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const float spatial_scale,
     const int pooled_height,
     const int pooled_width,
     const int pooled_depth,
     const int batch_size,
     const int channels,
     const int height,
     const int width,
     const int depth,
     const int sampling_ratio)
 {
   AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_backward_cuda";
   at::checkAllSameGPU(c, {grad_t, rois_t});
   at::checkAllSameType(c, {grad_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(grad.device());
 
   at::Tensor grad_input =
       at::zeros({batch_size, channels, height, width, depth}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
   // handle possibly empty gradients
   if (grad.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }
 
   int n_stride = grad.stride(0);
   int c_stride = grad.stride(1);
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
   int d_stride = grad.stride(4);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign backward 3D", [&] {
     RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         depth,
         pooled_height,
         pooled_width,
         pooled_depth,
         sampling_ratio,
         grad_input.data_ptr<scalar_t>(),
         rois.contiguous().data_ptr<scalar_t>(),
         n_stride,
         c_stride,
         h_stride,
         w_stride,
         d_stride);
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
\ No newline at end of file
diff --git a/exec.py b/exec.py
index f82b2f3..a448a17 100644
--- a/exec.py
+++ b/exec.py
@@ -1,247 +1,247 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """execution script."""
 
 import argparse
 import os
 import time
 import torch
 
 import utils.exp_utils as utils
 from evaluator import Evaluator
 from predictor import Predictor
 from plotting import plot_batch_prediction
 
 
 def train(logger):
     """
     perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
     specified in the configs.
     """
     logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
         cf.dim, cf.fold, cf.exp_dir, cf.model))
 
     net = model.net(cf, logger).cuda()
     optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
     model_selector = utils.ModelSelector(cf, logger)
     train_evaluator = Evaluator(cf, logger, mode='train')
     val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
 
     starting_epoch = 1
 
     # prepare monitoring
     monitor_metrics = utils.prepare_monitoring(cf)
 
     if cf.resume_to_checkpoint:
         starting_epoch, monitor_metrics = utils.load_checkpoint(cf.resume_to_checkpoint, net, optimizer)
         logger.info('resumed to checkpoint {} at epoch {}'.format(cf.resume_to_checkpoint, starting_epoch))
 
     logger.info('loading dataset and initializing batch generators...')
     batch_gen = data_loader.get_train_generators(cf, logger)
 
     for epoch in range(starting_epoch, cf.num_epochs + 1):
 
         logger.info('starting training epoch {}'.format(epoch))
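         # set this epoch's learning rate from the per-epoch schedule defined in the configs.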
         for param_group in optimizer.param_groups:
             param_group['lr'] = cf.learning_rate[epoch - 1]
 
         start_time = time.time()
 
         net.train()
         train_results_list = []
 
         for bix in range(cf.num_train_batches):
             batch = next(batch_gen['train'])
             tic_fw = time.time()
             results_dict = net.train_forward(batch)
             tic_bw = time.time()
             optimizer.zero_grad()
             results_dict['torch_loss'].backward()
             optimizer.step()
-            logger.info('tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
+            logger.info('tr. batch {0}/{1} (ep. {2}) fw {3:.2f} s / bw {4:.2f} s / total {5:.2f} s || '
                         .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                                 time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'])
             train_results_list.append([results_dict['boxes'], batch['pid']])
 
         _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])
         #import IPython; IPython.embed()
         train_time = time.time() - start_time
 
         logger.info('starting validation in mode {}.'.format(cf.val_mode))
         with torch.no_grad():
             net.eval()
             if cf.do_validation:
                 val_results_list = []
                 val_predictor = Predictor(cf, net, logger, mode='val')
                 for _ in range(batch_gen['n_val']):
                     batch = next(batch_gen[cf.val_mode])
                     if cf.val_mode == 'val_patient':
                         results_dict = val_predictor.predict_patient(batch)
                     elif cf.val_mode == 'val_sampling':
                         results_dict = net.train_forward(batch, is_validation=True)
                     val_results_list.append([results_dict['boxes'], batch['pid']])
 
                 _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
                 model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
 
             # update monitoring and prediction plots
             monitor_metrics.update({"lr":
                                         {str(g): group['lr'] for (g, group) in enumerate(optimizer.param_groups)}})
             logger.metrics2tboard(monitor_metrics, global_step=epoch)
 
             epoch_time = time.time() - start_time
-            logger.info('trained epoch {}: took {} sec. ({} train / {} val)'.format(
+            logger.info('trained epoch {}: took {:.2f} s ({:.2f} s train / {:.2f} s val)'.format(
                 epoch, epoch_time, train_time, epoch_time-train_time))
             batch = next(batch_gen['val_sampling'])
             results_dict = net.train_forward(batch, is_validation=True)
             logger.info('plotting predictions from validation sampling.')
             plot_batch_prediction(batch, results_dict, cf)
 
 
 def test(logger):
     """
     perform testing for a given fold (or hold out set). save stats in evaluator.
     """
     logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir))
     net = model.net(cf, logger).cuda()
     test_predictor = Predictor(cf, net, logger, mode='test')
     test_evaluator = Evaluator(cf, logger, mode='test')
     batch_gen = data_loader.get_test_generator(cf, logger)
     test_results_list = test_predictor.predict_test_set(batch_gen, return_results=True)
     test_evaluator.evaluate_predictions(test_results_list)
     test_evaluator.score_test_df()
 
 
 if __name__ == '__main__':
     stime = time.time()
 
     parser = argparse.ArgumentParser()
     parser.add_argument('-m', '--mode', type=str,  default='train_test',
                         help='one out of: train / test / train_test / analysis / create_exp')
     parser.add_argument('-f','--folds', nargs='+', type=int, default=None,
                         help='None runs over all folds in CV. otherwise specify list of folds.')
     parser.add_argument('--exp_dir', type=str, default='/path/to/experiment/directory',
                         help='path to experiment dir. will be created if non-existent.')
     parser.add_argument('--server_env', default=False, action='store_true',
                         help='change IO settings to deploy models on a cluster.')
     parser.add_argument('--slurm_job_id', type=str, default=None, help='job scheduler info')
     parser.add_argument('--use_stored_settings', default=False, action='store_true',
                         help='load configs from existing exp_dir instead of source dir. always done for testing, '
                              'but can be set to true to do the same for training. useful in a job-scheduler environment, '
                              'where source code might change before the job actually runs.')
     parser.add_argument('--resume_to_checkpoint', type=str, default=None,
                         help='if resuming to checkpoint, the desired fold still needs to be parsed via --folds.')
     parser.add_argument('--exp_source', type=str, default='experiments/toy_exp',
                         help='specifies from which source experiment to load configs and data_loader.')
     parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything")
 
     args = parser.parse_args()
     folds = args.folds
     resume_to_checkpoint = args.resume_to_checkpoint
 
     if args.mode == 'train' or args.mode == 'train_test':
 
         cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, args.use_stored_settings)
         if args.dev:
             folds = [0,1]
             cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 1
             cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1
             cf.test_n_epochs =  cf.save_n_models
             cf.max_test_patients = 1
 
         cf.slurm_job_id = args.slurm_job_id
         logger = utils.get_logger(cf.exp_dir, cf.server_env)
         data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
         model = utils.import_module('model', cf.model_path)
         logger.info("loaded model from {}".format(cf.model_path))
         if folds is None:
             folds = range(cf.n_cv_splits)
 
         for fold in folds:
             cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
             cf.fold = fold
             cf.resume_to_checkpoint = resume_to_checkpoint
             if not os.path.exists(cf.fold_dir):
                 os.mkdir(cf.fold_dir)
             logger.set_logfile(fold=fold)
             train(logger)
             cf.resume_to_checkpoint = None
             if args.mode == 'train_test':
                 test(logger)
 
     elif args.mode == 'test':
 
         cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
         if args.dev:
             folds = [0,1]
             cf.test_n_epochs = 1; cf.max_test_patients = 1
 
         cf.slurm_job_id = args.slurm_job_id
         logger = utils.get_logger(cf.exp_dir, cf.server_env)
         data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
         model = utils.import_module('model', cf.model_path)
         logger.info("loaded model from {}".format(cf.model_path))
         if folds is None:
             folds = range(cf.n_cv_splits)
 
         for fold in folds:
             cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
             cf.fold = fold
             logger.set_logfile(fold=fold)
             test(logger)
 
 
     # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation.
     elif args.mode == 'analysis':
         cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
         logger = utils.get_logger(cf.exp_dir, cf.server_env)
 
         if cf.hold_out_test_set:
             cf.folds = args.folds
             predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
             results_list = predictor.load_saved_predictions(apply_wbc=True)
             utils.create_csv_output(results_list, cf, logger)
 
         else:
             if folds is None:
                 folds = range(cf.n_cv_splits)
             for fold in folds:
                 cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
                 cf.fold = fold
                 logger.set_logfile(fold=fold)
                 predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
                 results_list = predictor.load_saved_predictions(apply_wbc=True)
                 logger.info('starting evaluation...')
                 evaluator = Evaluator(cf, logger, mode='test')
                 evaluator.evaluate_predictions(results_list)
                 evaluator.score_test_df()
 
     # create experiment folder and copy scripts without starting job.
     # useful for cloud deployment where configs might change before job actually runs.
     elif args.mode == 'create_exp':
         cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, use_stored_settings=True)
         logger = utils.get_logger(cf.exp_dir)
         logger.info('created experiment directory at {}'.format(args.exp_dir))
 
     else:
         raise RuntimeError('mode specified in args is not implemented...')
 
     mins, secs = divmod((time.time() - stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t))
     del logger
\ No newline at end of file
diff --git a/experiments/toy_exp/configs.py b/experiments/toy_exp/configs.py
index a4087ee..8a82acf 100644
--- a/experiments/toy_exp/configs.py
+++ b/experiments/toy_exp/configs.py
@@ -1,344 +1,344 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 
 class configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
 
         #########################
         #    Preprocessing      #
         #########################
 
-        self.root_dir = '/mnt/HDD2TB/Documents/data/mdt_toy'
+        self.root_dir = '/home/gregor/datasets/toy_mdt'
 
         #########################
         #         I/O           #
         #########################
 
 
         # one out of [2, 3]. dimension the model operates in.
         self.dim = 2
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn'].
-        self.model = 'retina_unet'
+        self.model = 'mrcnn'
 
         DefaultConfigs.__init__(self, self.model, server_env, self.dim)
 
         # int [0 < dataset_size]. select n patients from dataset for prototyping.
         self.select_prototype_subset = None
         self.hold_out_test_set = True
         self.n_train_data = 1000
 
         # choose one of the 3 toy experiments described in https://arxiv.org/pdf/1811.08661.pdf
         # one of ['donuts_shape', 'donuts_pattern', 'circles_scale'].
         toy_mode = 'donuts_shape'
 
 
         # path to preprocessed data.
         self.input_df_name = 'info_df.pickle'
         self.pp_name = os.path.join(toy_mode, 'train')
         self.pp_data_path = os.path.join(self.root_dir, self.pp_name)
         self.pp_test_name = os.path.join(toy_mode, 'test')
         self.pp_test_data_path = os.path.join(self.root_dir, self.pp_test_name)
 
         # settings for deployment in cloud.
         if server_env:
             # path to preprocessed data.
             pp_root_dir = '/path/to/data'
             self.pp_name = os.path.join(toy_mode, 'train')
             self.pp_data_path = os.path.join(pp_root_dir, self.pp_name)
             self.pp_test_name = os.path.join(toy_mode, 'test')
             self.pp_test_data_path = os.path.join(pp_root_dir, self.pp_test_name)
             self.select_prototype_subset = None
 
         #########################
         #      Data Loader      #
         #########################
 
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.pre_crop_size_2D = [320, 320]
         self.patch_size_2D = [320, 320]
 
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
         self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_sample_slack = 0.2
 
         # set 2D network to operate in 3D images.
         self.merge_2D_to_3D_preds = False
 
         # feed +/- n neighbouring slices into channel dimension. set to None for no context.
         self.n_3D_context = None
         if self.n_3D_context is not None and self.dim == 2:
             self.n_channels *= (self.n_3D_context * 2 + 1)
 
 
         #########################
         #      Architecture      #
         #########################
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = None # one of None, 'instance_norm', 'batch_norm'
         self.weight_decay = 0
 
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         #########################
         #  Schedule / Selection #
         #########################
 
         self.num_epochs = 100
         self.num_train_batches = 100 if self.dim == 2 else 140
         self.batch_size = 20 if self.dim == 2 else 8
 
         self.do_validation = True
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is more accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_patient' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = None  # if 'None' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 50
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         # set the top-n-epochs to be saved for temporal averaging in testing.
         self.save_n_models = 5
         self.test_n_epochs = 5
 
         # set a minimum epoch number for saving in case of instabilities in the first phase of training.
         self.min_save_thresh = 0 if self.dim == 2 else 0
 
         self.report_score_level = ['patient', 'rois']  # choose list from 'patient', 'rois'
         self.class_dict = {1: 'benign', 2: 'malignant'}  # 0 is background.
         self.patient_class_of_interest = 2  # patient metrics are only plotted for one class.
         self.ap_match_ious = [0.1]  # list of ious to be evaluated for ap-scoring.
 
         self.model_selection_criteria = ['benign_ap', 'malignant_ap'] # criteria to average over for saving epochs.
         self.min_det_thresh = 0.1  # minimum confidence value to select predictions for evaluation.
 
         # threshold for clustering predictions together (wcs = weighted cluster scoring).
         # needs to be >= the expected overlap of predictions coming from one model (typically NMS threshold).
         # if too high, preds of the same object are separate clusters.
         self.wcs_iou = 1e-5
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
 
         #########################
         #   Data Augmentation   #
         #########################
 
         self.da_kwargs={
         'do_elastic_deform': True,
         'alpha':(0., 1500.),
         'sigma':(30., 50.),
         'do_rotation':True,
         'angle_x': (0., 2 * np.pi),
         'angle_y': (0., 0),
         'angle_z': (0., 0),
         'do_scale': True,
         'scale':(0.8, 1.1),
         'random_crop':False,
         'rand_crop_dist':  (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
         'border_mode_data': 'constant',
         'border_cval_data': 0,
         'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
 
         #########################
         #   Add model specifics #
         #########################
 
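         # dispatch: map the chosen model to its config-extension method and call it right away.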
         {'detection_unet': self.add_det_unet_configs,
          'mrcnn': self.add_mrcnn_configs,
          'ufrcnn': self.add_mrcnn_configs,
          'ufrcnn_surrounding': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs,
          'retina_unet': self.add_mrcnn_configs,
          'prob_detector': self.add_mrcnn_configs,
         }[self.model]()
 
 
     def add_det_unet_configs(self):
 
         self.learning_rate = [1e-4] * self.num_epochs
 
         # aggregation from pixel prediction to object scores (connected component). One of ['max', 'median']
         self.aggregation_operation = 'max'
 
         # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
         self.n_roi_candidates = 3 if self.dim == 2 else 8
 
         # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice'), or the sum of both ('dice_wce')
         self.seg_loss_mode = 'dice_wce'
 
         # if <1, false positive predictions in foreground are penalized less.
         self.fp_dice_weight = 1 if self.dim == 2 else 1
 
         self.wce_weights = [1, 1, 1]
         self.detection_min_confidence = self.min_det_thresh
 
         # if 'True', loss distinguishes all classes, else only foreground vs. background (class agnostic).
         self.class_specific_seg_flag = True
         self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
         self.head_classes = self.num_seg_classes
 
     def add_mrcnn_configs(self):
 
         # learning rate is a list with one entry per epoch.
         self.learning_rate = [1e-4] * self.num_epochs
 
         # disable mask head loss. (e.g. if no pixelwise annotations available)
         self.frcnn_mode = False
 
         # disable the re-sampling of mask proposals to original size for speed-up.
         # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
         # mask-outputs are optional.
         self.return_masks_in_val = True
         self.return_masks_in_test = False
 
         # set number of proposal boxes to plot after each epoch.
         self.n_plot_rpn_props = 5 if self.dim == 2 else 30
 
         # number of classes for head networks: n_foreground_classes + 1 (background)
         self.head_classes = 3
 
         # seg_classes here refers to the first stage classifier (RPN)
         self.num_seg_classes = 2  # foreground vs. background
 
         # feature map strides per pyramid level are inferred from architecture.
         self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
 
         # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
         # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
         self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
 
         # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
         self.pyramid_levels = [0, 1, 2, 3]
 
         # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
         self.n_rpn_features = 512 if self.dim == 2 else 128
 
         # anchor ratios and strides per position in feature maps.
         self.rpn_anchor_ratios = [0.5, 1, 2]
         self.rpn_anchor_stride = 1
 
         # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
         self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
 
         # loss sampling settings.
         self.rpn_train_anchors_per_image = 2  #per batch element
         self.train_rois_per_image = 2 #per batch element
         self.roi_positive_ratio = 0.5
         self.anchor_matching_iou = 0.7
 
         # factor of top-k candidates to draw from per negative sample (stochastic hard-example mining).
         # the pool size to draw top-k candidates from will be shem_poolsize * n_negative_samples.
         self.shem_poolsize = 10
 
         self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
         self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
         self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
         self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
         self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1]])
         self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1]])
 
         if self.dim == 2:
             self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
             self.bbox_std_dev = self.bbox_std_dev[:4]
             self.window = self.window[:4]
             self.scale = self.scale[:4]
 
         # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
         self.pre_nms_limit = 3000 if self.dim == 2 else 6000
 
         # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
         # since proposals of the entire batch are forwarded through the second stage as one "batch".
         self.roi_chunk_size = 800 if self.dim == 2 else 600
         self.post_nms_rois_training = 500 if self.dim == 2 else 75
         self.post_nms_rois_inference = 500
 
         # Final selection of detections (refine_detections)
         self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30  # per batch element and class.
         self.detection_nms_threshold = 1e-5  # needs to be > 0, otherwise all predictions are one cluster.
         self.model_min_confidence = 0.1
 
         if self.dim == 2:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride))]
                  for stride in self.backbone_strides['xy']])
         else:
             self.backbone_shapes = np.array(
                 [[int(np.ceil(self.patch_size[0] / stride)),
                   int(np.ceil(self.patch_size[1] / stride)),
                   int(np.ceil(self.patch_size[2] / stride_z))]
                  for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                              )])
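         # e.g., for a hypothetical patch_size of (288, 288) in 2D, the 'xy' strides
         # [4, 8, 16, 32] yield backbone_shapes [[72, 72], [36, 36], [18, 18], [9, 9]].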
         if self.model == 'ufrcnn':
             self.operate_stride1 = True
             self.class_specific_seg_flag = True
             self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
             self.frcnn_mode = True
 
         if self.model in ('retina_net', 'retina_unet', 'prob_detector'):
             # implement extra anchor-scales according to retina-net publication.
             self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                             self.rpn_anchor_scales['xy']]
             self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                            self.rpn_anchor_scales['z']]
             self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
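             # e.g., an 'xy' scale [8] expands to [8, 8 * 2**(1/3), 8 * 2**(2/3)] ≈ [8.0, 10.1, 12.7],
             # giving len(rpn_anchor_ratios) * 3 = 9 anchors per feature-map position.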
 
             self.n_rpn_features = 256 if self.dim == 2 else 64
 
             # pre-selection of detections for NMS-speedup. per entire batch.
             self.pre_nms_limit = 10000 if self.dim == 2 else 50000
 
             # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
             self.anchor_matching_iou = 0.5
 
             # if 'True', seg loss distinguishes all classes, else only foreground vs. background (class agnostic).
             self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
 
             if self.model == 'retina_unet':
                 self.operate_stride1 = True
diff --git a/experiments/toy_exp/generate_toys.py b/experiments/toy_exp/generate_toys.py
index 4f44768..7af8cc9 100644
--- a/experiments/toy_exp/generate_toys.py
+++ b/experiments/toy_exp/generate_toys.py
@@ -1,107 +1,112 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
-import os
+import os, time
 import numpy as np
 import pandas as pd
 import pickle
 from multiprocessing import Pool
 import configs as cf
 
 def multi_processing_create_image(inputs):
 
 
     out_dir, six, foreground_margin, class_diameters, mode = inputs
     print('processing {} {}'.format(out_dir, six))
 
     img = np.random.rand(320, 320)
     seg = np.zeros((320, 320)).astype('uint8')
     center_x = np.random.randint(foreground_margin, img.shape[0] - foreground_margin)
     center_y = np.random.randint(foreground_margin, img.shape[1] - foreground_margin)
     class_id = np.random.randint(0, 2)
 
     for y in range(img.shape[0]):
         for x in range(img.shape[1]):
             if ((x - center_x) ** 2 + (y - center_y) ** 2 - class_diameters[class_id] ** 2) < 0:
                 img[y][x] += 0.2
                 seg[y][x] = 1
 
     if 'donuts' in mode:
         hole_diameter = 4
         if class_id == 1:
             for y in range(img.shape[0]):
                 for x in range(img.shape[1]):
                     if ((x - center_x) ** 2 + (y - center_y) ** 2 - hole_diameter ** 2) < 0:
                         img[y][x] -= 0.2
                         if mode == 'donuts_shape':
                             seg[y][x] = 0
 
     out = np.concatenate((img[None], seg[None]))
     out_path = os.path.join(out_dir, '{}.npy'.format(six))
     np.save(out_path, out)
 
     with open(os.path.join(out_dir, 'meta_info_{}.pickle'.format(six)), 'wb') as handle:
         pickle.dump([out_path, class_id, str(six)], handle)
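     # The nested pixel loops above could equivalently be written as a vectorized
     # numpy mask; a minimal sketch (not used here, shown for clarity):
     #
     #   yy, xx = np.ogrid[:img.shape[0], :img.shape[1]]
     #   mask = (xx - center_x) ** 2 + (yy - center_y) ** 2 < class_diameters[class_id] ** 2
     #   img[mask] += 0.2
     #   seg[mask] = 1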
 
 
 def generate_experiment(exp_name, n_train_images, n_test_images, mode, class_diameters=(20, 20)):
 
     train_dir = os.path.join(cf.root_dir, exp_name, 'train')
     test_dir = os.path.join(cf.root_dir, exp_name, 'test')
     if not os.path.exists(train_dir):
         os.makedirs(train_dir)
     if not os.path.exists(test_dir):
         os.makedirs(test_dir)
 
     # enforced distance between object center and image edge.
     foreground_margin = np.max(class_diameters) // 2
 
     info = []
     info += [[train_dir, six, foreground_margin, class_diameters, mode] for six in range(n_train_images)]
     info += [[test_dir, six, foreground_margin, class_diameters, mode] for six in range(n_test_images)]
 
     print('starting to create {} images'.format(len(info)))
-    pool = Pool(processes=12)
+    pool = Pool(processes=os.cpu_count()-1)
     pool.map(multi_processing_create_image, info, chunksize=1)
     pool.close()
     pool.join()
 
     aggregate_meta_info(train_dir)
     aggregate_meta_info(test_dir)
 
 
 def aggregate_meta_info(exp_dir):
 
     files = [os.path.join(exp_dir, f) for f in os.listdir(exp_dir) if 'meta_info' in f]
     df = pd.DataFrame(columns=['path', 'class_id', 'pid'])
     for f in files:
         with open(f, 'rb') as handle:
             df.loc[len(df)] = pickle.load(handle)
 
     df.to_pickle(os.path.join(exp_dir, 'info_df.pickle'))
     print ("aggregated meta info to df with length", len(df))
 
 
 if __name__ == '__main__':
-
+    stime = time.time()
     cf = cf.configs()
 
-    generate_experiment('donuts_shape_threads', n_train_images=1500, n_test_images=1000, mode='donuts_shape')
+    generate_experiment('donuts_shape', n_train_images=1500, n_test_images=1000, mode='donuts_shape')
     generate_experiment('donuts_pattern', n_train_images=1500, n_test_images=1000, mode='donuts_pattern')
     generate_experiment('circles_scale', n_train_images=1500, n_test_images=1000, mode='circles_scale', class_diameters=(19, 20))
 
 
+    mins, secs = divmod((time.time() - stime), 60)
+    h, mins = divmod(mins, 60)
+    t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
+    print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
+
 
diff --git a/requirements.txt b/requirements.txt
index 61eb923..4649d7a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,65 +1,60 @@
 absl-py==0.9.0
 backcall==0.1.0
 batchgenerators==0.19.3
 cachetools==4.0.0
 certifi==2019.11.28
 cffi==1.11.5
 chardet==3.0.4
 cycler==0.10.0
 Cython==0.29.14
 decorator==4.4.1
 future==0.18.2
 google-auth==1.10.0
 google-auth-oauthlib==0.4.1
 grpcio==1.26.0
 idna==2.8
 imageio==2.6.1
-ipython==7.10.2
-ipython-genutils==0.2.0
 jedi==0.15.1
 joblib==0.14.1
 kiwisolver==1.1.0
 linecache2==1.0.0
 Markdown==3.1.1
 matplotlib==3.1.2
 networkx==2.4
--e git+ssh://git@phabricator.mitk.org:2222/source/mdt-public.git@b0eacd38bd05a3438deb829cc3f5b94743fde063#egg=nms_extension&subdirectory=custom_extensions/nms
 numpy==1.17.4
 oauthlib==3.1.0
 pandas==0.25.3
 parso==0.5.2
 pexpect==4.7.0
 pickleshare==0.7.5
 Pillow==6.2.1
 prompt-toolkit==3.0.2
 protobuf==3.11.2
 ptyprocess==0.6.0
 pyasn1==0.4.8
 pyasn1-modules==0.2.7
 pycparser==2.19
 Pygments==2.5.2
 pyparsing==2.4.5
 python-dateutil==2.8.1
 pytz==2019.3
 PyWavelets==1.1.1
 requests==2.22.0
 requests-oauthlib==1.3.0
--e git+ssh://git@phabricator.mitk.org:2222/source/mdt-public.git@b0eacd38bd05a3438deb829cc3f5b94743fde063#egg=RoIAlign_extension_2D&subdirectory=custom_extensions/roi_align
--e git+ssh://git@phabricator.mitk.org:2222/source/mdt-public.git@b0eacd38bd05a3438deb829cc3f5b94743fde063#egg=RoIAlign_extension_3D&subdirectory=custom_extensions/roi_align
 rsa==4.0
 scikit-image==0.16.2
 scikit-learn==0.22.1
 scipy==1.3.3
 six==1.13.0
 sklearn==0.0
 tensorboard==2.1.0
 threadpoolctl==1.1.0
 torch==1.3.1
 torchvision==0.4.2
 tqdm==4.40.2
 traceback2==1.4.0
 traitlets==4.3.3
 unittest2==1.1.0
 urllib3==1.25.7
 wcwidth==0.1.7
 Werkzeug==0.16.0
diff --git a/setup.py b/setup.py
index 8fc6cb9..625f23e 100644
--- a/setup.py
+++ b/setup.py
@@ -1,61 +1,61 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 from setuptools import find_packages, setup
 import os
 
 def parse_requirements(filename, exclude=()):
     with open(filename) as f:
         lineiter = [line.strip() for line in f]
     return [line for line in lineiter if line and not line.startswith("#") and line.split("==")[0] not in exclude]
 
 def install_custom_ext(setup_path):
     os.system("python "+setup_path+" install")
     return
 
 def clean():
     """Custom clean command to tidy up the project root."""
     os.system('rm -vrf ./build ./dist ./*.pyc ./*.tgz ./*.egg-info')
 
 req_file = "requirements.txt"
 custom_exts = ["nms-extension", "RoIAlign-extension-2D", "RoIAlign-extension-3D"]
 install_reqs = parse_requirements(req_file, exclude=custom_exts)
 
 
 
 setup(name='medicaldetectiontoolkit',
       version='0.0.1',
       url="https://github.com/MIC-DKFZ/medicaldetectiontoolkit",
       author='P. Jaeger, G. Ramien, MIC at DKFZ Heidelberg',
-      licence="Apache 2.0",
+      license="Apache 2.0",
       description="Medical Object-Detection Toolkit.",
       classifiers=[
           "Development Status :: 4 - Beta",
           "Intended Audience :: Developers",
           "Programming Language :: Python :: 3.7"
       ],
       packages=find_packages(exclude=['test', 'test.*']),
       install_requires=install_reqs,
       )
 
 custom_exts =  ["custom_extensions/nms", "custom_extensions/roi_align"]
 for path in custom_exts:
     setup_path = os.path.join(path, "setup.py")
     try:
         install_custom_ext(setup_path)
     except Exception as e:
         print("FAILED to install custom extension {} due to Error:\n{}".format(path, e))
 
 clean()
\ No newline at end of file
diff --git a/unittests.py b/unittests.py
index e3f6e43..18f9258 100644
--- a/unittests.py
+++ b/unittests.py
@@ -1,272 +1,274 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import unittest
 
 import os
 import pickle
 import time
 from multiprocessing import  Pool
 import subprocess
 
 import numpy as np
 import pandas as pd
 import torch
 import torchvision as tv
 
 import tqdm
 
 import utils.exp_utils as utils
 import utils.model_utils as mutils
 
 """ Note on unittests: run this file either in the way intended for unittests by starting the script with
     python -m unittest unittests.py or start it as a normal python file as python unittests.py.
     You can selective run single tests by calling python -m unittest unittests.TestClassOfYourChoice, where 
     TestClassOfYourChoice is the name of the test defined below, e.g., CompareFoldSplits.
 """
 
 
 
 def inspect_info_df(pp_dir):
     """ use your debugger to look into the info df of a pp dir.
     :param pp_dir: preprocessed-data directory
     """
 
     info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle"))
 
     return
 
 
 def generate_boxes(count, dim=2, h=100, w=100, d=20, normalize=False, on_grid=False, seed=0):
     """ generate boxes of format [y1, x1, y2, x2, (z1, z2)].
     :param count: nr of boxes
     :param dim: dimension of boxes (2 or 3)
     :return: boxes in format (n_boxes, 4 or 6), scores
     """
     np.random.seed(seed)
     if on_grid:
         lower_y = np.random.randint(0, h // 2, (count,))
         lower_x = np.random.randint(0, w // 2, (count,))
         upper_y = np.random.randint(h // 2, h, (count,))
         upper_x = np.random.randint(w // 2, w, (count,))
         if dim == 3:
             lower_z = np.random.randint(0, d // 2, (count,))
             upper_z = np.random.randint(d // 2, d, (count,))
     else:
         lower_y = np.random.rand(count) * h / 2.
         lower_x = np.random.rand(count) * w / 2.
         upper_y = (np.random.rand(count) + 1.) * h / 2.
         upper_x = (np.random.rand(count) + 1.) * w / 2.
         if dim == 3:
             lower_z = np.random.rand(count) * d / 2.
             upper_z = (np.random.rand(count) + 1.) * d / 2.
 
     if dim == 3:
         boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x, lower_z, upper_z)))
         # add an extreme box that tests the boundaries
         boxes = np.concatenate((boxes, np.array([[0., 0., h, w, 0, d]])))
     else:
         boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x)))
         boxes = np.concatenate((boxes, np.array([[0., 0., h, w]])))
 
     scores = np.random.rand(count + 1)
     if normalize:
         divisor = np.array([h, w, h, w, d, d]) if dim == 3 else np.array([h, w, h, w])
         boxes = boxes / divisor
     return boxes, scores
 
 
 # -------- check own nms CUDA implement against own numpy implement ------
 class CheckNMSImplementation(unittest.TestCase):
 
     @staticmethod
     def assert_res_equality(keep_ics1, keep_ics2, boxes, scores, tolerance=0, names=("res1", "res2")):
         """
         :param keep_ics1: keep indices (results), torch.Tensor of shape (n_ics,)
         :param keep_ics2: keep indices from the second implementation, to compare against keep_ics1.
         :return:
         """
         keep_ics1, keep_ics2 = keep_ics1.cpu().numpy(), keep_ics2.cpu().numpy()
         discrepancies = np.setdiff1d(keep_ics1, keep_ics2)
         try:
             checks = np.array([
                 len(discrepancies) <= tolerance
             ])
         except Exception:
             checks = np.zeros((1,)).astype("bool")
         msgs = np.array([
             """{}: {} \n{}: {} \nboxes: {}\n {}\n""".format(names[0], keep_ics1, names[1], keep_ics2, boxes,
                                                             scores)
         ])
 
         assert np.all(checks), "NMS: results mismatch: " + "\n".join(msgs[~checks])
 
     def single_case(self, count=20, dim=3, threshold=0.2, seed=0):
         boxes, scores = generate_boxes(count, dim, seed=seed, h=320, w=280, d=30)
 
         keep_numpy = torch.tensor(mutils.nms_numpy(boxes, scores, threshold))
 
         # for some reason torchvision nms requires box coords as floats.
         boxes = torch.from_numpy(boxes).type(torch.float32)
         scores = torch.from_numpy(scores).type(torch.float32)
         if dim == 2:
             """need to wait until next pytorch release where they fixed nms on cpu (currently they have >= where it
             needs to be >.
             """
             # keep_ops = tv.ops.nms(boxes, scores, threshold)
             # self.assert_res_equality(keep_numpy, keep_ops, boxes, scores, tolerance=0, names=["np", "ops"])
             pass
 
         boxes = boxes.cuda()
         scores = scores.cuda()
         keep = self.nms_ext.nms(boxes, scores, threshold)
         self.assert_res_equality(keep_numpy, keep, boxes, scores, tolerance=0, names=["np", "cuda"])
 
     def test(self, n_cases=200, box_count=30, threshold=0.5):
         # dynamically import module so that it doesn't affect other tests if import fails
         self.nms_ext = utils.import_module("nms_ext", 'custom_extensions/nms/nms.py')
         # change seed to something fixed if you want an exactly reproducible test
         seed0 = np.random.randint(50)
         print("NMS test progress (done/total box configurations) 2D:", end="\n")
         for i in tqdm.tqdm(range(n_cases)):
             self.single_case(count=box_count, dim=2, threshold=threshold, seed=seed0+i)
         print("NMS test progress (done/total box configurations) 3D:", end="\n")
         for i in tqdm.tqdm(range(n_cases)):
             self.single_case(count=box_count, dim=3, threshold=threshold, seed=seed0+i)
 
         return
 
 class CheckRoIAlignImplementation(unittest.TestCase):
 
     def prepare(self, dim=2):
 
         b, c, h, w = 1, 3, 50, 50
         # feature map, (b, c, h, w(, z))
         if dim == 2:
             fmap = torch.rand(b, c, h, w).cuda()
             # rois = torch.tensor([[
             #     [0.1, 0.1, 0.3, 0.3],
             #     [0.2, 0.2, 0.4, 0.7],
             #     [0.5, 0.7, 0.7, 0.9],
             # ]]).cuda()
             pool_size = (7, 7)
             rois = generate_boxes(5, dim=dim, h=h, w=w, on_grid=True, seed=np.random.randint(50))[0]
         elif dim == 3:
             d = 20
             fmap = torch.rand(b, c, h, w, d).cuda()
             # rois = torch.tensor([[
             #     [0.1, 0.1, 0.3, 0.3, 0.1, 0.1],
             #     [0.2, 0.2, 0.4, 0.7, 0.2, 0.4],
             #     [0.5, 0.0, 0.7, 1.0, 0.4, 0.5],
             #     [0.0, 0.0, 0.9, 1.0, 0.0, 1.0],
             # ]]).cuda()
             pool_size = (7, 7, 3)
             rois = generate_boxes(5, dim=dim, h=h, w=w, d=d, on_grid=True, seed=np.random.randint(50),
                                   normalize=False)[0]
         else:
             raise ValueError("dim needs to be 2 or 3")
 
         rois = [torch.from_numpy(rois).type(dtype=torch.float32).cuda(), ]
         fmap.requires_grad_(True)
         return fmap, rois, pool_size
 
     def check_2d(self):
         """ check vs torchvision ops not possible as on purpose different approach.
         :return:
         """
         raise NotImplementedError
         # fmap, rois, pool_size = self.prepare(dim=2)
         # ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1)
         # align_ext = ra_object(fmap, rois)
         # loss_ext = align_ext.sum()
         # loss_ext.backward()
         #
         # rois_swapped = [rois[0][:, [1,3,0,2]]]
         # align_ops = tv.ops.roi_align(fmap, rois_swapped, pool_size)
         # loss_ops = align_ops.sum()
         # loss_ops.backward()
         #
         # assert (loss_ops == loss_ext), "sum of roialign ops and extension 2D diverges"
         # assert (align_ops == align_ext).all(), "ROIAlign failed 2D test"
 
     def check_3d(self):
         fmap, rois, pool_size = self.prepare(dim=3)
         ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1)
         align_ext = ra_object(fmap, rois)
         loss_ext = align_ext.sum()
         loss_ext.backward()
 
         align_np = mutils.roi_align_3d_numpy(fmap.cpu().detach().numpy(), [roi.cpu().numpy() for roi in rois],
                                              pool_size)
         align_np = np.squeeze(align_np)  # remove singleton batch dim
 
         align_ext = align_ext.cpu().detach().numpy()
         assert np.allclose(align_np, align_ext, rtol=1e-5,
                            atol=1e-8), "RoIAlign: numpy and CUDA implementations differ"
 
     def specific_example_check(self):
         # dummy input
         self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py')
         exp = 6
         pool_size = (2,2)
         fmap = torch.arange(exp**2).view(exp,exp).unsqueeze(0).unsqueeze(0).cuda().type(dtype=torch.float32)
 
         boxes = torch.tensor([[1., 1., 5., 5.]]).cuda()/exp
         ind = torch.tensor([0.]*len(boxes)).cuda().type(torch.float32)
         y_exp, x_exp = fmap.shape[2:]  # exp = expansion
         boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
         boxes = torch.cat((ind.unsqueeze(1), boxes), dim=1)
         aligned_tv = tv.ops.roi_align(fmap, boxes, output_size=pool_size, sampling_ratio=-1)
         aligned = self.ra_ext.roi_align_2d(fmap, boxes, output_size=pool_size, sampling_ratio=-1)
 
         boxes_3d = torch.cat((boxes, torch.tensor([[-1.,1.]]*len(boxes)).cuda()), dim=1)
         fmap_3d = fmap.unsqueeze(dim=-1)
         pool_size = (*pool_size,1)
         ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1.,)
         aligned_3d = ra_object(fmap_3d, boxes_3d)
 
-        expected_res = torch.tensor([[[[10.5000, 12.5000],
-                                       [22.5000, 24.5000]]]]).cuda()
-        expected_res_3d = torch.tensor([[[[[10.5000],[12.5000]],
-                                          [[22.5000],[24.5000]]]]]).cuda()
+        # expected_res = torch.tensor([[[[10.5000, 12.5000], # this would be with an alternative grid-point setting
+        #                                [22.5000, 24.5000]]]]).cuda()
+        expected_res = torch.tensor([[[[14., 16.],
+                                       [26., 28.]]]]).cuda()
+        expected_res_3d = torch.tensor([[[[[14.],[16.]],
+                                          [[26.],[28.]]]]]).cuda()
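+        # Worked check of expected_res: fmap[y, x] = 6 * y + x; for output bin (0, 0) the
+        # 2x2 sample grid lands at (y, x) in {1.5, 2.5} x {1.5, 2.5}, so the bin value is
+        # (10.5 + 11.5 + 16.5 + 17.5) / 4 = 14. The other bins follow analogously.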
         assert torch.all(aligned==expected_res), "2D RoIAlign check vs. specific example failed. res: {}\n expected: {}\n".format(aligned, expected_res)
         assert torch.all(aligned_3d==expected_res_3d), "3D RoIAlign check vs. specific example failed. res: {}\n expected: {}\n".format(aligned_3d, expected_res_3d)
 
 
     def test(self):
         # dynamically import module so that it doesn't affect other tests if import fails
         self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py')
 
         self.specific_example_check()
 
         # 2d test
         #self.check_2d()
 
         # 3d test
         self.check_3d()
 
         return
 
 
 if __name__=="__main__":
     stime = time.time()
 
     unittest.main()
 
     mins, secs = divmod((time.time() - stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/utils/model_utils.py b/utils/model_utils.py
index 64c7cc0..3251577 100644
--- a/utils/model_utils.py
+++ b/utils/model_utils.py
@@ -1,1011 +1,1011 @@
 #!/usr/bin/env python
 # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """
 Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
 published under MIT license.
 """
 
 import numpy as np
 import scipy.misc
 import scipy.ndimage
 import scipy.interpolate
 import torch
 from torch.autograd import Variable
 import torch.nn as nn
 
 import tqdm
 ############################################################
 #  Bounding Boxes
 ############################################################
 
 
 def compute_iou_2D(box, boxes, box_area, boxes_area):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2] THIS IS THE GT BOX
     boxes: [boxes_count, (y1, x1, y2, x2)]
     box_area: float. the area of 'box'
     boxes_area: array of length boxes_count.
 
     Note: the areas are passed in rather than calculated here for
           efficiency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
     union = box_area + boxes_area[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 
 def compute_iou_3D(box, boxes, box_volume, boxes_volume):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2, z1, z2] (typically gt box)
     boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)]
     box_volume: float. the volume of 'box'
     boxes_volume: array of length boxes_count.
 
     Note: the volumes are passed in rather than calculated here for
           efficiency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     z1 = np.maximum(box[4], boxes[:, 4])
     z2 = np.minimum(box[5], boxes[:, 5])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) * np.maximum(z2 - z1, 0)
     union = box_volume + boxes_volume[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 
 def compute_overlaps(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)]. / 3D: (z1, z2))
     For better performance, pass the largest set first and the smaller second.
     """
     # Areas of anchors and GT boxes
     if boxes1.shape[1] == 4:
         area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
         area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(overlaps.shape[1]):
             box2 = boxes2[i] #this is the gt box
             overlaps[:, i] = compute_iou_2D(box2, boxes1, area2[i], area1)
         return overlaps
 
     else:
         # Areas of anchors and GT boxes
         volume1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) * (boxes1[:, 5] - boxes1[:, 4])
         volume2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) * (boxes2[:, 5] - boxes2[:, 4])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(overlaps.shape[1]):
             box2 = boxes2[i]  # this is the gt box
             overlaps[:, i] = compute_iou_3D(box2, boxes1, volume2[i], volume1)
         return overlaps
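 
 # Usage example for compute_overlaps (exact values):
 #   anchors = np.array([[0, 0, 10, 10], [5, 5, 15, 15]])
 #   gts = np.array([[0, 0, 10, 10]])
 #   compute_overlaps(anchors, gts)  # -> [[1.0], [25 / 175 ~= 0.143]]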
 
 
 
 def box_refinement(box, gt_box):
     """Compute refinement needed to transform box to gt_box.
     box and gt_box are [N, (y1, x1, y2, x2)] / 3D: (z1, z2))
     """
     height = box[:, 2] - box[:, 0]
     width = box[:, 3] - box[:, 1]
     center_y = box[:, 0] + 0.5 * height
     center_x = box[:, 1] + 0.5 * width
 
     gt_height = gt_box[:, 2] - gt_box[:, 0]
     gt_width = gt_box[:, 3] - gt_box[:, 1]
     gt_center_y = gt_box[:, 0] + 0.5 * gt_height
     gt_center_x = gt_box[:, 1] + 0.5 * gt_width
 
     dy = (gt_center_y - center_y) / height
     dx = (gt_center_x - center_x) / width
     dh = torch.log(gt_height / height)
     dw = torch.log(gt_width / width)
     result = torch.stack([dy, dx, dh, dw], dim=1)
 
     if box.shape[1] > 4:
         depth = box[:, 5] - box[:, 4]
         center_z = box[:, 4] + 0.5 * depth
         gt_depth = gt_box[:, 5] - gt_box[:, 4]
         gt_center_z = gt_box[:, 4] + 0.5 * gt_depth
         dz = (gt_center_z - center_z) / depth
         dd = torch.log(gt_depth / depth)
         result = torch.stack([dy, dx, dz, dh, dw, dd], dim=1)
 
     return result
 
 
 
 def unmold_mask_2D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to its original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2 = bbox
     out_zoom = [y2 - y1, x2 - x1]
     zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:2])
     full_mask[y1:y2, x1:x2] = mask
     return full_mask
 
 
 
 def unmold_mask_3D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to its original shape.
     mask: [height, width, depth] of type float. A small, typically 28x28x10 mask.
     bbox: [y1, x1, y2, x2, z1, z2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2, z1, z2 = bbox
     out_zoom = [y2 - y1, x2 - x1, z2 - z1]
     zoom_factor = [i/j for i,j in zip(out_zoom, mask.shape)]
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:3])
     full_mask[y1:y2, x1:x2, z1:z2] = mask
     return full_mask
 
 
 ############################################################
 #  Anchors
 ############################################################
 
 def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
     """
     scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width] spatial shape of the feature map over which
             to generate anchors.
     feature_stride: Stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
     scales, ratios = np.meshgrid(np.array(scales), np.array(ratios))
     scales = scales.flatten()
     ratios = ratios.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales / np.sqrt(ratios)
     widths = scales * np.sqrt(ratios)
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride
     shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
 
     # Reshape to get a list of (y, x) and a list of (h, w)
     box_centers = np.stack(
         [box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
     box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])
 
     # Convert to corner coordinates (y1, x1, y2, x2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                             box_centers + 0.5 * box_sizes], axis=1)
     return boxes
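 
 # Worked example: scales=[32], ratios=[1], shape=[2, 2], feature_stride=16,
 # anchor_stride=1 yields four 32x32 anchors centered at (0, 0), (0, 16), (16, 0),
 # (16, 16), i.e. [[-16, -16, 16, 16], [-16, 0, 16, 32], [0, -16, 32, 16], [0, 0, 32, 32]].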
 
 
 
 def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride):
     """
     scales_xy: 1D array of xy anchor sizes in pixels. Example: [32, 64, 128]
     scales_z: 1D array of z anchor sizes in pixels.
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width, depth] spatial shape of the feature map over which
             to generate anchors.
     feature_stride_xy, feature_stride_z: Strides of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
 
     scales_xy, ratios_meshed = np.meshgrid(np.array(scales_xy), np.array(ratios))
     scales_xy = scales_xy.flatten()
     ratios_meshed = ratios_meshed.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales_xy / np.sqrt(ratios_meshed)
     widths = scales_xy * np.sqrt(ratios_meshed)
     depths = np.tile(np.array(scales_z), len(ratios_meshed)//np.array(scales_z)[..., None].shape[0])
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy #translate from fm positions to input coords.
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy
     shifts_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z)
     shifts_x, shifts_y, shifts_z = np.meshgrid(shifts_x, shifts_y, shifts_z)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
     box_depths, box_centers_z = np.meshgrid(depths, shifts_z)
 
     # Reshape to get a list of (y, x, z) and a list of (h, w, d)
     box_centers = np.stack(
         [box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3])
     box_sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3])
 
     # Convert to corner coordinates (y1, x1, y2, x2, z1, z2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                             box_centers + 0.5 * box_sizes], axis=1)
 
     boxes = np.transpose(np.array([boxes[:, 0], boxes[:, 1], boxes[:, 3], boxes[:, 4], boxes[:, 2], boxes[:, 5]]), axes=(1, 0))
     return boxes
 
 
 def generate_pyramid_anchors(logger, cf):
     """Generate anchors at different levels of a feature pyramid. Each scale
     is associated with a level of the pyramid, but each ratio is used in
     all levels of the pyramid.
 
     from configs:
     :param scales: cf.RPN_ANCHOR_SCALES , e.g. [4, 8, 16, 32]
     :param ratios: cf.RPN_ANCHOR_RATIOS , e.g. [0.5, 1, 2]
     :param feature_shapes: cf.BACKBONE_SHAPES , e.g.  [array of shapes per feature map] [80, 40, 20, 10, 5]
     :param feature_strides: cf.BACKBONE_STRIDES , e.g. [2, 4, 8, 16, 32, 64]
     :param anchors_stride: cf.RPN_ANCHOR_STRIDE , e.g. 1
     :return anchors: (N, (y1, x1, y2, x2, (z1), (z2)). All generated anchors in one array. Sorted
     with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on.
     """
     scales = cf.rpn_anchor_scales
     ratios = cf.rpn_anchor_ratios
     feature_shapes = cf.backbone_shapes
     anchor_stride = cf.rpn_anchor_stride
     pyramid_levels = cf.pyramid_levels
     feature_strides = cf.backbone_strides
 
     anchors = []
     logger.info("feature map shapes: {}".format(feature_shapes))
     logger.info("anchor scales: {}".format(scales))
 
     expected_anchors = [np.prod(feature_shapes[ii]) * len(ratios) * len(scales['xy'][ii]) for ii in pyramid_levels]
 
     for lix, level in enumerate(pyramid_levels):
         if len(feature_shapes[level]) == 2:
             anchors.append(generate_anchors(scales['xy'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], anchor_stride))
         else:
             anchors.append(generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], feature_strides['z'][level], anchor_stride))
 
         logger.info("level {}: built anchors {} / expected anchors {} ||| total build {} / total expected {}".format(
             level, anchors[-1].shape, expected_anchors[lix], np.concatenate(anchors).shape, np.sum(expected_anchors)))
 
     out_anchors = np.concatenate(anchors, axis=0)
     return out_anchors
 
 
 
 def apply_box_deltas_2D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 4] where each row is y1, x1, y2, x2
     deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     height *= torch.exp(deltas[:, 2])
     width *= torch.exp(deltas[:, 3])
     # Convert back to y1, x1, y2, x2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     y2 = y1 + height
     x2 = x1 + width
     result = torch.stack([y1, x1, y2, x2], dim=1)
     return result
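 
 # Sanity-check sketch: apply_box_deltas_2D inverts box_refinement (before the
 # bbox_std_dev normalization applied elsewhere), e.g.:
 #   box = torch.tensor([[10., 10., 20., 30.]])
 #   gt = torch.tensor([[12., 8., 22., 26.]])
 #   deltas = box_refinement(box, gt)
 #   assert torch.allclose(apply_box_deltas_2D(box, deltas), gt)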
 
 
 
 def apply_box_deltas_3D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2
     deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)]
     """
     # Convert to y, x, z, h, w, d
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     depth = boxes[:, 5] - boxes[:, 4]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     center_z = boxes[:, 4] + 0.5 * depth
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     center_z += deltas[:, 2] * depth
     height *= torch.exp(deltas[:, 3])
     width *= torch.exp(deltas[:, 4])
     depth *= torch.exp(deltas[:, 5])
     # Convert back to y1, x1, y2, x2, z1, z2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     z1 = center_z - 0.5 * depth
     y2 = y1 + height
     x2 = x1 + width
     z2 = z1 + depth
     result = torch.stack([y1, x1, y2, x2, z1, z2], dim=1)
     return result
 
 
 
 def clip_boxes_2D(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2
     window: [4] in the form y1, x1, y2, x2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3]))], 1)
     return boxes
 
 def clip_boxes_3D(boxes, window):
     """
     boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2
     window: [6] in the form y1, x1, y2, x2, z1, z2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3])),
          boxes[:, 4].clamp(float(window[4]), float(window[5])),
          boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1)
     return boxes
 
 
 
 def clip_boxes_numpy(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D.
     window: image shape (y, x, (z))
     """
     if boxes.shape[1] == 4:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
             np.clip(boxes[:, 1], 0, window[0])[:, None],
             np.clip(boxes[:, 2], 0, window[1])[:, None],
             np.clip(boxes[:, 3], 0, window[1])[:, None]), 1
         )
 
     else:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
              np.clip(boxes[:, 1], 0, window[0])[:, None],
              np.clip(boxes[:, 2], 0, window[1])[:, None],
              np.clip(boxes[:, 3], 0, window[1])[:, None],
              np.clip(boxes[:, 4], 0, window[2])[:, None],
              np.clip(boxes[:, 5], 0, window[2])[:, None]), 1
         )
 
     return boxes
 
 
 
 def bbox_overlaps_2D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)].
     """
     # 1. Tile boxes2 and repeat boxes1. This allows us to compare
     # every boxes1 against every boxes2 without loops.
     # TF doesn't have an equivalent to np.repeat() so simulate it
     # using tf.tile() and tf.reshape.
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,4)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2 = boxes1.chunk(4, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2 = boxes2.chunk(4, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     zeros = Variable(torch.zeros(y1.size()[0]), requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros)
 
     # 3. Compute unions
     b1_area = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)
     b2_area = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)
     union = b1_area[:,0] + b2_area[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     overlaps = iou.view(boxes2_repeat, boxes1_repeat)
     return overlaps
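 
 # Shape example: boxes1 (3, 4) and boxes2 (2, 4) are both expanded to (6, 4), with
 # boxes1 rows repeated block-wise and boxes2 rows cycled, so each of the 6 rows is one
 # (box1, box2) pair; the IoU vector is then reshaped to overlaps of shape (3, 2).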
 
 
 
 def bbox_overlaps_3D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2, z1, z2)].
     """
     # 1. Tile boxes2 and repeat boxes1. This allows us to compare
     # every boxes1 against every boxes2 without loops.
     # TF doesn't have an equivalent to np.repeat() so simulate it
     # using tf.tile() and tf.reshape.
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,6)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2, b1_z1, b1_z2 = boxes1.chunk(6, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2, b2_z1, b2_z2 = boxes2.chunk(6, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     z1 = torch.max(b1_z1, b2_z1)[:, 0]
     z2 = torch.min(b1_z2, b2_z2)[:, 0]
     zeros = Variable(torch.zeros(y1.size()[0]), requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros)
 
     # 3. Compute unions
     b1_volume = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)  * (b1_z2 - b1_z1)
     b2_volume = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)  * (b2_z2 - b2_z1)
     union = b1_volume[:,0] + b2_volume[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     overlaps = iou.view(boxes2_repeat, boxes1_repeat)
     return overlaps
 
 
 
 def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None):
     """Given the anchors and GT boxes, compute overlaps and identify positive
     anchors and deltas to refine them to match their corresponding GT boxes.
 
     anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))]
     gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))]
     gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN,
     set all positive matches to 1 (foreground)
 
     Returns:
     anchor_class_matches: [N] (int32) matches between anchors and GT boxes.
                1 = positive anchor, -1 = negative anchor, 0 = neutral.
                In case of one stage detectors like RetinaNet/RetinaUNet this flag takes
                class_ids as positive anchor values, i.e. values >= 1!
     anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas.
     """
 
     anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32)
     anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim))
     anchor_matching_iou = cf.anchor_matching_iou
 
     if gt_boxes is None:
         anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1)
         return anchor_class_matches, anchor_delta_targets
 
     # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground)
     if gt_class_ids is None:
         gt_class_ids = np.array([1] * len(gt_boxes))
 
     # Compute overlaps [num_anchors, num_gt_boxes]
     overlaps = compute_overlaps(anchors, gt_boxes)
 
     # Match anchors to GT Boxes
     # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive.
     # If an anchor overlaps a GT box with IoU < 0.1 then it's negative.
     # Neutral anchors are those that don't match the conditions above,
     # and they don't influence the loss function.
     # However, don't keep any GT box unmatched (rare, but happens). Instead,
     # match it to the closest anchor (even if its max IoU is < 0.1).
 
     # 1. Set negative anchors first. They get overwritten below if a GT box is
     # matched to them. Skip boxes in crowd areas.
     anchor_iou_argmax = np.argmax(overlaps, axis=1)
     anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
     if anchors.shape[1] == 4:
         anchor_class_matches[(anchor_iou_max < 0.1)] = -1
     elif anchors.shape[1] == 6:
         anchor_class_matches[(anchor_iou_max < 0.01)] = -1
     else:
         raise ValueError('anchor shape wrong {}'.format(anchors.shape))
 
     # 2. Set an anchor for each GT box (regardless of IoU value).
     gt_iou_argmax = np.argmax(overlaps, axis=0)
     for ix, ii in enumerate(gt_iou_argmax):
         anchor_class_matches[ii] = gt_class_ids[ix]
 
     # 3. Set anchors with high overlap as positive.
     above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou)
     anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]]
 
     # Subsample to balance positive anchors.
     ids = np.where(anchor_class_matches > 0)[0]
     extra = len(ids) - (cf.rpn_train_anchors_per_image // 2)
     if extra > 0:
         # Reset the extra ones to neutral
         ids = np.random.choice(ids, extra, replace=False)
         anchor_class_matches[ids] = 0
 
     # Leave all negative proposals negative now and sample from them in online hard example mining.
     # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes.
     ids = np.where(anchor_class_matches > 0)[0]
     ix = 0  # index into anchor_delta_targets
     for i, a in zip(ids, anchors[ids]):
         # closest gt box (it might have IoU < anchor_matching_iou)
         gt = gt_boxes[anchor_iou_argmax[i]]
 
         # convert coordinates to center plus width/height.
         gt_h = gt[2] - gt[0]
         gt_w = gt[3] - gt[1]
         gt_center_y = gt[0] + 0.5 * gt_h
         gt_center_x = gt[1] + 0.5 * gt_w
         # Anchor
         a_h = a[2] - a[0]
         a_w = a[3] - a[1]
         a_center_y = a[0] + 0.5 * a_h
         a_center_x = a[1] + 0.5 * a_w
 
         if cf.dim == 2:
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
             ]
 
         else:
             gt_d = gt[5] - gt[4]
             gt_center_z = gt[4] + 0.5 * gt_d
             a_d = a[5] - a[4]
             a_center_z = a[4] + 0.5 * a_d
 
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 (gt_center_z - a_center_z) / a_d,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
                 np.log(gt_d / a_d)
             ]
 
         # normalize.
         anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev
         ix += 1
 
     return anchor_class_matches, anchor_delta_targets
 
 
 
 def clip_to_window(window, boxes):
     """
         window: (y1, x1, y2, x2) / 3D: (z1, z2). The window in the image we want to clip to.
         boxes: [N, (y1, x1, y2, x2)]  / 3D: (z1, z2)
     """
     boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[2]))
     boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[3]))
     boxes[:, 2] = boxes[:, 2].clamp(float(window[0]), float(window[2]))
     boxes[:, 3] = boxes[:, 3].clamp(float(window[1]), float(window[3]))
 
     if boxes.shape[1] > 5:
         boxes[:, 4] = boxes[:, 4].clamp(float(window[4]), float(window[5]))
         boxes[:, 5] = boxes[:, 5].clamp(float(window[4]), float(window[5]))
 
     return boxes
 
 
 def nms_numpy(box_coords, scores, thresh):
     """ non-maximum suppression on 2D or 3D boxes in numpy.
     :param box_coords: [y1,x1,y2,x2 (,z1,z2)] with y1<=y2, x1<=x2, z1<=z2.
     :param scores: ranking scores (higher score == higher rank) of boxes.
     :param thresh: IoU threshold for clustering.
     :return:
     """
     y1 = box_coords[:, 0]
     x1 = box_coords[:, 1]
     y2 = box_coords[:, 2]
     x2 = box_coords[:, 3]
     assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: 
             coordinates of which maxima are taken need to be the lower coordinates"""
     areas = (x2 - x1) * (y2 - y1)
 
     is_3d = box_coords.shape[1] == 6
     if is_3d: # 3-dim case
         z1 = box_coords[:, 4]
         z2 = box_coords[:, 5]
         assert np.all(z1<=z2), """"the definition of the coordinates is crucially important here: 
            coordinates of which maxima are taken need to be the lower coordinates"""
         areas *= (z2 - z1)
 
     order = scores.argsort()[::-1]
 
     keep = []
     while order.size > 0:  # order is the sorted index.  maps order to index: order[1] = 24 means (rank1, ix 24)
         i = order[0] # highest scoring element
         yy1 = np.maximum(y1[i], y1[order])  # highest scoring element still in >order<, is compared to itself, that is okay.
         xx1 = np.maximum(x1[i], x1[order])
         yy2 = np.minimum(y2[i], y2[order])
         xx2 = np.minimum(x2[i], x2[order])
 
         h = np.maximum(0.0, yy2 - yy1)
         w = np.maximum(0.0, xx2 - xx1)
         inter = h * w
 
         if is_3d:
             zz1 = np.maximum(z1[i], z1[order])
             zz2 = np.minimum(z2[i], z2[order])
             d = np.maximum(0.0, zz2 - zz1)
             inter *= d
 
         iou = inter / (areas[i] + areas[order] - inter)
 
         non_matches = np.nonzero(iou <= thresh)[0]  # get all elements that were not matched and discard all others.
         order = order[non_matches]
         keep.append(i)
 
     return keep
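 
 # Worked example: box_coords = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]],
 # scores = [0.9, 0.8, 0.7], thresh = 0.5: box 1 overlaps box 0 with
 # IoU = 81 / (100 + 100 - 81) ~= 0.68 > 0.5 and is suppressed, so keep == [0, 2].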
 
 def roi_align_3d_numpy(input: np.ndarray, rois, output_size: tuple,
                        spatial_scale: float = 1., sampling_ratio: int = -1) -> np.ndarray:
     """ This fct mainly serves as a verification method for 3D CUDA implementation of RoIAlign, it's highly
         inefficient due to the nested loops.
     :param input:  (ndarray[N, C, H, W, D]): input feature map
     :param rois: list (N,K(n), 6), K(n) = nr of rois in batch-element n, single roi of format (y1,x1,y2,x2,z1,z2)
     :param output_size:
     :param spatial_scale:
     :param sampling_ratio:
     :return: (List[N, K(n), C, output_size[0], output_size[1], output_size[2]])
     """
 
     out_height, out_width, out_depth = output_size
 
     coord_grid = tuple([np.linspace(0, input.shape[dim] - 1, num=input.shape[dim]) for dim in range(2, 5)])
     pooled_rois = [[] for _ in range(len(rois))]  # note: [[]] * n would alias the same inner list n times
     assert len(rois) == input.shape[0], "batch dim mismatch, rois: {}, input: {}".format(len(rois), input.shape[0])
     print("Numpy 3D RoIAlign progress:", end="\n")
     for b in range(input.shape[0]):
         for roi in tqdm.tqdm(rois[b]):
             y1, x1, y2, x2, z1, z2 = np.array(roi) * spatial_scale
             roi_height = max(float(y2 - y1), 1.)
             roi_width = max(float(x2 - x1), 1.)
             roi_depth = max(float(z2 - z1), 1.)
 
             if sampling_ratio <= 0:
                 sampling_ratio_h = int(np.ceil(roi_height / out_height))
                 sampling_ratio_w = int(np.ceil(roi_width / out_width))
                 sampling_ratio_d = int(np.ceil(roi_depth / out_depth))
             else:
                 sampling_ratio_h = sampling_ratio_w = sampling_ratio_d = sampling_ratio  # == n points per bin
 
             bin_height = roi_height / out_height
             bin_width = roi_width / out_width
             bin_depth = roi_depth / out_depth
 
             n_points = sampling_ratio_h * sampling_ratio_w * sampling_ratio_d
             pooled_roi = np.empty((input.shape[1], out_height, out_width, out_depth), dtype="float32")
             for chan in range(input.shape[1]):
                 lin_interpolator = scipy.interpolate.RegularGridInterpolator(coord_grid, input[b, chan],
                                                                              method="linear")
                 for bin_iy in range(out_height):
                     for bin_ix in range(out_width):
                         for bin_iz in range(out_depth):
 
                             bin_val = 0.
                             for i in range(sampling_ratio_h):
                                 for j in range(sampling_ratio_w):
                                     for k in range(sampling_ratio_d):
                                         loc_ijk = [
-                                            y1 + bin_iy * bin_height + (i + 0.5) * ((bin_height -1) / sampling_ratio_h),
-                                            x1 + bin_ix * bin_width + (j + 0.5) * ((bin_width -1) / sampling_ratio_w),
-                                            z1 + bin_iz * bin_depth + (k + 0.5) * ((bin_depth -1) / sampling_ratio_d)]
+                                            y1 + bin_iy * bin_height + (i + 0.5) * (bin_height / sampling_ratio_h),
+                                            x1 + bin_ix * bin_width + (j + 0.5) * (bin_width / sampling_ratio_w),
+                                            z1 + bin_iz * bin_depth + (k + 0.5) * (bin_depth / sampling_ratio_d)]
                                         # print("loc_ijk", loc_ijk)
                                         if not (np.any([c < -1.0 for c in loc_ijk]) or loc_ijk[0] > input.shape[2] or
                                                 loc_ijk[1] > input.shape[3] or loc_ijk[2] > input.shape[4]):
                                             for catch_case in range(3):
                                                 # snap coordinates that fall into the last unit interval onto the
                                                 # border grid point, so the interpolator does not go out of bounds
                                                 if int(loc_ijk[catch_case]) == input.shape[catch_case + 2] - 1:
                                                     loc_ijk[catch_case] = input.shape[catch_case + 2] - 1
                                             bin_val += lin_interpolator(loc_ijk)
                             pooled_roi[chan, bin_iy, bin_ix, bin_iz] = bin_val / n_points
 
             pooled_rois[b].append(pooled_roi)
 
     return np.array(pooled_rois)
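 
 
 # A minimal sketch (added for illustration): run roi_align_3d_numpy on a tiny feature map
 # with one roi per batch element; the output shape follows the docstring,
 # (N, K(n), C, out_h, out_w, out_d).
 def _demo_roi_align_3d():
     feat = np.random.rand(1, 2, 8, 8, 4).astype("float32")   # (N, C, H, W, D)
     rois = [[(1., 1., 6., 6., 0., 3.)]]                      # one (y1, x1, y2, x2, z1, z2) roi
     pooled = roi_align_3d_numpy(feat, rois, output_size=(2, 2, 2))
     assert pooled.shape == (1, 1, 2, 2, 2, 2)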
 
 
 ############################################################
 #  Pytorch Utility Functions
 ############################################################
 
 
 def unique1d(tensor):
     if tensor.size()[0] == 0 or tensor.size()[0] == 1:
         return tensor
     tensor = tensor.sort()[0]
     unique_bool = tensor[1:] != tensor[:-1]
     first_element = torch.tensor([True], dtype=torch.bool, requires_grad=False)
     if tensor.is_cuda:
         first_element = first_element.cuda()
     unique_bool = torch.cat((first_element, unique_bool),dim=0)
     return tensor[unique_bool.data]
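 
 # Example (added for illustration):
 #   unique1d(torch.tensor([3, 1, 3, 2]))  # -> tensor([1, 2, 3])
 # Unlike torch.unique, inputs with fewer than two elements are returned unchanged.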
 
 
 
 def log2(x):
     """Implementatin of Log2. Pytorch doesn't have a native implemenation."""
     ln2 = Variable(torch.log(torch.FloatTensor([2.0])), requires_grad=False)
     if x.is_cuda:
         ln2 = ln2.cuda()
     return torch.log(x) / ln2
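 
 # Note (added): newer PyTorch releases ship torch.log2 natively; this helper is kept
 # for compatibility with the older versions this code base targets.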
 
 
 
 def intersect1d(tensor1, tensor2):
     """Intersection of two 1D index tensors; assumes each input contains unique elements."""
     aux = torch.cat((tensor1, tensor2), dim=0)
     aux = aux.sort(descending=True)[0]
     # after sorting, elements shared by both inputs sit next to each other,
     # so adjacent-equality picks out exactly the intersection
     return aux[:-1][(aux[1:] == aux[:-1]).data]
 
 
 
 def shem(roi_probs_neg, negative_count, ohem_poolsize):
     """
     stochastic hard example mining: from a list of indices (referring to non-matched predictions),
     determine a pool of highest scoring (worst false positives) of size negative_count*ohem_poolsize.
     Then, sample n (= negative_count) predictions of this pool as negative examples for loss.
     :param roi_probs_neg: tensor of shape (n_predictions, n_classes).
     :param negative_count: int.
     :param ohem_poolsize: int.
     :return: (negative_count,) indices referring to positions in roi_probs_neg. If the pool is smaller than expected due to
     a limited number of negative proposals available, this function returns fewer than negative_count sampled indices without
     throwing an error.
     """
     # sort according to highest foreground score.
     probs, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True)
     select = min(ohem_poolsize * int(negative_count), order.size()[0])
     pool_indices = order[:select]
     rand_idx = torch.randperm(pool_indices.size()[0])
     return pool_indices[rand_idx[:negative_count].cuda()]
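 
 
 # A minimal sketch (added for illustration; requires CUDA, since shem moves indices to
 # the GPU): with 10 negatives, negative_count=2 and ohem_poolsize=3, two indices are
 # sampled uniformly from the 6 hardest (highest-scoring) false positives.
 def _demo_shem():
     roi_probs_neg = torch.rand(10, 3).cuda()   # (n_predictions, n_classes)
     neg_ix = shem(roi_probs_neg, negative_count=2, ohem_poolsize=3)
     assert neg_ix.numel() == 2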
 
 
 
 def initialize_weights(net):
     """
    Initialize model weights. Current Default in Pytorch (version 0.4.1) is initialization from a uniform distriubtion.
    Will expectably be changed to kaiming_uniform in future versions.
    """
     init_type = net.cf.weight_init
 
     for m in [module for module in net.modules() if type(module) in [nn.Conv2d, nn.Conv3d,
                                                                      nn.ConvTranspose2d,
                                                                      nn.ConvTranspose3d,
                                                                      nn.Linear]]:
         if init_type == 'xavier_uniform':
             nn.init.xavier_uniform_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == 'xavier_normal':
             nn.init.xavier_normal_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == "kaiming_uniform":
             nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 nn.init.uniform_(m.bias, -bound, bound)
 
         elif init_type == "kaiming_normal":
             nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 nn.init.normal_(m.bias, 0., bound)  # normal_ takes (mean, std); centre the bias distribution at zero
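 
 # Usage sketch (added for illustration): the net is expected to carry a config object
 # exposing cf.weight_init in {'xavier_uniform', 'xavier_normal', 'kaiming_uniform',
 # 'kaiming_normal'} and cf.relu naming the nonlinearity, e.g.
 #   net.cf.weight_init, net.cf.relu = 'kaiming_uniform', 'relu'
 #   initialize_weights(net)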
 
 
 
 class NDConvGenerator(object):
     """
     Generic wrapper around conv layers to avoid distinguishing 2D vs. 3D in model code.
     """
     def __init__(self, dim):
         self.dim = dim
 
     def __call__(self, c_in, c_out, ks, pad=0, stride=1, norm=None, relu='relu'):
         """
         :param c_in: number of in_channels.
         :param c_out: number of out_channels.
         :param ks: kernel size.
         :param pad: pad size.
         :param stride: kernel stride.
         :param norm: string specifying type of feature map normalization. If None, no normalization is applied.
         :param relu: string specifying type of nonlinearity. If None, no nonlinearity is applied.
         :return: convolved feature_map.
         """
         if self.dim == 2:
             conv = nn.Conv2d(c_in, c_out, kernel_size=ks, padding=pad, stride=stride)
             if norm is not None:
                 if norm == 'instance_norm':
                     norm_layer = nn.InstanceNorm2d(c_out)
                 elif norm == 'batch_norm':
                     norm_layer = nn.BatchNorm2d(c_out)
                 else:
                     raise ValueError('norm type as specified in configs is not implemented...')
                 conv = nn.Sequential(conv, norm_layer)
 
         else:
             conv = nn.Conv3d(c_in, c_out, kernel_size=ks, padding=pad, stride=stride)
             if norm is not None:
                 if norm == 'instance_norm':
                     norm_layer = nn.InstanceNorm3d(c_out)
                 elif norm == 'batch_norm':
                     norm_layer = nn.BatchNorm3d(c_out)
                 else:
                     raise ValueError('norm type as specified in configs is not implemented... {}'.format(norm))
                 conv = nn.Sequential(conv, norm_layer)
 
         if relu is not None:
             if relu == 'relu':
                 relu_layer = nn.ReLU(inplace=True)
             elif relu == 'leaky_relu':
                 relu_layer = nn.LeakyReLU(inplace=True)
             else:
                 raise ValueError('relu type as specified in configs is not implemented...')
             conv = nn.Sequential(conv, relu_layer)
 
         return conv
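 
 
 # A minimal usage sketch (added for illustration): one generator per dimensionality, so
 # model code stays agnostic of 2D vs. 3D.
 def _demo_nd_conv():
     conv = NDConvGenerator(dim=2)
     block = conv(c_in=3, c_out=32, ks=3, pad=1, norm='batch_norm', relu='relu')
     out = block(torch.rand(1, 3, 64, 64))   # conv + batch norm + relu
     assert out.shape == (1, 32, 64, 64)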
 
 
 
 def get_one_hot_encoding(y, n_classes):
     """
     transform a numpy label array to a one-hot array of the same shape.
     :param y: array of shape (b, 1, y, x, (z)).
     :param n_classes: int, number of classes to unfold in one-hot encoding.
     :return y_ohe: array of shape (b, n_classes, y, x, (z))
     """
     dim = len(y.shape) - 2
     if dim == 2:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32')
     elif dim == 3:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32')
     for cl in range(n_classes):
         y_ohe[:, cl][y[:, 0] == cl] = 1
     return y_ohe
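 
 
 # Example (added for illustration): a (b, 1, y, x) label map with classes {0, 1}
 # becomes a (b, 2, y, x) one-hot array; channels sum to one at every pixel.
 def _demo_one_hot():
     y = np.array([[[[0, 1], [1, 0]]]])
     y_ohe = get_one_hot_encoding(y, n_classes=2)
     assert y_ohe.shape == (1, 2, 2, 2) and np.all(y_ohe.sum(axis=1) == 1)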
 
 
 
 def get_dice_per_batch_and_class(pred, y, n_classes):
     '''
     computes dice scores per batch instance and class.
     :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1)
     :param y: ground truth array of shape (b, 1, y, x, (z)) (contains integer values in [0, ..., n_classes - 1])
     :param n_classes: int
     :return: dice scores of shape (b, c)
     '''
     pred = get_one_hot_encoding(pred, n_classes)
     y = get_one_hot_encoding(y, n_classes)
     axes = tuple(range(2, len(pred.shape)))
     intersect = np.sum(pred*y, axis=axes)
     denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes) + 1e-8
     dice = 2.0*intersect / denominator
     return dice
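 
 # Example (added for illustration): for a perfect prediction, classes present in both
 # arrays score a dice of ~1.0, while classes absent from both score 0 because the
 # epsilon only appears in the denominator, e.g.
 #   pred = y = np.ones((2, 1, 4, 4), dtype="int32")
 #   get_dice_per_batch_and_class(pred, y, n_classes=2)   # -> shape (2, 2): [[0., ~1.], [0., ~1.]]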
 
 
 
 def sum_tensor(input, axes, keepdim=False):
     axes = np.unique(axes)
     if keepdim:
         for ax in axes:
             input = input.sum(ax, keepdim=True)
     else:
         for ax in sorted(axes, reverse=True):
             input = input.sum(int(ax))
     return input
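 
 # Note (added): this reduces over several axes one at a time, i.e. it behaves like
 # tensor.sum(dim=axes) on PyTorch versions whose sum() only accepts a single dim.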
 
 
 
 def batch_dice(pred, y, false_positive_weight=1.0, smooth=1e-6):
     '''
     compute soft dice over batch. this is a differentiable score and can be used as a loss function.
     only dice scores of foreground classes are returned, since training typically
     does not benefit from explicit background optimization. Pixels of the entire batch are treated as one pseudo-volume over which dice scores are computed.
     This way, single patches with missing foreground classes cannot produce faulty gradients.
     :param pred: (b, c, y, x, (z)), softmax probabilities (network output). (c == classes)
     :param y: (b, c, y, x, (z)), one-hot-encoded segmentation mask.
     :param false_positive_weight: float in [0, 1]. For weighting of imbalanced classes,
     reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances.
     :param smooth: additive epsilon that stabilizes the ratio for empty classes.
     :return: soft dice score (float). This function discards the background score and returns the mean of the foreground scores.
     '''
     if len(pred.size()) == 4:
         axes = (0, 2, 3)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2 * intersect + smooth) / (denom + smooth) )[1:]) # only fg dice here.
 
     elif len(pred.size()) == 5:
         axes = (0, 2, 3, 4)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth) )[1:]) # only fg dice here.
 
     else:
         raise ValueError('wrong input dimension in dice loss')
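 
 
 # A minimal sketch (added for illustration): batch_dice used as a differentiable loss term.
 def _demo_batch_dice_loss():
     logits = torch.randn(2, 3, 32, 32, requires_grad=True)   # (b, c, y, x)
     pred = torch.softmax(logits, dim=1)                      # soft class probabilities
     y = torch.zeros(2, 3, 32, 32)
     y[:, 0] = 1.                                             # one-hot mask, all background
     loss = 1. - batch_dice(pred, y)                          # fg dice -> loss near 1 here
     loss.backward()                                          # gradients flow back to logits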
 
 
 
 
 def batch_dice_mask(pred, y, mask, false_positive_weight=1.0, smooth=1e-6):
     '''
     compute soft dice over batch, restricted to the pixels selected by mask. this is a differentiable score and can be used as a loss function.
     only dice scores of foreground classes are returned, since training typically
     does not benefit from explicit background optimization. Pixels of the entire batch are treated as one pseudo-volume over which dice scores are computed.
     This way, single patches with missing foreground classes cannot produce faulty gradients.
     :param pred: (b, c, y, x, (z)), softmax probabilities (network output).
     :param y: (b, c, y, x, (z)), one-hot-encoded segmentation mask.
     :param mask: (b, y, x, (z)), binary mask selecting the pixels that contribute to the score.
     :param false_positive_weight: float in [0, 1]. For weighting of imbalanced classes,
     reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances.
     :return: soft dice score (float). This function discards the background score and returns the mean of the foreground scores.
     '''
 
     # broadcast the mask over the class dimension so it matches pred / y in shape
     mask = mask.unsqueeze(1).repeat(1, y.shape[1], *([1] * (len(mask.shape) - 1)))
 
     if len(pred.size()) == 4:
         axes = (0, 2, 3)
         intersect = sum_tensor(pred * y * mask, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred * mask + y * mask, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth))[1:]) # only fg dice here.
 
     elif len(pred.size()) == 5:
         axes = (0, 2, 3, 4)
         intersect = sum_tensor(pred * y * mask, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight * pred * mask + y * mask, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth) )[1:]) # only fg dice here.
 
     else:
         raise ValueError('wrong input dimension in dice loss')
\ No newline at end of file