diff --git a/custom_extensions/nms/setup.py b/custom_extensions/nms/setup.py
index 90a5d13..70ed1e6 100644
--- a/custom_extensions/nms/setup.py
+++ b/custom_extensions/nms/setup.py
@@ -1,22 +1,25 @@
 """
 Created at 07.11.19 19:12
 @author: gregor
 
 """
 
 import os, sys, site
 from pathlib import Path
 
 # recognise newly installed packages in path
 site.main()
 
 from setuptools import setup
 from torch.utils import cpp_extension
 
 dir_ = Path(os.path.dirname(sys.argv[0]))
 
+sources = [str(dir_/'src/nms_interface.cpp'), str(dir_/'src/nms.cu')]  # source paths are resolved against the directory part of sys.argv[0]
+# register a single CUDA extension module named 'nms_extension' built from the sources above
 setup(name='nms_extension',
-      ext_modules=[cpp_extension.CUDAExtension('nms_extension', [str(dir_/'src/nms_interface.cpp'), str(dir_/'src/nms.cu')])],
+      ext_modules=[cpp_extension.CUDAExtension(
+            'nms_extension', sources
+      )],
       cmdclass={'build_ext': cpp_extension.BuildExtension}
-      )
-
+      )
\ No newline at end of file
diff --git a/custom_extensions/nms/setup.py b/custom_extensions/roi_align/2D/setup.py
similarity index 57%
copy from custom_extensions/nms/setup.py
copy to custom_extensions/roi_align/2D/setup.py
index 90a5d13..921913f 100644
--- a/custom_extensions/nms/setup.py
+++ b/custom_extensions/roi_align/2D/setup.py
@@ -1,22 +1,22 @@
 """
 Created at 07.11.19 19:12
 @author: gregor
 
 """
 
 import os, sys, site
 from pathlib import Path
 
 # recognise newly installed packages in path
 site.main()
 
 from setuptools import setup
 from torch.utils import cpp_extension
 
 dir_ = Path(os.path.dirname(sys.argv[0]))
 
