diff --git a/custom_extensions/roi_align/roi_align.py b/custom_extensions/roi_align/roi_align.py
index 0091414..3054b2c 100644
--- a/custom_extensions/roi_align/roi_align.py
+++ b/custom_extensions/roi_align/roi_align.py
@@ -1,128 +1,132 @@
 """
 ROIAlign implementation from pytorch framework
 (https://github.com/pytorch/vision/blob/master/torchvision/ops/roi_align.py on Nov 14 2019)
 
 adapted for 3D support without additional python function interface (only cpp function interface).
 """
 
 import torch
 from torch import nn
 
 from torchvision.ops._utils import convert_boxes_to_roi_format
 
 import roi_al_extension
 import roi_al_extension_3d
 
 def roi_align_2d(input: torch.Tensor, boxes, output_size,
                  spatial_scale: float = 1.0, sampling_ratio: int =-1) -> torch.Tensor:
     """
     Performs Region of Interest (RoI) Align operator described in Mask R-CNN
 
     Arguments:
         input: (Tensor[N, C, H, W]), input tensor
-        boxes: (Tensor[K, 5] or List[Tensor[L, 4]]), the box coordinates in (x1, y1, x2, y2)
+        boxes: (Tensor[K, 5] or List[Tensor[L, 4]]), the box coordinates in (y1, x1, y2, x2)
+            NOTE: the order of box coordinates, (y1, x1, y2, x2), is swapped w.r.t. to the order in the
+                original torchvision implementation (which requires (x1, y1, x2, y2)).
             format where the regions will be taken from. If a single Tensor is passed,
             then the first column should contain the batch index. If a list of Tensors
             is passed, then each Tensor will correspond to the boxes for an element i
             in a batch
-        output_size: (int or Tuple[int, int]) the size of the output after the cropping
+        output_size: (Tuple[int, int]) the size of the output after the cropping
             is performed, as (height, width)
         spatial_scale: (float) a scaling factor that maps the input coordinates to
             the box coordinates. Default: 1.0
         sampling_ratio: (int) number of sampling points in the interpolation grid
             used to compute the output value of each pooled output bin. If > 0,
             then exactly sampling_ratio x sampling_ratio grid points are used. If
             <= 0, then an adaptive number of grid points are used (computed as
             ceil(roi_width / pooled_w), and likewise for height). Default: -1
 
     Returns:
         output (Tensor[K, C, output_size[0], output_size[1]])
     """
     rois = boxes
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     return roi_al_extension.roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio)
 
 
 def roi_align_3d(input: torch.Tensor, boxes, output_size,
                  spatial_scale: float = 1.0, sampling_ratio: int = -1) -> torch.Tensor:
     """
     Performs Region of Interest (RoI) Align operator described in Mask R-CNN for 3-dim input.
 
     Arguments:
         input (Tensor[N, C, H, W, D]): input tensor
-        boxes (Tensor[K, 7] or List[Tensor[L, 6]]): the box coordinates in (x1, y1, x2, y2, z1 ,z2)
+        boxes (Tensor[K, 7] or List[Tensor[L, 6]]): the box coordinates in (y1, x1, y2, x2, z1, z2).
+            NOTE: the order of x, y box coordinates, (y1, x1, y2, x2), is swapped w.r.t. to the order in the
+                original torchvision implementation (which requires (x1, y1, x2, y2)).
             format where the regions will be taken from. If a single Tensor is passed,
             then the first column should contain the batch index. If a list of Tensors
             is passed, then each Tensor will correspond to the boxes for an element i
             in a batch
         output_size (int or Tuple[int, int, int]): the size of the output after the cropping
             is performed, as (height, width, depth)
         spatial_scale (float): a scaling factor that maps the input coordinates to
             the box coordinates. Default: 1.0
         sampling_ratio (int): number of sampling points in the interpolation grid
             used to compute the output value of each pooled output bin. If > 0,
             then exactly sampling_ratio x sampling_ratio grid points are used. If
             <= 0, then an adaptive number of grid points are used (computed as
             ceil(roi_width / pooled_w), and likewise for height). Default: -1
 
     Returns:
         output (Tensor[K, C, output_size[0], output_size[1], output_size[2]])
     """
     rois = boxes
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     return roi_al_extension_3d.roi_align(input, rois, spatial_scale, output_size[0], output_size[1], output_size[2],
                                          sampling_ratio)
 
 
 class RoIAlign(nn.Module):
     """
     Performs Region of Interest (RoI) Align operator described in Mask R-CNN for 2- or 3-dim input.
 
     Arguments:
         input (Tensor[N, C, H, W(, D)]): input tensor
         boxes (Tensor[K, 5] or List[Tensor[L, 4]]) or (Tensor[K, 7] or List[Tensor[L, 6]]):
             the box coordinates in (x1, y1, x2, y2(, z1 ,z2))
             format where the regions will be taken from. If a single Tensor is passed,
             then the first column should contain the batch index. If a list of Tensors
             is passed, then each Tensor will correspond to the boxes for an element i
             in a batch
         output_size (int or Tuple[int, int(, int)]): the size of the output after the cropping
             is performed, as (height, width(, depth))
         spatial_scale (float): a scaling factor that maps the input coordinates to
             the box coordinates. Default: 1.0
         sampling_ratio (int): number of sampling points in the interpolation grid
             used to compute the output value of each pooled output bin. If > 0,
             then exactly sampling_ratio x sampling_ratio grid points are used. If
             <= 0, then an adaptive number of grid points are used (computed as
             ceil(roi_width / pooled_w), and likewise for height (and depth)). Default: -1
 
     Returns:
         output (Tensor[K, C, output_size[0], output_size[1](, output_size[2])])
     """
     def __init__(self, output_size, spatial_scale=1., sampling_ratio=-1):
         super(RoIAlign, self).__init__()
         self.output_size = output_size
         self.spatial_scale = spatial_scale
         self.sampling_ratio = sampling_ratio
         self.dim = len(self.output_size)
 
         if self.dim == 2:
             self.roi_align = roi_align_2d
         elif self.dim == 3:
             self.roi_align = roi_align_3d
         else:
             raise Exception("Tried to init RoIAlign module with incorrect output size: {}".format(self.output_size))
 
     def forward(self, input: torch.Tensor, rois) -> torch.Tensor:
         return self.roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio)
 
     def __repr__(self):
         tmpstr = self.__class__.__name__ + '('
         tmpstr += 'output_size=' + str(self.output_size)
         tmpstr += ', spatial_scale=' + str(self.spatial_scale)
         tmpstr += ', sampling_ratio=' + str(self.sampling_ratio)
         tmpstr += ', dimension=' + str(self.dim)
         tmpstr += ')'
         return tmpstr
diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda.cu b/custom_extensions/roi_align/src/RoIAlign_cuda.cu
index 7794eb1..47c870a 100644
--- a/custom_extensions/roi_align/src/RoIAlign_cuda.cu
+++ b/custom_extensions/roi_align/src/RoIAlign_cuda.cu
@@ -1,427 +1,422 @@
 /*
 ROIAlign implementation in CUDA from pytorch framework
 (https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019)
 
 */
 
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
-
+#include <typeinfo>
 #include "cuda_helpers.h"
 
 template <typename T>
 __device__ T bilinear_interpolate(
     const T* input,
     const int height,
     const int width,
     T y,
     T x,
     const int index /* index for debug only*/) {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width) {
     // empty
     return 0;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
 
   int y_low = (int)y;
   int x_low = (int)x;
   int y_high;
   int x_high;
 
   if (y_low >= height - 1) {
     y_high = y_low = height - 1;
     y = (T)y_low;
   } else {
     y_high = y_low + 1;
   }
 
   if (x_low >= width - 1) {
     x_high = x_low = width - 1;
     x = (T)x_low;
   } else {
     x_high = x_low + 1;
   }
 
   T ly = y - y_low;
   T lx = x - x_low;
   T hy = 1. - ly, hx = 1. - lx;
 
   // do bilinear interpolation
   T v1 = input[y_low * width + x_low];
   T v2 = input[y_low * width + x_high];
   T v3 = input[y_high * width + x_low];
   T v4 = input[y_high * width + x_high];
   T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
 
   T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
 
   return val;
 }
 
 template <typename T>
 __global__ void RoIAlignForward(
     const int nthreads,
     const T* input,
     const T spatial_scale,
     const int channels,
     const int height,
     const int width,
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
     const T* rois,
     T* output) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
-    int pw = index % pooled_width;
-    int ph = (index / pooled_width) % pooled_height;
-    int c = (index / pooled_width / pooled_height) % channels;
-    int n = index / pooled_width / pooled_height / channels;
+    const int pw = index % pooled_width;
+    const int ph = (index / pooled_width) % pooled_height;
+    const int c = (index / pooled_width / pooled_height) % channels;
+    const int n = index / pooled_width / pooled_height / channels;
 
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];
 
     // Do not using rounding; this implementation detail is critical
-    T roi_start_w = offset_rois[1] * spatial_scale;
-    T roi_start_h = offset_rois[2] * spatial_scale;
-    T roi_end_w = offset_rois[3] * spatial_scale;
-    T roi_end_h = offset_rois[4] * spatial_scale;
+    T roi_start_h = offset_rois[1] * spatial_scale;
+    T roi_start_w = offset_rois[2] * spatial_scale;
+    T roi_end_h = offset_rois[3] * spatial_scale;
+    T roi_end_w = offset_rois[4] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
-    //printf("roi height %f, width %f\n", (float) roi_height, (float) roi_width);
+
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
 
     const T* offset_input =
         input + (roi_batch_ind * channels + c) * height * width;
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0)
         ? sampling_ratio
         : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
 
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
-
     T output_val = 0.;
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
       for (int ix = 0; ix < roi_bin_grid_w; ix++) {
         const T x = roi_start_w + pw * bin_size_w +
-            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
-
+            static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
         T val = bilinear_interpolate(offset_input, height, width, y, x, index);
         output_val += val;
       }
     }
     output_val /= count;
 
     output[index] = output_val;
   }
 }
 
 template <typename T>
 __device__ void bilinear_interpolate_gradient(
     const int height,
     const int width,
     T y,
     T x,
     T& w1,
     T& w2,
     T& w3,
     T& w4,
     int& x_low,
     int& x_high,
     int& y_low,
     int& y_high,
     const int index /* index for debug only*/) {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width) {
     // empty
     w1 = w2 = w3 = w4 = 0.;
     x_low = x_high = y_low = y_high = -1;
     return;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
 
   y_low = (int)y;
   x_low = (int)x;
 
   if (y_low >= height - 1) {
     y_high = y_low = height - 1;
     y = (T)y_low;
   } else {
     y_high = y_low + 1;
   }
 
   if (x_low >= width - 1) {
     x_high = x_low = width - 1;
     x = (T)x_low;
   } else {
     x_high = x_low + 1;
   }
 
   T ly = y - y_low;
   T lx = x - x_low;
   T hy = 1. - ly, hx = 1. - lx;
 
   // reference in forward
   // T v1 = input[y_low * width + x_low];
   // T v2 = input[y_low * width + x_high];
   // T v3 = input[y_high * width + x_low];
   // T v4 = input[y_high * width + x_high];
   // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
 
   w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
 
   return;
 }
 
 template <typename T>
 __global__ void RoIAlignBackward(
     const int nthreads,
     const T* grad_output,
     const T spatial_scale,
     const int channels,
     const int height,
     const int width,
     const int pooled_height,
     const int pooled_width,
     const int sampling_ratio,
     T* grad_input,
     const T* rois,
     const int n_stride,
     const int c_stride,
     const int h_stride,
     const int w_stride)
 {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
     int ph = (index / pooled_width) % pooled_height;
     int c = (index / pooled_width / pooled_height) % channels;
     int n = index / pooled_width / pooled_height / channels;
 
     const T* offset_rois = rois + n * 5;
     int roi_batch_ind = offset_rois[0];
 
     // Do not using rounding; this implementation detail is critical
-    T roi_start_w = offset_rois[1] * spatial_scale;
-    T roi_start_h = offset_rois[2] * spatial_scale;
-    T roi_end_w = offset_rois[3] * spatial_scale;
-    T roi_end_h = offset_rois[4] * spatial_scale;
+    T roi_start_h = offset_rois[1] * spatial_scale;
+    T roi_start_w = offset_rois[2] * spatial_scale;
+    T roi_end_h = offset_rois[3] * spatial_scale;
+    T roi_end_w = offset_rois[4] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
 
     T* offset_grad_input =
         grad_input + ((roi_batch_ind * channels + c) * height * width);
 
     // We need to index the gradient using the tensor strides to access the
     // correct values.
     int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin =
         offset_grad_output[ph * h_stride + pw * w_stride];
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0)
         ? sampling_ratio
         : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
 
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
 
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * bin_size_h /
-              static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
       for (int ix = 0; ix < roi_bin_grid_w; ix++) {
-        const T x = roi_start_w + pw * bin_size_w +
-            static_cast<T>(ix + .5f) * bin_size_w /
-                static_cast<T>(roi_bin_grid_w);
+        const T x = roi_start_w + pw * bin_size_w  +
+            static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
 
         T w1, w2, w3, w4;
         int x_low, x_high, y_low, y_high;
 
         bilinear_interpolate_gradient(
             height,
             width,
             y,
             x,
             w1,
             w2,
             w3,
             w4,
             x_low,
             x_high,
             y_low,
             y_high,
             index);
 
         T g1 = grad_output_this_bin * w1 / count;
         T g2 = grad_output_this_bin * w2 / count;
         T g3 = grad_output_this_bin * w3 / count;
         T g4 = grad_output_this_bin * w4 / count;
 
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
           atomicAdd(
               offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
           atomicAdd(
               offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
           atomicAdd(
               offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
           atomicAdd(
               offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
         } // if
       } // ix
     } // iy
   } // CUDA_1D_KERNEL_LOOP
 } // RoIAlignBackward
 
 at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
                                 const int pooled_height, const int pooled_width, const int sampling_ratio) {
   /*
    input: feature-map tensor, shape (batch, n_channels, y, x(, z))
    */
   AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_forward_cuda";
   at::checkAllSameGPU(c, {input_t, rois_t});
   at::checkAllSameType(c, {input_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(input.device());
 
-  auto num_rois = rois.size(0);
-  auto channels = input.size(1);
-  auto height = input.size(2);
-  auto width = input.size(3);
+  int num_rois = rois.size(0);
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
 
   at::Tensor output = at::zeros(
       {num_rois, channels, pooled_height, pooled_width}, input.options());
 
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
   if (output.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return output;
   }
 
-
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
     RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.contiguous().data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         pooled_height,
         pooled_width,
         sampling_ratio,
         rois.contiguous().data_ptr<scalar_t>(),
         output.data_ptr<scalar_t>());
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return output;
 }
 
 at::Tensor ROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const float spatial_scale,
     const int pooled_height,
     const int pooled_width,
     const int batch_size,
     const int channels,
     const int height,
     const int width,
     const int sampling_ratio) {
   AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_backward_cuda";
   at::checkAllSameGPU(c, {grad_t, rois_t});
   at::checkAllSameType(c, {grad_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(grad.device());
 
   at::Tensor grad_input =
       at::zeros({batch_size, channels, height, width}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
   // handle possibly empty gradients
   if (grad.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }
 
   int n_stride = grad.stride(0);
   int c_stride = grad.stride(1);
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
     RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         pooled_height,
         pooled_width,
         sampling_ratio,
         grad_input.data_ptr<scalar_t>(),
         rois.contiguous().data_ptr<scalar_t>(),
         n_stride,
         c_stride,
         h_stride,
         w_stride);
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
\ No newline at end of file
diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu b/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
index a8022bb..0c75a34 100644
--- a/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
+++ b/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu
@@ -1,488 +1,487 @@
 /*
 ROIAlign implementation in CUDA from pytorch framework
 (https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019)
 
 Adapted for additional 3D capability by G. Ramien, DKFZ Heidelberg
 */
 
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <cstdio>
 #include "cuda_helpers.h"
 
 /*-------------- gpu kernels -----------------*/
 
 template <typename T>
 __device__ T linear_interpolate(const T xl, const T val_low, const T val_high){
 
   T val = (val_high - val_low) * xl + val_low;
   return val;
 }
 
 template <typename T>
 __device__ T trilinear_interpolate(const T* input, const int height, const int width, const int depth,
                 T y, T x, T z, const int index /* index for debug only*/) {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) {
     // empty
     return 0;
   }
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
   if (z <= 0)
     z = 0;
 
   int y0 = (int)y;
   int x0 = (int)x;
   int z0 = (int)z;
   int y1;
   int x1;
   int z1;
 
   if (y0 >= height - 1) {
   /*if nearest gridpoint to y on the lower end is on border or border-1, set low, high, mid(=actual point) to border-1*/
     y1 = y0 = height - 1;
     y = (T)y0;
   } else {
     /* y1 is one pixel from y0, y is the actual point somewhere in between  */
     y1 = y0 + 1;
   }
   if (x0 >= width - 1) {
     x1 = x0 = width - 1;
     x = (T)x0;
   } else {
     x1 = x0 + 1;
   }
   if (z0 >= depth - 1) {
     z1 = z0 = depth - 1;
     z = (T)z0;
   } else {
     z1 = z0 + 1;
   }
 
 
   // do linear interpolation of x values
   // distance of actual point to lower boundary point, already normalized since x_high - x0 = 1
   T dis = x - x0;
   /*  accessing element b,c,y,x,z in 1D-rolled-out array of a tensor with dimensions (B, C, Y, X, Z):
       tensor[b,c,y,x,z] = arr[ (((b*C+c)*Y+y)*X + x)*Z + z ] = arr[ alpha + (y*X + x)*Z + z ]
       with alpha = batch&channel locator = (b*C+c)*YXZ.
       hence, as current input pointer is already offset by alpha: y,x,z at input[( y*X + x)*Z + z], where
       X = width, Z = depth.
   */
   T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]);
   T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]);
   T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]);
   T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]);
 
   // linear interpol of y values = bilinear interpol of f(x,y)
   dis = y - y0;
   T xy0 = linear_interpolate(dis, x00, x10);
   T xy1 = linear_interpolate(dis, x01, x11);
 
   // linear interpol of z value = trilinear interpol of f(x,y,z)
   dis = z - z0;
   T xyz = linear_interpolate(dis, xy0, xy1);
 
   return xyz;
 }
 
 template <typename T>
 __device__ void trilinear_interpolate_gradient(const int height, const int width, const int depth, T y, T x, T z,
     T& g000, T& g001, T& g010, T& g100, T& g011, T& g101, T& g110, T& g111,
     int& x0, int& x1, int& y0, int& y1, int& z0, int&z1, const int index /* index for debug only*/)
 {
   // deal with cases that inverse elements are out of feature map boundary
   if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) {
     // empty
     g000 = g001 = g010 = g100 = g011 = g101 = g110 = g111 = 0.;
     x0 = x1 = y0 = y1 = z0 = z1 = -1;
     return;
   }
 
   if (y <= 0)
     y = 0;
   if (x <= 0)
     x = 0;
   if (z <= 0)
     z = 0;
 
   y0 = (int)y;
   x0 = (int)x;
   z0 = (int)z;
 
   if (y0 >= height - 1) {
     y1 = y0 = height - 1;
     y = (T)y0;
   } else {
     y1 = y0 + 1;
   }
 
   if (x0 >= width - 1) {
     x1 = x0 = width - 1;
     x = (T)x0;
   } else {
     x1 = x0 + 1;
   }
 
   if (z0 >= depth - 1) {
     z1 = z0 = depth - 1;
     z = (T)z0;
   } else {
     z1 = z0 + 1;
   }
 
   // forward calculations are added as hints
   T dis_x = x - x0;
   //T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]); // v000, v100
   //T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]); // v010, v110
   //T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]); // v001, v101
   //T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]); // v011, v111
 
   // linear interpol of y values = bilinear interpol of f(x,y)
   T dis_y = y - y0;
   //T xy0 = linear_interpolate(dis, x00, x10);
   //T xy1 = linear_interpolate(dis, x01, x11);
 
   // linear interpol of z value = trilinear interpol of f(x,y,z)
   T dis_z = z - z0;
   //T xyz = linear_interpolate(dis, xy0, xy1);
 
   /* need: grad_i := d(xyz)/d(v_i) with v_i = input_value_i  for all i = 0,..,7 (eight input values --> eight-entry gradient)
      d(lin_interp(dis,x,y))/dx = (-dis +1) and d(lin_interp(dis,x,y))/dy = dis --> derivatives are indep of x,y.
      notation: gxyz = gradient for d(trilin_interp)/d(input_value_at_xyz)
      below grads were calculated by hand
      save time by reusing (1-dis_x) = 1-x+x0 = x1-x =: dis_x1 */
   T dis_x1 = (1-dis_x), dis_y1 = (1-dis_y), dis_z1 = (1-dis_z);
 
   g000 = dis_z1 * dis_y1  * dis_x1;
   g001 = dis_z  * dis_y1  * dis_x1;
   g010 = dis_z1 * dis_y   * dis_x1;
   g100 = dis_z1 * dis_y1  * dis_x;
   g011 = dis_z  * dis_y   * dis_x1;
   g101 = dis_z  * dis_y1  * dis_x;
   g110 = dis_z1 * dis_y   * dis_x;
   g111 = dis_z  * dis_y   * dis_x;
 
   return;
 }
 
 template <typename T>
 __global__ void RoIAlignForward(const int nthreads, const T* input, const T spatial_scale, const int channels,
     const int height, const int width, const int depth, const int pooled_height, const int pooled_width,
     const int pooled_depth, const int sampling_ratio, const T* rois, T* output)
 {
 
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw, pd) is an element in the pooled output
     int pd =  index % pooled_depth;
     int pw = (index / pooled_depth) % pooled_width;
     int ph = (index / pooled_depth / pooled_width) % pooled_height;
     int c  = (index / pooled_depth / pooled_width / pooled_height) % channels;
     int n  =  index / pooled_depth / pooled_width / pooled_height / channels;
 
 
     // rois are (y1,x1,y2,x2,z1,z2) --> tensor of shape (n_rois, 6)
     const T* offset_rois = rois + n * 7;
     int roi_batch_ind = offset_rois[0];
     // Do not use rounding; this implementation detail is critical
     T roi_start_h = offset_rois[1] * spatial_scale;
     T roi_start_w = offset_rois[2] * spatial_scale;
     T roi_end_h = offset_rois[3] * spatial_scale;
     T roi_end_w = offset_rois[4] * spatial_scale;
     T roi_start_d = offset_rois[5] * spatial_scale;
     T roi_end_d = offset_rois[6] * spatial_scale;
 
     // Force malformed ROIs to be 1x1
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_depth = max(roi_end_d - roi_start_d, (T)1.);
 
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
     T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);
 
     const T* offset_input =
         input + (roi_batch_ind * channels + c) * height * width * depth;
 
     // We use roi_bin_grid to sample the grid and mimic integral
     // roi_bin_grid == nr of sampling points per bin >= 1
     int roi_bin_grid_h =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
     int roi_bin_grid_d =
         (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth);
 
     // We do average (integral) pooling inside a bin
     const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g. = 4
 
     T output_val = 0.;
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5, always in the middle of two grid pointsk
+          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5, always in the middle of two grid pointsk
 
       for (int ix = 0; ix < roi_bin_grid_w; ix++)
       {
         const T x = roi_start_w + pw * bin_size_w +
-            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+            static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
 
         for (int iz = 0; iz < roi_bin_grid_d; iz++)
         {
           const T z = roi_start_d + pd * bin_size_d +
-              static_cast<T>(iz + .5f) * bin_size_d / static_cast<T>(roi_bin_grid_d);
+              static_cast<T>(iz + .5f) * (bin_size_d - 1.f) / static_cast<T>(roi_bin_grid_d);
           // TODO verify trilinear interpolation
           T val = trilinear_interpolate(offset_input, height, width, depth, y, x, z, index);
           output_val += val;
         } // z iterator and calc+add value
       } // x iterator
     } // y iterator
     output_val /= n_voxels;
 
     output[index] = output_val;
   }
 }
 
 template <typename T>
 __global__ void RoIAlignBackward(const int nthreads, const T* grad_output, const T spatial_scale, const int channels,
     const int height, const int width, const int depth, const int pooled_height, const int pooled_width,
     const int pooled_depth, const int sampling_ratio, T* grad_input, const T* rois,
     const int n_stride, const int c_stride, const int h_stride, const int w_stride, const int d_stride)
 {
 
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw, pd) is an element in the pooled output
     int pd =  index % pooled_depth;
     int pw = (index / pooled_depth) % pooled_width;
     int ph = (index / pooled_depth / pooled_width) % pooled_height;
     int c  = (index / pooled_depth / pooled_width / pooled_height) % channels;
     int n  =  index / pooled_depth / pooled_width / pooled_height / channels;
 
 
     const T* offset_rois = rois + n * 7;
     int roi_batch_ind = offset_rois[0];
 
     // Do not using rounding; this implementation detail is critical
