diff --git a/custom_extensions/nms/nms.py b/custom_extensions/nms/nms.py new file mode 100644 index 0000000..cdfb078 --- /dev/null +++ b/custom_extensions/nms/nms.py @@ -0,0 +1,75 @@ +""" +adopted from pytorch framework, torchvision.ops.boxes + +""" + +import torch +import nms_extension + +def nms(boxes, scores, iou_threshold): + """ + Performs non-maximum suppression (NMS) on the boxes according + to their intersection-over-union (IoU). + + NMS iteratively removes lower scoring boxes which have an + IoU greater than iou_threshold with another (higher scoring) + box. + + Parameters + ---------- + boxes : Tensor[N, 4] for 2D or Tensor[N,6] for 3D. + boxes to perform NMS on. They + are expected to be in (y1, x1, y2, x2(, z1, z2)) format + scores : Tensor[N] + scores for each one of the boxes + iou_threshold : float + discards all overlapping + boxes with IoU < iou_threshold + + Returns + ------- + keep : Tensor + int64 tensor with the indices + of the elements that have been kept + by NMS, sorted in decreasing order of scores + """ + return nms_extension.nms(boxes, scores, iou_threshold) + + +def batched_nms(boxes, scores, idxs, iou_threshold): + """ + Performs non-maximum suppression in a batched fashion. + + Each index value correspond to a category, and NMS + will not be applied between elements of different categories. + + Parameters + ---------- + boxes : Tensor[N, 4] for 2D or Tensor[N,6] for 3D. + boxes to perform NMS on. They + are expected to be in (y1, x1, y2, x2(, z1, z2)) format + scores : Tensor[N] + scores for each one of the boxes + idxs : Tensor[N] + indices of the categories for each one of the boxes. + iou_threshold : float + discards all overlapping boxes + with IoU < iou_threshold + + Returns + ------- + keep : Tensor + int64 tensor with the indices of + the elements that have been kept by NMS, sorted + in decreasing order of scores + """ + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + # strategy: in order to perform NMS independently per class. + # we add an offset to all the boxes. The offset is dependent + # only on the class idx, and is large enough so that boxes + # from different classes do not overlap + max_coordinate = boxes.max() + offsets = idxs.to(boxes) * (max_coordinate + 1) + boxes_for_nms = boxes + offsets[:, None] + return nms(boxes_for_nms, scores, iou_threshold) diff --git a/custom_extensions/nms/setup.py b/custom_extensions/nms/setup.py new file mode 100644 index 0000000..911e616 --- /dev/null +++ b/custom_extensions/nms/setup.py @@ -0,0 +1,18 @@ +""" +Created at 07.11.19 19:12 +@author: gregor + +""" + +import os, sys +from pathlib import Path + +from setuptools import setup +from torch.utils import cpp_extension + +dir_ = Path(os.path.dirname(sys.argv[0])) + +setup(name='nms_extension', + ext_modules=[cpp_extension.CUDAExtension('nms_extension', [str(dir_/'src/nms_interface.cpp'), str(dir_/'src/nms.cu')])], + cmdclass={'build_ext': cpp_extension.BuildExtension} + ) \ No newline at end of file diff --git a/custom_extensions/nms/src/nms.cu b/custom_extensions/nms/src/nms.cu new file mode 100644 index 0000000..913d835 --- /dev/null +++ b/custom_extensions/nms/src/nms.cu @@ -0,0 +1,220 @@ +/* +NMS implementation in CUDA from pytorch framework +(https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 13 2019) + +Adapted for additional 3D capability by G. 
Ramien, DKFZ Heidelberg +*/ + +#include +#include +#include +#include +#include + +#include "cuda_helpers.h" + +#include +#include + +int const threadsPerBlock = sizeof(unsigned long long) * 8; + +template +__device__ inline float devIoU(T const* const a, T const* const b) { + // a, b hold box coords as (y1, x1, y2, x2) with y1 < y2 etc. + T bottom = max(a[0], b[0]), top = min(a[2], b[2]); + T left = max(a[1], b[1]), right = min(a[3], b[3]); + T width = max(right - left, (T)0), height = max(top - bottom, (T)0); + T interS = width * height; + + T Sa = (a[2] - a[0]) * (a[3] - a[1]); + T Sb = (b[2] - b[0]) * (b[3] - b[1]); + + return interS / (Sa + Sb - interS); +} + +template +__device__ inline float devIoU_3d(T const* const a, T const* const b) { + // a, b hold box coords as (y1, x1, y2, x2, z1, z2) with y1 < y2 etc. + // get coordinates of intersection, calc intersection + T bottom = max(a[0], b[0]), top = min(a[2], b[2]); + T left = max(a[1], b[1]), right = min(a[3], b[3]); + T front = max(a[4], b[4]), back = min(a[5], b[5]); + T width = max(right - left, (T)0), height = max(top - bottom, (T)0); + T depth = max(back - front + 1, (T)0); + T interS = width * height * depth; + // calc separate boxes volumes + T Sa = (a[2] - a[0]) * (a[3] - a[1]) * (a[5] - a[4] +1); + T Sb = (b[2] - b[0]) * (b[3] - b[1]) * (b[5] - b[4] +1); + + return interS / (Sa + Sb - interS); +} + + +template +__global__ void nms_kernel(const int n_boxes, const float iou_threshold, const T* dev_boxes, + unsigned long long* dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ T block_boxes[threadsPerBlock * 4]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 4 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0]; + block_boxes[threadIdx.x * 4 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1]; + block_boxes[threadIdx.x * 4 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2]; + block_boxes[threadIdx.x * 4 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const T* cur_box = dev_boxes + cur_box_idx * 4; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 4) > iou_threshold) { + t |= 1ULL << i; + } + } + const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + + +template +__global__ void nms_kernel_3d(const int n_boxes, const float iou_threshold, const T* dev_boxes, + unsigned long long* dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ T block_boxes[threadsPerBlock * 6]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 6 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 0]; + block_boxes[threadIdx.x * 6 + 1] = + dev_boxes[(threadsPerBlock * 
col_start + threadIdx.x) * 6 + 1]; + block_boxes[threadIdx.x * 6 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 2]; + block_boxes[threadIdx.x * 6 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 3]; + block_boxes[threadIdx.x * 6 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 4]; + block_boxes[threadIdx.x * 6 + 5] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 5]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const T* cur_box = dev_boxes + cur_box_idx * 6; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU_3d(cur_box, block_boxes + i * 6) > iou_threshold) { + t |= 1ULL << i; + } + } + const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + + +at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, float iou_threshold) { + /* dets expected as (n_dets, dim) where dim=4 in 2D, dim=6 in 3D */ + AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor"); + AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor"); + at::cuda::CUDAGuard device_guard(dets.device()); + + bool is_3d = dets.size(1) == 6; + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + auto dets_sorted = dets.index_select(0, order_t); + + int dets_num = dets.size(0); + + const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock); + + at::Tensor mask = + at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + + dim3 blocks(col_blocks, col_blocks); + dim3 threads(threadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + + if (is_3d) { + //std::cout << "performing NMS on 3D boxes in CUDA" << std::endl; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + dets_sorted.type(), "nms_kernel_cuda", [&] { + nms_kernel_3d<<>>( + dets_num, + iou_threshold, + dets_sorted.data_ptr(), + (unsigned long long*)mask.data_ptr()); + }); + } + else { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + dets_sorted.type(), "nms_kernel_cuda", [&] { + nms_kernel<<>>( + dets_num, + iou_threshold, + dets_sorted.data_ptr(), + (unsigned long long*)mask.data_ptr()); + }); + + } + + at::Tensor mask_cpu = mask.to(at::kCPU); + unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr(); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = + at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + int num_to_keep = 0; + for (int i = 0; i < dets_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long* p = mask_host + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + AT_CUDA_CHECK(cudaGetLastError()); + return order_t.index( + {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep) + .to(order_t.device(), keep.scalar_type())}); +} \ No newline at end of file diff --git a/custom_extensions/nms/src/nms_interface.cpp b/custom_extensions/nms/src/nms_interface.cpp new file mode 100644 index 0000000..e04f65d --- /dev/null +++ b/custom_extensions/nms/src/nms_interface.cpp @@ -0,0 +1,32 @@ +/* adopted from + 
https://github.com/pytorch/vision/blob/master/torchvision/csrc/nms.h on Nov 15 2019 + no cpu support, but could be added with this interface. +*/ +#include + + +//#include "cpu/vision_cpu.h" + +at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, float iou_threshold); + +at::Tensor nms( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold) { + if (dets.device().is_cuda()) { + + if (dets.numel() == 0) { + //at::cuda::CUDAGuard device_guard(dets.device()); + return at::empty({0}, dets.options().dtype(at::kLong)); + } + return nms_cuda(dets, scores, iou_threshold); + + } + AT_ERROR("Not compiled with CPU support"); + //at::Tensor result = nms_cpu(dets, scores, iou_threshold); + //return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("nms", &nms, "NMS C++ and/or CUDA"); +} \ No newline at end of file diff --git a/custom_extensions/roi_align/roi_align.py b/custom_extensions/roi_align/roi_align.py new file mode 100644 index 0000000..817e871 --- /dev/null +++ b/custom_extensions/roi_align/roi_align.py @@ -0,0 +1,128 @@ +""" +ROIAlign implementation from pytorch framework +(https://github.com/pytorch/vision/blob/master/torchvision/ops/roi_align.py on Nov 14 2019) + +adapted for 3D support without additional python function interface (only cpp function interface). +""" + +import torch +from torch import nn + +from torchvision.ops._utils import convert_boxes_to_roi_format + +import roi_al_extension +import roi_al_extension_3d + +def roi_align_2d(input: torch.Tensor, boxes, output_size, + spatial_scale: float = 1.0, sampling_ratio: int =-1) -> torch.Tensor: + """ + Performs Region of Interest (RoI) Align operator described in Mask R-CNN + + Arguments: + input: (Tensor[N, C, H, W]), input tensor + boxes: (Tensor[K, 5] or List[Tensor[L, 4]]), the box coordinates in (x1, y1, x2, y2) + format where the regions will be taken from. If a single Tensor is passed, + then the first column should contain the batch index. If a list of Tensors + is passed, then each Tensor will correspond to the boxes for an element i + in a batch + output_size: (int or Tuple[int, int]) the size of the output after the cropping + is performed, as (height, width) + spatial_scale: (float) a scaling factor that maps the input coordinates to + the box coordinates. Default: 1.0 + sampling_ratio: (int) number of sampling points in the interpolation grid + used to compute the output value of each pooled output bin. If > 0, + then exactly sampling_ratio x sampling_ratio grid points are used. If + <= 0, then an adaptive number of grid points are used (computed as + ceil(roi_width / pooled_w), and likewise for height). Default: -1 + + Returns: + output (Tensor[K, C, output_size[0], output_size[1]]) + """ + rois = boxes + if not isinstance(rois, torch.Tensor): + rois = convert_boxes_to_roi_format(rois) + return roi_al_extension.roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio) + + +def roi_align_3d(input: torch.Tensor, boxes, output_size, + spatial_scale: float = 1.0, sampling_ratio: int = -1) -> torch.Tensor: + """ + Performs Region of Interest (RoI) Align operator described in Mask R-CNN for 3-dim input. + + Arguments: + input (Tensor[N, C, H, W, D]): input tensor + boxes (Tensor[K, 7] or List[Tensor[L, 6]]): the box coordinates in (x1, y1, x2, y2, z1 ,z2) + format where the regions will be taken from. If a single Tensor is passed, + then the first column should contain the batch index. 
If a list of Tensors + is passed, then each Tensor will correspond to the boxes for an element i + in a batch + output_size (int or Tuple[int, int, int]): the size of the output after the cropping + is performed, as (height, width, depth) + spatial_scale (float): a scaling factor that maps the input coordinates to + the box coordinates. Default: 1.0 + sampling_ratio (int): number of sampling points in the interpolation grid + used to compute the output value of each pooled output bin. If > 0, + then exactly sampling_ratio x sampling_ratio grid points are used. If + <= 0, then an adaptive number of grid points are used (computed as + ceil(roi_width / pooled_w), and likewise for height). Default: -1 + + Returns: + output (Tensor[K, C, output_size[0], output_size[1], output_size[2]]) + """ + rois = boxes + if not isinstance(rois, torch.Tensor): + rois = convert_boxes_to_roi_format(rois) + return roi_al_extension_3d.roi_align(input, rois, spatial_scale, output_size[0], output_size[1], output_size[2], + sampling_ratio) + + +class RoIAlign(nn.Module): + """ + Performs Region of Interest (RoI) Align operator described in Mask R-CNN for 2- or 3-dim input. + + Arguments: + input (Tensor[N, C, H, W(, D)]): input tensor + boxes (Tensor[K, 5] or List[Tensor[L, 4]]) or (Tensor[K, 7] or List[Tensor[L, 6]]): + the box coordinates in (x1, y1, x2, y2(, z1 ,z2)) + format where the regions will be taken from. If a single Tensor is passed, + then the first column should contain the batch index. If a list of Tensors + is passed, then each Tensor will correspond to the boxes for an element i + in a batch + output_size (int or Tuple[int, int(, int)]): the size of the output after the cropping + is performed, as (height, width(, depth)) + spatial_scale (float): a scaling factor that maps the input coordinates to + the box coordinates. Default: 1.0 + sampling_ratio (int): number of sampling points in the interpolation grid + used to compute the output value of each pooled output bin. If > 0, + then exactly sampling_ratio x sampling_ratio grid points are used. If + <= 0, then an adaptive number of grid points are used (computed as + ceil(roi_width / pooled_w), and likewise for height (and depth)). 
Default: -1 + + Returns: + output (Tensor[K, C, output_size[0], output_size[1](, output_size[2])]) + """ + def __init__(self, output_size, spatial_scale, sampling_ratio): + super(RoIAlign, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + self.dim = len(self.output_size) + + if self.dim == 2: + self.roi_align = roi_align_2d + elif self.dim == 3: + self.roi_align = roi_align_3d + else: + raise Exception("Tried to init RoIAlign module with incorrect output size: {}".format(self.output_size)) + + def forward(self, input: torch.Tensor, rois) -> torch.Tensor: + return self.roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio) + + def __repr__(self): + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'output_size=' + str(self.output_size) + tmpstr += ', spatial_scale=' + str(self.spatial_scale) + tmpstr += ', sampling_ratio=' + str(self.sampling_ratio) + tmpstr += ', dimension=' + str(self.dim) + tmpstr += ')' + return tmpstr diff --git a/custom_extensions/roi_align/setup.py b/custom_extensions/roi_align/setup.py new file mode 100644 index 0000000..a1edef8 --- /dev/null +++ b/custom_extensions/roi_align/setup.py @@ -0,0 +1,25 @@ +""" +Created at 07.11.19 19:12 +@author: gregor + +""" + +import os, sys +from pathlib import Path + +from setuptools import setup +from torch.utils import cpp_extension + +dir_ = Path(os.path.dirname(sys.argv[0])) + +setup(name='RoIAlign extension 2D', + ext_modules=[cpp_extension.CUDAExtension('roi_al_extension', [str(dir_/'src/RoIAlign_interface.cpp'), + str(dir_/'src/RoIAlign_cuda.cu')])], + cmdclass={'build_ext': cpp_extension.BuildExtension} + ) + +setup(name='RoIAlign extension 3D', + ext_modules=[cpp_extension.CUDAExtension('roi_al_extension_3d', [str(dir_/'src/RoIAlign_interface_3d.cpp'), + str(dir_/'src/RoIAlign_cuda_3d.cu')])], + cmdclass={'build_ext': cpp_extension.BuildExtension} + ) \ No newline at end of file diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda.cu b/custom_extensions/roi_align/src/RoIAlign_cuda.cu new file mode 100644 index 0000000..b31ba44 --- /dev/null +++ b/custom_extensions/roi_align/src/RoIAlign_cuda.cu @@ -0,0 +1,429 @@ +/* +ROIAlign implementation in CUDA from pytorch framework +(https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019) + +*/ + +#include +#include +#include +#include +#include + +#include "cuda_helpers.h" + +template +__device__ T bilinear_interpolate( + const T* input, + const int height, + const int width, + T y, + T x, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +__global__ void RoIAlignForward( + const int nthreads, + const T* input, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* rois, + T* output) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + //printf("roi height %f, width %f\n", (float) roi_height, (float) roi_width); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + output[index] = output_val; + } +} + +template +__device__ void bilinear_interpolate_gradient( + const int height, + const int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template +__global__ void RoIAlignBackward( + const int nthreads, + const T* grad_output, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + T* grad_input, + const T* rois, + const int n_stride, + const int c_stride, + const int h_stride, + const int w_stride) +{ + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + // We need to index the gradient using the tensor strides to access the + // correct values. + int output_offset = n * n_stride + c * c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd( + offset_grad_input + y_low * width + x_low, static_cast(g1)); + atomicAdd( + offset_grad_input + y_low * width + x_high, static_cast(g2)); + atomicAdd( + offset_grad_input + y_high * width + x_low, static_cast(g3)); + atomicAdd( + offset_grad_input + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // CUDA_1D_KERNEL_LOOP +} // RoIAlignBackward + +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale, + const int pooled_height, const int pooled_width, const int sampling_ratio) { + /* + input: feature-map tensor, shape (batch, n_channels, y, x(, z)) + */ + AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlign_forward_cuda"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + at::cuda::CUDAGuard device_guard(input.device()); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + + auto output_size = num_rois * pooled_height * pooled_width * channels; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min( + at::cuda::ATenCeilDiv( + static_cast(output_size), static_cast(512)), + static_cast(4096))); + dim3 block(512); + + if (output.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return output; + } + + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] { + RoIAlignForward<<>>( + output_size, + input.contiguous().data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois.contiguous().data_ptr(), + output.data_ptr()); + }); + AT_CUDA_CHECK(cudaGetLastError()); + return output; +} + +at::Tensor ROIAlign_backward_cuda( + const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlign_backward_cuda"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + at::cuda::CUDAGuard device_guard(grad.device()); + + at::Tensor grad_input = + 
at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min( + at::cuda::ATenCeilDiv( + static_cast(grad.numel()), static_cast(512)), + static_cast(4096))); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_input; + } + + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] { + RoIAlignBackward<<>>( + grad.numel(), + grad.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + grad_input.data_ptr(), + rois.contiguous().data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + AT_CUDA_CHECK(cudaGetLastError()); + return grad_input; +} \ No newline at end of file diff --git a/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu b/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu new file mode 100644 index 0000000..a8022bb --- /dev/null +++ b/custom_extensions/roi_align/src/RoIAlign_cuda_3d.cu @@ -0,0 +1,488 @@ +/* +ROIAlign implementation in CUDA from pytorch framework +(https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 14 2019) + +Adapted for additional 3D capability by G. Ramien, DKFZ Heidelberg +*/ + +#include +#include +#include +#include +#include +#include +#include "cuda_helpers.h" + +/*-------------- gpu kernels -----------------*/ + +template +__device__ T linear_interpolate(const T xl, const T val_low, const T val_high){ + + T val = (val_high - val_low) * xl + val_low; + return val; +} + +template +__device__ T trilinear_interpolate(const T* input, const int height, const int width, const int depth, + T y, T x, T z, const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) { + // empty + return 0; + } + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + if (z <= 0) + z = 0; + + int y0 = (int)y; + int x0 = (int)x; + int z0 = (int)z; + int y1; + int x1; + int z1; + + if (y0 >= height - 1) { + /*if nearest gridpoint to y on the lower end is on border or border-1, set low, high, mid(=actual point) to border-1*/ + y1 = y0 = height - 1; + y = (T)y0; + } else { + /* y1 is one pixel from y0, y is the actual point somewhere in between */ + y1 = y0 + 1; + } + if (x0 >= width - 1) { + x1 = x0 = width - 1; + x = (T)x0; + } else { + x1 = x0 + 1; + } + if (z0 >= depth - 1) { + z1 = z0 = depth - 1; + z = (T)z0; + } else { + z1 = z0 + 1; + } + + + // do linear interpolation of x values + // distance of actual point to lower boundary point, already normalized since x_high - x0 = 1 + T dis = x - x0; + /* accessing element b,c,y,x,z in 1D-rolled-out array of a tensor with dimensions (B, C, Y, X, Z): + tensor[b,c,y,x,z] = arr[ (((b*C+c)*Y+y)*X + x)*Z + z ] = arr[ alpha + (y*X + x)*Z + z ] + with alpha = batch&channel locator = (b*C+c)*YXZ. + hence, as current input pointer is already offset by alpha: y,x,z at input[( y*X + x)*Z + z], where + X = width, Z = depth. 
+ */ + T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]); + T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]); + T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]); + T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]); + + // linear interpol of y values = bilinear interpol of f(x,y) + dis = y - y0; + T xy0 = linear_interpolate(dis, x00, x10); + T xy1 = linear_interpolate(dis, x01, x11); + + // linear interpol of z value = trilinear interpol of f(x,y,z) + dis = z - z0; + T xyz = linear_interpolate(dis, xy0, xy1); + + return xyz; +} + +template +__device__ void trilinear_interpolate_gradient(const int height, const int width, const int depth, T y, T x, T z, + T& g000, T& g001, T& g010, T& g100, T& g011, T& g101, T& g110, T& g111, + int& x0, int& x1, int& y0, int& y1, int& z0, int&z1, const int index /* index for debug only*/) +{ + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width || z < -1.0 || z > depth) { + // empty + g000 = g001 = g010 = g100 = g011 = g101 = g110 = g111 = 0.; + x0 = x1 = y0 = y1 = z0 = z1 = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + if (z <= 0) + z = 0; + + y0 = (int)y; + x0 = (int)x; + z0 = (int)z; + + if (y0 >= height - 1) { + y1 = y0 = height - 1; + y = (T)y0; + } else { + y1 = y0 + 1; + } + + if (x0 >= width - 1) { + x1 = x0 = width - 1; + x = (T)x0; + } else { + x1 = x0 + 1; + } + + if (z0 >= depth - 1) { + z1 = z0 = depth - 1; + z = (T)z0; + } else { + z1 = z0 + 1; + } + + // forward calculations are added as hints + T dis_x = x - x0; + //T x00 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z0], input[(y0*width+ x1)*depth+z0]); // v000, v100 + //T x10 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z0], input[(y1*width+ x1)*depth+z0]); // v010, v110 + //T x01 = linear_interpolate(dis, input[(y0*width+ x0)*depth+z1], input[(y0*width+ x1)*depth+z1]); // v001, v101 + //T x11 = linear_interpolate(dis, input[(y1*width+ x0)*depth+z1], input[(y1*width+ x1)*depth+z1]); // v011, v111 + + // linear interpol of y values = bilinear interpol of f(x,y) + T dis_y = y - y0; + //T xy0 = linear_interpolate(dis, x00, x10); + //T xy1 = linear_interpolate(dis, x01, x11); + + // linear interpol of z value = trilinear interpol of f(x,y,z) + T dis_z = z - z0; + //T xyz = linear_interpolate(dis, xy0, xy1); + + /* need: grad_i := d(xyz)/d(v_i) with v_i = input_value_i for all i = 0,..,7 (eight input values --> eight-entry gradient) + d(lin_interp(dis,x,y))/dx = (-dis +1) and d(lin_interp(dis,x,y))/dy = dis --> derivatives are indep of x,y. 
+ notation: gxyz = gradient for d(trilin_interp)/d(input_value_at_xyz) + below grads were calculated by hand + save time by reusing (1-dis_x) = 1-x+x0 = x1-x =: dis_x1 */ + T dis_x1 = (1-dis_x), dis_y1 = (1-dis_y), dis_z1 = (1-dis_z); + + g000 = dis_z1 * dis_y1 * dis_x1; + g001 = dis_z * dis_y1 * dis_x1; + g010 = dis_z1 * dis_y * dis_x1; + g100 = dis_z1 * dis_y1 * dis_x; + g011 = dis_z * dis_y * dis_x1; + g101 = dis_z * dis_y1 * dis_x; + g110 = dis_z1 * dis_y * dis_x; + g111 = dis_z * dis_y * dis_x; + + return; +} + +template +__global__ void RoIAlignForward(const int nthreads, const T* input, const T spatial_scale, const int channels, + const int height, const int width, const int depth, const int pooled_height, const int pooled_width, + const int pooled_depth, const int sampling_ratio, const T* rois, T* output) +{ + + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw, pd) is an element in the pooled output + int pd = index % pooled_depth; + int pw = (index / pooled_depth) % pooled_width; + int ph = (index / pooled_depth / pooled_width) % pooled_height; + int c = (index / pooled_depth / pooled_width / pooled_height) % channels; + int n = index / pooled_depth / pooled_width / pooled_height / channels; + + + // rois are (y1,x1,y2,x2,z1,z2) --> tensor of shape (n_rois, 6) + const T* offset_rois = rois + n * 7; + int roi_batch_ind = offset_rois[0]; + // Do not use rounding; this implementation detail is critical + T roi_start_h = offset_rois[1] * spatial_scale; + T roi_start_w = offset_rois[2] * spatial_scale; + T roi_end_h = offset_rois[3] * spatial_scale; + T roi_end_w = offset_rois[4] * spatial_scale; + T roi_start_d = offset_rois[5] * spatial_scale; + T roi_end_d = offset_rois[6] * spatial_scale; + + // Force malformed ROIs to be 1x1 + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_depth = max(roi_end_d - roi_start_d, (T)1.); + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + T bin_size_d = static_cast(roi_depth) / static_cast(pooled_depth); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width * depth; + + // We use roi_bin_grid to sample the grid and mimic integral + // roi_bin_grid == nr of sampling points per bin >= 1 + int roi_bin_grid_h = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + int roi_bin_grid_d = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_depth / pooled_depth); + + // We do average (integral) pooling inside a bin + const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g. 
= 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5, always in the middle of two grid pointsk + + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + for (int iz = 0; iz < roi_bin_grid_d; iz++) + { + const T z = roi_start_d + pd * bin_size_d + + static_cast(iz + .5f) * bin_size_d / static_cast(roi_bin_grid_d); + // TODO verify trilinear interpolation + T val = trilinear_interpolate(offset_input, height, width, depth, y, x, z, index); + output_val += val; + } // z iterator and calc+add value + } // x iterator + } // y iterator + output_val /= n_voxels; + + output[index] = output_val; + } +} + +template +__global__ void RoIAlignBackward(const int nthreads, const T* grad_output, const T spatial_scale, const int channels, + const int height, const int width, const int depth, const int pooled_height, const int pooled_width, + const int pooled_depth, const int sampling_ratio, T* grad_input, const T* rois, + const int n_stride, const int c_stride, const int h_stride, const int w_stride, const int d_stride) +{ + + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw, pd) is an element in the pooled output + int pd = index % pooled_depth; + int pw = (index / pooled_depth) % pooled_width; + int ph = (index / pooled_depth / pooled_width) % pooled_height; + int c = (index / pooled_depth / pooled_width / pooled_height) % channels; + int n = index / pooled_depth / pooled_width / pooled_height / channels; + + + const T* offset_rois = rois + n * 7; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale; + T roi_start_h = offset_rois[2] * spatial_scale; + T roi_end_w = offset_rois[3] * spatial_scale; + T roi_end_h = offset_rois[4] * spatial_scale; + T roi_start_d = offset_rois[5] * spatial_scale; + T roi_end_d = offset_rois[6] * spatial_scale; + + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T roi_depth = max(roi_end_d - roi_start_d, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + T bin_size_d = static_cast(roi_depth) / static_cast(pooled_depth); + + // offset: index b,c,y,x,z of tensor of shape (B,C,Y,X,Z) is + // b*C*Y*X*Z + c * Y*X*Z + y * X*Z + x *Z + z = (b*C+c)Y*X*Z + ... + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width * depth); + + // We need to index the gradient using the tensor strides to access the correct values. + int output_offset = n * n_stride + c * c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = offset_grad_output[ph * h_stride + pw * w_stride + pd * d_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + int roi_bin_grid_d = (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_depth / pooled_depth); + + // We do average (integral) pooling inside a bin + const T n_voxels = roi_bin_grid_h * roi_bin_grid_w * roi_bin_grid_d; // e.g. = 6 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + for (int iz = 0; iz < roi_bin_grid_d; iz++) + { + const T z = roi_start_d + pd * bin_size_d + + static_cast(iz + .5f) * bin_size_d / static_cast(roi_bin_grid_d); + + T g000, g001, g010, g100, g011, g101, g110, g111; // will hold the current partial derivatives + int x0, x1, y0, y1, z0, z1; + /* notation: gxyz = gradient at xyz, where x,y,z need to lie on feature-map grid (i.e., =x0,x1 etc.) */ + trilinear_interpolate_gradient(height, width, depth, y, x, z, + g000, g001, g010, g100, g011, g101, g110, g111, + x0, x1, y0, y1, z0, z1, index); + /* chain rule: derivatives (i.e., the gradient) of trilin_interpolate(v1,v2,v3,v4,...) (div by n_voxels + as we actually need gradient of whole roi_align) are multiplied with gradient so far*/ + g000 *= grad_output_this_bin / n_voxels; + g001 *= grad_output_this_bin / n_voxels; + g010 *= grad_output_this_bin / n_voxels; + g100 *= grad_output_this_bin / n_voxels; + g011 *= grad_output_this_bin / n_voxels; + g101 *= grad_output_this_bin / n_voxels; + g110 *= grad_output_this_bin / n_voxels; + g111 *= grad_output_this_bin / n_voxels; + + if (x0 >= 0 && x1 >= 0 && y0 >= 0 && y1 >= 0 && z0 >= 0 && z1 >= 0) + { // atomicAdd(address, content) reads content under address, adds content to it, while: no other thread + // can interfere with the memory at address during this operation (thread lock, therefore "atomic"). 
+ atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z0, static_cast(g000)); + atomicAdd(offset_grad_input + (y0 * width + x0) * depth + z1, static_cast(g001)); + atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z0, static_cast(g010)); + atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z0, static_cast(g100)); + atomicAdd(offset_grad_input + (y1 * width + x0) * depth + z1, static_cast(g011)); + atomicAdd(offset_grad_input + (y0 * width + x1) * depth + z1, static_cast(g101)); + atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z0, static_cast(g110)); + atomicAdd(offset_grad_input + (y1 * width + x1) * depth + z1, static_cast(g111)); + } // if + } // iz + } // ix + } // iy + } // CUDA_1D_KERNEL_LOOP +} // RoIAlignBackward + + +/*----------- wrapper functions ----------------*/ + +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale, + const int pooled_height, const int pooled_width, const int pooled_depth, const int sampling_ratio) { + /* + input: feature-map tensor, shape (batch, n_channels, y, x(, z)) + */ + AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlign_forward_cuda"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + + at::cuda::CUDAGuard device_guard(input.device()); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + auto depth = input.size(4); + //std::cout << "input.options" << input.options() << std::endl; + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width, pooled_depth}, input.options()); + + auto output_size = num_rois * channels * pooled_height * pooled_width * pooled_depth; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min( + at::cuda::ATenCeilDiv(static_cast(output_size), static_cast(512)), static_cast(4096))); + dim3 block(512); + + if (output.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return output; + } + + //std::printf("launching kernel\n"); + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign forward in 3d", [&] { + RoIAlignForward<<>>( + output_size, + input.contiguous().data_ptr(), + spatial_scale, + channels, + height, + width, + depth, + pooled_height, + pooled_width, + pooled_depth, + sampling_ratio, + rois.contiguous().data_ptr(), + output.data_ptr()); + }); + AT_CUDA_CHECK(cudaGetLastError()); + return output; +} + +at::Tensor ROIAlign_backward_cuda( + const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int pooled_depth, + const int batch_size, + const int channels, + const int height, + const int width, + const int depth, + const int sampling_ratio) +{ + AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlign_backward_cuda"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + + at::cuda::CUDAGuard device_guard(grad.device()); + + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width, depth}, grad.options()); + + cudaStream_t stream = 
at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min( + at::cuda::ATenCeilDiv( + static_cast(grad.numel()), static_cast(512)), + static_cast(4096))); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_input; + } + + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + int d_stride = grad.stride(4); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign backward 3D", [&] { + RoIAlignBackward<<>>( + grad.numel(), + grad.data_ptr(), + spatial_scale, + channels, + height, + width, + depth, + pooled_height, + pooled_width, + pooled_depth, + sampling_ratio, + grad_input.data_ptr(), + rois.contiguous().data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride, + d_stride); + }); + AT_CUDA_CHECK(cudaGetLastError()); + return grad_input; +} \ No newline at end of file diff --git a/custom_extensions/roi_align/src/RoIAlign_interface.cpp b/custom_extensions/roi_align/src/RoIAlign_interface.cpp new file mode 100644 index 0000000..41d5fdf --- /dev/null +++ b/custom_extensions/roi_align/src/RoIAlign_interface.cpp @@ -0,0 +1,104 @@ +/* adopted from pytorch framework + https://github.com/pytorch/vision/blob/master/torchvision/csrc/ROIAlign.h on Nov 15 2019. + + does not include CPU support but could be added with this interface. +*/ + +#include + +// Declarations that are initialized in cuda file +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale, + const int pooled_height, const int pooled_width, const int sampling_ratio); +at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale, + const int pooled_height, const int pooled_width, const int batch_size, const int channels, + const int height, const int width, const int sampling_ratio); + +// Interface for Python +at::Tensor ROIAlign_forward( + const at::Tensor& input, // Input feature map. + const at::Tensor& rois, // List of ROIs to pool over. + const double spatial_scale, // The scale of the image features. ROIs will be scaled to this. + const int64_t pooled_height, // The height of the pooled feature map. + const int64_t pooled_width, // The width of the pooled feature + const int64_t sampling_ratio) // The number of points to sample in each bin along each axis. 
+{ + if (input.type().is_cuda()) { + return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); + } + AT_ERROR("Not compiled with CPU support"); + //return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +at::Tensor ROIAlign_backward(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale, + const int pooled_height, const int pooled_width, const int batch_size, const int channels, + const int height, const int width, const int sampling_ratio) { + if (grad.type().is_cuda()) { + return ROIAlign_backward_cuda( grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, + height, width, sampling_ratio); + } + AT_ERROR("Not compiled with CPU support"); + //return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, + // height, width, sampling_ratio); +} + +using namespace at; +using torch::Tensor; +using torch::autograd::AutogradContext; +using torch::autograd::Variable; +using torch::autograd::variable_list; + +class ROIAlignFunction : public torch::autograd::Function { + public: + static variable_list forward( + AutogradContext* ctx, + Variable input, + Variable rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["sampling_ratio"] = sampling_ratio; + ctx->saved_data["input_shape"] = input.sizes(); + ctx->save_for_backward({rois}); + auto result = ROIAlign_forward(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); + return {result}; + } + + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto input_shape = ctx->saved_data["input_shape"].toIntList(); + auto grad_in = ROIAlign_backward( + grad_output[0], + rois, + ctx->saved_data["spatial_scale"].toDouble(), + ctx->saved_data["pooled_height"].toInt(), + ctx->saved_data["pooled_width"].toInt(), + input_shape[0], //b + input_shape[1], //c + input_shape[2], //h + input_shape[3], //w + ctx->saved_data["sampling_ratio"].toInt()); + return { + grad_in, Variable(), Variable(), Variable(), Variable(), Variable()}; + } +}; + +Tensor roi_align(const Tensor& input, const Tensor& rois, const double spatial_scale, const int64_t pooled_height, + const int64_t pooled_width, const int64_t sampling_ratio) { + + return ROIAlignFunction::apply(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio)[0]; + +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("roi_align", &roi_align, "ROIAlign 2D in c++ and/or cuda"); +} \ No newline at end of file diff --git a/custom_extensions/roi_align/src/RoIAlign_interface_3d.cpp b/custom_extensions/roi_align/src/RoIAlign_interface_3d.cpp new file mode 100644 index 0000000..bb680fa --- /dev/null +++ b/custom_extensions/roi_align/src/RoIAlign_interface_3d.cpp @@ -0,0 +1,112 @@ +/* adopted from pytorch framework + https://github.com/pytorch/vision/blob/master/torchvision/csrc/ROIAlign.h on Nov 15 2019. + + does not include CPU support but could be added with this interface. 
+*/ + +#include + +/*---------------- 3D implementation ---------------------------*/ + +// Declarations that are initialized in cuda file +at::Tensor ROIAlign_forward_cuda(const at::Tensor& input, const at::Tensor& rois, const float spatial_scale, + const int pooled_height, const int pooled_width, const int pooled_depth, + const int sampling_ratio); +at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale, + const int pooled_height, const int pooled_width, const int pooled_depth, const int batch_size, const int channels, + const int height, const int width, const int depth, const int sampling_ratio); + +// Interface for Python +at::Tensor ROIAlign_forward( + const at::Tensor& input, // Input feature map. + const at::Tensor& rois, // List of ROIs to pool over. + const double spatial_scale, // The scale of the image features. ROIs will be scaled to this. + const int64_t pooled_height, // The height of the pooled feature map. + const int64_t pooled_width, // The width of the pooled feature + const int64_t pooled_depth, + const int64_t sampling_ratio) // The number of points to sample in each bin along each axis. +{ + if (input.type().is_cuda()) { + return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, pooled_depth, sampling_ratio); + } + AT_ERROR("Not compiled with CPU support"); + //return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +at::Tensor ROIAlign_backward(const at::Tensor& grad, const at::Tensor& rois, const float spatial_scale, + const int pooled_height, const int pooled_width, const int pooled_depth, const int batch_size, const int channels, + const int height, const int width, const int depth, const int sampling_ratio) { + if (grad.type().is_cuda()) { + return ROIAlign_backward_cuda( grad, rois, spatial_scale, pooled_height, pooled_width, pooled_depth, batch_size, + channels, height, width, depth, sampling_ratio); + } + AT_ERROR("Not compiled with CPU support"); + //return ROIAlign_backward_cpu(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, + // height, width, sampling_ratio); +} + +using namespace at; +using torch::Tensor; +using torch::autograd::AutogradContext; +using torch::autograd::Variable; +using torch::autograd::variable_list; + +class ROIAlignFunction : public torch::autograd::Function { + public: + static variable_list forward( + AutogradContext* ctx, + Variable input, + Variable rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t pooled_depth, + const int64_t sampling_ratio) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["pooled_depth"] = pooled_depth; + ctx->saved_data["sampling_ratio"] = sampling_ratio; + ctx->saved_data["input_shape"] = input.sizes(); + ctx->save_for_backward({rois}); + auto result = ROIAlign_forward(input, rois, spatial_scale, pooled_height, pooled_width, pooled_depth, + sampling_ratio); + return {result}; + } + + static variable_list backward( + AutogradContext* ctx, + variable_list grad_output) { + // Use data saved in forward + auto saved = ctx->get_saved_variables(); + auto rois = saved[0]; + auto input_shape = ctx->saved_data["input_shape"].toIntList(); + auto grad_in = ROIAlign_backward( + grad_output[0], + rois, + ctx->saved_data["spatial_scale"].toDouble(), + 
ctx->saved_data["pooled_height"].toInt(), + ctx->saved_data["pooled_width"].toInt(), + ctx->saved_data["pooled_depth"].toInt(), + input_shape[0], + input_shape[1], + input_shape[2], + input_shape[3], + input_shape[4], + ctx->saved_data["sampling_ratio"].toInt()); + return { + grad_in, Variable(), Variable(), Variable(), Variable(), Variable(), Variable()}; + } +}; + +Tensor roi_align(const Tensor& input, const Tensor& rois, const double spatial_scale, const int64_t pooled_height, + const int64_t pooled_width, const int64_t pooled_depth, const int64_t sampling_ratio) { + + return ROIAlignFunction::apply(input, rois, spatial_scale, pooled_height, pooled_width, pooled_depth, + sampling_ratio)[0]; + +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("roi_align", &roi_align, "ROIAlign 3D in c++ and/or cuda"); +} \ No newline at end of file diff --git a/custom_extensions/sandbox/setup.py b/custom_extensions/sandbox/setup.py new file mode 100644 index 0000000..71ac616 --- /dev/null +++ b/custom_extensions/sandbox/setup.py @@ -0,0 +1,13 @@ +""" +Created at 07.11.19 19:12 +@author: gregor + +""" + + +from setuptools import setup +from torch.utils import cpp_extension + +setup(name='sandbox_cuda', + ext_modules=[cpp_extension.CUDAExtension('sandbox', ['src/sandbox.cpp', 'src/sandbox_cuda.cu'])], + cmdclass={'build_ext': cpp_extension.BuildExtension}) \ No newline at end of file diff --git a/custom_extensions/sandbox/src/sandbox.cpp b/custom_extensions/sandbox/src/sandbox.cpp new file mode 100644 index 0000000..e1fbe12 --- /dev/null +++ b/custom_extensions/sandbox/src/sandbox.cpp @@ -0,0 +1,82 @@ +// ------------------------------------------------------------------ +// Faster R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Shaoqing Ren, rewritten in C++ by Gregor Ramien +// ------------------------------------------------------------------ +//#include +//#include + +#include +#include + +#include "sandbox.h" + +//#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) // divide m by n, add +1 if there is a remainder +#define getNBlocks(m,n) ( (m+n-1) / (n) ) // m = nr of total (required) threads, n = nr of threads per block. 
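// Illustration (not part of the original source): getNBlocks rounds the block count up so that
// all m required threads are covered, e.g. getNBlocks(130, 64) = (130 + 64 - 1) / 64 = 3 blocks
// of 64 threads for 130 work items.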
+int const threadsPerBlock = sizeof(unsigned long long) * 8; + +//---- declarations that will be defined in cuda kernel +void add_cuda(int n=1<<3); +void nms_cuda(at::Tensor *boxes, at::Tensor *scores, float thresh); +//----------------------------------------------------------------- + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +void sandbox() { + + //std::cout<< "number: "<< number << std::endl; + + torch::Tensor tensor = torch::tensor({0,1,2,3,4,5}, at::kLong).view({2,3}); + std::cout<< "tensor: " << tensor << std::endl; + std::cout<< "tensor shape: " << at::size(tensor,0) << ", " << at::size(tensor,1) << std::endl; + return; +} + +void add(){ + //tutorial function: add two arrays (x,y) of length n + add_cuda(); +} + +//void nms(at::Tensor boxes, at::Tensor scores, float thresh) { +void nms() { + + // x1, y1, x2, y2 + at::Tensor boxes = torch::tensor({ + {20, 10, 60, 40}, + {10, 20, 40, 60}, + {20, 20, 80, 50} + }, at::TensorOptions().dtype(at::kInt).device(at::kCUDA)); + std::cout << "boxes: \n" << boxes << std::endl; + at::Tensor scores = torch::tensor({ + 0.5, + 0.6, + 0.7 + }, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + std::cout<< "scores: \n" << scores << std::endl; + + CHECK_INPUT(boxes); CHECK_INPUT(scores); + + int boxes_num = at::size(boxes,0); + int boxes_dim = at::size(boxes,1); + + std::cout << "boxes shape: " << boxes_num << ", " << boxes_dim << std::endl; + + float * boxes_dev; unsigned long long * mask_dev; float thresh = 0.05; + + dim3 blocks((boxes_num, threadsPerBlock); + int threads(threadsPerBlock); + + +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sandbox", &sandbox, "Sandy Box Function"); + m.def("nms", &nms, "NMS in cpp"); + m.def("add", &add, "tutorial add function"); + +} \ No newline at end of file diff --git a/custom_extensions/sandbox/src/sandbox_cuda.cu b/custom_extensions/sandbox/src/sandbox_cuda.cu new file mode 100644 index 0000000..8bb6764 --- /dev/null +++ b/custom_extensions/sandbox/src/sandbox_cuda.cu @@ -0,0 +1,130 @@ +// ------------------------------------------------------------------ +// Faster R-CNN +// Copyright (c) 2015 Microsoft +// Licensed under The MIT License [see fast-rcnn/LICENSE for details] +// Written by Shaoqing Ren +// ------------------------------------------------------------------ + +#include +#include + +#include +#include +#include +#include + +#include "sandbox.h" +//tutorial cuda function add + +__global__ void add_kernel(float *x, float *y, int n){ + printf("block %d: threadIdx.x %d, threadIdx.y %d, threadIdx.z %d.\n", blockIdx.x, threadIdx.x, threadIdx.y, threadIdx.z); + int index = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i=index; i>>(x, y, n); + + cudaDeviceSynchronize(); + + float maxError = 0.0f; + for (int i = 0; i < n; i++) + maxError = fmax(maxError, fabs(y[i]-3.0f)); + std::cout << "Max error: " << maxError << std::endl; + + cudaFree(x); + cudaFree(y); + +} + +/* +__device__ inline float devIoU(float const * const a, float const * const b) { + float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); + float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); + float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f); + float interS = width * height; + float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); + float Sb = (b[2] - 
b[0] + 1) * (b[3] - b[1] + 1); + return interS / (Sa + Sb - interS); +} + +__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, + const float *dev_boxes, unsigned long long *dev_mask) { + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + __shared__ float block_boxes[threadsPerBlock * 5]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const float *cur_box = dev_boxes + cur_box_idx * 5; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { + t |= 1ULL << i; + } + } + const int col_blocks = DIVUP(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} +// +*/ + +__global__ void nms_kernel(){ + +} + +void nms_cuda(){ + + + +} + + +int main(void){ + + nms_cuda(); + + return 0; +} \ No newline at end of file
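The diff itself only adds the extension sources. As a rough orientation for reviewers, the sketch below shows how the new Python wrappers could be called; it is not part of the diff and rests on several assumptions: the extensions have been built (e.g. by running `python setup.py install` inside `custom_extensions/nms` and `custom_extensions/roi_align`), the wrapper modules `nms.py` and `roi_align.py` are on the import path, and a CUDA device is available. All tensor values and the variable names other than `nms`, `batched_nms` and `RoIAlign` are illustrative only.

# Usage sketch (hypothetical): assumes the extensions above are built/installed,
# the wrapper modules are importable, and a CUDA device is present.
import torch

from nms import nms, batched_nms       # custom_extensions/nms/nms.py
from roi_align import RoIAlign         # custom_extensions/roi_align/roi_align.py

device = torch.device("cuda")

# NMS on 2D boxes in (y1, x1, y2, x2) format, as documented in nms.py.
boxes = torch.tensor([[10., 10., 50., 50.],
                      [12., 12., 52., 52.],
                      [100., 100., 140., 140.]], device=device)
scores = torch.tensor([0.9, 0.8, 0.7], device=device)
keep = nms(boxes, scores, iou_threshold=0.5)            # here: boxes 0 and 2 survive

# Class-aware NMS: boxes of different classes never suppress each other.
class_ids = torch.tensor([0, 1, 0], device=device)
keep_per_class = batched_nms(boxes, scores, class_ids, iou_threshold=0.5)

# RoIAlign on a 2D feature map; a single roi tensor carries the batch index in
# column 0 followed by (x1, y1, x2, y2), as documented in roi_align.py.
feature_map = torch.randn(1, 64, 32, 32, device=device)
rois = torch.tensor([[0., 4., 4., 20., 20.]], device=device)
pooler = RoIAlign(output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
pooled = pooler(feature_map, rois)                      # shape: (1, 64, 7, 7)

For the 3D variants, the same entry points accept boxes with two extra coordinates (z1, z2) and, for `RoIAlign`, a three-element `output_size`, which routes the call to the `*_3d` extensions shown in the diff.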