-setup(name='nms_extension',
-      ext_modules=[cpp_extension.CUDAExtension('nms_extension', [str(dir_/'src/nms_interface.cpp'), str(dir_/'src/nms.cu')])],
+setup(name='RoIAlign extension 2D',  # distribution name; the extension module itself is named 'roi_al_extension'
+      ext_modules=[cpp_extension.CUDAExtension('roi_al_extension', [str(dir_/'src/RoIAlign_interface.cpp'),  # autograd wrapper + pybind11 binding
+                                                                    str(dir_/'src/RoIAlign_cuda.cu')])],  # CUDA forward/backward kernels
       cmdclass={'build_ext': cpp_extension.BuildExtension}
       )
-
diff --git a/custom_extensions/roi_align/2D/src/RoIAlign_cuda.cu b/custom_extensions/roi_align/2D/src/RoIAlign_cuda.cu
new file mode 100644
index 0000000..39426bf
--- /dev/null
+++ b/custom_extensions/roi_align/2D/src/RoIAlign_cuda.cu
@@ -0,0 +1,422 @@
+/*
+ROIAlign implementation in CUDA from pytorch framework
+(https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019)
+
+*/
+
+#include <ATen/ATen.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <typeinfo>
+#include "cuda_helpers.h"
+// Bilinear interpolation of one row-major (height x width) plane at the fractional point (y, x); yields 0 when y is outside [-1, height] or x is outside [-1, width].
+template <typename T>
+__device__ T bilinear_interpolate(
+    const T* input, // start of one feature-map plane, read as input[row * width + col]
+    const int height, // number of rows of the plane
+    const int width, // number of columns of the plane
+    T y, // fractional row coordinate of the sample point
+    T x, // fractional column coordinate of the sample point
+    const int index /* index for debug only; not read in this function */) {
+  // sample points outside the feature map (beyond a one-pixel margin at the low end) contribute nothing
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    return 0;
+  }
+  // coordinates within [-1, 0] are clamped onto the first row / column
+  if (y <= 0)
+    y = 0;
+  if (x <= 0)
+    x = 0;
+  // y and x are non-negative here, so the integer cast yields the lower neighbour
+  int y_low = (int)y;
+  int x_low = (int)x;
+  int y_high;
+  int x_high;
+  // on or past the last row: use that row for both neighbours and move y onto it (ly becomes 0)
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+  // same treatment for the last column (lx becomes 0)
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+  // ly / lx: fractional distance to the low neighbour; hy / hx: the complementary weights
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  // do bilinear interpolation: weighted sum of the four surrounding values
+  T v1 = input[y_low * width + x_low];
+  T v2 = input[y_low * width + x_high];
+  T v3 = input[y_high * width + x_low];
+  T v4 = input[y_high * width + x_high];
+  T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  return val;
+}
+// Forward kernel: every loop iteration fills one pooled output element (n, c, ph, pw) with the mean of the bilinear samples taken inside its RoI bin.
+template <typename T>
+__global__ void RoIAlignForward(
+    const int nthreads, // number of output elements; upper bound of the kernel loop
+    const T* input, // feature maps, addressed as a dense (batch, channel, y, x) array
+    const T spatial_scale, // factor multiplied onto the RoI coordinates before sampling
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio, // samples per bin and axis; values <= 0 select ceil(roi_size / pooled_size)
+    const T* rois, // five values per RoI: (batch index, start_h, start_w, end_h, end_w)
+    T* output) { // written at the flat loop index, i.e. laid out as (n, c, ph, pw)
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    const int pw = index % pooled_width;
+    const int ph = (index / pooled_width) % pooled_height;
+    const int c = (index / pooled_width / pooled_height) % channels;
+    const int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_h = offset_rois[1] * spatial_scale;
+    T roi_start_w = offset_rois[2] * spatial_scale;
+    T roi_end_h = offset_rois[3] * spatial_scale;
+    T roi_end_w = offset_rois[4] * spatial_scale;
+
+    // Force malformed ROIs to be 1x1: the scaled extent is clamped to at least 1 per axis
+    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+    // start of the (roi_batch_ind, c) plane inside input
+    const T* offset_input =
+        input + (roi_batch_ind * channels + c) * height * width;
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+    T output_val = 0.;
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+    {
+      const T y = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = roi_start_w + pw * bin_size_w +
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
+        output_val += val;
+      }
+    }
+    output_val /= count;
+
+    output[index] = output_val;
+  }
+}
+// Computes the bilinear weights w1..w4 and the corner indices for point (y, x) without reading feature data; outside [-1, height] x [-1, width] all weights are 0 and all indices are -1.
+template <typename T>
+__device__ void bilinear_interpolate_gradient(
+    const int height,
+    const int width,
+    T y,
+    T x,
+    T& w1, // out: weight of corner (y_low, x_low)
+    T& w2, // out: weight of corner (y_low, x_high)
+    T& w3, // out: weight of corner (y_high, x_low)
+    T& w4, // out: weight of corner (y_high, x_high)
+    int& x_low,
+    int& x_high,
+    int& y_low,
+    int& y_high,
+    const int index /* index for debug only; not read in this function */) {
+  // sample points outside the feature map get zero weights and the marker index -1
+  if (y < -1.0 || y > height || x < -1.0 || x > width) {
+    // empty
+    w1 = w2 = w3 = w4 = 0.;
+    x_low = x_high = y_low = y_high = -1;
+    return;
+  }
+  // coordinates within [-1, 0] are clamped onto the first row / column
+  if (y <= 0)
+    y = 0;
+  if (x <= 0)
+    x = 0;
+
+  y_low = (int)y;
+  x_low = (int)x;
+  // on or past the last row: use that row for both neighbours and move y onto it (ly becomes 0)
+  if (y_low >= height - 1) {
+    y_high = y_low = height - 1;
+    y = (T)y_low;
+  } else {
+    y_high = y_low + 1;
+  }
+  // same treatment for the last column (lx becomes 0)
+  if (x_low >= width - 1) {
+    x_high = x_low = width - 1;
+    x = (T)x_low;
+  } else {
+    x_high = x_low + 1;
+  }
+  // ly / lx: fractional distance to the low neighbour; hy / hx: the complementary weights
+  T ly = y - y_low;
+  T lx = x - x_low;
+  T hy = 1. - ly, hx = 1. - lx;
+
+  // reference in forward
+  // T v1 = input[y_low * width + x_low];
+  // T v2 = input[y_low * width + x_high];
+  // T v3 = input[y_high * width + x_low];
+  // T v4 = input[y_high * width + x_high];
+  // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+  w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+  return;
+}
+// Backward kernel: every loop iteration takes one output-gradient element (n, c, ph, pw) and adds its share to the four input pixels around each sample point of that bin.
+template <typename T>
+__global__ void RoIAlignBackward(
+    const int nthreads, // number of output-gradient elements; upper bound of the kernel loop
+    const T* grad_output, // gradient w.r.t. the pooled output, addressed through the strides below
+    const T spatial_scale, // factor multiplied onto the RoI coordinates before sampling
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio, // samples per bin and axis; values <= 0 select ceil(roi_size / pooled_size)
+    T* grad_input, // gradient w.r.t. the input, addressed as a dense (batch, channel, y, x) array; only added to
+    const T* rois, // five values per RoI: (batch index, start_h, start_w, end_h, end_w)
+    const int n_stride, // element strides of grad_output along (n, c, ph, pw)
+    const int c_stride,
+    const int h_stride,
+    const int w_stride)
+{
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c = (index / pooled_width / pooled_height) % channels;
+    int n = index / pooled_width / pooled_height / channels;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_h = offset_rois[1] * spatial_scale;
+    T roi_start_w = offset_rois[2] * spatial_scale;
+    T roi_end_h = offset_rois[3] * spatial_scale;
+    T roi_end_w = offset_rois[4] * spatial_scale;
+
+    // Force malformed ROIs to be 1x1: the scaled extent is clamped to at least 1 per axis
+    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+    // start of the (roi_batch_ind, c) plane inside grad_input
+    T* offset_grad_input =
+        grad_input + ((roi_batch_ind * channels + c) * height * width);
+
+    // We need to index the gradient using the tensor strides to access the
+    // correct values.
+    int output_offset = n * n_stride + c * c_stride;
+    const T* offset_grad_output = grad_output + output_offset;
+    const T grad_output_this_bin =
+        offset_grad_output[ph * h_stride + pw * w_stride];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+    // We do average (integral) pooling inside a bin
+    const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+    {
+      const T y = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = roi_start_w + pw * bin_size_w  +
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(
+            height,
+            width,
+            y,
+            x,
+            w1,
+            w2,
+            w3,
+            w4,
+            x_low,
+            x_high,
+            y_low,
+            y_high,
+            index);
+        // each sample point carries 1/count of the bin gradient, split by its bilinear weights
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { // indices are -1 for sample points outside the feature map
+          atomicAdd(
+              offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+          atomicAdd(
+              offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+          atomicAdd(
+              offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+          atomicAdd(
+              offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+        } // if
+      } // ix
+    } // iy
+  } // CUDA_1D_KERNEL_LOOP
+} // RoIAlignBackward
+// Host entry point of the forward pass: checks the tensors, allocates the pooled output and launches RoIAlignForward on the current CUDA stream.
+at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
+                                const int pooled_height, const int pooled_width, const int sampling_ratio) {
+  /*
+   input: feature-map tensor, shape (batch, n_channels, y, x); rois: one row per RoI, read by the kernel as (batch index, y1, x1, y2, x2)
+   */
+  AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+  // both tensors must live on the same GPU and share one scalar type
+  at::CheckedFrom c = "ROIAlign_forward_cuda";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  at::cuda::CUDAGuard device_guard(input.device());
+
+  int num_rois = rois.size(0);
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
+  // one pooled (channels, pooled_height, pooled_width) map per RoI
+  at::Tensor output = at::zeros(
+      {num_rois, channels, pooled_height, pooled_width}, input.options());
+
+  auto output_size = num_rois * pooled_height * pooled_width * channels;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // at most 4096 blocks of 512 threads; the loop inside the kernel covers any remaining elements
+  dim3 grid(std::min(
+      at::cuda::ATenCeilDiv(
+          static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
+      static_cast<int64_t>(4096)));
+  dim3 block(512);
+  // empty output: return it without launching the kernel
+  if (output.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return output;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
+    RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
+        output_size,
+        input.contiguous().data_ptr<scalar_t>(),
+        spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        sampling_ratio,
+        rois.contiguous().data_ptr<scalar_t>(),
+        output.data_ptr<scalar_t>());
+  });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return output;
+}
+// Host entry point of the backward pass: allocates a zero-filled (batch_size, channels, height, width) gradient and lets RoIAlignBackward accumulate into it.
+at::Tensor ROIAlign_backward_cuda(
+    const at::Tensor& grad, // gradient w.r.t. the pooled output; read through its own strides
+    const at::Tensor& rois, // the RoIs that were used in the forward pass
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int batch_size, // batch_size, channels, height, width: shape of the returned input gradient
+    const int channels,
+    const int height,
+    const int width,
+    const int sampling_ratio) {
+  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
+  // both tensors must live on the same GPU and share one scalar type
+  at::CheckedFrom c = "ROIAlign_backward_cuda";
+  at::checkAllSameGPU(c, {grad_t, rois_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  at::cuda::CUDAGuard device_guard(grad.device());
+  // zero-initialised because the kernel only adds contributions via atomicAdd
+  at::Tensor grad_input =
+      at::zeros({batch_size, channels, height, width}, grad.options());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // at most 4096 blocks of 512 threads; the loop inside the kernel covers any remaining elements
+  dim3 grid(std::min(
+      at::cuda::ATenCeilDiv(
+          static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
+      static_cast<int64_t>(4096)));
+  dim3 block(512);
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return grad_input;
+  }
+  // grad is not made contiguous; its strides are handed to the kernel instead
+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
+    RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
+        grad.numel(),
+        grad.data_ptr<scalar_t>(),
+        spatial_scale,
+        channels,
+        height,
+        width,
+        pooled_height,
+        pooled_width,
+        sampling_ratio,
+        grad_input.data_ptr<scalar_t>(),
+        rois.contiguous().data_ptr<scalar_t>(),
+        n_stride,
+        c_stride,
+        h_stride,
+        w_stride);
+  });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return grad_input;
+}
\ No newline at end of file
diff --git a/custom_extensions/roi_align/2D/src/RoIAlign_interface.cpp b/custom_extensions/roi_align/2D/src/RoIAlign_interface.cpp
new file mode 100644
index 0000000..41d5fdf
--- /dev/null
+++ b/custom_extensions/roi_align/2D/src/RoIAlign_interface.cpp
@@ -0,0 +1,104 @@
+/* adopted from pytorch framework
+    https://github.com/pytorch/vision/blob/master/torchvision/csrc/ROIAlign.h on Nov 15 2019.
+
+    does not include CPU support but could be added with this interface.
+*/
+
+#include <torch/extension.h>
+
+// Forward declarations of the host functions that are defined in the CUDA source file (RoIAlign_cuda.cu)
+at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
+                                const int pooled_height, const int pooled_width, const int sampling_ratio);
+at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
+    const int pooled_height, const int pooled_width, const int batch_size, const int channels,
+    const int height, const int width, const int sampling_ratio);
+
+// Forward pass dispatch: CUDA inputs are handed to ROIAlign_forward_cuda, any other input raises an error (no CPU implementation is compiled in).
+at::Tensor ROIAlign_forward(
+    const at::Tensor& input, // Input feature map.
+    const at::Tensor& rois, // List of ROIs to pool over.
+    const double spatial_scale, // Scale of the feature map relative to the RoI coordinates; the ROIs are multiplied by it.
+    const int64_t pooled_height, // The height of the pooled feature map.
+    const int64_t pooled_width, // The width of the pooled feature map.
+    const int64_t sampling_ratio) // The number of points to sample in each bin along each axis.
+{
+  if (input.type().is_cuda()) {
+    return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+  }
+  AT_ERROR("Not compiled with CPU support");
+  //return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+}
+// Backward pass dispatch: CUDA gradients are handed to ROIAlign_backward_cuda, any other gradient raises an error (no CPU implementation is compiled in).
+at::Tensor ROIAlign_backward(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
+    const int pooled_height, const int pooled_width, const int batch_size, const int channels,
+    const int height, const int width, const int sampling_ratio) {
+  if (grad.type().is_cuda()) {
+    return ROIAlign_backward_cuda( grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels,
+        height, width, sampling_ratio);
+  }
+  AT_ERROR("Not compiled with CPU support");
+  //return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels,
+  //   height, width, sampling_ratio);
+}
+
+using namespace at;
+using torch::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+// Custom autograd function that ties ROIAlign_forward and ROIAlign_backward together so that gradients flow back to the input feature map.
+class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
+ public:
+  static variable_list forward( // stores everything backward needs in ctx, then runs the forward pass
+      AutogradContext* ctx,
+      Variable input,
+      Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t sampling_ratio) {
+    ctx->saved_data["spatial_scale"] = spatial_scale;
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["sampling_ratio"] = sampling_ratio;
+    ctx->saved_data["input_shape"] = input.sizes(); // only the shape of input is kept, not its values
+    ctx->save_for_backward({rois});
+    auto result = ROIAlign_forward(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+    return {result};
+  }
+  // Returns one entry per forward argument; only the input feature map receives a defined gradient.
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = ROIAlign_backward(
+        grad_output[0],
+        rois,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        input_shape[0], //b
+        input_shape[1], //c
+        input_shape[2], //h
+        input_shape[3], //w
+        ctx->saved_data["sampling_ratio"].toInt());
+    return { // empty Variables stand for rois, spatial_scale, pooled_height, pooled_width and sampling_ratio
+        grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
+  }
+};
+// Function exported to Python: applies ROIAlignFunction and returns its single output tensor.
+Tensor roi_align(const Tensor& input, const Tensor& rois, const double spatial_scale, const int64_t pooled_height,
+    const int64_t pooled_width, const int64_t sampling_ratio) {
+
+  return ROIAlignFunction::apply(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio)[0];
+
+}
+
+
+// Python binding: registers roi_align on the extension module; TORCH_EXTENSION_NAME is expected to be supplied by the torch extension build.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("roi_align", &roi_align, "ROIAlign 2D in c++ and/or cuda");
+}
\ No newline at end of file
diff --git a/custom_extensions/roi_align/2D/src/cuda_helpers.h b/custom_extensions/roi_align/2D/src/cuda_helpers.h
new file mode 100644
index 0000000..af32f60
--- /dev/null
+++ b/custom_extensions/roi_align/2D/src/cuda_helpers.h
@@ -0,0 +1,5 @@
+#pragma once
+// Loop header for 1D kernels: i starts at the global thread index and advances by the total number of launched threads (blockDim.x * gridDim.x) while i < n.
+#define CUDA_1D_KERNEL_LOOP(i, n)                                \
+  for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
+       i += (blockDim.x * gridDim.x))
diff --git a/custom_extensions/nms/setup.py b/custom_extensions/roi_align/3D/setup.py
similarity index 54%
copy from custom_extensions/nms/setup.py
copy to custom_extensions/roi_align/3D/setup.py
index 90a5d13..f2d164b 100644
--- a/custom_extensions/nms/setup.py
+++ b/custom_extensions/roi_align/3D/setup.py
@@ -1,22 +1,22 @@
 """
 Created at 07.11.19 19:12
 @author: gregor
 
 """
 
 import os, sys, site
 from pathlib import Path
 
 # recognise newly installed packages in path
 site.main()
 
 from setuptools import setup
 from torch.utils import cpp_extension
 
 dir_ = Path(os.path.dirname(sys.argv[0]))
 