-    T roi_start_w = offset_rois[1] * spatial_scale;
-    T roi_start_h = offset_rois[2] * spatial_scale;
-    T roi_end_w = offset_rois[3] * spatial_scale;
-    T roi_end_h = offset_rois[4] * spatial_scale;
+    T roi_start_h = offset_rois[1] * spatial_scale;
+    T roi_start_w = offset_rois[2] * spatial_scale;
+    T roi_end_h = offset_rois[3] * spatial_scale;
+    T roi_end_w = offset_rois[4] * spatial_scale;
     T roi_start_d = offset_rois[5] * spatial_scale;
     T roi_end_d = offset_rois[6] * spatial_scale;
 
 
     // Force malformed ROIs to be 1x1
     T roi_width = max(roi_end_w - roi_start_w, (T)1.);
     T roi_height = max(roi_end_h - roi_start_h, (T)1.);
     T roi_depth = max(roi_end_d - roi_start_d, (T)1.);
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
     T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);
 
     // offset: index b,c,y,x,z of tensor of shape (B,C,Y,X,Z) is
     // b*C*Y*X*Z + c * Y*X*Z + y * X*Z + x *Z + z = (b*C+c)Y*X*Z + ...
     T* offset_grad_input =
         grad_input + ((roi_batch_ind * channels + c) * height * width * depth);
 
     // We need to index the gradient using the tensor strides to access the correct values.
     int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin = offset_grad_output[ph * h_stride + pw * w_stride + pd * d_stride];
 
     // We use roi_bin_grid to sample the grid and mimic integral
     int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
     int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
     int roi_bin_grid_d = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth);
 
     // We do average (integral) pooling inside a bin
     const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g. = 6
 
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
-          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+          static_cast<T>(iy + .5f) * (bin_size_h - 1.f) / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
 
       for (int ix = 0; ix < roi_bin_grid_w; ix++)
       {
         const T x = roi_start_w + pw * bin_size_w +
-          static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+          static_cast<T>(ix + .5f) * (bin_size_w - 1.f) / static_cast<T>(roi_bin_grid_w);
 
         for (int iz = 0; iz < roi_bin_grid_d; iz++)
         {
           const T z = roi_start_d + pd * bin_size_d +
-              static_cast<T>(iz + .5f) * bin_size_d / static_cast<T>(roi_bin_grid_d);
+              static_cast<T>(iz + .5f) * (bin_size_d - 1.f) / static_cast<T>(roi_bin_grid_d);
 
           T g000, g001, g010, g100, g011, g101, g110, g111; // will hold the current partial derivatives
           int x0, x1, y0, y1, z0, z1;
           /* notation: gxyz = gradient at xyz, where x,y,z need to lie on feature-map grid (i.e., =x0,x1 etc.) */
           trilinear_interpolate_gradient(height, width, depth, y, x, z,
                                          g000, g001, g010, g100, g011, g101, g110, g111,
                                          x0, x1, y0, y1, z0, z1, index);
           /* chain rule: derivatives (i.e., the gradient) of trilin_interpolate(v1,v2,v3,v4,...) (div by n_voxels
              as we actually need gradient of whole roi_align) are multiplied with gradient so far*/
           g000 *= grad_output_this_bin / n_voxels;
           g001 *= grad_output_this_bin / n_voxels;
           g010 *= grad_output_this_bin / n_voxels;
           g100 *= grad_output_this_bin / n_voxels;
           g011 *= grad_output_this_bin / n_voxels;
           g101 *= grad_output_this_bin / n_voxels;
           g110 *= grad_output_this_bin / n_voxels;
           g111 *= grad_output_this_bin / n_voxels;
 
           if (x0 >= 0 && x1 >= 0 && y0 >= 0 && y1 >= 0 && z0 >= 0 && z1 >= 0)
           { // atomicAdd(address, content) reads content under address, adds content to it, while: no other thread
             // can interfere with the memory at address during this operation (thread lock, therefore "atomic").
             atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z0, static_cast<T>(g000));
             atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z1, static_cast<T>(g001));
             atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z0, static_cast<T>(g010));
             atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z0, static_cast<T>(g100));
             atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z1, static_cast<T>(g011));
             atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z1, static_cast<T>(g101));
             atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z0, static_cast<T>(g110));
             atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z1, static_cast<T>(g111));
           } // if
         } // iz
       } // ix
     } // iy
   } // CUDA_1D_KERNEL_LOOP
 } // RoIAlignBackward
 
 
 /*----------- wrapper functions ----------------*/
 
 at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
                                 const int pooled_height, const int pooled_width, const int pooled_depth, const int sampling_ratio) {
   /*
    input: feature-map tensor, shape (batch, n_channels, y, x(, z))
    */
   AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_forward_cuda";
   at::checkAllSameGPU(c, {input_t, rois_t});
   at::checkAllSameType(c, {input_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(input.device());
 
   auto num_rois = rois.size(0);
   auto channels = input.size(1);
   auto height = input.size(2);
   auto width = input.size(3);
   auto depth = input.size(4);
-  //std::cout << "input.options" << input.options() << std::endl;
+
   at::Tensor output = at::zeros(
       {num_rois, channels, pooled_height, pooled_width, pooled_depth}, input.options());
 
   auto output_size = num_rois * channels * pooled_height * pooled_width * pooled_depth;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
   dim3 block(512);
 
   if (output.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return output;
   }
 
-  //std::printf("launching kernel\n");
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign forward in 3d", [&] {
     RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
         output_size,
         input.contiguous().data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         depth,
         pooled_height,
         pooled_width,
         pooled_depth,
         sampling_ratio,
         rois.contiguous().data_ptr<scalar_t>(),
         output.data_ptr<scalar_t>());
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return output;
 }
 
 at::Tensor ROIAlign_backward_cuda(
     const at::Tensor& grad,
     const at::Tensor& rois,
     const float spatial_scale,
     const int pooled_height,
     const int pooled_width,
     const int pooled_depth,
     const int batch_size,
     const int channels,
     const int height,
     const int width,
     const int depth,
     const int sampling_ratio)
 {
   AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
   AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
 
   at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
 
   at::CheckedFrom c = "ROIAlign_backward_cuda";
   at::checkAllSameGPU(c, {grad_t, rois_t});
   at::checkAllSameType(c, {grad_t, rois_t});
 
   at::cuda::CUDAGuard device_guard(grad.device());
 
   at::Tensor grad_input =
       at::zeros({batch_size, channels, height, width, depth}, grad.options());
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   dim3 grid(std::min(
       at::cuda::ATenCeilDiv(
           static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
       static_cast<int64_t>(4096)));
   dim3 block(512);
 
   // handle possibly empty gradients
   if (grad.numel() == 0) {
     AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }
 
   int n_stride = grad.stride(0);
   int c_stride = grad.stride(1);
   int h_stride = grad.stride(2);
   int w_stride = grad.stride(3);
   int d_stride = grad.stride(4);
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign backward 3D", [&] {
     RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
         grad.data_ptr<scalar_t>(),
         spatial_scale,
         channels,
         height,
         width,
         depth,
         pooled_height,
         pooled_width,
         pooled_depth,
         sampling_ratio,
         grad_input.data_ptr<scalar_t>(),
         rois.contiguous().data_ptr<scalar_t>(),
         n_stride,
         c_stride,
         h_stride,
         w_stride,
         d_stride);
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
\ No newline at end of file
diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py
index a9e59b9..c13c954 100644
--- a/datasets/toy/configs.py
+++ b/datasets/toy/configs.py
@@ -1,491 +1,491 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 from collections import namedtuple
 
 boxLabel = namedtuple('boxLabel', ["name", "color"])
 Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion'])
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals'])
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #         Prepro        #
         #########################
 
         self.pp_rootdir = os.path.join('/mnt/HDD2TB/Documents/data/toy', "cyl1ps_dev_exact")
         self.pp_npz_dir = self.pp_rootdir+"_npz"
 
         self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now)
         self.min_2d_radius = 6 #in pixels
         self.n_train_samples, self.n_test_samples = 80, 80
 
         # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes.
         self.pp_create_ohe_seg = False
         self.pp_empty_samples_ratio = 0.1
 
         self.pp_place_radii_mid_bin = True
         self.pp_only_distort_2d = True
         # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing.
         # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity.
         self.pp_blur_min_intensity = 0.2
 
         self.max_instances_per_sample = 3 #how many max instances over all classes per sample (img if 2d, vol if 3d)
         self.max_instances_per_class = self.max_instances_per_sample  # how many max instances per image per class
         self.noise_scale = 0.  # std-dev of gaussian noise
 
         self.ambigs_sampling = "gaussian" #"gaussian" or "uniform"
         """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from
             image by distinguishing it from the rest of the object.
             blurring width around edge will be shifted so that symmetric rel to orig radius.
             blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale
             since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius.
             if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2.
         """
         self.ambiguities = {
              #set which classes to apply which ambs to below in class labels
              #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'.
              #kind              #probability   #scale (gaussian std, relative to unperturbed value)
             #"outer_radius":     (1.,            0.5),
             #"outer_radius_xy":  (1.,            0.5),
             #"inner_radius":     (0.5,            0.1),
             #"radii_relations":  (0.5,            0.1),
             "radius_calib":     (1.,            1./6)
         }
 
         # shape choices: 'cylinder', 'block'
         #                        id,    name,       shape,      radius,                 color,              regression,     ambiguities,    gt_distortion
         self.pp_classes = [Label(1,     'cylinder', 'cylinder', ((6,6,1),(40,40,8)),    (*self.blue, 1.),   "radius_2d",    (),             ()),
                            #Label(2,      'block',      'block',        ((6,6,1),(40,40,8)),  (*self.aubergine,1.),  "radii_2d", (), ('radius_calib',))
             ]
 
 
         #########################
         #         I/O           #
         #########################
 
         self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_exact'
 
         if server_env:
             self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_exact_npz'
 
 
         self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test')
         self.data_sourcedir = os.path.join(self.data_sourcedir, "train")
 
         self.info_df_name = 'info_df.pickle'
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'detection_fpn'].
-        self.model = 'retina_net'
+        self.model = 'retinau'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # one out of [2, 3]. dimension the model operates in.
         self.dim = 2
 
         # 'class', 'regression', 'regression_bin', 'regression_ken_gal'
         # currently only tested mode is a single-task at a time (i.e., only one task in below list)
         # but, in principle, tasks could be combined (e.g., object classes and regression per class)
         self.prediction_tasks = ['class',]
 
-        self.start_filts = 48 if self.dim == 2 else 18
+        self.start_filts = 36 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = None #'instance_norm' # one of None, 'instance_norm', 'batch_norm'
         self.relu = 'relu'
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1  # length of regressor target vector
 
 
         #########################
         #      Data Loader      #
         #########################
 
         self.num_epochs = 32
-        self.num_train_batches = 120 if self.dim == 2 else 80
+        self.num_train_batches = 100 if self.dim == 2 else 80
         self.batch_size = 12 if self.dim == 2 else 8
 
         self.n_cv_splits = 4
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
         self.plot_bg_chan = 0
         self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
         self.patch_size_2D = self.pre_crop_size[:2]
         self.patch_size_3D = self.pre_crop_size[:2]+[8]
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.2
         self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets"
 
         self.observables_patient = []
         self.observables_rois = []
 