-setup(name='nms_extension',
-      ext_modules=[cpp_extension.CUDAExtension('nms_extension', [str(dir_/'src/nms_interface.cpp'), str(dir_/'src/nms.cu')])],
+setup(name='RoIAlign extension 3D',  # distribution name; the extension module itself is named 'roi_al_extension_3d'
+      ext_modules=[cpp_extension.CUDAExtension('roi_al_extension_3d', [str(dir_/'src/RoIAlign_interface_3d.cpp'),  # C++ interface source
+                                                                       str(dir_/'src/RoIAlign_cuda_3d.cu')])],  # CUDA kernel source
       cmdclass={'build_ext': cpp_extension.BuildExtension}
-      )
-
+      )
\ No newline at end of file
diff --git a/custom_extensions/roi_align/3D/src/RoIAlign_cuda_3d.cu b/custom_extensions/roi_align/3D/src/RoIAlign_cuda_3d.cu
new file mode 100644
index 0000000..182274f
--- /dev/null
+++ b/custom_extensions/roi_align/3D/src/RoIAlign_cuda_3d.cu
@@ -0,0 +1,487 @@
+/*
+ROIAlign implementation in CUDA from pytorch framework
+(https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019)
+
+Adapted for additional 3D capability by G. Ramien, DKFZ Heidelberg
+*/
+
+#include <ATen/ATen.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <cstdio>
+#include "cuda_helpers.h"
+
+/*-------------- gpu kernels -----------------*/
+// Linear blend between two values: returns val_low for xl = 0 and val_high for xl = 1.
+template <typename T>
+__device__ T linear_interpolate(const T xl, const T val_low, const T val_high){
+
+  T val = (val_high - val_low) * xl + val_low;
+  return val;
+}
+// Trilinear interpolation of one (height x width x depth) volume at the fractional point (y, x, z); yields 0 when a coordinate is below -1 or above its axis size.
+template <typename T>
+__device__ T trilinear_interpolate(const T* input, const int height, const int width, const int depth,
+                T y, T x, T z, const int index /* index for debug only; not read in this function */) {
+  // sample points outside the feature map (beyond a one-voxel margin at the low end) contribute nothing
+  if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) {
+    // empty
+    return 0;
+  }
+  if (y <= 0)
+    y = 0;
+  if (x <= 0)
+    x = 0;
+  if (z <= 0)
+    z = 0;
+  // coordinates are non-negative here, so the integer cast yields the lower grid point
+  int y0 = (int)y;
+  int x0 = (int)x;
+  int z0 = (int)z;
+  int y1;
+  int x1;
+  int z1;
+
+  if (y0 >= height - 1) {
+  /* the lower grid point is the last one (index height - 1) or beyond: use it for both neighbours and move y onto it */
+    y1 = y0 = height - 1;
+    y = (T)y0;
+  } else {
+    /* y1 is one pixel from y0, y is the actual point somewhere in between  */
+    y1 = y0 + 1;
+  }
+  if (x0 >= width - 1) {
+    x1 = x0 = width - 1;
+    x = (T)x0;
+  } else {
+    x1 = x0 + 1;
+  }
+  if (z0 >= depth - 1) {
+    z1 = z0 = depth - 1;
+    z = (T)z0;
+  } else {
+    z1 = z0 + 1;
+  }
+
+
+  // do linear interpolation of x values
+  // distance of actual point to lower grid point, already normalized since x1 - x0 = 1 away from the border
+  T dis = x - x0;
+  /*  accessing element b,c,y,x,z in 1D-rolled-out array of a tensor with dimensions (B, C, Y, X, Z):
+      tensor[b,c,y,x,z] = arr[ (((b*C+c)*Y+y)*X + x)*Z + z ] = arr[ alpha + (y*X + x)*Z + z ]
+      with alpha = batch&channel locator = (b*C+c)*YXZ.
+      hence, as current input pointer is already offset by alpha: y,x,z is at input[( y*X + x)*Z + z], where
+      X = width, Z = depth.
+  */
+  T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]);
+  T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]);
+  T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]);
+  T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]);
+
+  // linear interpol of y values = bilinear interpol of f(x,y)
+  dis = y - y0;
+  T xy0 = linear_interpolate(dis, x00, x10);
+  T xy1 = linear_interpolate(dis, x01, x11);
+
+  // linear interpol of z value = trilinear interpol of f(x,y,z)
+  dis = z - z0;
+  T xyz = linear_interpolate(dis, xy0, xy1);
+
+  return xyz;
+}
+// Computes the eight trilinear weights g_xyz and the corner indices for point (y, x, z) without reading feature data; for points outside the volume all weights are 0 and all indices are -1.
+template <typename T>
+__device__ void trilinear_interpolate_gradient(const int height, const int width, const int depth, T y, T x, T z,
+    T& g000, T& g001, T& g010, T& g100, T& g011, T& g101, T& g110, T& g111,
+    int& x0, int& x1, int& y0, int& y1, int& z0, int&z1, const int index /* index for debug only; not read in this function */)
+{
+  // sample points outside the feature map get zero weights and the marker index -1
+  if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) {
+    // empty
+    g000 = g001 = g010 = g100 = g011 = g101 = g110 = g111 = 0.;
+    x0 = x1 = y0 = y1 = z0 = z1 = -1;
+    return;
+  }
+  // coordinates within [-1, 0] are clamped onto the first grid point of their axis
+  if (y <= 0)
+    y = 0;
+  if (x <= 0)
+    x = 0;
+  if (z <= 0)
+    z = 0;
+
+  y0 = (int)y;
+  x0 = (int)x;
+  z0 = (int)z;
+  // on or past the last grid point of an axis: use it for both neighbours and move the coordinate onto it
+  if (y0 >= height - 1) {
+    y1 = y0 = height - 1;
+    y = (T)y0;
+  } else {
+    y1 = y0 + 1;
+  }
+
+  if (x0 >= width - 1) {
+    x1 = x0 = width - 1;
+    x = (T)x0;
+  } else {
+    x1 = x0 + 1;
+  }
+
+  if (z0 >= depth - 1) {
+    z1 = z0 = depth - 1;
+    z = (T)z0;
+  } else {
+    z1 = z0 + 1;
+  }
+
+  // forward calculations are added as hints
+  T dis_x = x - x0;
+  //T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]); // v000, v100
+  //T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]); // v010, v110
+  //T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]); // v001, v101
+  //T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]); // v011, v111
+
+  // linear interpol of y values = bilinear interpol of f(x,y)
+  T dis_y = y - y0;
+  //T xy0 = linear_interpolate(dis, x00, x10);
+  //T xy1 = linear_interpolate(dis, x01, x11);
+
+  // linear interpol of z value = trilinear interpol of f(x,y,z)
+  T dis_z = z - z0;
+  //T xyz = linear_interpolate(dis, xy0, xy1);
+
+  /* need: grad_i := d(xyz)/d(v_i) with v_i = input_value_i  for all i = 0,..,7 (eight input values --> eight-entry gradient)
+     d(lin_interp(dis,x,y))/dx = (-dis +1) and d(lin_interp(dis,x,y))/dy = dis --> derivatives are indep of x,y.
+     notation: gxyz = gradient for d(trilin_interp)/d(input_value_at_xyz), each digit selecting the low (0) or high (1) grid point
+     below grads were calculated by hand
+     save time by reusing (1-dis_x) = 1-x+x0 = x1-x =: dis_x1 */
+  T dis_x1 = (1-dis_x), dis_y1 = (1-dis_y), dis_z1 = (1-dis_z);
+
+  g000 = dis_z1 * dis_y1  * dis_x1;
+  g001 = dis_z  * dis_y1  * dis_x1;
+  g010 = dis_z1 * dis_y   * dis_x1;
+  g100 = dis_z1 * dis_y1  * dis_x;
+  g011 = dis_z  * dis_y   * dis_x1;
+  g101 = dis_z  * dis_y1  * dis_x;
+  g110 = dis_z1 * dis_y   * dis_x;
+  g111 = dis_z  * dis_y   * dis_x;
+
+  return;
+}
+// 3D forward kernel: every loop iteration fills one pooled output element (n, c, ph, pw, pd) with the mean of the trilinear samples taken inside its RoI bin.
+template <typename T>
+__global__ void RoIAlignForward(const int nthreads, const T* input, const T spatial_scale, const int channels,
+    const int height, const int width, const int depth, const int pooled_height, const int pooled_width,
+    const int pooled_depth, const int sampling_ratio, const T* rois, T* output)
+{
+
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c, ph, pw, pd) is an element in the pooled output
+    int pd =  index % pooled_depth;
+    int pw = (index / pooled_depth) % pooled_width;
+    int ph = (index / pooled_depth / pooled_width) % pooled_height;
+    int c  = (index / pooled_depth / pooled_width / pooled_height) % channels;
+    int n  =  index / pooled_depth / pooled_width / pooled_height / channels;
+
+
+    // each roi holds 7 consecutive values: (batch_index, y1, x1, y2, x2, z1, z2)
+    const T* offset_rois = rois + n * 7;
+    int roi_batch_ind = offset_rois[0];
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_h = offset_rois[1] * spatial_scale;
+    T roi_start_w = offset_rois[2] * spatial_scale;
+    T roi_end_h = offset_rois[3] * spatial_scale;
+    T roi_end_w = offset_rois[4] * spatial_scale;
+    T roi_start_d = offset_rois[5] * spatial_scale;
+    T roi_end_d = offset_rois[6] * spatial_scale;
+
+    // Force malformed ROIs to be 1x1x1: the scaled extent is clamped to at least 1 per axis
+    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+    T roi_depth = max(roi_end_d - roi_start_d, (T)1.);
+
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+    T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);
+    // start of the (roi_batch_ind, c) volume inside input
+    const T* offset_input =
+        input + (roi_batch_ind * channels + c) * height * width * depth;
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    // roi_bin_grid == nr of sampling points per bin >= 1
+    int roi_bin_grid_h =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+    int roi_bin_grid_d =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth);
+
+    // We do average (integral) pooling inside a bin
+    const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g. = 8 for a 2x2x2 sampling grid
+
+    T output_val = 0.;
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+    {
+      const T y = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5, always in the middle of two grid points
+
+      for (int ix = 0; ix < roi_bin_grid_w; ix++)
+      {
+        const T x = roi_start_w + pw * bin_size_w +
+            static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+        for (int iz = 0; iz < roi_bin_grid_d; iz++)
+        {
+          const T z = roi_start_d + pd * bin_size_d +
+              static_cast<T>(iz + .5f) * bin_size_d / static_cast<T>(roi_bin_grid_d);
+          // TODO verify trilinear interpolation
+          T val = trilinear_interpolate(offset_input, height, width, depth, y, x, z, index);
+          output_val += val;
+        } // z iterator and calc+add value
+      } // x iterator
+    } // y iterator
+    output_val /= n_voxels;
+
+    output[index] = output_val;
+  }
+}
+
+template <typename T>
+__global__ void RoIAlignBackward(const int nthreads, const T* grad_output, const T spatial_scale, const int channels,
+    const int height, const int width, const int depth, const int pooled_height, const int pooled_width,
+    const int pooled_depth, const int sampling_ratio, T* grad_input, const T* rois,
+    const int n_stride, const int c_stride, const int h_stride, const int w_stride, const int d_stride)
+{
+  // Backward pass of 3D RoIAlign: each pooled-output gradient element is distributed onto the voxels of grad_input.
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // decompose the flat index into (n, c, ph, pw, pd), an element of the pooled output; pd varies fastest
+    int pd =  index % pooled_depth;
+    int pw = (index / pooled_depth) % pooled_width;
+    int ph = (index / pooled_depth / pooled_width) % pooled_height;
+    int c  = (index / pooled_depth / pooled_width / pooled_height) % channels;
+    int n  =  index / pooled_depth / pooled_width / pooled_height / channels;
+
+    // each RoI is 7 consecutive values: (batch index, start_h, start_w, end_h, end_w, start_d, end_d)
+    const T* offset_rois = rois + n * 7;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_h = offset_rois[1] * spatial_scale;
+    T roi_start_w = offset_rois[2] * spatial_scale;
+    T roi_end_h = offset_rois[3] * spatial_scale;
+    T roi_end_w = offset_rois[4] * spatial_scale;
+    T roi_start_d = offset_rois[5] * spatial_scale;
+    T roi_end_d = offset_rois[6] * spatial_scale;
+
+
+    // Force malformed ROIs to be at least 1x1x1
+    T roi_width = max(roi_end_w - roi_start_w, (T)1.);
+    T roi_height = max(roi_end_h - roi_start_h, (T)1.);
+    T roi_depth = max(roi_end_d - roi_start_d, (T)1.);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+    T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);
+
+    // offset: index b,c,y,x,z of tensor of shape (B,C,Y,X,Z) is
+    // b*C*Y*X*Z + c * Y*X*Z + y * X*Z + x *Z + z = (b*C+c)Y*X*Z + ...
+    T* offset_grad_input =
+        grad_input + ((roi_batch_ind * channels + c) * height * width * depth);
+
+    // We need to index the gradient using the tensor strides to access the correct values.
+    int output_offset = n * n_stride + c * c_stride;
+    const T* offset_grad_output = grad_output + output_offset;
+    const T grad_output_this_bin = offset_grad_output[ph * h_stride + pw * w_stride + pd * d_stride];
+
+    // We use roi_bin_grid to sample the grid and mimic integral; if sampling_ratio <= 0 the count adapts to the bin size
+    int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+    int roi_bin_grid_d = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth);
+
+    // We do average (integral) pooling inside a bin
+    const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g. = 2*2*2 = 8
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+    {
+      const T y = roi_start_h + ph * bin_size_h +
+          static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+
+      for (int ix = 0; ix < roi_bin_grid_w; ix++)
+      {
+        const T x = roi_start_w + pw * bin_size_w +
+          static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w);
+
+        for (int iz = 0; iz < roi_bin_grid_d; iz++)
+        {
+          const T z = roi_start_d + pd * bin_size_d +
+              static_cast<T>(iz + .5f) * bin_size_d / static_cast<T>(roi_bin_grid_d);
+
+          T g000, g001, g010, g100, g011, g101, g110, g111; // will hold the current partial derivatives
+          int x0, x1, y0, y1, z0, z1;
+          /* notation: gxyz = gradient at xyz, where x,y,z need to lie on feature-map grid (i.e., =x0,x1 etc.) */
+          trilinear_interpolate_gradient(height, width, depth, y, x, z,
+                                         g000, g001, g010, g100, g011, g101, g110, g111,
+                                         x0, x1, y0, y1, z0, z1, index);
+          /* chain rule: the partial derivatives of the trilinear interpolation are multiplied with the incoming
+             gradient of this bin and divided by n_voxels, because the bin value is the average over all samples */
+          g000 *= grad_output_this_bin / n_voxels;
+          g001 *= grad_output_this_bin / n_voxels;
+          g010 *= grad_output_this_bin / n_voxels;
+          g100 *= grad_output_this_bin / n_voxels;
+          g011 *= grad_output_this_bin / n_voxels;
+          g101 *= grad_output_this_bin / n_voxels;
+          g110 *= grad_output_this_bin / n_voxels;
+          g111 *= grad_output_this_bin / n_voxels;
+
+          if (x0 >= 0 && x1 >= 0 && y0 >= 0 && y1 >= 0 && z0 >= 0 && z1 >= 0)
+          { // atomicAdd(address, val) adds val to the value stored at address, while no other thread
+            // can interfere with the memory at address during this operation (thread lock, therefore "atomic").
+            atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z0, static_cast<T>(g000));
+            atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z1, static_cast<T>(g001));
+            atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z0, static_cast<T>(g010));
+            atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z0, static_cast<T>(g100));
+            atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z1, static_cast<T>(g011));
+            atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z1, static_cast<T>(g101));
+            atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z0, static_cast<T>(g110));
+            atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z1, static_cast<T>(g111));
+          } // if
+        } // iz
+      } // ix
+    } // iy
+  } // CUDA_1D_KERNEL_LOOP
+} // RoIAlignBackward
+
+
+/*----------- wrapper functions ----------------*/
+
+at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
+                                const int pooled_height, const int pooled_width, const int pooled_depth, const int sampling_ratio) {
+  /* Host-side launcher of the 3D RoIAlign forward kernel. Returns a tensor of shape
+   (num_rois, n_channels, pooled_height, pooled_width, pooled_depth). input: feature-map tensor, shape (batch, n_channels, y, x, z).
+   */
+  AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ROIAlign_forward_cuda";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  at::cuda::CUDAGuard device_guard(input.device());
+  // sizes handed to the kernel; dimension 4 (z) is required in this 3D version
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+  auto depth = input.size(4);
+
+  at::Tensor output = at::zeros(
+      {num_rois, channels, pooled_height, pooled_width, pooled_depth}, input.options());
+
+  auto output_size = num_rois * channels * pooled_height * pooled_width * pooled_depth;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // launch configuration: 512 threads per block, capped at 4096 blocks
+  dim3 grid(std::min(
+      at::cuda::ATenCeilDiv(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)));
+  dim3 block(512);
+  // nothing to launch for an empty output
+  if (output.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return output;
+  }
+  // dispatch on the scalar type; passing Tensor::type() to the dispatch macro is deprecated
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "ROIAlign forward in 3d", [&] {
+    RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
+        output_size,
+        input.contiguous().data_ptr<scalar_t>(),
+        spatial_scale,
+        channels,
+        height,
+        width,
+        depth,
+        pooled_height,
+        pooled_width,
+        pooled_depth,
+        sampling_ratio,
+        rois.contiguous().data_ptr<scalar_t>(),
+        output.data_ptr<scalar_t>());
+  });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return output;
+}
+
+at::Tensor ROIAlign_backward_cuda(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int pooled_depth,
+    const int batch_size,
+    const int channels,
+    const int height,
+    const int width,
+    const int depth,
+    const int sampling_ratio)
+{ // launcher of RoIAlignBackward; returns a gradient tensor of shape (batch_size, channels, height, width, depth)
+  AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "ROIAlign_backward_cuda";
+  at::checkAllSameGPU(c, {grad_t, rois_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  at::cuda::CUDAGuard device_guard(grad.device());
+  // starts at zero; the kernel accumulates into it with atomicAdd
+  at::Tensor grad_input =
+      at::zeros({batch_size, channels, height, width, depth}, grad.options());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // 512 threads per block, capped at 4096 blocks; the kernel loop covers any elements beyond that
+  dim3 grid(std::min(
+      at::cuda::ATenCeilDiv(
+          static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
+      static_cast<int64_t>(4096)));
+  dim3 block(512);
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return grad_input;
+  }
+  // grad is not made contiguous; its element strides are passed so the kernel can index it directly
+  int n_stride = grad.stride(0);
+  int c_stride = grad.stride(1);
+  int h_stride = grad.stride(2);
+  int w_stride = grad.stride(3);
+  int d_stride = grad.stride(4);
+  // dispatch on the scalar type; passing Tensor::type() to the dispatch macro is deprecated
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "ROIAlign backward 3D", [&] {
+    RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
+        grad.numel(),
+        grad.data_ptr<scalar_t>(),
+        spatial_scale,
+        channels,
+        height,
+        width,
+        depth,
+        pooled_height,
+        pooled_width,
+        pooled_depth,
+        sampling_ratio,
+        grad_input.data_ptr<scalar_t>(),
+        rois.contiguous().data_ptr<scalar_t>(),
+        n_stride,
+        c_stride,
+        h_stride,
+        w_stride,
+        d_stride);
+  });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return grad_input;
+}
\ No newline at end of file
diff --git a/custom_extensions/roi_align/3D/src/RoIAlign_interface_3d.cpp b/custom_extensions/roi_align/3D/src/RoIAlign_interface_3d.cpp
new file mode 100644
index 0000000..bb680fa
--- /dev/null
+++ b/custom_extensions/roi_align/3D/src/RoIAlign_interface_3d.cpp
@@ -0,0 +1,112 @@
+/* adapted from the pytorch framework
+    https://github.com/pytorch/vision/blob/master/torchvision/csrc/ROIAlign.h on Nov 15 2019.
+
+    does not include CPU support, but it could be added through this interface.
+*/
+
+#include <torch/extension.h>
+
+/*---------------- 3D implementation ---------------------------*/
+
+// Declarations of the functions that are defined in the CUDA file
+at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale,
+                                const int pooled_height, const int pooled_width, const int pooled_depth,
+                                const int sampling_ratio);
+at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
+    const int pooled_height, const int pooled_width, const int pooled_depth, const int batch_size, const int channels,
+    const int height, const int width, const int depth, const int sampling_ratio);
+
+// Interface for Python
+at::Tensor ROIAlign_forward(
+    const at::Tensor& input, // Input feature map.
+    const at::Tensor& rois, // List of ROIs to pool over.
+    const double spatial_scale, // The scale of the image features. ROIs will be scaled to this.
+    const int64_t pooled_height, // The height of the pooled feature map.
+    const int64_t pooled_width, // The width of the pooled feature map.
+    const int64_t pooled_depth, // The depth of the pooled feature map.
+    const int64_t sampling_ratio) // The number of points to sample in each bin along each axis.
+{
+  if (input.is_cuda()) { // Tensor::is_cuda() replaces the deprecated input.type().is_cuda()
+    return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, pooled_depth, sampling_ratio);
+  }
+  AT_ERROR("Not compiled with CPU support");
+  // a CPU implementation would be dispatched here, e.g. ROIAlign_forward_cpu(input, rois, spatial_scale, ...)
+}
+
+at::Tensor ROIAlign_backward(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale,
+    const int pooled_height, const int pooled_width, const int pooled_depth, const int batch_size, const int channels,
+    const int height, const int width, const int depth, const int sampling_ratio) {
+  if (grad.is_cuda()) { // only CUDA tensors are supported; Tensor::is_cuda() replaces the deprecated grad.type().is_cuda()
+    return ROIAlign_backward_cuda( grad, rois, spatial_scale, pooled_height, pooled_width, pooled_depth, batch_size,
+     channels, height, width, depth, sampling_ratio);
+  }
+  AT_ERROR("Not compiled with CPU support");
+  // a CPU implementation would be dispatched here, e.g.
+  //   ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, ..., sampling_ratio);
+}
+
+using namespace at;
+using torch::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> { // custom autograd op for 3D RoIAlign
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      Variable input,
+      Variable rois,
+      const double spatial_scale,
+      const int64_t pooled_height,
+      const int64_t pooled_width,
+      const int64_t pooled_depth,
+      const int64_t sampling_ratio) {
+    ctx->saved_data["spatial_scale"] = spatial_scale; // non-tensor arguments are stashed for backward()
+    ctx->saved_data["pooled_height"] = pooled_height;
+    ctx->saved_data["pooled_width"] = pooled_width;
+    ctx->saved_data["pooled_depth"] = pooled_depth;
+    ctx->saved_data["sampling_ratio"] = sampling_ratio;
+    ctx->saved_data["input_shape"] = input.sizes(); // backward() passes these sizes on to ROIAlign_backward
+    ctx->save_for_backward({rois}); // rois is the only tensor saved for backward()
+    auto result = ROIAlign_forward(input, rois, spatial_scale, pooled_height, pooled_width, pooled_depth,
+     sampling_ratio);
+    return {result};
+  }
+  // Returns one entry per forward() argument; only the first (for `input`) is a defined gradient.
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output) {
+    // Use data saved in forward
+    auto saved = ctx->get_saved_variables();
+    auto rois = saved[0];
+    auto input_shape = ctx->saved_data["input_shape"].toIntList();
+    auto grad_in = ROIAlign_backward(
+        grad_output[0],
+        rois,
+        ctx->saved_data["spatial_scale"].toDouble(),
+        ctx->saved_data["pooled_height"].toInt(),
+        ctx->saved_data["pooled_width"].toInt(),
+        ctx->saved_data["pooled_depth"].toInt(),
+        input_shape[0], // batch_size
+        input_shape[1], // channels
+        input_shape[2], // height
+        input_shape[3], // width
+        input_shape[4], // depth
+        ctx->saved_data["sampling_ratio"].toInt());
+    return { // the six default-constructed Variables stand for rois and the five scalar arguments
+        grad_in, Variable(), Variable(), Variable(), Variable(), Variable(), Variable()};
+  }
+};
+
+Tensor roi_align(const Tensor& input, const Tensor& rois, const double spatial_scale, const int64_t pooled_height,
+    const int64_t pooled_width, const int64_t pooled_depth, const int64_t sampling_ratio) {
+  // Runs ROIAlignFunction through autograd and unwraps the single tensor in its returned list.
+  return ROIAlignFunction::apply(input, rois, spatial_scale, pooled_height, pooled_width, pooled_depth,
+                                    sampling_ratio)[0];
+
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("roi_align", &roi_align, "ROIAlign 3D in c++ and/or cuda"); // binds roi_align under the Python name "roi_align"
+}
\ No newline at end of file
diff --git a/custom_extensions/roi_align/3D/src/cuda_helpers.h b/custom_extensions/roi_align/3D/src/cuda_helpers.h
new file mode 100644
index 0000000..af32f60
--- /dev/null
+++ b/custom_extensions/roi_align/3D/src/cuda_helpers.h
@@ -0,0 +1,5 @@
+#pragma once
+// Loops i from this thread's global 1D index up to n in steps of blockDim.x * gridDim.x; i is declared as int.
+#define CUDA_1D_KERNEL_LOOP(i, n)                                \
+  for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
+       i += (blockDim.x * gridDim.x))
diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py
index 228241f..94288ad 100644
--- a/datasets/toy/configs.py
+++ b/datasets/toy/configs.py
@@ -1,488 +1,490 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 from collections import namedtuple
 
 boxLabel = namedtuple('boxLabel', ["name", "color"])
 Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion'])
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals'])
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #         Prepro        #
         #########################
+
         self.pp_rootdir = os.path.join('/home/gregor/datasets/toy', "cyl1ps_dev")
         self.pp_npz_dir = self.pp_rootdir+"_npz"
 
         self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now)
         self.min_2d_radius = 6 #in pixels
         self.n_train_samples, self.n_test_samples = 1200, 1000
 
         # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes.
         self.pp_create_ohe_seg = False
         self.pp_empty_samples_ratio = 0.1
 
         self.pp_place_radii_mid_bin = True
         self.pp_only_distort_2d = True
         # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing.
         # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity.
         self.pp_blur_min_intensity = 0.2
 
         self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d)
         self.max_instances_per_class = self.max_instances_per_sample  # how many max instances per image per class
         self.noise_scale = 0.  # std-dev of gaussian noise
 
         self.ambigs_sampling = "gaussian" #"gaussian" or "uniform"
         """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from
             image by distinguishing it from the rest of the object.
             blurring width around edge will be shifted so that symmetric rel to orig radius.
             blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale
             since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius.
             if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2.
         """
         self.ambiguities = {
              #set which classes to apply which ambs to below in class labels
              #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'.
              #kind              #probability   #scale (gaussian std, relative to unperturbed value)
             #"outer_radius":     (1.,            0.5),
             #"outer_radius_xy":  (1.,            0.5),
             #"inner_radius":     (0.5,            0.1),
             #"radii_relations":  (0.5,            0.1),
             "radius_calib":     (1.,            1./6)
         }
 
         # shape choices: 'cylinder', 'block'
         #                        id,    name,       shape,      radius,                 color,              regression,     ambiguities,    gt_distortion
         self.pp_classes = [Label(1,     'cylinder', 'cylinder', ((6,6,1),(40,40,8)),    (*self.blue, 1.),   "radius_2d",    (),             ()),
                            #Label(2,      'block',      'block',        ((6,6,1),(40,40,8)),  (*self.aubergine,1.),  "radii_2d", (), ('radius_calib',))
             ]
 
 
         #########################
         #         I/O           #
         #########################