-        self.seed = 3 #for generating folds
+        self.seed = 3 # for generating folds
 
         #############################
         # Colors, Classes, Legends  #
         #############################
         self.plot_frequency = 1
 
         binary_bin_labels = [binLabel(1,  'r<=25',      (*self.green, 1.),      (1,25)),
                              binLabel(2,  'r>25',       (*self.red, 1.),        (25,))]
         quintuple_bin_labels = [binLabel(1,  'r2-10',   (*self.green, 1.),      (2,10)),
                                 binLabel(2,  'r10-20',  (*self.yellow, 1.),     (10,20)),
                                 binLabel(3,  'r20-30',  (*self.orange, 1.),     (20,30)),
                                 binLabel(4,  'r30-40',  (*self.bright_red, 1.), (30,40)),
                                 binLabel(5,  'r>40',    (*self.red, 1.), (40,))]
 
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_bin_labels = quintuple_bin_labels
 
         self.class_labels = [
             # regression: regression-task label, either value or "(x,y,z)_radius" or "radii".
             # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables!
             # gt_distortion: name of ambig to apply to gt only; needs to be iterable!
             #      #id  #name   #shape  #radius     #color              #regression #ambiguities    #gt_distortion
             Label(  0,  'bg',   None,   (0, 0, 0),  (*self.white, 0.),  (0, 0, 0),  (),             ())]
         if "class" in self.prediction_tasks:
             self.class_labels += self.pp_classes
         else:
             self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))]
 
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0,  'bg',       (*self.white, 1.),      (0,))]
             self.bin_labels += task_spec_bin_labels
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
             self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
 
         if self.class_specific_seg:
           self.seg_labels = self.class_labels
 
         self.box_type2label = {label.name: label for label in self.box_labels}
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)
         self.num_seg_classes = len(self.seg_labels)
 
         #########################
         #   Data Augmentation   #
         #########################
         self.do_aug = True
         self.da_kwargs = {
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': False,
             'alpha': (500., 1500.),
             'sigma': (40., 45.),
             'do_rotation': False,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': False,
             'scale': (0.8, 1.1),
             'random_crop': False,
             'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #########################
         #  Schedule / Selection #
         #########################
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 220  # if 'all' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 25 if self.dim==2 else 15
 
         self.save_n_models = 2
         self.min_save_thresh = 1 if self.dim == 2 else 1  # =wait time in epochs
         if "class" in self.prediction_tasks:
             self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()}
             self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()})
 
         self.lr_decay_factor = 0.5
         self.scheduling_patience = int(self.num_epochs / 5)
         self.weight_decay = 1e-5
         self.clip_norm = None  # number or None
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1)
         self.held_out_test_set = True
         self.max_test_patients = "all"  # number or "all" for all
 
         self.test_against_exact_gt = not 'exact' in self.data_sourcedir
         self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario.
         self.report_score_level = ['rois']  # 'patient' or 'rois' (incl)
         self.patient_class_of_interest = 1
         self.patient_bin_of_interest = 2
 
         self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False
         self.metrics = ['ap', 'auc', 'dice']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.5]  # threshold(s) for considering a prediction as true positive
         self.min_det_thresh = 0.3
 
         self.model_max_iou_resolution = 0.2
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = self.model_max_iou_resolution  # has to be larger than desired possible overlap iou of model predictions
 
         self.merge_2D_to_3D_preds = False
         self.merge_3D_iou = self.model_max_iou_resolution
         self.n_test_plots = 1  # per fold and rank
 
         self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
         # is multiplied by (1 + nr of test augs)
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
          'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
          }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         #only meant for isotropic radii!!
         # only 2D radii (x and y dims) or 1D (x or y) are expected
         return np.round(np.digitize(rg_val, self.bin_edges).mean())
 
 
     def add_det_fpn_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = 'torch_loss'
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       self.n_roi_candidates = 4 if self.dim == 2 else 6
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
 
       self.fp_dice_weight = 1 if self.dim == 2 else 1
       # if <1, false positive predictions in foreground are penalized less.
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
     def add_det_unet_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = "torch_loss"
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
       self.n_roi_candidates = 4 if self.dim == 2 else 6
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
       # if <1, false positive predictions in foreground are penalized less.
       self.fp_dice_weight = 1 if self.dim == 2 else 1
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
       self.init_filts = 32
       self.kernel_size = 3  # ks for horizontal, normal convs
       self.kernel_size_m = 2  # ks for max pool
       self.pad = "same"  # "same" or integer, padding of horizontal convs
 
     def add_mrcnn_configs(self):
       self.frcnn_mode = False
 
       self.learning_rate = [1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True  # with scheduler set in exec
       self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # number of classes for network heads: n_foreground_classes + 1 (background)
       self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2
 
       # feed +/- n neighbouring slices into channel dimension. set to None for no context.
       self.n_3D_context = None
       if self.n_3D_context is not None and self.dim == 2:
         self.n_channels *= (self.n_3D_context * 2 + 1)
 
       self.detect_while_training = True
       # disable the re-sampling of mask proposals to original size for speed-up.
       # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
       # mask outputs are optional.
       self.return_masks_in_train = True
       self.return_masks_in_val = True
       self.return_masks_in_test = True
 
       # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
       self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
       # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
       # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
       self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
       # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
       self.pyramid_levels = [0, 1, 2, 3]
       # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
-      self.n_rpn_features = 512 if self.dim == 2 else 64
+      self.n_rpn_features = 128 if self.dim == 2 else 64
 
       # anchor ratios and strides per position in feature maps.
       self.rpn_anchor_ratios = [0.5, 1., 2.]
       self.rpn_anchor_stride = 1
       # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
       self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution)
 
       # loss sampling settings.
       self.rpn_train_anchors_per_image = 4
       self.train_rois_per_image = 6 # per batch_instance
       self.roi_positive_ratio = 0.5
       self.anchor_matching_iou = 0.8
 
       # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
       # where k<=#positive examples.
       self.shem_poolsize = 2
 
       self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
       self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
       self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
       self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
       self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                              self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
 
       if self.dim == 2:
         self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
         self.bbox_std_dev = self.bbox_std_dev[:4]
         self.window = self.window[:4]
         self.scale = self.scale[:4]
 
       self.plot_y_max = 1.5
       self.n_plot_rpn_props = 5 if self.dim == 2 else 30  # per batch_instance (slice in 2D / patient in 3D)
 
       # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
       self.pre_nms_limit = 2000 if self.dim == 2 else 4000
 
       # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
       # since proposals of the entire batch are forwarded through second stage as one "batch".
       self.roi_chunk_size = 1300 if self.dim == 2 else 500
       self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400
       self.post_nms_rois_inference = 200 * (self.head_classes-1)
 
       # Final selection of detections (refine_detections)
       self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class.
       self.detection_nms_threshold = self.model_max_iou_resolution  # needs to be > 0, otherwise all predictions are one cluster.
       self.model_min_confidence = 0.2  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
 
       if self.dim == 2:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride))]
            for stride in self.backbone_strides['xy']])
       else:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride)),
             int(np.ceil(self.patch_size[2] / stride_z))]
            for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                        )])
 
       if self.model == 'retina_net' or self.model == 'retina_unet':
         # whether to use focal loss or SHEM for loss-sample selection
         self.focal_loss = False
         # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
         self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                         self.rpn_anchor_scales['xy']]
         self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                        self.rpn_anchor_scales['z']]
         self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
         # pre-selection of detections for NMS-speedup. per entire batch.
         self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
         # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
         self.anchor_matching_iou = 0.7
 
         if self.model == 'retina_unet':
           self.operate_stride1 = True
diff --git a/datasets/toy/data_loader.py b/datasets/toy/data_loader.py
index dc9a03f..0590fae 100644
--- a/datasets/toy/data_loader.py
+++ b/datasets/toy/data_loader.py
@@ -1,597 +1,598 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 sys.path.append('../') # works on cluster indep from where sbatch job is started
 import plotting as plg
 
 import numpy as np
 import os
 from collections import OrderedDict
 import pandas as pd
 import pickle
 import time
 
 # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
 from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
 from batchgenerators.transforms.abstract_transforms import Compose
 from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
 from batchgenerators.transforms.spatial_transforms import SpatialTransform
 from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
 
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import utils.dataloader_utils as dutils
 from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates
 
 
 def load_obj(file_path):
     with open(file_path, 'rb') as handle:
         return pickle.load(handle)
 
 class Dataset(dutils.Dataset):
     r""" Load a dict holding memmapped arrays and clinical parameters for each patient,
     evtly subset of those.
         If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to
         cf.data_dir.
     :param cf: config file
     :param folds: number of folds out of @params n_cv folds to include
     :param n_cv: number of total folds
     :return: dict with imgs, segs, pids, class_labels, observables
     """
 
     def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'):
         super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir)
 
         load_exact_gts = (mode=='test' or cf.val_mode=="val_patient") and self.cf.test_against_exact_gt
 
         p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name))
 
         if subset_ids is not None:
             p_df = p_df[p_df.pid.isin(subset_ids)]
             logger.info('subset: selected {} instances from df'.format(len(p_df)))
 
         pids = p_df.pid.tolist()
         #evtly copy data from data_sourcedir to data_dest
         if cf.server_env and not hasattr(cf, "data_dir"):
             file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids]
             file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids]
             file_subset += [cf.info_df_name]
             if load_exact_gts:
                 file_subset += [os.path.join(self.data_dir, '{}_exact_seg.*'.format(pid)) for pid in pids]
             self.copy_data(cf, file_subset=file_subset)
 
         img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids]
         seg_paths = [os.path.join(self.data_dir, '{}_seg.npy'.format(pid)) for pid in pids]
         if load_exact_gts:
             exact_seg_paths = [os.path.join(self.data_dir, '{}_exact_seg.npy'.format(pid)) for pid in pids]
 
         class_targets = p_df['class_ids'].tolist()
         rg_targets = p_df['regression_vectors'].tolist()
         if load_exact_gts:
             exact_rg_targets = p_df['undistorted_rg_vectors'].tolist()
         fg_slices = p_df['fg_slices'].tolist()
 
         self.data = OrderedDict()
         for ix, pid in enumerate(pids):
             self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid,
                               'fg_slices': np.array(fg_slices[ix])}
             if load_exact_gts:
                 self.data[pid]['exact_seg'] = exact_seg_paths[ix]
             if 'class' in self.cf.prediction_tasks:
                 self.data[pid]['class_targets'] = np.array(class_targets[ix], dtype='uint8')
             else:
                 self.data[pid]['class_targets'] = np.ones_like(np.array(class_targets[ix]), dtype='uint8')
             if load_exact_gts:
                 self.data[pid]['exact_class_targets'] = self.data[pid]['class_targets']
             if any(['regression' in task for task in self.cf.prediction_tasks]):
                 self.data[pid]['regression_targets'] = np.array(rg_targets[ix], dtype='float16')
                 self.data[pid]["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in rg_targets[ix]], dtype='uint8')
                 if load_exact_gts:
                     self.data[pid]['exact_regression_targets'] = np.array(exact_rg_targets[ix], dtype='float16')
                     self.data[pid]["exact_rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in exact_rg_targets[ix]],
                                                                 dtype='uint8')
 
 
         cf.roi_items = cf.observables_rois[:]
         cf.roi_items += ['class_targets']
         if any(['regression' in task for task in self.cf.prediction_tasks]):
             cf.roi_items += ['regression_targets']
             cf.roi_items += ['rg_bin_targets']
 
         self.set_ids = np.array(list(self.data.keys()))
         self.df = None
 
 class BatchGenerator(dutils.BatchGenerator):
     """
     creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
     from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
     Actual patch_size is obtained after data augmentation.
     :param data: data dictionary as provided by 'load_dataset'.
     :param batch_size: number of patients to sample for the batch
     :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
     """
     def __init__(self, cf, data, sample_pids_w_replace=True):
         super(BatchGenerator, self).__init__(cf, data)
 
         self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
         assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
 
 
         self.sample_pids_w_replace = sample_pids_w_replace
         self.eligible_pids = list(self._data.keys())
 
         self.crop_margin = np.array(self.cf.patch_size) / 8.  # min distance of ROI center to edge of cropped_patch.
         self.p_fg = 0.5
         self.empty_samples_max_ratio = 0.6
         self.random_count = int(cf.batch_random_ratio * cf.batch_size)
 
         self.balance_target_distribution(plot=sample_pids_w_replace)
         self.stats = {"roi_counts": np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count": 0}
 
 
     def generate_train_batch(self):
         # everything done in here is per batch
         # print statements in here get confusing due to multithreading
         if self.sample_pids_w_replace:
             # fully random patients
             batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
             # target-balanced patients
             batch_patient_ids += list(np.random.choice(
                 self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs))
         else:
             batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size,
                                                  replace=False)
         if self.sample_pids_w_replace == False:
             self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids]
             if len(self.eligible_pids) < self.batch_size:
                 self.eligible_pids = self.dataset_pids
 
         batch_data, batch_segs, batch_patient_targets = [], [], []
         batch_roi_items = {name: [] for name in self.cf.roi_items}
         # record roi count of classes in batch
         # empty count for full bg samples (empty slices in 2D/patients in 3D) in slot num_classes (last)
         batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0
 
         for b in range(self.batch_size):
             patient = self._data[batch_patient_ids[b]]
 
             data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
             seg =  np.load(patient['seg'], mmap_mode='r').astype('uint8')
 
             (c, y, x, z) = data.shape
             if self.cf.dim == 2:
                 elig_slices, choose_fg = [], False
                 if len(patient['fg_slices']) > 0:
                     if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand(
                             1) <= self.p_fg:
                         # fg is to be picked
                         for tix in np.argsort(batch_roi_counts):
                             # pick slices of patient that have roi of sought-for target
                             # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                             elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                 patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                                 self.unique_ts[tix]) > 0]
                             if len(elig_slices) > 0:
                                 choose_fg = True
                                 break
                     else:
                         # pick bg
                         elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
                 if len(elig_slices) > 0:
                     sl_pick_ix = np.random.choice(elig_slices, size=None)
                 else:
                     sl_pick_ix = np.random.choice(z, size=None)
                 data = data[..., sl_pick_ix]
                 seg = seg[..., sl_pick_ix]
 
             spatial_shp = data[0].shape
             assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg"
             if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]):
                 new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))]
                 data = dutils.pad_nd_image(data, (len(data), *new_shape))
                 seg = dutils.pad_nd_image(seg, new_shape)
 
             # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center,
             # if possible, to that pixel, so that img still contains ROI after pre-cropping
             dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))]
             if np.any(dim_cropflags):
                 # sample pixel from random ROI and shift center, if possible, to that pixel
                 if self.cf.dim==3:
                     choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg
                 if choose_fg and np.any(seg):
                     available_roi_ids = np.unique(seg)[1:]
                     for tix in np.argsort(batch_roi_counts):
                         elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]]
                         if len(elig_roi_ids)>0:
                             seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                             break
                     roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                     assert seg[tuple(roi_anchor_pixel)] > 0
 
                     # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and
                     # distance to the selected ROI < patch_size /2
                     def get_cropped_centercoords(dim):
                         low = np.max((self.cf.pre_crop_size[dim] // 2,
                                       roi_anchor_pixel[dim] - (
                                                   self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                         high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2,
                                        roi_anchor_pixel[dim] + (
                                                    self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim])))
                         if low >= high:  # happens if lesion on the edge of the image.
                             low = self.cf.pre_crop_size[dim] // 2
                             high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2
 
                         assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format(
                             dim,
                             spatial_shp, patient['pid'], low, high)
                         return np.random.randint(low=low, high=high)
                 else:
                     # sample crop center regardless of ROIs, not guaranteed to be empty
                     def get_cropped_centercoords(dim):
                         return np.random.randint(low=self.cf.pre_crop_size[dim] // 2,
                                                  high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2)
 
                 sample_seg_center = {}
                 for dim in np.where(dim_cropflags)[0]:
                     sample_seg_center[dim] = get_cropped_centercoords(dim)
                     min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2)
                     max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2)
                     data = np.take(data, indices=range(min_, max_), axis=dim + 1)  # +1 for channeldim
                     seg = np.take(seg, indices=range(min_, max_), axis=dim)
 
             batch_data.append(data)
             batch_segs.append(seg[np.newaxis])
 
             for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable
                     batch_roi_items[o].append(patient[o])
 
             if self.cf.dim == 3:
                 for tix in range(len(self.unique_ts)):
                     batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
             elif self.cf.dim == 2:
                 for tix in range(len(self.unique_ts)):
                     batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
             if not np.any(seg):
                 empty_samples_count += 1
 
         batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'),
                  'pid': batch_patient_ids,
                  'roi_counts': batch_roi_counts, 'empty_samples_count': empty_samples_count}
         for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic
             batch[key] = np.array(val)
 
         return batch
 
 class PatientBatchIterator(dutils.PatientBatchIterator):
     """
     creates a test generator that iterates over entire given dataset returning 1 patient per batch.
     Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D),
     if willing to accept speed-loss during training.
     Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are
     exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by
     the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs.
 
     :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
     batch_size = n_2D_patches in 2D .
     """
 
     def __init__(self, cf, data, mode='test'):
         super(PatientBatchIterator, self).__init__(cf, data)
 
         self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D
         self.chans = cf.channels if cf.channels is not None else np.index_exp[:]
         assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing"
 
         if (mode=="validation" and hasattr(self.cf, 'val_against_exact_gt') and self.cf.val_against_exact_gt) or \
                 (mode == 'test' and self.cf.test_against_exact_gt):
             self.gt_prefix = 'exact_'
             print("PatientIterator: Loading exact Ground Truths.")
         else:
             self.gt_prefix = ''
 
         self.patient_ix = 0  # running index over all patients in set
 
     def generate_train_batch(self, pid=None):
 
         if pid is None:
             pid = self.dataset_pids[self.patient_ix]
         patient = self._data[pid]
 
         # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling
         data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis]
         seg =  np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis]
 
         data_shp_raw = data.shape
         plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None
         data = data[self.chans]
         discarded_chans = len(
             [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan])
         spatial_shp = data[0].shape  # spatial dims need to be in order x,y,z
         assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg"
 
         if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]):
             new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))]
             data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
             seg = dutils.pad_nd_image(seg, new_shape)
             if plot_bg is not None:
                 plot_bg = dutils.pad_nd_image(plot_bg, new_shape)
 
         if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
             # adds the batch dim here bc won't go through MTaugmenter
             out_data = data[np.newaxis]
             out_seg = seg[np.newaxis]
             if plot_bg is not None:
                out_plot_bg = plot_bg[np.newaxis]
             # data and seg shape: (1,c,x,y,z), where c=1 for seg
 
             batch_3D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_3D[o] = np.array([patient[self.gt_prefix+o]])
             converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_3D = converter(**batch_3D)
             batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
             for o in self.cf.roi_items:
                 batch_3D["patient_" + o] = batch_3D[o]
 
         if self.cf.dim == 2:
             out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32')  # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim
             out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8')  # (c,y,x,z) to (b=z,c,x,y)
 
             batch_2D = {'data': out_data, 'seg': out_seg}
             for o in self.cf.roi_items:
                 batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0)
             converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
             batch_2D = converter(**batch_2D)
 
             if plot_bg is not None:
                 out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32')
 
             if self.cf.merge_2D_to_3D_preds:
                 batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_3D[o]
             else:
                 batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                  'original_img_shape': out_data.shape})
                 for o in self.cf.roi_items:
                     batch_2D["patient_" + o] = batch_2D[o]
 
         out_batch = batch_3D if self.cf.dim == 3 else batch_2D
         out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})
 
         if self.cf.plot_bg_chan in self.chans and discarded_chans > 0:  # len(self.chans[:self.cf.plot_bg_chan])<data_shp_raw[0]:
             assert plot_bg is None
             plot_bg = int(self.cf.plot_bg_chan - discarded_chans)
             out_plot_bg = plot_bg
         if plot_bg is not None:
             out_batch['plot_bg'] = out_plot_bg
 
         # eventual tiling into patches
         spatial_shp = out_batch["data"].shape[2:]
         if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]):
             patient_batch = out_batch
             print("patientiterator produced patched batch!")
             patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
             new_img_batch, new_seg_batch = [], []
 
             for c in patch_crop_coords_list:
                 new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                 seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]]
                 new_seg_batch.append(seg_patch)
             shps = []
             for arr in new_img_batch:
                 shps.append(arr.shape)
 
             data = np.array(new_img_batch)  # (patches, c, x, y, z)
             seg = np.array(new_seg_batch)
             if self.cf.dim == 2:
                 # all patches have z dimension 1 (slices). discard dimension
                 data = data[..., 0]
                 seg = seg[..., 0]
             patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                            'pid': np.array([patient['pid']] * data.shape[0])}
             for o in self.cf.roi_items:
                 patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0)
             #patient-wise (orig) batch info for putting the patches back together after prediction
             for o in self.cf.roi_items:
                 patch_batch["patient_"+o] = patient_batch["patient_"+o]
                 if self.cf.dim == 2:
                     # this could also be named "unpatched_2d_roi_items"
                     patch_batch["patient_" + o + "_2d"] = patient_batch[o]
             patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
             patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
             if self.cf.dim == 2:
                 patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
             patch_batch['patient_data'] = patient_batch['data']
             patch_batch['patient_seg'] = patient_batch['seg']
             patch_batch['original_img_shape'] = patient_batch['original_img_shape']
             if plot_bg is not None:
                 patch_batch['patient_plot_bg'] = patient_batch['plot_bg']
 
             converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False,
                                                            class_specific_seg=self.cf.class_specific_seg)
 
             patch_batch = converter(**patch_batch)
             out_batch = patch_batch
 
         self.patient_ix += 1
         if self.patient_ix == len(self.dataset_pids):
             self.patient_ix = 0
 
         return out_batch
 
 
 def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True):
     """
     create mutli-threaded train/val/test batch generation and augmentation pipeline.
     :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
     :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
     :return: multithreaded_generator
     """
 
     # create instance of batch generator as first element in pipeline.
     data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace)
 
     my_transforms = []
     if do_aug:
         if cf.da_kwargs["mirror"]:
             mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
             my_transforms.append(mirror_transform)
 
         spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                              patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                              do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                              alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                              do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
                                              angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
                                              do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                              random_crop=cf.da_kwargs['random_crop'])
 
         my_transforms.append(spatial_transform)
     else:
         my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
 
     my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
     all_transforms = Compose(my_transforms)
     # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
     multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
     return multithreaded_generator
 
 def get_train_generators(cf, logger, data_statistics=False):
     """
     wrapper function for creating the training batch generator pipeline. returns the train/val generators.
     selects patients according to cv folds (generated by first run/fold of experiment):
     splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set)
     If cf.hold_out_test_set is True, adds the test split to the training data.
     """
     dataset = Dataset(cf, logger)
     dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits)
     dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle'))
     set_splits = dataset.fg.splits
 
     test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1)
     train_ids = np.concatenate(set_splits, axis=0)
 
     if cf.held_out_test_set:
         train_ids = np.concatenate((train_ids, test_ids), axis=0)
         test_ids = []
 
     train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids}
     val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids}
 
     logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids),
                                                                                     len(test_ids)))
     if data_statistics:
         dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=
         os.path.join(cf.plot_dir,"dataset"))
 
     batch_gen = {}
     batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True)
     batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False)
 
     if cf.val_mode == 'val_patient':
         batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation')
         batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else min(len(val_ids), cf.max_val_patients)
     elif cf.val_mode == 'val_sampling':
         batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches != "all" else len(val_data)
 
     return batch_gen
 
 def get_test_generator(cf, logger):
     """
     if get_test_generators is possibly called multiple times in server env, every time of
     Dataset initiation rsync will check for copying the data; this should be okay
     since rsync will not copy if files already exist in destination.
     """
 
     if cf.held_out_test_set:
         sourcedir = cf.test_data_sourcedir
         test_ids = None
     else:
         sourcedir = None
         with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
             set_splits = pickle.load(handle)
         test_ids = set_splits[cf.fold]
 
     test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test')
     logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids)))
     batch_gen = {}
     batch_gen['test'] = PatientBatchIterator(cf, test_set.data)
     batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \
         min(cf.max_test_patients, len(test_set.set_ids))
 
     return batch_gen
 
 
 if __name__=="__main__":
 
     import utils.exp_utils as utils
     from configs import Configs
 
     cf = Configs()
 
     total_stime = time.time()
     times = {}
 
     # cf.server_env = True
     # cf.data_dir = "experiments/dev_data"
 
     cf.exp_dir = "experiments/dev/"
     cf.plot_dir = cf.exp_dir + "plots"
     os.makedirs(cf.exp_dir, exist_ok=True)
     cf.fold = 0
     logger = utils.get_logger(cf.exp_dir)
     gens = get_train_generators(cf, logger)
     train_loader = gens['train']