+
         self.data_sourcedir = '/home/gregor/datasets/toy/cyl1ps_dev'
 
         if server_env:
             self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_dev_npz'
 
 
         self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test')
         self.data_sourcedir = os.path.join(self.data_sourcedir, "train")
 
         self.info_df_name = 'info_df.pickle'
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn'].
         self.model = 'mrcnn'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # one out of [2, 3]. dimension the model operates in.
         self.dim = 2
 
         # 'class', 'regression', 'regression_bin', 'regression_ken_gal'
         # currently only tested mode is a single-task at a time (i.e., only one task in below list)
         # but, in principle, tasks could be combined (e.g., object classes and regression per class)
         self.prediction_tasks = ['class', ]
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm'
         self.relu = 'relu'
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1  # length of regressor target vector
 
 
         #########################
         #      Data Loader      #
         #########################
 
         self.num_epochs = 32
-        self.num_train_batches = 120 if self.dim == 2 else 80
+        self.num_train_batches = 120 if self.dim == 2 else 180
         self.batch_size = 8 if self.dim == 2 else 4
 
         self.n_cv_splits = 4
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
         self.plot_bg_chan = 0
         self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
         self.patch_size_2D = self.pre_crop_size[:2]
         self.patch_size_3D = self.pre_crop_size[:2]+[8]
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.2
         self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets"
 
         self.observables_patient = []
         self.observables_rois = []
 
         self.seed = 3 #for generating folds
 
         #############################
         # Colors, Classes, Legends  #
         #############################
-        self.plot_frequency = 1
+        self.plot_frequency = 4
 
         binary_bin_labels = [binLabel(1,  'r<=25',      (*self.green, 1.),      (1,25)),
                              binLabel(2,  'r>25',       (*self.red, 1.),        (25,))]
         quintuple_bin_labels = [binLabel(1,  'r2-10',   (*self.green, 1.),      (2,10)),
                                 binLabel(2,  'r10-20',  (*self.yellow, 1.),     (10,20)),
                                 binLabel(3,  'r20-30',  (*self.orange, 1.),     (20,30)),
                                 binLabel(4,  'r30-40',  (*self.bright_red, 1.), (30,40)),
                                 binLabel(5,  'r>40',    (*self.red, 1.), (40,))]
 
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_bin_labels = quintuple_bin_labels
 
         self.class_labels = [
             # regression: regression-task label, either value or "(x,y,z)_radius" or "radii".
             # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables!
             # gt_distortion: name of ambig to apply to gt only; needs to be iterable!
             #      #id  #name   #shape  #radius     #color              #regression #ambiguities    #gt_distortion
             Label(  0,  'bg',   None,   (0, 0, 0),  (*self.white, 0.),  (0, 0, 0),  (),             ())]
         if "class" in self.prediction_tasks:
             self.class_labels += self.pp_classes
         else:
             self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))]
 
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0,  'bg',       (*self.white, 1.),      (0,))]
             self.bin_labels += task_spec_bin_labels
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
             self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
 
         if self.class_specific_seg:
           self.seg_labels = self.class_labels
 
         self.box_type2label = {label.name: label for label in self.box_labels}
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)
         self.num_seg_classes = len(self.seg_labels)
 
         #########################
         #   Data Augmentation   #
         #########################
         self.do_aug = True
         self.da_kwargs = {
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': False,
             'alpha': (500., 1500.),
             'sigma': (40., 45.),
             'do_rotation': False,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': False,
             'scale': (0.8, 1.1),
             'random_crop': False,
             'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #########################
         #  Schedule / Selection #
         #########################
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 220  # if 'all' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 35 if self.dim==2 else 25
 
         self.save_n_models = 2
         self.min_save_thresh = 1 if self.dim == 2 else 1  # =wait time in epochs
         if "class" in self.prediction_tasks:
             self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()}
             self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()})
 
         self.lr_decay_factor = 0.5
         self.scheduling_patience = int(self.num_epochs / 5)
         self.weight_decay = 1e-5
         self.clip_norm = None  # number or None
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1)
         self.held_out_test_set = True
         self.max_test_patients = "all"  # number or "all" for all
 
         self.test_against_exact_gt = True # only True implemented
         self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario.
         self.report_score_level = ['rois']  # 'patient' or 'rois' (incl)
         self.patient_class_of_interest = 1
         self.patient_bin_of_interest = 2
 
         self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False
         self.metrics = ['ap', 'auc', 'dice']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.5]  # threshold(s) for considering a prediction as true positive
         self.min_det_thresh = 0.3
 
         self.model_max_iou_resolution = 0.2
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = self.model_max_iou_resolution  # has to be larger than desired possible overlap iou of model predictions
 
         self.merge_2D_to_3D_preds = False
         self.merge_3D_iou = self.model_max_iou_resolution
         self.n_test_plots = 1  # per fold and rank
 
         self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
         # is multiplied by (1 + nr of test augs)
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
          'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
          }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         #only meant for isotropic radii!!
         # only 2D radii (x and y dims) or 1D (x or y) are expected
         return np.round(np.digitize(rg_val, self.bin_edges).mean())
 
 
     def add_det_fpn_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = 'torch_loss'
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       self.n_roi_candidates = 4 if self.dim == 2 else 6
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
 
       self.fp_dice_weight = 1 if self.dim == 2 else 1
       # if <1, false positive predictions in foreground are penalized less.
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
     def add_det_unet_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = "torch_loss"
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
       self.n_roi_candidates = 4 if self.dim == 2 else 6
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
       # if <1, false positive predictions in foreground are penalized less.
       self.fp_dice_weight = 1 if self.dim == 2 else 1
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
       self.init_filts = 32
       self.kernel_size = 3  # ks for horizontal, normal convs
       self.kernel_size_m = 2  # ks for max pool
       self.pad = "same"  # "same" or integer, padding of horizontal convs
 
     def add_mrcnn_configs(self):
 
       self.learning_rate = [1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True  # with scheduler set in exec
       self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # number of classes for network heads: n_foreground_classes + 1 (background)
       self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2
 
       # feed +/- n neighbouring slices into channel dimension. set to None for no context.
       self.n_3D_context = None
       if self.n_3D_context is not None and self.dim == 2:
         self.n_channels *= (self.n_3D_context * 2 + 1)
 
       self.detect_while_training = True
       # disable the re-sampling of mask proposals to original size for speed-up.
       # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
       # mask outputs are optional.
       self.return_masks_in_train = True
       self.return_masks_in_val = True
       self.return_masks_in_test = True
 
       # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
       self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
       # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
       # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
       self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
       # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
       self.pyramid_levels = [0, 1, 2, 3]
       # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
       self.n_rpn_features = 512 if self.dim == 2 else 64
 
       # anchor ratios and strides per position in feature maps.
       self.rpn_anchor_ratios = [0.5, 1., 2.]
       self.rpn_anchor_stride = 1
       # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
       self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution)
 
       # loss sampling settings.
       self.rpn_train_anchors_per_image = 4
       self.train_rois_per_image = 6 # per batch_instance
       self.roi_positive_ratio = 0.5
       self.anchor_matching_iou = 0.8
 
       # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
       # where k<=#positive examples.
       self.shem_poolsize = 2
 
       self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
       self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
       self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
       self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
       self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                              self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
 
       if self.dim == 2:
         self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
         self.bbox_std_dev = self.bbox_std_dev[:4]
         self.window = self.window[:4]
         self.scale = self.scale[:4]
 
       self.plot_y_max = 1.5
       self.n_plot_rpn_props = 5 if self.dim == 2 else 30  # per batch_instance (slice in 2D / patient in 3D)
 
       # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
       self.pre_nms_limit = 2000 if self.dim == 2 else 4000
 
       # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
       # since proposals of the entire batch are forwarded through second stage as one "batch".
       self.roi_chunk_size = 1300 if self.dim == 2 else 500
       self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400
       self.post_nms_rois_inference = 200 * (self.head_classes-1)
 
       # Final selection of detections (refine_detections)
       self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class.
       self.detection_nms_threshold = self.model_max_iou_resolution  # needs to be > 0, otherwise all predictions are one cluster.
       self.model_min_confidence = 0.2  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
 
       if self.dim == 2:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride))]
            for stride in self.backbone_strides['xy']])
       else:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride)),
             int(np.ceil(self.patch_size[2] / stride_z))]
            for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                        )])
 
       if self.model == 'retina_net' or self.model == 'retina_unet':
         # whether to use focal loss or SHEM for loss-sample selection
         self.focal_loss = False
         # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
         self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                         self.rpn_anchor_scales['xy']]
         self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                        self.rpn_anchor_scales['z']]
         self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
         # pre-selection of detections for NMS-speedup. per entire batch.
         self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
         # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