-    for i in range(0):
+    for i in range(1):
         stime = time.time()
         print("producing training batch nr ", i)
         ex_batch = next(train_loader)
         times["train_batch"] = time.time() - stime
         #experiments/dev/dev_exbatch_{}.png".format(i)
+        print(ex_batch.keys())
         plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False)
 
 
     val_loader = gens['val_sampling']
     stime = time.time()
     for i in range(1):
         ex_batch = next(val_loader)
         times["val_batch"] = time.time() - stime
         stime = time.time()
         #"experiments/dev/dev_exvalbatch_{}.png"
         plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True)
         times["val_plot"] = time.time() - stime
     import IPython; IPython.embed()
     #
     test_loader = get_test_generator(cf, logger)["test"]
     stime = time.time()
     ex_batch = test_loader.generate_train_batch(pid=None)
     times["test_batch"] = time.time() - stime
     stime = time.time()
     plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0)
     times["test_patchbatch_plot"] = time.time() - stime
 
 
 
     print("Times recorded throughout:")
     for (k, v) in times.items():
         print(k, "{:.2f}".format(v))
 
     mins, secs = divmod((time.time() - total_stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/unittests.py b/unittests.py
index 3b5ea77..ed02de1 100644
--- a/unittests.py
+++ b/unittests.py
@@ -1,492 +1,569 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import unittest
 
 import os
 import pickle
 import time
 from multiprocessing import  Pool
 import subprocess
+from pathlib import Path
 
 import numpy as np
 import pandas as pd
 import torch
 import torchvision as tv
 
 import tqdm
 
+import plotting as plg
 import utils.exp_utils as utils
 import utils.model_utils as mutils
 
 """ Note on unittests: run this file either in the way intended for unittests by starting the script with
     python -m unittest unittests.py or start it as a normal python file as python unittests.py.
     You can selective run single tests by calling python -m unittest unittests.TestClassOfYourChoice, where 
     TestClassOfYourChoice is the name of the test defined below, e.g., CompareFoldSplits.
 """
 
 
 
 def inspect_info_df(pp_dir):
     """ use your debugger to look into the info df of a pp dir.
     :param pp_dir: preprocessed-data directory
     """
 
     info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle"))
 
     return
 
 
 def generate_boxes(count, dim=2, h=100, w=100, d=20, normalize=False, on_grid=False, seed=0):
     """ generate boxes of format [y1, x1, y2, x2, (z1, z2)].
     :param count: nr of boxes
     :param dim: dimension of boxes (2 or 3)
     :return: boxes in format (n_boxes, 4 or 6), scores
     """
     np.random.seed(seed)
     if on_grid:
         lower_y = np.random.randint(0, h // 2, (count,))
         lower_x = np.random.randint(0, w // 2, (count,))
         upper_y = np.random.randint(h // 2, h, (count,))
         upper_x = np.random.randint(w // 2, w, (count,))
         if dim == 3:
             lower_z = np.random.randint(0, d // 2, (count,))
             upper_z = np.random.randint(d // 2, d, (count,))
     else:
         lower_y = np.random.rand(count) * h / 2.
         lower_x = np.random.rand(count) * w / 2.
         upper_y = (np.random.rand(count) + 1.) * h / 2.
         upper_x = (np.random.rand(count) + 1.) * w / 2.
         if dim == 3:
             lower_z = np.random.rand(count) * d / 2.
             upper_z = (np.random.rand(count) + 1.) * d / 2.
 
     if dim == 3:
         boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x, lower_z, upper_z)))
         # add an extreme box that tests the boundaries
         boxes = np.concatenate((boxes, np.array([[0., 0., h, w, 0, d]])))
     else:
         boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x)))
         boxes = np.concatenate((boxes, np.array([[0., 0., h, w]])))
 
     scores = np.random.rand(count + 1)
     if normalize:
         divisor = np.array([h, w, h, w, d, d]) if dim == 3 else np.array([h, w, h, w])
         boxes = boxes / divisor
     return boxes, scores
 
 #------- perform integrity checks on data set(s) -----------
 class VerifyLIDCSAIntegrity(unittest.TestCase):
     """ Perform integrity checks on preprocessed single-annotator GTs of LIDC data set.
     """
     @staticmethod
     def check_patient_sa_gt(pid, pp_dir, check_meta_files, check_info_df):
 
         faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids'])
 
         all_segs = np.load(os.path.join(pp_dir, pid + "_rois.npz"), mmap_mode='r')
         all_segs = all_segs[list(all_segs.keys())[0]]
         all_roi_ids = np.unique(all_segs[all_segs > 0])
         assert len(all_roi_ids) == np.max(all_segs), "roi ids not consecutive"
         if check_meta_files:
             meta_file = os.path.join(pp_dir, pid + "_meta_info.pickle")
             with open(meta_file, "rb") as handle:
                 info = pickle.load(handle)
             assert info["pid"] == pid, "wrong pid in meta_file"
             all_cl_targets = info["class_target"]
         if check_info_df:
             info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle"))
             pid_info = info_df[info_df.pid == pid]
             assert len(pid_info) == 1, "found {} entries for pid {} in info df, expected exactly 1".format(len(pid_info),
                                                                                                            pid)
             if check_meta_files:
                 assert pid_info[
                            "class_target"] == all_cl_targets, "meta_info and info_df class targets mismatch:\n{}\n{}".format(
                     pid_info["class_target"], all_cl_targets)
             all_cl_targets = pid_info["class_target"].iloc[0]
         assert len(all_roi_ids) == len(all_cl_targets)
         for rater in range(4):
             seg = all_segs[rater]
             roi_ids = np.unique(seg[seg > 0])
             cl_targs = np.array([roi[rater] for roi in all_cl_targets])
             assert np.count_nonzero(cl_targs) == len(roi_ids), "rater {} has targs {} but roi ids {}".format(rater, cl_targs, roi_ids)
             assert len(cl_targs) >= len(roi_ids), "not all marked rois have a label"
             for zeroix_roi_id, rating in enumerate(cl_targs):
                 if not ((rating > 0) == (np.any(seg == zeroix_roi_id + 1))):
                     print("\n\nFAULTY CASE:", end=" ", )
                     print("pid {}, rater {}, cl_targs {}, ids {}\n".format(pid, rater, cl_targs, roi_ids))
                     faulty_cases = faulty_cases.append(
                         {'pid': pid, 'rater': rater, 'cl_targets': cl_targs, 'roi_ids': roi_ids}, ignore_index=True)
         print("finished checking pid {}, {} faulty cases".format(pid, len(faulty_cases)))
         return faulty_cases
 
     def check_sa_gts(cf, pp_dir, pid_subset=None, check_meta_files=False, check_info_df=True, processes=os.cpu_count()):
         report_name = "verify_seg_label_pairings.csv"
         pids = {file_name.split("_")[0] for file_name in os.listdir(pp_dir) if file_name not in [report_name, "info_df.pickle"]}
         if pid_subset is not None:
             pids = [pid for pid in pids if pid in pid_subset]
 
 
         faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids'])
 
         p = Pool(processes=processes)
         mp_args = zip(pids, [pp_dir]*len(pids), [check_meta_files]*len(pids), [check_info_df]*len(pids))
         patient_cases = p.starmap(self.check_patient_sa_gt, mp_args)
         p.close(); p.join()
         faulty_cases = faulty_cases.append(patient_cases, sort=False)
 
 
         print("\n\nfaulty case count {}".format(len(faulty_cases)))
         print(faulty_cases)
         findings_file = os.path.join(pp_dir, "verify_seg_label_pairings.csv")
         faulty_cases.to_csv(findings_file)
 
         assert len(faulty_cases)==0, "there was a faulty case in data set {}.\ncheck {}".format(pp_dir, findings_file)
 
     def test(self):
         pp_root = "/mnt/HDD2TB/Documents/data/"
         pp_dir = "lidc/pp_20190805"
         gt_dir = os.path.join(pp_root, pp_dir, "patient_gts_sa")
         self.check_sa_gts(gt_dir, check_meta_files=True, check_info_df=False, pid_subset=None)  # ["0811a", "0812a"])
 
 #------ compare segmentation gts of preprocessed data sets ------
 class CompareSegGTs(unittest.TestCase):
     """ load and compare pre-processed gts by dice scores of segmentations.
 
     """
     @staticmethod
     def group_seg_paths(ref_path, comp_paths):
         # not working recursively
         ref_files = [fn for fn in os.listdir(ref_path) if
                      os.path.isfile(os.path.join(ref_path, fn)) and 'seg' in fn and fn.endswith('.npy')]
 
         comp_files = [[os.path.join(c_path, fn) for c_path in comp_paths] for fn in ref_files]
 
         ref_files = [os.path.join(ref_path, fn) for fn in ref_files]
 
         return zip(ref_files, comp_files)
 
     @staticmethod
     def load_calc_dice(paths):
         dices = []
         ref_seg = np.load(paths[0])[np.newaxis, np.newaxis]
         n_classes = len(np.unique(ref_seg))
         ref_seg = mutils.get_one_hot_encoding(ref_seg, n_classes)
 
         for c_file in paths[1]:
             c_seg = np.load(c_file)[np.newaxis, np.newaxis]
             assert n_classes == len(np.unique(c_seg)), "unequal nr of objects/classes betw segs {} {}".format(paths[0],
                                                                                                               c_file)
             c_seg = mutils.get_one_hot_encoding(c_seg, n_classes)
 
             dice = mutils.dice_per_batch_inst_and_class(c_seg, ref_seg, n_classes, convert_to_ohe=False)
             dices.append(dice)
         print("processed ref_path {}".format(paths[0]))
         return np.mean(dices), np.std(dices)
 
     def iterate_files(self, grouped_paths, processes=os.cpu_count()):
         p = Pool(processes)
 
         means_stds = np.array(p.map(self.load_calc_dice, grouped_paths))
 
         p.close(); p.join()
         min_dice = np.min(means_stds[:, 0])
         print("min mean dice {:.2f}, max std {:.4f}".format(min_dice, np.max(means_stds[:, 1])))
         assert min_dice > 1-1e5, "compared seg gts have insufficient minimum mean dice overlap of {}".format(min_dice)
 
     def test(self):
         ref_path = '/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071'
         comp_paths = ['/mnt/HDD2TB/Documents/data/prostate/data_t2_190419_ps384_gs6071', ]
         paths = self.group_seg_paths(ref_path, comp_paths)
         self.iterate_files(paths)
 
 #------- check if cross-validation fold splits of different experiments are identical ----------
 class CompareFoldSplits(unittest.TestCase):
     """ Find evtl. differences in cross-val file splits across different experiments.
     """
     @staticmethod
     def group_id_paths(ref_exp_dir, comp_exp_dirs):
 
         f_name = 'fold_ids.pickle'
 
         ref_paths = os.path.join(ref_exp_dir, f_name)
         assert os.path.isfile(ref_paths), "ref file {} does not exist.".format(ref_paths)
 
 
         ref_paths = [ref_paths for comp_ed in comp_exp_dirs]
         comp_paths = [os.path.join(comp_ed, f_name) for comp_ed in comp_exp_dirs]
 
         return zip(ref_paths, comp_paths)
 
     @staticmethod
     def comp_fold_ids(mp_input):
         fold_ids1, fold_ids2 = mp_input
         with open(fold_ids1, 'rb') as f:
             fold_ids1 = pickle.load(f)
         try:
             with open(fold_ids2, 'rb') as f:
                 fold_ids2 = pickle.load(f)
         except FileNotFoundError:
             print("comp file {} does not exist.".format(fold_ids2))
             return
 
         n_splits = len(fold_ids1)
         assert n_splits == len(fold_ids2), "mismatch n splits: ref has {}, comp {}".format(n_splits, len(fold_ids2))
         split_diffs = [np.setdiff1d(fold_ids1[s], fold_ids2[s]) for s in range(n_splits)]
         all_equal = np.any(split_diffs)
         return (split_diffs, all_equal)
 
     def iterate_exp_dirs(self, ref_exp, comp_exps, processes=os.cpu_count()):
 
         grouped_paths = list(self.group_id_paths(ref_exp, comp_exps))
         print("performing {} comparisons of cross-val file splits".format(len(grouped_paths)))
         p = Pool(processes)
         split_diffs = p.map(self.comp_fold_ids, grouped_paths)
         p.close(); p.join()
 
         df = pd.DataFrame(index=range(0,len(grouped_paths)), columns=["ref", "comp", "all_equal"])#, "diffs"])
         for ix, (ref, comp) in enumerate(grouped_paths):
             df.iloc[ix] = [ref, comp, split_diffs[ix][1]]#, split_diffs[ix][0]]
 
         print("Any splits not equal?", df.all_equal.any())
         assert not df.all_equal.any(), "a split set is different from reference split set, {}".format(df[~df.all_equal])
 
     def test(self):
         exp_parent_dir = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/'
         ref_exp = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn2d_cl_bs10'
         comp_exps = [os.path.join(exp_parent_dir, p) for p in os.listdir(exp_parent_dir)]
         comp_exps = [p for p in comp_exps if os.path.isdir(p) and p != ref_exp]
         self.iterate_exp_dirs(ref_exp, comp_exps)
 
 
 #------- check if cross-validation fold splits of a single experiment are actually incongruent (as required) ----------
 class VerifyFoldSplits(unittest.TestCase):
     """ Check, for a single fold_ids file, i.e., for a single experiment, if the assigned folds (assignment of data
         identifiers) is actually incongruent. No overlaps between folds are required for a correct cross validation.
     """
     @staticmethod
     def verify_fold_ids(splits):
         for i, split1 in enumerate(splits):
             for j, split2 in enumerate(splits):
                 if j > i:
                     inter = np.intersect1d(split1, split2)
                     if len(inter) > 0:
                         raise Exception("Split {} and {} intersect by pids {}".format(i, j, inter))
     def test(self):
         exp_dir = "/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/dev"
         check_file = os.path.join(exp_dir, 'fold_ids.pickle')
         with open(check_file, 'rb') as handle:
             splits = pickle.load(handle)
         self.verify_fold_ids(splits)
 
 # -------- check own nms CUDA implement against own numpy implement ------
 class CheckNMSImplementation(unittest.TestCase):
 
     @staticmethod
     def assert_res_equality(keep_ics1, keep_ics2, boxes, scores, tolerance=0, names=("res1", "res2")):
         """
         :param keep_ics1: keep indices (results), torch.Tensor of shape (n_ics,)
         :param keep_ics2:
         :return:
         """
         keep_ics1, keep_ics2 = keep_ics1.cpu().numpy(), keep_ics2.cpu().numpy()
         discrepancies = np.setdiff1d(keep_ics1, keep_ics2)
         try:
             checks = np.array([
                 len(discrepancies) <= tolerance
             ])
         except:
             checks = np.zeros((1,)).astype("bool")
         msgs = np.array([
             """{}: {} \n{}: {} \nboxes: {}\n {}\n""".format(names[0], keep_ics1, names[1], keep_ics2, boxes,
                                                             scores)
         ])
 
         assert np.all(checks), "NMS: results mismatch: " + "\n".join(msgs[~checks])
 
     def single_case(self, count=20, dim=3, threshold=0.2, seed=0):
         boxes, scores = generate_boxes(count, dim, seed=seed, h=320, w=280, d=30)
 
         keep_numpy = torch.tensor(mutils.nms_numpy(boxes, scores, threshold))
 
         # for some reason torchvision nms requires box coords as floats.
         boxes = torch.from_numpy(boxes).type(torch.float32)
         scores = torch.from_numpy(scores).type(torch.float32)
         if dim == 2:
             """need to wait until next pytorch release where they fixed nms on cpu (currently they have >= where it
             needs to be >.
             """
-            # keep_ops = tv.ops.nms(boxes, scores, threshold)
+            keep_ops = tv.ops.nms(boxes, scores, threshold)
             # self.assert_res_equality(keep_numpy, keep_ops, boxes, scores, tolerance=0, names=["np", "ops"])
             pass
 
         boxes = boxes.cuda()
         scores = scores.cuda()
         keep = self.nms_ext.nms(boxes, scores, threshold)
         self.assert_res_equality(keep_numpy, keep, boxes, scores, tolerance=0, names=["np", "cuda"])
 
     def test(self, n_cases=200, box_count=30, threshold=0.5):
         # dynamically import module so that it doesn't affect other tests if import fails
         self.nms_ext = utils.import_module("nms_ext", 'custom_extensions/nms/nms.py')
         # change seed to something fix if you want exactly reproducible test
         seed0 = np.random.randint(50)
         print("NMS test progress (done/total box configurations) 2D:", end="\n")
         for i in tqdm.tqdm(range(n_cases)):
             self.single_case(count=box_count, dim=2, threshold=threshold, seed=seed0+i)
         print("NMS test progress (done/total box configurations) 3D:", end="\n")
         for i in tqdm.tqdm(range(n_cases)):
             self.single_case(count=box_count, dim=3, threshold=threshold, seed=seed0+i)
 
         return
 
 class CheckRoIAlignImplementation(unittest.TestCase):
 
     def prepare(self, dim=2):
 
         b, c, h, w = 1, 3, 50, 50
         # feature map, (b, c, h, w(, z))
         if dim == 2:
             fmap = torch.rand(b, c, h, w).cuda()
             # rois = torch.tensor([[
             #     [0.1, 0.1, 0.3, 0.3],
             #     [0.2, 0.2, 0.4, 0.7],
             #     [0.5, 0.7, 0.7, 0.9],
             # ]]).cuda()
             pool_size = (7, 7)
             rois = generate_boxes(5, dim=dim, h=h, w=w, on_grid=True, seed=np.random.randint(50))[0]
         elif dim == 3:
             d = 20
             fmap = torch.rand(b, c, h, w, d).cuda()
             # rois = torch.tensor([[
             #     [0.1, 0.1, 0.3, 0.3, 0.1, 0.1],
             #     [0.2, 0.2, 0.4, 0.7, 0.2, 0.4],
             #     [0.5, 0.0, 0.7, 1.0, 0.4, 0.5],
             #     [0.0, 0.0, 0.9, 1.0, 0.0, 1.0],
             # ]]).cuda()
             pool_size = (7, 7, 3)
             rois = generate_boxes(5, dim=dim, h=h, w=w, d=d, on_grid=True, seed=np.random.randint(50),
                                   normalize=False)[0]
         else:
             raise ValueError("dim needs to be 2 or 3")
 
         rois = [torch.from_numpy(rois).type(dtype=torch.float32).cuda(), ]
         fmap.requires_grad_(True)
         return fmap, rois, pool_size
 
     def check_2d(self):
-
-        fmap, rois, pool_size = self.prepare(dim=2)
-        align_ops = tv.ops.roi_align(fmap, rois, pool_size)
-        loss_ops = align_ops.sum()
-        loss_ops.backward()
-
-        ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1)
-        align_ext = ra_object(fmap, rois)
-        loss_ext = align_ext.sum()
-        loss_ext.backward()
-        assert (loss_ops == loss_ext), "sum of roialign ops and extension 2D diverges"
-        assert (align_ops == align_ext).all(), "ROIAlign failed 2D test"
+        """ check vs torchvision ops not possible as on purpose different approach.
+        :return:
+        """
+        raise NotImplementedError
+        # fmap, rois, pool_size = self.prepare(dim=2)
+        # ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1)
+        # align_ext = ra_object(fmap, rois)
+        # loss_ext = align_ext.sum()
+        # loss_ext.backward()
+        #
+        # rois_swapped = [rois[0][:, [1,3,0,2]]]
+        # align_ops = tv.ops.roi_align(fmap, rois_swapped, pool_size)
+        # loss_ops = align_ops.sum()
+        # loss_ops.backward()
+        #
+        # assert (loss_ops == loss_ext), "sum of roialign ops and extension 2D diverges"
+        # assert (align_ops == align_ext).all(), "ROIAlign failed 2D test"
 
     def check_3d(self):
         fmap, rois, pool_size = self.prepare(dim=3)
         ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1)
         align_ext = ra_object(fmap, rois)
         loss_ext = align_ext.sum()
         loss_ext.backward()
 
         align_np = mutils.roi_align_3d_numpy(fmap.cpu().detach().numpy(), [roi.cpu().numpy() for roi in rois],
                                              pool_size)
         align_np = np.squeeze(align_np)  # remove singleton batch dim
 
         align_ext = align_ext.cpu().detach().numpy()
         assert np.allclose(align_np, align_ext, rtol=1e-5,
                            atol=1e-8), "RoIAlign differences in numpy and CUDA implement"
 