-        self.anchor_matching_iou = 0.5
+        self.anchor_matching_iou = 0.7
 
         if self.model == 'retina_unet':
           self.operate_stride1 = True
diff --git a/exec.py b/exec.py
index 4d89fcd..5d46dd6 100644
--- a/exec.py
+++ b/exec.py
@@ -1,342 +1,341 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 """ execution script. this where all routines come together and the only script you need to call.
     refer to parse args below to see options for execution.
 """
 
 import plotting as plg
 
 import os
 import warnings
 import argparse
 import time
 
 import torch
 
 import utils.exp_utils as utils
 from evaluator import Evaluator
 from predictor import Predictor
 
 
 for msg in ["Attempting to set identical bottom==top results",
             "This figure includes Axes that are not compatible with tight_layout",
             "Data has no positive values, and therefore cannot be log-scaled.",
             ".*invalid value encountered in true_divide.*"]:
     warnings.filterwarnings("ignore", msg)
 
 
 def train(cf, logger):
     """
     performs the training routine for a given fold. saves plots and selected parameters to the experiment dir
     specified in the configs. logs to file and tensorboard.
     """
     logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
         cf.dim, cf.fold, cf.exp_dir, cf.model))
     logger.time("train_val")
 
     # -------------- inits and settings -----------------
     net = model.net(cf, logger).cuda()
     if cf.optimizer == "ADAM":
         optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
     elif cf.optimizer == "SGD":
         optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3)
     if cf.dynamic_lr_scheduling:
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor,
                                                                     patience=cf.scheduling_patience)
     model_selector = utils.ModelSelector(cf, logger)
 
     starting_epoch = 1
     if cf.resume:
         checkpoint_path = os.path.join(cf.fold_dir, "last_state.pth")
         starting_epoch, net, optimizer, model_selector = \
             utils.load_checkpoint(checkpoint_path, net, optimizer, model_selector)
         logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch))
 
     # prepare monitoring
     monitor_metrics = utils.prepare_monitoring(cf)
 
     logger.info('loading dataset and initializing batch generators...')
     batch_gen = data_loader.get_train_generators(cf, logger)
 
     # -------------- training -----------------
     for epoch in range(starting_epoch, cf.num_epochs + 1):
 
         logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs))
         logger.time("train_epoch")
 
         net.train()
 
         train_results_list = []
         train_evaluator = Evaluator(cf, logger, mode='train')
 
         for i in range(cf.num_train_batches):
             logger.time("train_batch_loadfw")
             batch = next(batch_gen['train'])
             batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts']
             batch_gen['train'].generator.stats['empty_counts'] += batch['empty_counts']
 
             logger.time("train_batch_loadfw")
             logger.time("train_batch_netfw")
             results_dict = net.train_forward(batch)
             logger.time("train_batch_netfw")
             logger.time("train_batch_bw")
             optimizer.zero_grad()
             results_dict['torch_loss'].backward()
             if cf.clip_norm:
                 torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping
             optimizer.step()
             train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict
             if not cf.server_env:
                 print("\rFinished training batch " +
                       "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches,
                                                                                              logger.get_time("train_batch_loadfw")+
                                                                                              logger.get_time("train_batch_netfw")
                                                                                              +logger.time("train_batch_bw"),
                                                                                              logger.get_time("train_batch_loadfw",reset=True),
                                                                                              logger.get_time("train_batch_netfw", reset=True),
                                                                                              logger.get_time("train_batch_bw", reset=True)), end="", flush=True)
         print()
 
         #--------------- train eval ----------------
         if (epoch-1)%cf.plot_frequency==0:
             # view an example batch
             utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels,
                                     show_gt_labels=True, get_time="train-example plot",
                                     out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold)))
 
 
         logger.time("evals")
         _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])
         logger.time("evals")
         logger.time("train_epoch", toggle=False)
         del train_results_list
 
         #----------- validation ------------
         logger.info('starting validation in mode {}.'.format(cf.val_mode))
         logger.time("val_epoch")
         with torch.no_grad():
             net.eval()
             val_results_list = []
             val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
             val_predictor = Predictor(cf, net, logger, mode='val')
 
             for i in range(batch_gen['n_val']):
                 logger.time("val_batch")
                 batch = next(batch_gen[cf.val_mode])
                 if cf.val_mode == 'val_patient':
                     results_dict = val_predictor.predict_patient(batch)
                 elif cf.val_mode == 'val_sampling':
                     results_dict = net.train_forward(batch, is_validation=True)
                 val_results_list.append([results_dict, batch["pid"]])
                 if not cf.server_env:
                     print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch',
                                                                               i + 1, batch_gen['n_val'],
                                                                               logger.time("val_batch")), end="", flush=True)
             print()
 
             #------------ val eval -------------
             if (epoch - 1) % cf.plot_frequency == 0:
                 utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels,
                                         show_gt_labels=True, get_time="val-example plot",
                                         out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold)))
 
             logger.time("evals")
             _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
 
             model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
             del val_results_list
             #----------- monitoring -------------
             monitor_metrics.update({"lr": 
                 {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}})
             logger.metrics2tboard(monitor_metrics, global_step=epoch)
             logger.time("evals")
 
             logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format(
                 epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"),
                 logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"),
                 logger.get_time("val_epoch", reset=True)/batch_gen["n_val"]))
             logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True)))
 
         #-------------- scheduling -----------------
-        if not cf.dynamic_lr_scheduling:
+        if cf.dynamic_lr_scheduling:
+            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
+        else:
             for param_group in optimizer.param_groups:
                 param_group['lr'] = cf.learning_rate[epoch-1]
-        else:
-            scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1])
 
     logger.time("train_val")
     logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True)))
     batch_gen['train'].generator.print_stats(logger, plot=True)
 
 def test(cf, logger, max_fold=None):
     """performs testing for a given fold (or held out set). saves stats in evaluator.
     """
     logger.time("test_fold")
     logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir))
     net = model.net(cf, logger).cuda()
     batch_gen = data_loader.get_test_generator(cf, logger)
 
     test_predictor = Predictor(cf, net, logger, mode='test')
     test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr(
         cf, "eval_test_separately") or not cf.eval_test_separately)
 
     if test_results_list is not None:
         test_evaluator = Evaluator(cf, logger, mode='test')
         test_evaluator.evaluate_predictions(test_results_list)
         test_evaluator.score_test_df(max_fold=max_fold)
 
     logger.info('Testing of fold {} took {}.\n'.format(cf.fold, logger.get_time("test_fold", reset=True, format="hms")))
 
 if __name__ == '__main__':
     stime = time.time()
 
     parser = argparse.ArgumentParser()
     parser.add_argument('--dataset_name', type=str, default='toy',
                         help="path to the dataset-specific code in source_dir/datasets")
     parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/regrcnn/datasets/toy/experiments/dev',
                         help='path to experiment dir. will be created if non existent.')
     parser.add_argument('-m', '--mode', type=str,  default='train_test', help='one out of: create_exp, analysis, train, train_test, or test')
     parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.')
     parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.')
     parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config")
     parser.add_argument('--use_stored_settings', default=False, action='store_true',
                         help='load configs from existing exp_dir instead of source dir. always done for testing, '
                              'but can be set to true to do the same for training. useful in job scheduler environment, '
                              'where source code might change before the job actually runs.')
     parser.add_argument('--resume', action="store_true", default=False,
                         help='if given, resume from checkpoint(s) of the specified folds.')
     parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything")
 
     args = parser.parse_args()
     args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name
     folds = args.folds
     resume = None if args.resume in ['None', 'none'] else args.resume
 
     if args.mode == 'create_exp':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False)
         logger = utils.get_logger(cf.exp_dir, cf.server_env, -1)
         logger.info('created experiment directory at {}'.format(args.exp_dir))
 
     elif args.mode == 'train' or args.mode == 'train_test':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings)
         if args.dev:
             folds = [0,1]
             cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 2, 0, 1
             cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1
             cf.test_n_epochs =  cf.save_n_models
             cf.max_test_patients = 1
             torch.backends.cudnn.benchmark = cf.dim==3
         else:
             torch.backends.cudnn.benchmark = cf.cuda_benchmark
         if args.data_dest is not None:
             cf.data_dest = args.data_dest
             
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
         data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py'))
         model = utils.import_module('model', cf.model_path)
         logger.info("loaded model from {}".format(cf.model_path))
         if folds is None:
             folds = range(cf.n_cv_splits)
 
         for fold in folds:
             """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation,
             one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the
             splits. k==folds, fold in [0,folds) says which split is used for testing.
             """
             cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold
             logger.set_logfile(fold=fold)
             cf.resume = resume