-    def manual_check(self):
+    def specific_example_check(self):
+        # dummy input
         self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py')
-        exp = 5
+        exp = 6
         pool_size = (2,2)
-        fmap = torch.arange(25).view(exp,exp).unsqueeze(0).unsqueeze(0).cuda().float()
-        boxes = (torch.tensor([[-1., -1., 5., 5.]]).cuda()/exp)
-        ind = torch.tensor([0.]).cuda()
+        fmap = torch.arange(exp**2).view(exp,exp).unsqueeze(0).unsqueeze(0).cuda().type(dtype=torch.float32)
+
+        boxes = torch.tensor([[1., 1., 5., 5.]]).cuda()/exp
+        ind = torch.tensor([0.]*len(boxes)).cuda().type(torch.float32)
         y_exp, x_exp = fmap.shape[2:]  # exp = expansion
         boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
         boxes = torch.cat((ind.unsqueeze(1), boxes), dim=1)
-        aligned = tv.ops.roi_align(fmap, boxes, output_size=pool_size)
-        # ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1.,)
-        # aligned_own = ra_object(fmap, boxes)
-        boxes_3d = torch.cat((boxes, torch.tensor([[-1.,1.]]).cuda()), dim=1)
+        aligned_tv = tv.ops.roi_align(fmap, boxes, output_size=pool_size, sampling_ratio=-1)
+        aligned = self.ra_ext.roi_align_2d(fmap, boxes, output_size=pool_size, sampling_ratio=-1)
+
+        boxes_3d = torch.cat((boxes, torch.tensor([[-1.,1.]]*len(boxes)).cuda()), dim=1)
         fmap_3d = fmap.unsqueeze(dim=-1)
         pool_size = (*pool_size,1)
         ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1.,)
-        aligned_own_3d = ra_object(fmap_3d, boxes_3d)
+        aligned_3d = ra_object(fmap_3d, boxes_3d)
+
+        expected_res = torch.tensor([[[[10.5000, 12.5000],
+                                       [22.5000, 24.5000]]]]).cuda()
+        expected_res_3d = torch.tensor([[[[[10.5000],[12.5000]],
+                                          [[22.5000],[24.5000]]]]]).cuda()
+        assert torch.all(aligned==expected_res), "2D RoIAlign check vs. specific example failed. res: {}\n expected: {}\n".format(aligned, expected_res)
+        assert torch.all(aligned_3d==expected_res_3d), "3D RoIAlign check vs. specific example failed. res: {}\n expected: {}\n".format(aligned_3d, expected_res_3d)
+
+    def manual_check(self):
+        """ print examples from a toy batch to file.
+        :return:
+        """
+        self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py')
+        # actual mrcnn mask input
+        from datasets.toy import configs
+        cf = configs.Configs()
+        cf.exp_dir = "datasets/toy/experiments/dev/"
+        cf.plot_dir = cf.exp_dir + "plots"
+        os.makedirs(cf.exp_dir, exist_ok=True)
+        cf.fold = 0
+        cf.n_workers = 1
+        logger = utils.get_logger(cf.exp_dir)
+        data_loader = utils.import_module('data_loader', os.path.join("datasets", "toy", 'data_loader.py'))
+        batch_gen = data_loader.get_train_generators(cf, logger=logger)
+        batch = next(batch_gen['train'])
+        roi_mask = np.zeros((1, 320, 200))
+        bb_target = (np.array([50, 40, 90, 120])).astype("int")
+        roi_mask[:, bb_target[0]+1:bb_target[2]+1, bb_target[1]+1:bb_target[3]+1] = 1.
+        #batch = {"roi_masks": np.array([np.array([roi_mask, roi_mask]), np.array([roi_mask])]), "bb_target": [[bb_target, bb_target + 25], [bb_target-20]]}
+        #batch_boxes_cor = [torch.tensor(batch_el_boxes).cuda().float() for batch_el_boxes in batch_cor["bb_target"]]
+        batch_boxes = [torch.tensor(batch_el_boxes).cuda().float() for batch_el_boxes in batch["bb_target"]]
+        #import IPython; IPython.embed()
+        for b in range(len(batch_boxes)):
+            roi_masks = batch["roi_masks"][b]
+            #roi_masks_cor = batch_cor["roi_masks"][b]
+            if roi_masks.sum()>0:
+                boxes = batch_boxes[b]
+                roi_masks = torch.tensor(roi_masks).cuda().type(dtype=torch.float32)
+                box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).type(dtype=torch.float32)
+                masks = tv.ops.roi_align(roi_masks, [boxes], cf.mask_shape)
+                masks = masks.squeeze(1)
+                masks = torch.round(masks)
+                masks_own = self.ra_ext.roi_align_2d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape)
+                boxes = boxes.type(torch.int)
+                #print("check roi mask", roi_masks[0, 0, boxes[0][0]:boxes[0][2], boxes[0][1]:boxes[0][3]].sum(), (boxes[0][2]-boxes[0][0]) * (boxes[0][3]-boxes[0][1]))
+                #print("batch masks", batch["roi_masks"])
+                masks_own = masks_own.squeeze(1)
+                masks_own = torch.round(masks_own)
+                #import IPython; IPython.embed()
+                for mix, mask in enumerate(masks):
+                    fig = plg.plt.figure()
+                    ax = fig.add_subplot()
+                    ax.imshow(roi_masks[mix][0].cpu().numpy(), cmap="gray", vmin=0.)
+                    ax.axis("off")
+                    y1, x1, y2, x2 = boxes[mix]
+                    bbox = plg.mpatches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=0.9, edgecolor="c", facecolor='none')
+                    ax.add_patch(bbox)
+                    x1, y1, x2, y2 = boxes[mix]
+                    bbox = plg.mpatches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=0.9, edgecolor="r",
+                                                  facecolor='none')
+                    ax.add_patch(bbox)
+                    debug_dir = Path("/home/gregor/Documents/regrcnn/datasets/toy/experiments/debugroial")
+                    os.makedirs(debug_dir, exist_ok=True)
+                    plg.plt.savefig(debug_dir/"mask_b{}_{}.png".format(b, mix))
+                    plg.plt.imsave(debug_dir/"mask_b{}_{}_pooled_tv.png".format(b, mix), mask.cpu().numpy(), cmap="gray", vmin=0.)
+                    plg.plt.imsave(debug_dir/"mask_b{}_{}_pooled_own.png".format(b, mix), masks_own[mix].cpu().numpy(), cmap="gray", vmin=0.)
         return
 
     def test(self):
         # dynamically import module so that it doesn't affect other tests if import fails
         self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py')
 
+        self.specific_example_check()
+
         # 2d test
-        self.check_2d()
+        #self.check_2d()
 
         # 3d test
         self.check_3d()
 
         return
 
 
 class CheckRuntimeErrors(unittest.TestCase):
     """ Check if minimal examples of the exec.py module finish without runtime errors.
         This check requires a working path to data in the toy-dataset configs.
     """
 
     def test(self):
         cf = utils.import_module("toy_cf", 'datasets/toy/configs.py').Configs()
         exp_dir = "./unittesting/"
         #checks = {"retina_net": False, "mrcnn": False}
         #print("Testing for runtime errors with models {}".format(list(checks.keys())))
         #for model in tqdm.tqdm(list(checks.keys())):
             # cf.model = model
             # cf.model_path = 'models/{}.py'.format(cf.model if not 'retina' in cf.model else 'retina_net')
             # cf.model_path = os.path.join(cf.source_dir, cf.model_path)
             # {'mrcnn': cf.add_mrcnn_configs,
             #  'retina_net': cf.add_mrcnn_configs, 'retina_unet': cf.add_mrcnn_configs,
             #  'detection_unet': cf.add_det_unet_configs, 'detection_fpn': cf.add_det_fpn_configs
             #  }[model]()
         # todo change structure of configs-handling with exec.py so that its dynamically parseable instead of needing to
         # todo be changed in the file all the time.
         checks = {cf.model:False}
         completed_process = subprocess.run("python exec.py --dev --dataset_name toy -m train_test --exp_dir {}".format(exp_dir),
                                            shell=True, capture_output=True, text=True)
         if completed_process.returncode!=0:
             print("Runtime test of model {} failed due to\n{}".format(cf.model, completed_process.stderr))
         else:
             checks[cf.model] = True
         subprocess.call("rm -rf {}".format(exp_dir), shell=True)
         assert all(checks.values()), "A runtime test crashed."
 
 
 if __name__=="__main__":
     stime = time.time()
 
     t = CheckRoIAlignImplementation()
     t.manual_check()