-
             if not os.path.exists(cf.fold_dir):
                 os.mkdir(cf.fold_dir)
             train(cf, logger)
             cf.resume = None
             if args.mode == 'train_test':
                 test(cf, logger)
 
     elif args.mode == 'test':
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False)
         if args.data_dest is not None:
             cf.data_dest = args.data_dest
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
         data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py'))
         model = utils.import_module('model', cf.model_path)
         logger.info("loaded model from {}".format(cf.model_path))
 
         fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
                      os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
         if folds is None:
             folds = range(cf.n_cv_splits)
         if args.dev:
             folds = folds[:2]
             cf.batch_size, cf.max_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2
         else:
             torch.backends.cudnn.benchmark = cf.cuda_benchmark
         for fold in folds:
             cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold
             logger.set_logfile(fold=fold)
             if cf.fold_dir in fold_dirs:
                 test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs]))
             else:
                 logger.info("Skipping fold {} since no model parameters found.".format(fold))
     # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation.
     elif args.mode == 'analysis':
         """ analyse already saved predictions.
         """
         cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False)
         logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval)
 
         if cf.held_out_test_set and not cf.eval_test_fold_wise:
             predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
             results_list = predictor.load_saved_predictions()
             logger.info('starting evaluation...')
             cf.fold = 0
             evaluator = Evaluator(cf, logger, mode='test')
             evaluator.evaluate_predictions(results_list)
             evaluator.score_test_df(max_fold=0)
         else:
             fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if
                          os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")])
             if args.dev:
                 fold_dirs = fold_dirs[:1]
             if folds is None:
                 folds = range(cf.n_cv_splits)
             for fold in folds:
                 cf.fold = fold; cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold))
                 logger.set_logfile(fold=fold)
                 if cf.fold_dir in fold_dirs:
                     predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
                     results_list = predictor.load_saved_predictions()
                     # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds
                     # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x
                     logger.info('starting evaluation...')
                     evaluator = Evaluator(cf, logger, mode='test')
                     evaluator.evaluate_predictions(results_list)
                     max_fold = max([int(f[-1]) for f in fold_dirs])
                     evaluator.score_test_df(max_fold=max_fold)
                 else:
                     logger.info("Skipping fold {} since no model parameters found.".format(fold))
     else:
         raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode))
         
     mins, secs = divmod((time.time() - stime), 60)
     h, mins = divmod(mins, 60)
     t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
     logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t))
     del logger
     torch.cuda.empty_cache()
 
diff --git a/requirements.txt b/requirements.txt
index 5a09f51..18097a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,64 +1,67 @@
 absl-py==0.8.1
 backcall==0.1.0
 batchgenerators==0.19.7
 cachetools==3.1.1
 certifi==2019.11.28
 chardet==3.0.4
 cycler==0.10.0
 Cython==0.29.14
 decorator==4.4.1
 future==0.18.2
 google-auth==1.7.2
 google-auth-oauthlib==0.4.1
 grpcio==1.25.0
 idna==2.8
 imageio==2.6.1
 ipython==7.9.0
 ipython-genutils==0.2.0
 jedi==0.15.1
 joblib==0.14.0
 kiwisolver==1.1.0
 linecache2==1.0.0
 Markdown==3.1.1
 matplotlib==3.1.2
 networkx==2.4
 nms-extension==0.0.0
 numpy==1.17.4
 nvidia-ml-py3==7.352.0
 oauthlib==3.1.0
-pandas==0.25.3
+pandas==1.0.3
 parso==0.5.1
 pexpect==4.7.0
 pickleshare==0.7.5
-Pillow==6.2.1
+Pillow==7.1.0
 prompt-toolkit==2.0.10
 protobuf==3.11.1
 psutil==5.7.0
 ptyprocess==0.6.0
 pyasn1==0.4.8
 pyasn1-modules==0.2.7
 Pygments==2.5.2
 pyparsing==2.4.5
 python-dateutil==2.8.1
 pytz==2019.3
 PyWavelets==1.1.1
 RegRCNN==0.0.2
 requests==2.22.0
 requests-oauthlib==1.3.0
+RoIAlign-extension-2D==0.0.0
+RoIAlign-extension-3D==0.0.0
 rsa==4.0
 scikit-image==0.16.2
 scikit-learn==0.21.3
 scipy==1.3.1
 SimpleITK==1.2.3
 six==1.13.0
-tensorboard==2.0.2
+tensorboard==2.2.0
+tensorboard-plugin-wit==1.6.0.post2
 threadpoolctl==2.0.0
-torch==1.3.1
-torchvision==0.4.2
+torch==1.4.0
+torchvision==0.5.0
 tqdm==4.39.0
 traceback2==1.4.0
 traitlets==4.3.3
 unittest2==1.1.0
 urllib3==1.25.7
 wcwidth==0.1.7
-Werkzeug==0.16.0
+Werkzeug==1.0.1
diff --git a/setup.py b/setup.py
index 2ab3c74..b59bbe0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,60 +1,72 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 from setuptools import find_packages, setup
-import os
+import os, sys, subprocess
 
 def parse_requirements(filename, exclude=[]):
     lineiter = (line.strip() for line in open(filename))
     return [line for line in lineiter if line and not line.startswith("#") and not line.split("==")[0] in exclude]
 
+def pip_install(item):  # install `item` by running pip as a module of the current interpreter
+    subprocess.check_call([sys.executable, "-m", "pip", "install", item])  # raises CalledProcessError on non-zero exit
+
 def install_custom_ext(setup_path):
-    os.system("python "+setup_path+" install")
-    return
+    try:  # setup_path is the extension's directory; first attempt: let pip build it from source
+        pip_install(setup_path)
+    except Exception as e:  # deliberate best-effort: any build failure falls through to the wheel fallback
+        print("Could not install custom extension {} from source due to error:\n{}\n".format(setup_path, e) +
+              "Trying to install from pre-compiled wheel.")
+        dist_path = setup_path+"/dist"
+        wheel_file = [fn for fn in os.listdir(dist_path) if fn.endswith(".whl")][0]  # first wheel found; errors propagate to caller
+        pip_install(os.path.join(dist_path, wheel_file))
 
 def clean():
     """Custom clean command to tidy up the project root."""
     os.system('rm -vrf ./build ./dist ./*.pyc ./*.tgz ./*.egg-info')
 
-req_file = "requirements.txt"
-custom_exts = ["nms-extension", "RoIAlign-extension-2D", "RoIAlign-extension-3D"]
-install_reqs = parse_requirements(req_file, exclude=custom_exts)
-
-setup(name='RegRCNN',
-      version='0.0.2',
-      url="https://github.com/MIC-DKFZ/RegRCNN",
-      author='G. Ramien, P. Jaeger, MIC at DKFZ Heidelberg',
-      author_email='g.ramien@dkfz.de',
-      licence="Apache 2.0",
-      description="Medical Object-Detection Toolkit incl. Regression Capability.",
-      classifiers=[
-          "Development Status :: 4 - Beta",
-          "Intended Audience :: Developers",
-          "Programming Language :: Python :: 3.7"
-      ],
-      packages=find_packages(exclude=['test', 'test.*']),
-      install_requires=install_reqs,
-      )
-
-custom_exts =  ["custom_extensions/nms", "custom_extensions/roi_align"]
-for path in custom_exts:
-    setup_path = os.path.join(path, "setup.py")
-    try:
-        install_custom_ext(setup_path)
-    except Exception as e:
-        print("FAILED to install custom extension {} due to Error:\n{}".format(path, e))
 
-clean()
\ No newline at end of file
+
+if __name__ == "__main__":  # run the installation steps only when executed as a script
+
+    req_file = "requirements.txt"
+    custom_exts = ["nms-extension", "RoIAlign-extension-2D", "RoIAlign-extension-3D"]  # excluded here, installed from local paths below
+    install_reqs = parse_requirements(req_file, exclude=custom_exts)
+
+    setup(name='RegRCNN',
+          version='0.0.2',
+          url="https://github.com/MIC-DKFZ/RegRCNN",
+          author='G. Ramien, P. Jaeger, MIC at DKFZ Heidelberg',
+          author_email='g.ramien@dkfz.de',
+          license="Apache 2.0",
+          description="Medical Object-Detection Toolkit incl. Regression Capability.",
+          classifiers=[
+              "Development Status :: 4 - Beta",
+              "Intended Audience :: Developers",
+              "Programming Language :: Python :: 3.7"
+          ],
+          packages=find_packages(exclude=['test', 'test.*']),
+          install_requires=install_reqs,
+          )
+
+    custom_exts =  ["custom_extensions/nms", "custom_extensions/roi_align/2D", "custom_extensions/roi_align/3D"]  # extension directories
+    for path in custom_exts:
+        try:
+            install_custom_ext(path)
+        except Exception as e:  # one failing extension must not abort the remaining ones
+            print("FAILED to install custom extension {} due to Error:\n{}".format(path, e))
+
+    clean()  # remove build artefacts from the project root
\ No newline at end of file