-    unittest.main()
+    #unittest.main()
 
     mins, secs = divmod((time.time() - stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/utils/model_utils.py b/utils/model_utils.py
index 58585f2..da1f34a 100644
--- a/utils/model_utils.py
+++ b/utils/model_utils.py
@@ -1,1537 +1,1527 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """
 Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
 published under MIT license.
 """
 import warnings
 warnings.filterwarnings('ignore', '.*From scipy 0.13.0, the output shape of zoom()*')
 
 import numpy as np
 import scipy.misc
 import scipy.ndimage
 import scipy.interpolate
 from scipy.ndimage.measurements import label as lb
 import torch
 
 import tqdm
 
 from custom_extensions.nms import nms
 from custom_extensions.roi_align import roi_align
 
-import torchvision as tv
-
 ############################################################
 #  Segmentation Processing
 ############################################################
 
 def sum_tensor(input, axes, keepdim=False):
     axes = np.unique(axes)
     if keepdim:
         for ax in axes:
             input = input.sum(ax, keepdim=True)
     else:
         for ax in sorted(axes, reverse=True):
             input = input.sum(int(ax))
     return input
 
 def get_one_hot_encoding(y, n_classes):
     """
     transform a numpy label array to a one-hot array of the same shape.
     :param y: array of shape (b, 1, y, x, (z)).
     :param n_classes: int, number of classes to unfold in one-hot encoding.
     :return y_ohe: array of shape (b, n_classes, y, x, (z))
     """
 
     dim = len(y.shape) - 2
     if dim == 2:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32')
     elif dim == 3:
         y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32')
     else:
         raise Exception("invalid dimensions {} encountered".format(y.shape))
     for cl in np.arange(n_classes):
         y_ohe[:, cl][y[:, 0] == cl] = 1
     return y_ohe
 
 def dice_per_batch_inst_and_class(pred, y, n_classes, convert_to_ohe=True, smooth=1e-8):
     '''
     computes dice scores per batch instance and class.
     :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1)
     :param y: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes]
     :param n_classes: int
     :return: dice scores of shape (b, c)
     '''
     if convert_to_ohe:
         pred = get_one_hot_encoding(pred, n_classes)
         y = get_one_hot_encoding(y, n_classes)
     axes = tuple(range(2, len(pred.shape)))
     intersect = np.sum(pred*y, axis=axes)
     denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes)
     dice = (2.0*intersect + smooth) / (denominator + smooth)
     return dice
 
 def dice_per_batch_and_class(pred, targ, n_classes, convert_to_ohe=True, smooth=1e-8):
     '''
     computes dice scores per batch and class.
     :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1)
     :param targ: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes])
     :param n_classes: int
     :param smooth: Laplacian smooth, https://en.wikipedia.org/wiki/Additive_smoothing
     :return: dice scores of shape (b, c)
     '''
     if convert_to_ohe:
         pred = get_one_hot_encoding(pred, n_classes)
         targ = get_one_hot_encoding(targ, n_classes)
     axes = (0, *list(range(2, len(pred.shape)))) #(0,2,3(,4))
 
     intersect = np.sum(pred * targ, axis=axes)
 
     denominator = np.sum(pred, axis=axes) + np.sum(targ, axis=axes)
     dice = (2.0 * intersect + smooth) / (denominator + smooth)
 
     assert dice.shape==(n_classes,), "dice shp {}".format(dice.shape)
     return dice
 
 
 def batch_dice(pred, y, false_positive_weight=1.0, smooth=1e-6):
     '''
     compute soft dice over batch. this is a differentiable score and can be used as a loss function.
     only dice scores of foreground classes are returned, since training typically
     does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of.
     This way, single patches with missing foreground classes can not produce faulty gradients.
     :param pred: (b, c, y, x, (z)), softmax probabilities (network output).
     :param y: (b, c, y, x, (z)), one hote encoded segmentation mask.
     :param false_positive_weight: float [0,1]. For weighting of imbalanced classes,
     reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances.
     :return: soft dice score (float).This function discards the background score and returns the mena of foreground scores.
     '''
 
     if len(pred.size()) == 4:
         axes = (0, 2, 3)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth))[1:]) #only fg dice here.
 
     elif len(pred.size()) == 5:
         axes = (0, 2, 3, 4)
         intersect = sum_tensor(pred * y, axes, keepdim=False)
         denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False)
         return torch.mean(( (2*intersect + smooth) / (denom + smooth))[1:]) #only fg dice here.
     else:
         raise ValueError('wrong input dimension in dice loss')
 
 
 ############################################################
 #  Bounding Boxes
 ############################################################
 
 def compute_iou_2D(box, boxes, box_area, boxes_area):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2] THIS IS THE GT BOX
     boxes: [boxes_count, (y1, x1, y2, x2)]
     box_area: float. the area of 'box'
     boxes_area: array of length boxes_count.
 
     Note: the areas are passed in rather than calculated here for
           efficency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
     union = box_area + boxes_area[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 def compute_iou_3D(box, boxes, box_volume, boxes_volume):
     """Calculates IoU of the given box with the array of the given boxes.
     box: 1D vector [y1, x1, y2, x2, z1, z2] (typically gt box)
     boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)]
     box_area: float. the area of 'box'
     boxes_area: array of length boxes_count.
 
     Note: the areas are passed in rather than calculated here for
           efficency. Calculate once in the caller to avoid duplicate work.
     """
     # Calculate intersection areas
     y1 = np.maximum(box[0], boxes[:, 0])
     y2 = np.minimum(box[2], boxes[:, 2])
     x1 = np.maximum(box[1], boxes[:, 1])
     x2 = np.minimum(box[3], boxes[:, 3])
     z1 = np.maximum(box[4], boxes[:, 4])
     z2 = np.minimum(box[5], boxes[:, 5])
     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) * np.maximum(z2 - z1, 0)
     union = box_volume + boxes_volume[:] - intersection[:]
     iou = intersection / union
 
     return iou
 
 
 
 def compute_overlaps(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)]. / 3D: (z1, z2))
     For better performance, pass the largest set first and the smaller second.
     :return: (#boxes1, #boxes2), ious of each box of 1 machted with each of 2
     """
     # Areas of anchors and GT boxes
     if boxes1.shape[1] == 4:
         area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
         area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(overlaps.shape[1]):
             box2 = boxes2[i] #this is the gt box
             overlaps[:, i] = compute_iou_2D(box2, boxes1, area2[i], area1)
         return overlaps
 
     else:
         # Areas of anchors and GT boxes
         volume1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) * (boxes1[:, 5] - boxes1[:, 4])
         volume2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) * (boxes2[:, 5] - boxes2[:, 4])
         # Compute overlaps to generate matrix [boxes1 count, boxes2 count]
         # Each cell contains the IoU value.
         overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
         for i in range(boxes2.shape[0]):
             box2 = boxes2[i]  # this is the gt box
             overlaps[:, i] = compute_iou_3D(box2, boxes1, volume2[i], volume1)
         return overlaps
 
 
 
 def box_refinement(box, gt_box):
     """Compute refinement needed to transform box to gt_box.
     box and gt_box are [N, (y1, x1, y2, x2)] / 3D: (z1, z2))
     """
     height = box[:, 2] - box[:, 0]
     width = box[:, 3] - box[:, 1]
     center_y = box[:, 0] + 0.5 * height
     center_x = box[:, 1] + 0.5 * width
 
     gt_height = gt_box[:, 2] - gt_box[:, 0]
     gt_width = gt_box[:, 3] - gt_box[:, 1]
     gt_center_y = gt_box[:, 0] + 0.5 * gt_height
     gt_center_x = gt_box[:, 1] + 0.5 * gt_width
 
     dy = (gt_center_y - center_y) / height
     dx = (gt_center_x - center_x) / width
     dh = torch.log(gt_height / height)
     dw = torch.log(gt_width / width)
     result = torch.stack([dy, dx, dh, dw], dim=1)
 
     if box.shape[1] > 4:
         depth = box[:, 5] - box[:, 4]
         center_z = box[:, 4] + 0.5 * depth
         gt_depth = gt_box[:, 5] - gt_box[:, 4]
         gt_center_z = gt_box[:, 4] + 0.5 * gt_depth
         dz = (gt_center_z - center_z) / depth
         dd = torch.log(gt_depth / depth)
         result = torch.stack([dy, dx, dz, dh, dw, dd], dim=1)
 
     return result
 
 
 
 def unmold_mask_2D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2 = bbox
     out_zoom = [y2 - y1, x2 - x1]
     zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]
 
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:2]) #only y,x
     full_mask[y1:y2, x1:x2] = mask
     return full_mask
 
 
 def unmold_mask_2D_torch(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2 = bbox
     out_zoom = [(y2 - y1).float(), (x2 - x1).float()]
     zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]
 
     mask = mask.unsqueeze(0).unsqueeze(0)
     mask = torch.nn.functional.interpolate(mask, scale_factor=zoom_factor)
     mask = mask[0][0]
     #mask = scipy.ndimage.zoom(mask.cpu().numpy(), zoom_factor, order=1).astype(np.float32)
     #mask = torch.from_numpy(mask).cuda()
     # Put the mask in the right location.
     full_mask = torch.zeros(image_shape[:2])  # only y,x
     full_mask[y1:y2, x1:x2] = mask
     return full_mask
 
 
 
 def unmold_mask_3D(mask, bbox, image_shape):
     """Converts a mask generated by the neural network into a format similar
     to it's original shape.
     mask: [height, width] of type float. A small, typically 28x28 mask.
     bbox: [y1, x1, y2, x2, z1, z2]. The box to fit the mask in.
 
     Returns a binary mask with the same size as the original image.
     """
     y1, x1, y2, x2, z1, z2 = bbox
     out_zoom = [y2 - y1, x2 - x1, z2 - z1]
     zoom_factor = [i/j for i,j in zip(out_zoom, mask.shape)]
     mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)
 
     # Put the mask in the right location.
     full_mask = np.zeros(image_shape[:3])
     full_mask[y1:y2, x1:x2, z1:z2] = mask
     return full_mask
 
 def nms_numpy(box_coords, scores, thresh):
     """ non-maximum suppression on 2D or 3D boxes in numpy.
     :param box_coords: [y1,x1,y2,x2 (,z1,z2)] with y1<=y2, x1<=x2, z1<=z2.
     :param scores: ranking scores (higher score == higher rank) of boxes.
     :param thresh: IoU threshold for clustering.
     :return:
     """
     y1 = box_coords[:, 0]
     x1 = box_coords[:, 1]
     y2 = box_coords[:, 2]
     x2 = box_coords[:, 3]
     assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: 
             coordinates of which maxima are taken need to be the lower coordinates"""
     areas = (x2 - x1) * (y2 - y1)
 
     is_3d = box_coords.shape[1] == 6
     if is_3d: # 3-dim case
         z1 = box_coords[:, 4]
         z2 = box_coords[:, 5]
         assert np.all(z1<=z2), """"the definition of the coordinates is crucially important here: 
            coordinates of which maxima are taken need to be the lower coordinates"""
         areas *= (z2 - z1)
 
     order = scores.argsort()[::-1]
 
     keep = []
     while order.size > 0:  # order is the sorted index.  maps order to index: order[1] = 24 means (rank1, ix 24)
         i = order[0] # highest scoring element
         yy1 = np.maximum(y1[i], y1[order])  # highest scoring element still in >order<, is compared to itself, that is okay.
         xx1 = np.maximum(x1[i], x1[order])
         yy2 = np.minimum(y2[i], y2[order])
         xx2 = np.minimum(x2[i], x2[order])
 
         h = np.maximum(0.0, yy2 - yy1)
         w = np.maximum(0.0, xx2 - xx1)
         inter = h * w
 
         if is_3d:
             zz1 = np.maximum(z1[i], z1[order])
             zz2 = np.minimum(z2[i], z2[order])
             d = np.maximum(0.0, zz2 - zz1)
             inter *= d
 
         iou = inter / (areas[i] + areas[order] - inter)
 
         non_matches = np.nonzero(iou <= thresh)[0]  # get all elements that were not matched and discard all others.
         order = order[non_matches]
         keep.append(i)
 
     return keep
 
 
 
 ############################################################
 #  M-RCNN
 ############################################################
 
 def refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, batch_anchors, cf):
     """
     Receives anchor scores and selects a subset to pass as proposals
     to the second stage. Filtering is done based on anchor scores and
     non-max suppression to remove overlaps. It also applies bounding
     box refinment details to anchors.
     :param rpn_pred_probs: (b, n_anchors, 2)
     :param rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
     :return: batch_normalized_props: Proposals in normalized coordinates (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
     :return: batch_out_proposals: Box coords + RPN foreground scores
     for monitoring/plotting (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
     """
     std_dev = torch.from_numpy(cf.rpn_bbox_std_dev[None]).float().cuda()
     norm = torch.from_numpy(cf.scale).float().cuda()
     anchors = batch_anchors.clone()
 
 
 
     batch_scores = rpn_pred_probs[:, :, 1]
     # norm deltas
     batch_deltas = rpn_pred_deltas * std_dev
     batch_normalized_props = []
     batch_out_proposals = []
 
     # loop over batch dimension.
     for ix in range(batch_scores.shape[0]):
 
         scores = batch_scores[ix]
         deltas = batch_deltas[ix]
 
         # improve performance by trimming to top anchors by score
         # and doing the rest on the smaller subset.
         pre_nms_limit = min(cf.pre_nms_limit, anchors.size()[0])
         scores, order = scores.sort(descending=True)
         order = order[:pre_nms_limit]
         scores = scores[:pre_nms_limit]
         deltas = deltas[order, :]
 
         # apply deltas to anchors to get refined anchors and filter with non-maximum suppression.
         if batch_deltas.shape[-1] == 4:
             boxes = apply_box_deltas_2D(anchors[order, :], deltas)
             boxes = clip_boxes_2D(boxes, cf.window)
         else:
             boxes = apply_box_deltas_3D(anchors[order, :], deltas)
             boxes = clip_boxes_3D(boxes, cf.window)
         # boxes are y1,x1,y2,x2, torchvision-nms requires x1,y1,x2,y2, but consistent swap x<->y is irrelevant.
         keep = nms.nms(boxes, scores, cf.rpn_nms_threshold)
 
 
         keep = keep[:proposal_count]
         boxes = boxes[keep, :]
         rpn_scores = scores[keep][:, None]
 
         # pad missing boxes with 0.
         if boxes.shape[0] < proposal_count:
             n_pad_boxes = proposal_count - boxes.shape[0]
             zeros = torch.zeros([n_pad_boxes, boxes.shape[1]]).cuda()
             boxes = torch.cat([boxes, zeros], dim=0)
             zeros = torch.zeros([n_pad_boxes, rpn_scores.shape[1]]).cuda()
             rpn_scores = torch.cat([rpn_scores, zeros], dim=0)
 
         # concat box and score info for monitoring/plotting.
         batch_out_proposals.append(torch.cat((boxes, rpn_scores), 1).cpu().data.numpy())
         # normalize dimensions to range of 0 to 1.
         normalized_boxes = boxes / norm
-        assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found"
+        where = normalized_boxes <=1
+        assert torch.all(where), "normalized box coords >1 found:\n {}\n".format(normalized_boxes[where])
+        #assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found"
 
         # add again batch dimension
         batch_normalized_props.append(torch.cat((normalized_boxes, rpn_scores), 1).unsqueeze(0))
 
     batch_normalized_props = torch.cat(batch_normalized_props)
     batch_out_proposals = np.array(batch_out_proposals)
 
     return batch_normalized_props, batch_out_proposals
 
 def pyramid_roi_align(feature_maps, rois, pool_size, pyramid_levels, dim):
     """
     Implements ROI Pooling on multiple levels of the feature pyramid.
     :param feature_maps: list of feature maps, each of shape (b, c, y, x , (z))
     :param rois: proposals (normalized coords.) as returned by RPN. contain info about original batch element allocation.
     (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs)
     :param pool_size: list of poolsizes in dims: [x, y, (z)]
     :param pyramid_levels: list. [0, 1, 2, ...]
     :return: pooled: pooled feature map rois (n_proposals, c, poolsize_y, poolsize_x, (poolsize_z))
 
     Output:
     Pooled regions in the shape: [num_boxes, height, width, channels].
     The width and height are those specific in the pool_shape in the layer
     constructor.
     """
     boxes = rois[:, :dim*2]
     batch_ixs = rois[:, dim*2]
 
     # Assign each ROI to a level in the pyramid based on the ROI area.
     if dim == 2:
         y1, x1, y2, x2 = boxes.chunk(4, dim=1)
     else:
         y1, x1, y2, x2, z1, z2 = boxes.chunk(6, dim=1)
 
     h = y2 - y1
     w = x2 - x1
 
     # Equation 1 in https://arxiv.org/abs/1612.03144. Account for
     # the fact that our coordinates are normalized here.
     # divide sqrt(h*w) by 1 instead image_area.
     roi_level = (4 + torch.log2(torch.sqrt(h*w))).round().int().clamp(pyramid_levels[0], pyramid_levels[-1])
     # if Pyramid contains additional level P6, adapt the roi_level assignment accordingly.
     if len(pyramid_levels) == 5:
         roi_level[h*w > 0.65] = 5
 
     # Loop through levels and apply ROI pooling to each.
     pooled = []
     box_to_level = []
     fmap_shapes = [f.shape for f in feature_maps]
     for level_ix, level in enumerate(pyramid_levels):
         ix = roi_level == level
         if not ix.any():
             continue
         ix = torch.nonzero(ix)[:, 0]
         level_boxes = boxes[ix, :]
         # re-assign rois to feature map of original batch element.
         ind = batch_ixs[ix].int()
 
         # Keep track of which box is mapped to which level
         box_to_level.append(ix)
 
         # Stop gradient propogation to ROI proposals
         level_boxes = level_boxes.detach()
         if len(pool_size) == 2:
             # remap to feature map coordinate system
             y_exp, x_exp = fmap_shapes[level_ix][2:]  # exp = expansion
             level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
-            # pooled_features_own = roi_align.roi_align_2d(feature_maps[level_ix],
-            #                                          torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
-            #                                          pool_size)
-            import IPython; IPython.embed()
-            pooled_features = tv.ops.roi_align(feature_maps[level_ix],
+            pooled_features = roi_align.roi_align_2d(feature_maps[level_ix],
                                                      torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                      pool_size)
         else:
             y_exp, x_exp, z_exp = fmap_shapes[level_ix][2:]
             level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
             pooled_features = roi_align.roi_align_3d(feature_maps[level_ix],
                                                      torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                      pool_size)
         pooled.append(pooled_features)
 
 
     # Pack pooled features into one tensor
     pooled = torch.cat(pooled, dim=0)
 
     # Pack box_to_level mapping into one array and add another
     # column representing the order of pooled boxes
     box_to_level = torch.cat(box_to_level, dim=0)
 
     # Rearrange pooled features to match the order of the original boxes
     _, box_to_level = torch.sort(box_to_level)
     pooled = pooled[box_to_level, :, :]
 
     return pooled
 
 
 def roi_align_3d_numpy(input: np.ndarray, rois, output_size: tuple,
                        spatial_scale: float = 1., sampling_ratio: int = -1) -> np.ndarray:
     """ This fct mainly serves as a verification method for 3D CUDA implementation of RoIAlign, it's highly
         inefficient due to the nested loops.
     :param input:  (ndarray[N, C, H, W, D]): input feature map
     :param rois: list (N,K(n), 6), K(n) = nr of rois in batch-element n, single roi of format (y1,x1,y2,x2,z1,z2)
     :param output_size:
     :param spatial_scale:
     :param sampling_ratio:
     :return: (List[N, K(n), C, output_size[0], output_size[1], output_size[2]])
     """
 
     out_height, out_width, out_depth = output_size
 
     coord_grid = tuple([np.linspace(0, input.shape[dim] - 1, num=input.shape[dim]) for dim in range(2, 5)])
     pooled_rois = [[]] * len(rois)
     assert len(rois) == input.shape[0], "batch dim mismatch, rois: {}, input: {}".format(len(rois), input.shape[0])
     print("Numpy 3D RoIAlign progress:", end="\n")
     for b in range(input.shape[0]):
         for roi in tqdm.tqdm(rois[b]):
             y1, x1, y2, x2, z1, z2 = np.array(roi) * spatial_scale
             roi_height = max(float(y2 - y1), 1.)
             roi_width = max(float(x2 - x1), 1.)
             roi_depth = max(float(z2 - z1), 1.)
 
             if sampling_ratio <= 0:
                 sampling_ratio_h = int(np.ceil(roi_height / out_height))
                 sampling_ratio_w = int(np.ceil(roi_width / out_width))
                 sampling_ratio_d = int(np.ceil(roi_depth / out_depth))
             else:
                 sampling_ratio_h = sampling_ratio_w = sampling_ratio_d = sampling_ratio  # == n points per bin
 
             bin_height = roi_height / out_height
             bin_width = roi_width / out_width
             bin_depth = roi_depth / out_depth
 
             n_points = sampling_ratio_h * sampling_ratio_w * sampling_ratio_d
             pooled_roi = np.empty((input.shape[1], out_height, out_width, out_depth), dtype="float32")
             for chan in range(input.shape[1]):
                 lin_interpolator = scipy.interpolate.RegularGridInterpolator(coord_grid, input[b, chan],
                                                                              method="linear")
                 for bin_iy in range(out_height):
                     for bin_ix in range(out_width):
                         for bin_iz in range(out_depth):
 
                             bin_val = 0.
                             for i in range(sampling_ratio_h):
                                 for j in range(sampling_ratio_w):
                                     for k in range(sampling_ratio_d):
                                         loc_ijk = [
-                                            y1 + bin_iy * bin_height + (i + 0.5) * (bin_height / sampling_ratio_h),
-                                            x1 + bin_ix * bin_width + (j + 0.5) * (bin_width / sampling_ratio_w),
-                                            z1 + bin_iz * bin_depth + (k + 0.5) * (bin_depth / sampling_ratio_d)]
+                                            y1 + bin_iy * bin_height + (i + 0.5)* ((bin_height -1) / sampling_ratio_h),
+                                            x1 + bin_ix * bin_width + (j + 0.5) * ((bin_width -1) / sampling_ratio_w),
+                                            z1 + bin_iz * bin_depth + (k + 0.5) * ((bin_depth -1) / sampling_ratio_d)]
                                         # print("loc_ijk", loc_ijk)
                                         if not (np.any([c < -1.0 for c in loc_ijk]) or loc_ijk[0] > input.shape[2] or
                                                 loc_ijk[1] > input.shape[3] or loc_ijk[2] > input.shape[4]):
                                             for catch_case in range(3):
                                                 # catch on-border cases
                                                 if int(loc_ijk[catch_case]) == input.shape[catch_case + 2] - 1:
                                                     loc_ijk[catch_case] = input.shape[catch_case + 2] - 1
                                             bin_val += lin_interpolator(loc_ijk)
                             pooled_roi[chan, bin_iy, bin_ix, bin_iz] = bin_val / n_points
 
             pooled_rois[b].append(pooled_roi)
 
     return np.array(pooled_rois)
 
 def refine_detections(cf, batch_ixs, rois, deltas, scores, regressions):
     """
     Refine classified proposals (apply deltas to rpn rois), filter overlaps (nms) and return final detections.
 
     :param rois: (n_proposals, 2 * dim) normalized boxes as proposed by RPN. n_proposals = batch_size * POST_NMS_ROIS
     :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by mrcnn bbox regressor.
     :param batch_ixs: (n_proposals) batch element assignment info for re-allocation.
     :param scores: (n_proposals, n_classes) probabilities for all classes per roi as predicted by mrcnn classifier.
     :param regressions: (n_proposals, n_classes, regression_features (+1 for uncertainty if predicted) regression vector
     :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, *regression vector features))
     """
     # class IDs per ROI. Since scores of all classes are of interest (not just max class), all are kept at this point.
     class_ids = []
     fg_classes = cf.head_classes - 1
     # repeat vectors to fill in predictions for all foreground classes.
     for ii in range(1, fg_classes + 1):
         class_ids += [ii] * rois.shape[0]
     class_ids = torch.from_numpy(np.array(class_ids)).cuda()
 
     batch_ixs = batch_ixs.repeat(fg_classes)
     rois = rois.repeat(fg_classes, 1)
     deltas = deltas.repeat(fg_classes, 1, 1)
     scores = scores.repeat(fg_classes, 1)
     regressions = regressions.repeat(fg_classes, 1, 1)
 
     # get class-specific scores and  bounding box deltas
     idx = torch.arange(class_ids.size()[0]).long().cuda()
     # using idx instead of slice [:,] squashes first dimension.
     #len(class_ids)>scores.shape[1] --> probs is broadcasted by expansion from fg_classes-->len(class_ids)
     batch_ixs = batch_ixs[idx]
     deltas_specific = deltas[idx, class_ids]
     class_scores = scores[idx, class_ids]
     regressions = regressions[idx, class_ids]
 
     # apply bounding box deltas. re-scale to image coordinates.
     std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda()
     scale = torch.from_numpy(cf.scale).float().cuda()
     refined_rois = apply_box_deltas_2D(rois, deltas_specific * std_dev) * scale if cf.dim == 2 else \
         apply_box_deltas_3D(rois, deltas_specific * std_dev) * scale
 
     # round and cast to int since we're dealing with pixels now
     refined_rois = clip_to_window(cf.window, refined_rois)
     refined_rois = torch.round(refined_rois)
 
     # filter out low confidence boxes
     keep = idx
     keep_bool = (class_scores >= cf.model_min_confidence)
     if not 0 in torch.nonzero(keep_bool).size():
 
         score_keep = torch.nonzero(keep_bool)[:, 0]
         pre_nms_class_ids = class_ids[score_keep]
         pre_nms_rois = refined_rois[score_keep]
         pre_nms_scores = class_scores[score_keep]
         pre_nms_batch_ixs = batch_ixs[score_keep]
 
         for j, b in enumerate(unique1d(pre_nms_batch_ixs)):
 
             bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0]
             bix_class_ids = pre_nms_class_ids[bixs]
             bix_rois = pre_nms_rois[bixs]
             bix_scores = pre_nms_scores[bixs]
 
             for i, class_id in enumerate(unique1d(bix_class_ids)):
 
                 ixs = torch.nonzero(bix_class_ids == class_id)[:, 0]
                 # nms expects boxes sorted by score.
                 ix_rois = bix_rois[ixs]
                 ix_scores = bix_scores[ixs]
                 ix_scores, order = ix_scores.sort(descending=True)
                 ix_rois = ix_rois[order, :]
 
                 class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold)
 
                 # map indices back.
                 class_keep = keep[score_keep[bixs[ixs[order[class_keep]]]]]
                 # merge indices over classes for current batch element
                 b_keep = class_keep if i == 0 else unique1d(torch.cat((b_keep, class_keep)))
 
             # only keep top-k boxes of current batch-element
             top_ids = class_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element]
             b_keep = b_keep[top_ids]
 
             # merge indices over batch elements.
             batch_keep = b_keep  if j == 0 else unique1d(torch.cat((batch_keep, b_keep)))
 
         keep = batch_keep
 
     else:
         keep = torch.tensor([0]).long().cuda()
 
     # arrange output
     output = [refined_rois[keep], batch_ixs[keep].unsqueeze(1)]
     output += [class_ids[keep].unsqueeze(1).float(), class_scores[keep].unsqueeze(1)]
     output += [regressions[keep]]
 
     result = torch.cat(output, dim=1)
     # shape: (n_keeps, catted feats), catted feats: [0:dim*2] are box_coords, [dim*2] are batch_ics,
     # [dim*2+1] are class_ids, [dim*2+2] are scores, [dim*2+3:] are regression vector features (incl uncertainty)
     return result
 
 
 def loss_example_mining(cf, batch_proposals, batch_gt_boxes, batch_gt_masks, batch_roi_scores,
                            batch_gt_class_ids, batch_gt_regressions):
     """
     Subsamples proposals for mrcnn losses and generates targets. Sampling is done per batch element, seems to have positive
     effects on training, as opposed to sampling over entire batch. Negatives are sampled via stochastic hard-example mining
     (SHEM), where a number of negative proposals is drawn from larger pool of highest scoring proposals for stochasticity.
     Scoring is obtained here as the max over all foreground probabilities as returned by mrcnn_classifier (worked better than
     loss-based class-balancing methods like "online hard-example mining" or "focal loss".)
 
     Classification-regression duality: regressions can be given along with classes (at least fg/bg, only class scores
     are used for ranking).
 
     :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs).
     boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS.
     :param mrcnn_class_logits: (n_proposals, n_classes)
     :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
     :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, c, y, x, (z))
     :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
         if no classes predicted (only fg/bg from RPN): expected as pseudo classes [0, 1] for bg, fg.
     :param batch_gt_regressions: list over b elements. Each element is a regression target vector. if None--> pseudo
     :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions.
     :return: target_class_ids: (n_sampled_rois)containing target class labels of sampled proposals.
     :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement.
     :return: target_masks: (n_sampled_rois, y, x, (z)) containing target masks of sampled proposals.
     """
     # normalization of target coordinates
     #global sample_regressions
     if cf.dim == 2:
         h, w = cf.patch_size
         scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda()
     else:
         h, w, z = cf.patch_size
         scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda()
 
     positive_count = 0
     negative_count = 0
     sample_positive_indices = []
     sample_negative_indices = []
     sample_deltas = []
     sample_masks = []
     sample_class_ids = []
     if batch_gt_regressions is not None:
         sample_regressions = []
     else:
         target_regressions = torch.FloatTensor().cuda()
 
     std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda()
 
     # loop over batch and get positive and negative sample rois.
     for b in range(len(batch_gt_boxes)):
 
         gt_masks = torch.from_numpy(batch_gt_masks[b]).float().cuda()
         gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda()
         if batch_gt_regressions is not None:
             gt_regressions = torch.from_numpy(batch_gt_regressions[b]).float().cuda()
 
         #if np.any(batch_gt_class_ids[b] > 0):  # skip roi selection for no gt images.
         if np.any([len(coords)>0 for coords in batch_gt_boxes[b]]):
             gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale
         else:
             gt_boxes = torch.FloatTensor().cuda()
 
         # get proposals and indices of current batch element.
         proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1]
         batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1)
 
         # Compute overlaps matrix [proposals, gt_boxes]
         if not 0 in gt_boxes.size():
             if gt_boxes.shape[1] == 4:
                 assert cf.dim == 2, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim)
                 overlaps = bbox_overlaps_2D(proposals, gt_boxes)
             else:
                 assert cf.dim == 3, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim)
                 overlaps = bbox_overlaps_3D(proposals, gt_boxes)
 
             # Determine positive and negative ROIs
             roi_iou_max = torch.max(overlaps, dim=1)[0]
             # 1. Positive ROIs are those with >= 0.5 IoU with a GT box
             positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3)
             # 2. Negative ROIs are those with < 0.1 with every GT box.
             negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01)
         else:
             positive_roi_bool = torch.FloatTensor().cuda()
             negative_roi_bool = torch.from_numpy(np.array([1]*proposals.shape[0])).cuda()
 
         # Sample Positive ROIs
         if not 0 in torch.nonzero(positive_roi_bool).size():
             positive_indices = torch.nonzero(positive_roi_bool).squeeze(1)
             positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio)
             rand_idx = torch.randperm(positive_indices.size()[0])
             rand_idx = rand_idx[:positive_samples].cuda()
             positive_indices = positive_indices[rand_idx]
             positive_samples = positive_indices.size()[0]
             positive_rois = proposals[positive_indices, :]
             # Assign positive ROIs to GT boxes.
             positive_overlaps = overlaps[positive_indices, :]
             roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1]
             roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :]
             roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment]
             if batch_gt_regressions is not None:
                 roi_gt_regressions = gt_regressions[roi_gt_box_assignment]
 
             # Compute bbox refinement targets for positive ROIs
             deltas = box_refinement(positive_rois, roi_gt_boxes)
             deltas /= std_dev
 
             roi_masks = gt_masks[roi_gt_box_assignment]
-            #print("roi_masks[b] in ex mining pre align", roi_masks.unique(return_counts=True))
             assert roi_masks.shape[1] == 1, "gt masks have more than one channel --> is this desired?"
             # Compute mask targets
             boxes = positive_rois
             box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).float()
 
             if len(cf.mask_shape) == 2:
                 y_exp, x_exp = roi_masks.shape[2:]  # exp = expansion
                 boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
-                # masks = roi_align.roi_align_2d(roi_masks,
-                #                                torch.cat((box_ids, boxes), dim=1),
-                #                                cf.mask_shape)
-                #import IPython; IPython.embed()
-                masks = tv.ops.roi_align(roi_masks,
-                                         torch.cat((box_ids, boxes), dim=1),
-                                         cf.mask_shape)
+                masks = roi_align.roi_align_2d(roi_masks,
+                                               torch.cat((box_ids, boxes), dim=1),
+                                               cf.mask_shape)
             else:
                 y_exp, x_exp, z_exp = roi_masks.shape[2:]  # exp = expansion
                 boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
                 masks = roi_align.roi_align_3d(roi_masks,
                                                torch.cat((box_ids, boxes), dim=1),
                                                cf.mask_shape)
-            #print("roi_masks[b] in ex mining POST align", masks.unique(return_counts=True))
 
             masks = masks.squeeze(1)
             # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
             # binary cross entropy loss.
             masks = torch.round(masks)
 
             sample_positive_indices.append(batch_element_indices[positive_indices])
             sample_deltas.append(deltas)
             sample_masks.append(masks)
             sample_class_ids.append(roi_gt_class_ids)
             if batch_gt_regressions is not None:
                 sample_regressions.append(roi_gt_regressions)
             positive_count += positive_samples
         else:
             positive_samples = 0
 
         # Sample negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM.
         if not 0 in torch.nonzero(negative_roi_bool).size():
             negative_indices = torch.nonzero(negative_roi_bool).squeeze(1)
             r = 1.0 / cf.roi_positive_ratio
             b_neg_count = np.max((int(r * positive_samples - positive_samples), 1))
             roi_scores_neg = batch_roi_scores[batch_element_indices[negative_indices]]
             raw_sampled_indices = shem(roi_scores_neg, b_neg_count, cf.shem_poolsize)
             sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]])
             negative_count  += raw_sampled_indices.size()[0]
 
     if len(sample_positive_indices) > 0:
         target_deltas = torch.cat(sample_deltas)
         target_masks = torch.cat(sample_masks)
         target_class_ids = torch.cat(sample_class_ids)
         if batch_gt_regressions is not None:
             target_regressions = torch.cat(sample_regressions)
 
     # Pad target information with zeros for negative ROIs.
     if positive_count > 0 and negative_count > 0:
         sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0)
         zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
         target_deltas = torch.cat([target_deltas, zeros], dim=0)
         zeros = torch.zeros(negative_count, *cf.mask_shape).cuda()
         target_masks = torch.cat([target_masks, zeros], dim=0)
         zeros = torch.zeros(negative_count).int().cuda()
         target_class_ids = torch.cat([target_class_ids, zeros], dim=0)
         if batch_gt_regressions is not None:
             # regression targets need to have 0 as background/negative with below practice
             if 'regression_bin' in cf.prediction_tasks:
                 zeros = torch.zeros(negative_count, dtype=torch.float).cuda()
             else:
                 zeros = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda()
             target_regressions = torch.cat([target_regressions, zeros], dim=0)
 
     elif positive_count > 0:
         sample_indices = torch.cat(sample_positive_indices)
     elif negative_count > 0:
         sample_indices = torch.cat(sample_negative_indices)
         target_deltas = torch.zeros(negative_count, cf.dim * 2).cuda()
         target_masks = torch.zeros(negative_count, *cf.mask_shape).cuda()
         target_class_ids = torch.zeros(negative_count).int().cuda()
         if batch_gt_regressions is not None:
             if 'regression_bin' in cf.prediction_tasks:
                 target_regressions = torch.zeros(negative_count, dtype=torch.float).cuda()
             else:
                 target_regressions = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda()
     else:
         sample_indices = torch.LongTensor().cuda()
         target_class_ids = torch.IntTensor().cuda()
         target_deltas = torch.FloatTensor().cuda()
         target_masks = torch.FloatTensor().cuda()
         target_regressions = torch.FloatTensor().cuda()
 
     return sample_indices, target_deltas, target_masks, target_class_ids, target_regressions
 
 ############################################################
 #  Anchors
 ############################################################
 
 def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
     """
     scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width] spatial shape of the feature map over which
             to generate anchors.
     feature_stride: Stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
     scales, ratios = np.meshgrid(np.array(scales), np.array(ratios))
     scales = scales.flatten()
     ratios = ratios.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales / np.sqrt(ratios)
     widths = scales * np.sqrt(ratios)
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride
     shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
 
     # Reshape to get a list of (y, x) and a list of (h, w)
     box_centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2])
     box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])
 
     # Convert to corner coordinates (y1, x1, y2, x2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1)
     return boxes
 
 
 
 def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride):
     """
     scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
     ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
     shape: [height, width] spatial shape of the feature map over which
             to generate anchors.
     feature_stride: Stride of the feature map relative to the image in pixels.
     anchor_stride: Stride of anchors on the feature map. For example, if the
         value is 2 then generate anchors for every other feature map pixel.
     """
     # Get all combinations of scales and ratios
 
     scales_xy, ratios_meshed = np.meshgrid(np.array(scales_xy), np.array(ratios))
     scales_xy = scales_xy.flatten()
     ratios_meshed = ratios_meshed.flatten()
 
     # Enumerate heights and widths from scales and ratios
     heights = scales_xy / np.sqrt(ratios_meshed)
     widths = scales_xy * np.sqrt(ratios_meshed)
     depths = np.tile(np.array(scales_z), len(ratios_meshed)//np.array(scales_z)[..., None].shape[0])
 
     # Enumerate shifts in feature space
     shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy #translate from fm positions to input coords.
     shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy
     shifts_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z)
     shifts_x, shifts_y, shifts_z = np.meshgrid(shifts_x, shifts_y, shifts_z)
 
     # Enumerate combinations of shifts, widths, and heights
     box_widths, box_centers_x = np.meshgrid(widths, shifts_x)
     box_heights, box_centers_y = np.meshgrid(heights, shifts_y)
     box_depths, box_centers_z = np.meshgrid(depths, shifts_z)
 
     # Reshape to get a list of (y, x, z) and a list of (h, w, d)
     box_centers = np.stack(
         [box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3])
     box_sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3])
 
     # Convert to corner coordinates (y1, x1, y2, x2, z1, z2)
     boxes = np.concatenate([box_centers - 0.5 * box_sizes,
                             box_centers + 0.5 * box_sizes], axis=1)
 
     boxes = np.transpose(np.array([boxes[:, 0], boxes[:, 1], boxes[:, 3], boxes[:, 4], boxes[:, 2], boxes[:, 5]]), axes=(1, 0))
     return boxes
 
 
 def generate_pyramid_anchors(logger, cf):
     """Generate anchors at different levels of a feature pyramid. Each scale
     is associated with a level of the pyramid, but each ratio is used in
     all levels of the pyramid.
 
     from configs:
     :param scales: cf.RPN_ANCHOR_SCALES , for conformity with retina nets: scale entries need to be list, e.g. [[4], [8], [16], [32]]
     :param ratios: cf.RPN_ANCHOR_RATIOS , e.g. [0.5, 1, 2]
     :param feature_shapes: cf.BACKBONE_SHAPES , e.g.  [array of shapes per feature map] [80, 40, 20, 10, 5]
     :param feature_strides: cf.BACKBONE_STRIDES , e.g. [2, 4, 8, 16, 32, 64]
     :param anchors_stride: cf.RPN_ANCHOR_STRIDE , e.g. 1
     :return anchors: (N, (y1, x1, y2, x2, (z1), (z2)). All generated anchors in one array. Sorted
     with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on.
     """
     scales = cf.rpn_anchor_scales
     ratios = cf.rpn_anchor_ratios
     feature_shapes = cf.backbone_shapes
     anchor_stride = cf.rpn_anchor_stride
     pyramid_levels = cf.pyramid_levels
     feature_strides = cf.backbone_strides
 
     logger.info("anchor scales {} and feature map shapes {}".format(scales, feature_shapes))
     expected_anchors = [np.prod(feature_shapes[level]) * len(ratios) * len(scales['xy'][level]) for level in pyramid_levels]
 
     anchors = []
     for lix, level in enumerate(pyramid_levels):
         if len(feature_shapes[level]) == 2:
             anchors.append(generate_anchors(scales['xy'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], anchor_stride))
         elif len(feature_shapes[level]) == 3:
             anchors.append(generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, feature_shapes[level],
                                             feature_strides['xy'][level], feature_strides['z'][level], anchor_stride))
         else:
             raise Exception("invalid feature_shapes[{}] size {}".format(level, feature_shapes[level]))
         logger.info("level {}: expected anchors {}, built anchors {}.".format(level, expected_anchors[lix], anchors[-1].shape))
 
     out_anchors = np.concatenate(anchors, axis=0)
     logger.info("Total: expected anchors {}, built anchors {}.".format(np.sum(expected_anchors), out_anchors.shape))
 
     return out_anchors
 
 
 
 def apply_box_deltas_2D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 4] where each row is y1, x1, y2, x2
     deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     height *= torch.exp(deltas[:, 2])
     width *= torch.exp(deltas[:, 3])
     # Convert back to y1, x1, y2, x2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     y2 = y1 + height
     x2 = x1 + width
     result = torch.stack([y1, x1, y2, x2], dim=1)
     return result
 
 
 
 def apply_box_deltas_3D(boxes, deltas):
     """Applies the given deltas to the given boxes.
     boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2
     deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)]
     """
     # Convert to y, x, h, w
     height = boxes[:, 2] - boxes[:, 0]
     width = boxes[:, 3] - boxes[:, 1]
     depth = boxes[:, 5] - boxes[:, 4]
     center_y = boxes[:, 0] + 0.5 * height
     center_x = boxes[:, 1] + 0.5 * width
     center_z = boxes[:, 4] + 0.5 * depth
     # Apply deltas
     center_y += deltas[:, 0] * height
     center_x += deltas[:, 1] * width
     center_z += deltas[:, 2] * depth
     height *= torch.exp(deltas[:, 3])
     width *= torch.exp(deltas[:, 4])
     depth *= torch.exp(deltas[:, 5])
     # Convert back to y1, x1, y2, x2
     y1 = center_y - 0.5 * height
     x1 = center_x - 0.5 * width
     z1 = center_z - 0.5 * depth
     y2 = y1 + height
     x2 = x1 + width
     z2 = z1 + depth
     result = torch.stack([y1, x1, y2, x2, z1, z2], dim=1)
     return result
 
 
 
 def clip_boxes_2D(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2
     window: [4] in the form y1, x1, y2, x2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3]))], 1)
     return boxes
 
 def clip_boxes_3D(boxes, window):
     """
     boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2
     window: [6] in the form y1, x1, y2, x2, z1, z2
     """
     boxes = torch.stack( \
         [boxes[:, 0].clamp(float(window[0]), float(window[2])),
          boxes[:, 1].clamp(float(window[1]), float(window[3])),
          boxes[:, 2].clamp(float(window[0]), float(window[2])),
          boxes[:, 3].clamp(float(window[1]), float(window[3])),
          boxes[:, 4].clamp(float(window[4]), float(window[5])),
          boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1)
     return boxes
 
 from matplotlib import pyplot as plt
 
 
 def clip_boxes_numpy(boxes, window):
     """
     boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D.
     window: iamge shape (y, x, (z))
     """
     if boxes.shape[1] == 4:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
             np.clip(boxes[:, 1], 0, window[0])[:, None],
             np.clip(boxes[:, 2], 0, window[1])[:, None],
             np.clip(boxes[:, 3], 0, window[1])[:, None]), 1
         )
 
     else:
         boxes = np.concatenate(
             (np.clip(boxes[:, 0], 0, window[0])[:, None],
              np.clip(boxes[:, 1], 0, window[0])[:, None],
              np.clip(boxes[:, 2], 0, window[1])[:, None],
              np.clip(boxes[:, 3], 0, window[1])[:, None],
              np.clip(boxes[:, 4], 0, window[2])[:, None],
              np.clip(boxes[:, 5], 0, window[2])[:, None]), 1
         )
 
     return boxes
 
 
 
 def bbox_overlaps_2D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2)].
     """
     # 1. Tile boxes2 and repeate boxes1. This allows us to compare
     # every boxes1 against every boxes2 without loops.
     # TF doesn't have an equivalent to np.repeate() so simulate it
     # using tf.tile() and tf.reshape.
 
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
 
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,4)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2 = boxes1.chunk(4, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2 = boxes2.chunk(4, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     #--> expects x1<x2 & y1<y2
     zeros = torch.zeros(y1.size()[0], requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros)
 
     # 3. Compute unions
     b1_area = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)
     b2_area = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)
     union = b1_area[:,0] + b2_area[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     assert torch.all(iou<=1), "iou score>1 produced in bbox_overlaps_2D"
     overlaps = iou.view(boxes2_repeat, boxes1_repeat) #--> per gt box: ious of all proposal boxes with that gt box
 
     return overlaps
 
 def bbox_overlaps_3D(boxes1, boxes2):
     """Computes IoU overlaps between two sets of boxes.
     boxes1, boxes2: [N, (y1, x1, y2, x2, z1, z2)].
     """
     # 1. Tile boxes2 and repeate boxes1. This allows us to compare
     # every boxes1 against every boxes2 without loops.
     # TF doesn't have an equivalent to np.repeate() so simulate it
     # using tf.tile() and tf.reshape.
     boxes1_repeat = boxes2.size()[0]
     boxes2_repeat = boxes1.size()[0]
     boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,6)
     boxes2 = boxes2.repeat(boxes2_repeat,1)
 
     # 2. Compute intersections
     b1_y1, b1_x1, b1_y2, b1_x2, b1_z1, b1_z2 = boxes1.chunk(6, dim=1)
     b2_y1, b2_x1, b2_y2, b2_x2, b2_z1, b2_z2 = boxes2.chunk(6, dim=1)
     y1 = torch.max(b1_y1, b2_y1)[:, 0]
     x1 = torch.max(b1_x1, b2_x1)[:, 0]
     y2 = torch.min(b1_y2, b2_y2)[:, 0]
     x2 = torch.min(b1_x2, b2_x2)[:, 0]
     z1 = torch.max(b1_z1, b2_z1)[:, 0]
     z2 = torch.min(b1_z2, b2_z2)[:, 0]
     zeros = torch.zeros(y1.size()[0], requires_grad=False)
     if y1.is_cuda:
         zeros = zeros.cuda()
     intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros)
 
     # 3. Compute unions
     b1_volume = (b1_y2 - b1_y1) * (b1_x2 - b1_x1)  * (b1_z2 - b1_z1)
     b2_volume = (b2_y2 - b2_y1) * (b2_x2 - b2_x1)  * (b2_z2 - b2_z1)
     union = b1_volume[:,0] + b2_volume[:,0] - intersection
 
     # 4. Compute IoU and reshape to [boxes1, boxes2]
     iou = intersection / union
     overlaps = iou.view(boxes2_repeat, boxes1_repeat)
     return overlaps
 
 def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None):
     """Given the anchors and GT boxes, compute overlaps and identify positive
     anchors and deltas to refine them to match their corresponding GT boxes.
 
     anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))]
     gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))]
     gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN,
     set all positive matches to 1 (foreground)
 
     Returns:
     anchor_class_matches: [N] (int32) matches between anchors and GT boxes.
                1 = positive anchor, -1 = negative anchor, 0 = neutral
     anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas.
     """
 
     anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32)
     anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim))
     anchor_matching_iou = cf.anchor_matching_iou
 
     if gt_boxes is None:
         anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1)
         return anchor_class_matches, anchor_delta_targets
 
     # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground)
     if gt_class_ids is None:
         gt_class_ids = np.array([1] * len(gt_boxes))
 
     # Compute overlaps [num_anchors, num_gt_boxes]
     overlaps = compute_overlaps(anchors, gt_boxes)
 
     # Match anchors to GT Boxes
     # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive.
     # If an anchor overlaps a GT box with IoU < 0.1 then it's negative.
     # Neutral anchors are those that don't match the conditions above,
     # and they don't influence the loss function.
     # However, don't keep any GT box unmatched (rare, but happens). Instead,
     # match it to the closest anchor (even if its max IoU is < 0.1).
 
     # 1. Set negative anchors first. They get overwritten below if a GT box is
     # matched to them. Skip boxes in crowd areas.
     anchor_iou_argmax = np.argmax(overlaps, axis=1)
     anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
     if anchors.shape[1] == 4:
         anchor_class_matches[(anchor_iou_max < 0.1)] = -1
     elif anchors.shape[1] == 6:
         anchor_class_matches[(anchor_iou_max < 0.01)] = -1
     else:
         raise ValueError('anchor shape wrong {}'.format(anchors.shape))
 
     # 2. Set an anchor for each GT box (regardless of IoU value).
     gt_iou_argmax = np.argmax(overlaps, axis=0)
     for ix, ii in enumerate(gt_iou_argmax):
         anchor_class_matches[ii] = gt_class_ids[ix]
 
     # 3. Set anchors with high overlap as positive.
     above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou)
     anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]]
 
     # Subsample to balance positive anchors.
     ids = np.where(anchor_class_matches > 0)[0]
     extra = len(ids) - (cf.rpn_train_anchors_per_image // 2)
     if extra > 0:
         # Reset the extra ones to neutral
         ids = np.random.choice(ids, extra, replace=False)
         anchor_class_matches[ids] = 0
 
     # Leave all negative proposals negative for now and sample from them later in online hard example mining.
     # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes.
     ids = np.where(anchor_class_matches > 0)[0]
     ix = 0  # index into anchor_delta_targets
     for i, a in zip(ids, anchors[ids]):
         # closest gt box (it might have IoU < anchor_matching_iou)
         gt = gt_boxes[anchor_iou_argmax[i]]
 
         # convert coordinates to center plus width/height.
         gt_h = gt[2] - gt[0]
         gt_w = gt[3] - gt[1]
         gt_center_y = gt[0] + 0.5 * gt_h
         gt_center_x = gt[1] + 0.5 * gt_w
         # Anchor
         a_h = a[2] - a[0]
         a_w = a[3] - a[1]
         a_center_y = a[0] + 0.5 * a_h
         a_center_x = a[1] + 0.5 * a_w
 
         if cf.dim == 2:
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
             ]
 
         else:
             gt_d = gt[5] - gt[4]
             gt_center_z = gt[4] + 0.5 * gt_d
             a_d = a[5] - a[4]
             a_center_z = a[4] + 0.5 * a_d
 
             anchor_delta_targets[ix] = [
                 (gt_center_y - a_center_y) / a_h,
                 (gt_center_x - a_center_x) / a_w,
                 (gt_center_z - a_center_z) / a_d,
                 np.log(gt_h / a_h),
                 np.log(gt_w / a_w),
                 np.log(gt_d / a_d)
             ]
 
         # normalize.
         anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev
         ix += 1
 
     return anchor_class_matches, anchor_delta_targets
 
 
 
 def clip_to_window(window, boxes):
     """
         window: (y1, x1, y2, x2) / 3D: (z1, z2). The window in the image we want to clip to.
         boxes: [N, (y1, x1, y2, x2)]  / 3D: (z1, z2)
     """
     boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[2]))
     boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[3]))
     boxes[:, 2] = boxes[:, 2].clamp(float(window[0]), float(window[2]))
     boxes[:, 3] = boxes[:, 3].clamp(float(window[1]), float(window[3]))
 
     if boxes.shape[1] > 5:
         boxes[:, 4] = boxes[:, 4].clamp(float(window[4]), float(window[5]))
         boxes[:, 5] = boxes[:, 5].clamp(float(window[4]), float(window[5]))
 
     return boxes
 
 ############################################################
 #  Connected Componenent Analysis
 ############################################################
 
 def get_coords(binary_mask, n_components, dim):
     """
     loops over batch to perform connected component analysis on binary input mask. computes box coordinates around
     n_components - biggest components (rois).
     :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class.
     :param n_components: int. number of components to extract per batch element and class.
     :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2))
     :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2))
     """
     assert len(binary_mask.shape)==dim+1
     binary_mask = binary_mask.astype('uint8')
     batch_coords = []
     batch_components = []
     for ix,b in enumerate(binary_mask):
         clusters, n_cands = lb(b)  # performs connected component analysis.
         uniques, counts = np.unique(clusters, return_counts=True)
         keep_uniques = uniques[1:][np.argsort(counts[1:])[::-1]][:n_components] #only keep n_components largest components
         p_components = np.array([(clusters == ii) * 1 for ii in keep_uniques])  # separate clusters and concat
         p_coords = []
         if p_components.shape[0] > 0:
             for roi in p_components:
                 mask_ixs = np.argwhere(roi != 0)
 
                 # get coordinates around component.
                 roi_coords = [np.min(mask_ixs[:, 0]) - 1, np.min(mask_ixs[:, 1]) - 1, np.max(mask_ixs[:, 0]) + 1,
                                np.max(mask_ixs[:, 1]) + 1]
                 if dim == 3:
                     roi_coords += [np.min(mask_ixs[:, 2]), np.max(mask_ixs[:, 2])+1]
                 p_coords.append(roi_coords)
 
             p_coords = np.array(p_coords)
 
             #clip coords.
             p_coords[p_coords < 0] = 0
             p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2]
             if dim == 3:
                 p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1]
 
         batch_coords.append(p_coords)
         batch_components.append(p_components)
     return batch_coords, batch_components
 
 
 # noinspection PyCallingNonCallable
 def get_coords_gpu(binary_mask, n_components, dim):
     """
     loops over batch to perform connected component analysis on binary input mask. computes box coordiantes around
     n_components - biggest components (rois).
     :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class.
     :param n_components: int. number of components to extract per batch element and class.
     :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2))
     :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2))
     """
     raise Exception("throws floating point exception")
     assert len(binary_mask.shape)==dim+1
     binary_mask = binary_mask.type(torch.uint8)
     batch_coords = []
     batch_components = []
     for ix,b in enumerate(binary_mask):
         clusters, n_cands = lb(b.cpu().data.numpy())  # peforms connected component analysis.
         clusters = torch.from_numpy(clusters).cuda()
         uniques = torch.unique(clusters)
         counts = torch.stack([(clusters==unique).sum() for unique in uniques])
         keep_uniques = uniques[1:][torch.sort(counts[1:])[1].flip(0)][:n_components] #only keep n_components largest components
         p_components = torch.cat([(clusters == ii).unsqueeze(0) for ii in keep_uniques]).cuda()  # separate clusters and concat
         p_coords = []
         if p_components.shape[0] > 0:
             for roi in p_components:
                 mask_ixs = torch.nonzero(roi)
 
                 # get coordinates around component.
                 roi_coords = [torch.min(mask_ixs[:, 0]) - 1, torch.min(mask_ixs[:, 1]) - 1,
                               torch.max(mask_ixs[:, 0]) + 1,
                               torch.max(mask_ixs[:, 1]) + 1]
                 if dim == 3:
                     roi_coords += [torch.min(mask_ixs[:, 2]), torch.max(mask_ixs[:, 2])+1]
                 p_coords.append(roi_coords)
 
             p_coords = torch.tensor(p_coords)
 
             #clip coords.
             p_coords[p_coords < 0] = 0
             p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2]
             if dim == 3:
                 p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1]
 
         batch_coords.append(p_coords)
         batch_components.append(p_components)
     return batch_coords, batch_components
 
 
 ############################################################
 #  Pytorch Utility Functions
 ############################################################
 
 def unique1d(tensor):
     """discard all elements of tensor that occur more than once; make tensor unique.
     :param tensor:
     :return:
     """
     if tensor.size()[0] == 0 or tensor.size()[0] == 1:
         return tensor
     tensor = tensor.sort()[0]
     unique_bool = tensor[1:] != tensor[:-1]
     first_element = torch.tensor([True], dtype=torch.bool, requires_grad=False)
     if tensor.is_cuda:
         first_element = first_element.cuda()
     unique_bool = torch.cat((first_element, unique_bool), dim=0)
     return tensor[unique_bool.data]
 
 
 def intersect1d(tensor1, tensor2):
     aux = torch.cat((tensor1, tensor2), dim=0)
     aux = aux.sort(descending=True)[0]
     return aux[:-1][(aux[1:] == aux[:-1]).data]
 
 
 
 def shem(roi_probs_neg, negative_count, poolsize):
     """
     stochastic hard example mining: from a list of indices (referring to non-matched predictions),
     determine a pool of highest scoring (worst false positives) of size negative_count*poolsize.
     Then, sample n (= negative_count) predictions of this pool as negative examples for loss.
     :param roi_probs_neg: tensor of shape (n_predictions, n_classes).
     :param negative_count: int.
     :param poolsize: int.
     :return: (negative_count).  indices refer to the positions in roi_probs_neg. If pool smaller than expected due to
     limited negative proposals availabel, this function will return sampled indices of number < negative_count without
     throwing an error.
     """
     # sort according to higehst foreground score.
     probs, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True)
     select = torch.tensor((poolsize * int(negative_count), order.size()[0])).min().int()
 
     pool_indices = order[:select]
     rand_idx = torch.randperm(pool_indices.size()[0])
     return pool_indices[rand_idx[:negative_count].cuda()]
 
 
 ############################################################
 #  Weight Init
 ############################################################
 
 
 def initialize_weights(net):
     """Initialize model weights. Current Default in Pytorch (version 0.4.1) is initialization from a uniform distriubtion.
     Will expectably be changed to kaiming_uniform in future versions.
     """
     init_type = net.cf.weight_init
 
     for m in [module for module in net.modules() if type(module) in [torch.nn.Conv2d, torch.nn.Conv3d,
                                                                      torch.nn.ConvTranspose2d,
                                                                      torch.nn.ConvTranspose3d,
                                                                      torch.nn.Linear]]:
         if init_type == 'xavier_uniform':
             torch.nn.init.xavier_uniform_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == 'xavier_normal':
             torch.nn.init.xavier_normal_(m.weight.data)
             if m.bias is not None:
                 m.bias.data.zero_()
 
         elif init_type == "kaiming_uniform":
             torch.nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 torch.nn.init.uniform_(m.bias, -bound, bound)
 
         elif init_type == "kaiming_normal":
             torch.nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0)
             if m.bias is not None:
                 fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                 bound = 1 / np.sqrt(fan_out)
                 torch.nn.init.normal_(m.bias, -bound, bound)
     net.logger.info("applied {} weight init.".format(init_type))
\ No newline at end of file