diff --git a/models/retina_net.py b/models/retina_net.py index 81dff0a..c1c8b0a 100644 --- a/models/retina_net.py +++ b/models/retina_net.py @@ -1,508 +1,508 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Retina Net. According to https://arxiv.org/abs/1708.02002 Retina U-Net. According to https://arxiv.org/abs/1811.08661 """ import utils.model_utils as mutils import utils.exp_utils as utils import sys sys.path.append('../') from cuda_functions.nms_2D.pth_nms import nms_gpu as nms_2D from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils ############################################################ # Network Heads ############################################################ class Classifier(nn.Module): def __init__(self, cf, conv): """ Builds the classifier sub-network. """ super(Classifier, self).__init__() self.dim = conv.dim self.n_classes = cf.head_classes n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * cf.head_classes anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: class_logits (b, n_anchors, n_classes) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) class_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) class_logits = class_logits.permute(*axes) class_logits = class_logits.contiguous() class_logits = class_logits.view(x.size()[0], -1, self.n_classes) return [class_logits] class BBRegressor(nn.Module): def __init__(self, cf, conv): """ Builds the bb-regression sub-network. """ super(BBRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * self.dim * 2 anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) bb_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) bb_logits = bb_logits.permute(*axes) bb_logits = bb_logits.contiguous() bb_logits = bb_logits.view(x.size()[0], -1, self.dim * 2) return [bb_logits] ############################################################ # Loss Functions ############################################################ def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20): """ - :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :param anchor_matches: (n_anchors). [-1, 0, class_id] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param shem_poolsize: int. factor of top-k candidates to draw from per negative sample (online-hard-example-mining). :return: loss: torch tensor. :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0) neg_indices = torch.nonzero(anchor_matches == -1) # get positive samples and calucalte loss. if 0 not in pos_indices.size(): pos_indices = pos_indices.squeeze(1) roi_logits_pos = class_pred_logits[pos_indices] targets_pos = anchor_matches[pos_indices] pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long()) else: pos_loss = torch.FloatTensor([0]).cuda() # get negative samples, such that the amount matches the number of positive samples, but at least 1. # get high scoring negatives by applying online-hard-example-mining. if 0 not in neg_indices.size(): neg_indices = neg_indices.squeeze(1) roi_logits_neg = class_pred_logits[neg_indices] negative_count = np.max((1, pos_indices.size()[0])) roi_probs_neg = F.softmax(roi_logits_neg, dim=1) neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) # return the indices of negative samples, which contributed to the loss (for monitoring plots). np_neg_ix = neg_ix.cpu().data.numpy() else: neg_loss = torch.FloatTensor([0]).cuda() np_neg_ix = np.array([]).astype('int32') loss = (pos_loss + neg_loss) / 2 return loss, np_neg_ix def compute_bbox_loss(target_deltas, pred_deltas, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unsed bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) - :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :param anchor_matches: (n_anchors). [-1, 0, class_id] for negative, neutral, and positive matched anchors. :return: loss: torch 1D tensor. """ - if 0 not in torch.nonzero(anchor_matches == 1).size(): + if 0 not in torch.nonzero(anchor_matches > 0).size(): - indices = torch.nonzero(anchor_matches == 1).squeeze(1) + indices = torch.nonzero(anchor_matches > 0).squeeze(1) # Pick bbox deltas that contribute to the loss pred_deltas = pred_deltas[indices] # Trim target bounding box deltas to the same length as pred_deltas. target_deltas = target_deltas[:pred_deltas.size()[0], :] # Smooth L1 loss loss = F.smooth_l1_loss(pred_deltas, target_deltas) else: loss = torch.FloatTensor([0]).cuda() return loss ############################################################ # Output Handler ############################################################ def refine_detections(anchors, probs, deltas, batch_ixs, cf): """ Refine classified proposals, filter overlaps and return final detections. n_proposals here is typically a very large number: batch_size * n_anchors. This function is hence optimized on trimming down n_proposals. :param anchors: (n_anchors, 2 * dim) :param probs: (n_proposals, n_classes) softmax probabilities for all rois as predicted by classifier head. :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by bbox regressor head. :param batch_ixs: (n_proposals) batch element assignemnt info for re-allocation. :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)) """ anchors = anchors.repeat(len(np.unique(batch_ixs)), 1) # flatten foreground probabilities, sort and trim down to highest confidences by pre_nms limit. fg_probs = probs[:, 1:].contiguous() flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True) keep_ix = flat_probs_order[:cf.pre_nms_limit] # reshape indices to 2D index array with shape like fg_probs. keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) pre_nms_scores = flat_probs[:cf.pre_nms_limit] pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again. pre_nms_batch_ixs = batch_ixs[keep_arr[:, 0]] pre_nms_anchors = anchors[keep_arr[:, 0]] pre_nms_deltas = deltas[keep_arr[:, 0]] keep = torch.arange(pre_nms_scores.size()[0]).long().cuda() # apply bounding box deltas. re-scale to image coordinates. std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() scale = torch.from_numpy(cf.scale).float().cuda() refined_rois = mutils.apply_box_deltas_2D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale \ if cf.dim == 2 else mutils.apply_box_deltas_3D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale # round and cast to int since we're deadling with pixels now refined_rois = mutils.clip_to_window(cf.window, refined_rois) pre_nms_rois = torch.round(refined_rois) for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)): bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] bix_class_ids = pre_nms_class_ids[bixs] bix_rois = pre_nms_rois[bixs] bix_scores = pre_nms_scores[bixs] for i, class_id in enumerate(mutils.unique1d(bix_class_ids)): ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] # nms expects boxes sorted by score. ix_rois = bix_rois[ixs] ix_scores = bix_scores[ixs] ix_scores, order = ix_scores.sort(descending=True) ix_rois = ix_rois[order, :] ix_scores = ix_scores if cf.dim == 2: class_keep = nms_2D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) else: class_keep = nms_3D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) # map indices back. class_keep = keep[bixs[ixs[order[class_keep]]]] # merge indices over classes for current batch element b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep))) # only keep top-k boxes of current batch-element. top_ids = pre_nms_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] b_keep = b_keep[top_ids] # merge indices over batch elements. batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep))) keep = batch_keep # arrange output. result = torch.cat((pre_nms_rois[keep], pre_nms_batch_ixs[keep].unsqueeze(1).float(), pre_nms_class_ids[keep].unsqueeze(1).float(), pre_nms_scores[keep].unsqueeze(1)), dim=1) return result def get_results(cf, img_shape, detections, seg_logits, box_results_list=None): """ Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. :param img_shape: :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score) :param box_results_list: None or list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, ..., n_classes] for retina_unet and dummy array for retina_net. """ detections = detections.cpu().data.numpy() batch_ixs = detections[:, cf.dim*2] detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] # for test_forward, where no previous list exists. if box_results_list is None: box_results_list = [[] for _ in range(img_shape[0])] for ix in range(img_shape[0]): if 0 not in detections[ix].shape: boxes = detections[ix][:, :2 * cf.dim].astype(np.int32) class_ids = detections[ix][:, 2 * cf.dim + 1].astype(np.int32) scores = detections[ix][:, 2 * cf.dim + 2] # Filter out detections with zero area. Often only happens in early # stages of training when the network weights are still a bit random. if cf.dim == 2: exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] else: exclude_ix = np.where( (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] if exclude_ix.shape[0] > 0: boxes = np.delete(boxes, exclude_ix, axis=0) class_ids = np.delete(class_ids, exclude_ix, axis=0) scores = np.delete(scores, exclude_ix, axis=0) if 0 not in boxes.shape: for ix2, score in enumerate(scores): if score >= cf.model_min_confidence: box_results_list[ix].append({'box_coords': boxes[ix2], 'box_score': score, 'box_type': 'det', 'box_pred_class_id': class_ids[ix2]}) results_dict = {'boxes': box_results_list} if seg_logits is None: # output dummy segmentation for retina_net. results_dict['seg_preds'] = np.zeros(img_shape)[:, 0][:, np.newaxis] else: # output label maps for retina_unet. results_dict['seg_preds'] = F.softmax(seg_logits, 1).argmax(1).cpu().data.numpy()[:, np.newaxis].astype('uint8') return results_dict ############################################################ # Retina (U-)Net Class ############################################################ class net(nn.Module): def __init__(self, cf, logger): super(net, self).__init__() self.cf = cf self.logger = logger self.build() if self.cf.weight_init is not None: logger.info("using pytorch weight init of type {}".format(self.cf.weight_init)) mutils.initialize_weights(self) else: logger.info("using default pytorch weight init") def build(self): """ Build Retina Net architecture. """ # Image size must be dividable by 2 multiple times. h, w = self.cf.patch_size[:2] if h / 2 ** 5 != int(h / 2 ** 5) or w / 2 ** 5 != int(w / 2 ** 5): raise Exception("Image size must be dividable by 2 at least 5 times " "to avoid fractions when downscaling and upscaling." "For example, use 256, 320, 384, 448, 512, ... etc. ") # instanciate abstract multi dimensional conv class and backbone model. conv = mutils.NDConvGenerator(self.cf.dim) backbone = utils.import_module('bbone', self.cf.backbone_path) # build Anchors, FPN, Classifier / Bbox-Regressor -head self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) self.anchors = torch.from_numpy(self.np_anchors).float().cuda() self.Fpn = backbone.FPN(self.cf, conv, operate_stride1=self.cf.operate_stride1) self.Classifier = Classifier(self.cf, conv) self.BBRegressor = BBRegressor(self.cf, conv) def train_forward(self, batch, **kwargs): """ train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data for processing, computes losses, and stores outputs in a dictionary. :param batch: dictionary containing 'data', 'seg', etc. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes]. 'monitor_values': dict of values to be monitored. """ img = batch['data'] gt_class_ids = batch['roi_labels'] gt_boxes = batch['bb_target'] var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda() var_seg = torch.LongTensor(batch['seg']).cuda() img = torch.from_numpy(img).float().cuda() batch_class_loss = torch.FloatTensor([0]).cuda() batch_bbox_loss = torch.FloatTensor([0]).cuda() # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. box_results_list = [[] for _ in range(img.shape[0])] detections, class_logits, pred_deltas, seg_logits = self.forward(img) # loop over batch for b in range(img.shape[0]): # add gt boxes to results dict for monitoring. if len(gt_boxes[b]) > 0: for ix in range(len(gt_boxes[b])): box_results_list[b].append({'box_coords': batch['bb_target'][b][ix], 'box_label': batch['roi_labels'][b][ix], 'box_type': 'gt'}) # match gt boxes with anchors to generate targets. anchor_class_match, anchor_target_deltas = mutils.gt_anchor_matching( self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b]) # add positive anchors used for loss to results_dict for monitoring. pos_anchors = mutils.clip_boxes_numpy( self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:]) for p in pos_anchors: box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) else: anchor_class_match = np.array([-1]*self.np_anchors.shape[0]) anchor_target_deltas = np.array([0]) anchor_class_match = torch.from_numpy(anchor_class_match).cuda() anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda() # compute losses. class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b]) bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match) # add negative anchors used for loss to results_dict for monitoring. neg_anchors = mutils.clip_boxes_numpy( self.np_anchors[np.argwhere(anchor_class_match == -1)][0, neg_anchor_ix], img.shape[2:]) for n in neg_anchors: box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) batch_class_loss += class_loss / img.shape[0] batch_bbox_loss += bbox_loss / img.shape[0] results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list) loss = batch_class_loss + batch_bbox_loss results_dict['torch_loss'] = loss results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': batch_class_loss.item()} results_dict['logger_string'] = "loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}"\ .format(loss.item(), batch_class_loss.item(), batch_bbox_loss.item()) return results_dict def test_forward(self, batch, **kwargs): """ test method. wrapper around forward pass of network without usage of any ground truth information. prepares input data for processing and stores outputs in a dictionary. :param batch: dictionary containing 'data' :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, ..., n_classes] for retina_unet and dummy array for retina_net. """ img = batch['data'] img = torch.from_numpy(img).float().cuda() detections, _, _, seg_logits = self.forward(img) results_dict = get_results(self.cf, img.shape, detections, seg_logits) return results_dict def forward(self, img): """ forward pass of the model. :param img: input img (b, c, y, x, (z)). :return: rpn_pred_logits: (b, n_anchors, 2) :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d)))) :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting. :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score) :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. """ # Feature extraction fpn_outs = self.Fpn(img) seg_logits = None selected_fmaps = [fpn_outs[i] for i in self.cf.pyramid_levels] # Loop through pyramid layers class_layer_outputs, bb_reg_layer_outputs = [], [] # list of lists for p in selected_fmaps: class_layer_outputs.append(self.Classifier(p)) bb_reg_layer_outputs.append(self.BBRegressor(p)) # Concatenate layer outputs # Convert from list of lists of level outputs to list of lists # of outputs across levels. # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] class_logits = list(zip(*class_layer_outputs)) class_logits = [torch.cat(list(o), dim=1) for o in class_logits][0] bb_outputs = list(zip(*bb_reg_layer_outputs)) bb_outputs = [torch.cat(list(o), dim=1) for o in bb_outputs][0] # merge batch_dimension and store info in batch_ixs for re-allocation. batch_ixs = torch.arange(class_logits.shape[0]).unsqueeze(1).repeat(1, class_logits.shape[1]).view(-1).cuda() flat_class_softmax = F.softmax(class_logits.view(-1, class_logits.shape[-1]), 1) flat_bb_outputs = bb_outputs.view(-1, bb_outputs.shape[-1]) detections = refine_detections(self.anchors, flat_class_softmax, flat_bb_outputs, batch_ixs, self.cf) return detections, class_logits, bb_outputs, seg_logits diff --git a/models/retina_unet.py b/models/retina_unet.py index 5ce2471..ae43d19 100644 --- a/models/retina_unet.py +++ b/models/retina_unet.py @@ -1,513 +1,513 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Retina Net. According to https://arxiv.org/abs/1708.02002 Retina U-Net. According to https://arxiv.org/abs/1811.08661 """ import utils.model_utils as mutils import utils.exp_utils as utils import sys sys.path.append('../') from cuda_functions.nms_2D.pth_nms import nms_gpu as nms_2D from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils ############################################################ # Network Heads ############################################################ class Classifier(nn.Module): def __init__(self, cf, conv): """ Builds the classifier sub-network. """ super(Classifier, self).__init__() self.dim = conv.dim self.n_classes = cf.head_classes n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * cf.head_classes anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: class_logits (b, n_anchors, n_classes) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) class_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) class_logits = class_logits.permute(*axes) class_logits = class_logits.contiguous() class_logits = class_logits.view(x.size()[0], -1, self.n_classes) return [class_logits] class BBRegressor(nn.Module): def __init__(self, cf, conv): """ Builds the bb-regression sub-network. """ super(BBRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * self.dim * 2 anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) bb_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) bb_logits = bb_logits.permute(*axes) bb_logits = bb_logits.contiguous() bb_logits = bb_logits.view(x.size()[0], -1, self.dim * 2) return [bb_logits] ############################################################ # Loss Functions ############################################################ def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20): """ - :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :param anchor_matches: (n_anchors). [-1, 0, class_id] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param shem_poolsize: int. factor of top-k candidates to draw from per negative sample (online-hard-example-mining). :return: loss: torch tensor. :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0) neg_indices = torch.nonzero(anchor_matches == -1) # get positive samples and calucalte loss. if 0 not in pos_indices.size(): pos_indices = pos_indices.squeeze(1) roi_logits_pos = class_pred_logits[pos_indices] targets_pos = anchor_matches[pos_indices] pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long()) else: pos_loss = torch.FloatTensor([0]).cuda() # get negative samples, such that the amount matches the number of positive samples, but at least 1. # get high scoring negatives by applying online-hard-example-mining. if 0 not in neg_indices.size(): neg_indices = neg_indices.squeeze(1) roi_logits_neg = class_pred_logits[neg_indices] negative_count = np.max((1, pos_indices.size()[0])) roi_probs_neg = F.softmax(roi_logits_neg, dim=1) neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) # return the indices of negative samples, which contributed to the loss (for monitoring plots). np_neg_ix = neg_ix.cpu().data.numpy() else: neg_loss = torch.FloatTensor([0]).cuda() np_neg_ix = np.array([]).astype('int32') loss = (pos_loss + neg_loss) / 2 return loss, np_neg_ix def compute_bbox_loss(target_deltas, pred_deltas, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unsed bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) - :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. + :param anchor_matches: (n_anchors). [-1, 0, class_id] for negative, neutral, and positive matched anchors. :return: loss: torch 1D tensor. """ - if 0 not in torch.nonzero(anchor_matches == 1).size(): + if 0 not in torch.nonzero(anchor_matches > 0).size(): - indices = torch.nonzero(anchor_matches == 1).squeeze(1) + indices = torch.nonzero(anchor_matches > 0).squeeze(1) # Pick bbox deltas that contribute to the loss pred_deltas = pred_deltas[indices] # Trim target bounding box deltas to the same length as pred_deltas. target_deltas = target_deltas[:pred_deltas.size()[0], :] # Smooth L1 loss loss = F.smooth_l1_loss(pred_deltas, target_deltas) else: loss = torch.FloatTensor([0]).cuda() return loss ############################################################ # Output Handler ############################################################ def refine_detections(anchors, probs, deltas, batch_ixs, cf): """ Refine classified proposals, filter overlaps and return final detections. n_proposals here is typically a very large number: batch_size * n_anchors. This function is hence optimized on trimming down n_proposals. :param anchors: (n_anchors, 2 * dim) :param probs: (n_proposals, n_classes) softmax probabilities for all rois as predicted by classifier head. :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by bbox regressor head. :param batch_ixs: (n_proposals) batch element assignemnt info for re-allocation. :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)) """ anchors = anchors.repeat(len(np.unique(batch_ixs)), 1) # flatten foreground probabilities, sort and trim down to highest confidences by pre_nms limit. fg_probs = probs[:, 1:].contiguous() flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True) keep_ix = flat_probs_order[:cf.pre_nms_limit] # reshape indices to 2D index array with shape like fg_probs. keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) pre_nms_scores = flat_probs[:cf.pre_nms_limit] pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again. pre_nms_batch_ixs = batch_ixs[keep_arr[:, 0]] pre_nms_anchors = anchors[keep_arr[:, 0]] pre_nms_deltas = deltas[keep_arr[:, 0]] keep = torch.arange(pre_nms_scores.size()[0]).long().cuda() # apply bounding box deltas. re-scale to image coordinates. std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() scale = torch.from_numpy(cf.scale).float().cuda() refined_rois = mutils.apply_box_deltas_2D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale \ if cf.dim == 2 else mutils.apply_box_deltas_3D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale # round and cast to int since we're deadling with pixels now refined_rois = mutils.clip_to_window(cf.window, refined_rois) pre_nms_rois = torch.round(refined_rois) for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)): bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] bix_class_ids = pre_nms_class_ids[bixs] bix_rois = pre_nms_rois[bixs] bix_scores = pre_nms_scores[bixs] for i, class_id in enumerate(mutils.unique1d(bix_class_ids)): ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] # nms expects boxes sorted by score. ix_rois = bix_rois[ixs] ix_scores = bix_scores[ixs] ix_scores, order = ix_scores.sort(descending=True) ix_rois = ix_rois[order, :] ix_scores = ix_scores if cf.dim == 2: class_keep = nms_2D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) else: class_keep = nms_3D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) # map indices back. class_keep = keep[bixs[ixs[order[class_keep]]]] # merge indices over classes for current batch element b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep))) # only keep top-k boxes of current batch-element. top_ids = pre_nms_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] b_keep = b_keep[top_ids] # merge indices over batch elements. batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep))) keep = batch_keep # arrange output. result = torch.cat((pre_nms_rois[keep], pre_nms_batch_ixs[keep].unsqueeze(1).float(), pre_nms_class_ids[keep].unsqueeze(1).float(), pre_nms_scores[keep].unsqueeze(1)), dim=1) return result def get_results(cf, img_shape, detections, seg_logits, box_results_list=None): """ Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. :param img_shape: :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score) :param box_results_list: None or list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, ..., n_classes] for retina_unet and dummy array for retina_net. """ detections = detections.cpu().data.numpy() batch_ixs = detections[:, cf.dim*2] detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] # for test_forward, where no previous list exists. if box_results_list is None: box_results_list = [[] for _ in range(img_shape[0])] for ix in range(img_shape[0]): if 0 not in detections[ix].shape: boxes = detections[ix][:, :2 * cf.dim].astype(np.int32) class_ids = detections[ix][:, 2 * cf.dim + 1].astype(np.int32) scores = detections[ix][:, 2 * cf.dim + 2] # Filter out detections with zero area. Often only happens in early # stages of training when the network weights are still a bit random. if cf.dim == 2: exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] else: exclude_ix = np.where( (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] if exclude_ix.shape[0] > 0: boxes = np.delete(boxes, exclude_ix, axis=0) class_ids = np.delete(class_ids, exclude_ix, axis=0) scores = np.delete(scores, exclude_ix, axis=0) if 0 not in boxes.shape: for ix2, score in enumerate(scores): if score >= cf.model_min_confidence: box_results_list[ix].append({'box_coords': boxes[ix2], 'box_score': score, 'box_type': 'det', 'box_pred_class_id': class_ids[ix2]}) results_dict = {'boxes': box_results_list} if seg_logits is None: # output dummy segmentation for retina_net. results_dict['seg_preds'] = np.zeros(img_shape)[:, 0][:, np.newaxis] else: # output label maps for retina_unet. results_dict['seg_preds'] = F.softmax(seg_logits, 1).argmax(1).cpu().data.numpy()[:, np.newaxis].astype('uint8') return results_dict ############################################################ # Retina (U-)Net Class ############################################################ class net(nn.Module): def __init__(self, cf, logger): super(net, self).__init__() self.cf = cf self.logger = logger self.build() if self.cf.weight_init is not None: logger.info("using pytorch weight init of type {}".format(self.cf.weight_init)) mutils.initialize_weights(self) else: logger.info("using default pytorch weight init") def build(self): """ Build Retina Net architecture. """ # Image size must be dividable by 2 multiple times. h, w = self.cf.patch_size[:2] if h / 2 ** 5 != int(h / 2 ** 5) or w / 2 ** 5 != int(w / 2 ** 5): raise Exception("Image size must be dividable by 2 at least 5 times " "to avoid fractions when downscaling and upscaling." "For example, use 256, 320, 384, 448, 512, ... etc. ") # instanciate abstract multi dimensional conv class and backbone model. conv = mutils.NDConvGenerator(self.cf.dim) backbone = utils.import_module('bbone', self.cf.backbone_path) # build Anchors, FPN, Classifier / Bbox-Regressor -head self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) self.anchors = torch.from_numpy(self.np_anchors).float().cuda() self.Fpn = backbone.FPN(self.cf, conv, operate_stride1=self.cf.operate_stride1) self.Classifier = Classifier(self.cf, conv) self.BBRegressor = BBRegressor(self.cf, conv) self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=False, relu=None) def train_forward(self, batch, **kwargs): """ train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data for processing, computes losses, and stores outputs in a dictionary. :param batch: dictionary containing 'data', 'seg', etc. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes]. 'monitor_values': dict of values to be monitored. """ img = batch['data'] gt_class_ids = batch['roi_labels'] gt_boxes = batch['bb_target'] var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda() var_seg = torch.LongTensor(batch['seg']).cuda() img = torch.from_numpy(img).float().cuda() batch_class_loss = torch.FloatTensor([0]).cuda() batch_bbox_loss = torch.FloatTensor([0]).cuda() # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. box_results_list = [[] for _ in range(img.shape[0])] detections, class_logits, pred_deltas, seg_logits = self.forward(img) # loop over batch for b in range(img.shape[0]): # add gt boxes to results dict for monitoring. if len(gt_boxes[b]) > 0: for ix in range(len(gt_boxes[b])): box_results_list[b].append({'box_coords': batch['bb_target'][b][ix], 'box_label': batch['roi_labels'][b][ix], 'box_type': 'gt'}) # match gt boxes with anchors to generate targets. anchor_class_match, anchor_target_deltas = mutils.gt_anchor_matching( self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b]) # add positive anchors used for loss to results_dict for monitoring. pos_anchors = mutils.clip_boxes_numpy( self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:]) for p in pos_anchors: box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) else: anchor_class_match = np.array([-1]*self.np_anchors.shape[0]) anchor_target_deltas = np.array([0]) anchor_class_match = torch.from_numpy(anchor_class_match).cuda() anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda() # compute losses. class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b]) bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match) # add negative anchors used for loss to results_dict for monitoring. neg_anchors = mutils.clip_boxes_numpy( self.np_anchors[np.argwhere(anchor_class_match == -1)][0, neg_anchor_ix], img.shape[2:]) for n in neg_anchors: box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) batch_class_loss += class_loss / img.shape[0] batch_bbox_loss += bbox_loss / img.shape[0] results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list) seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),var_seg_ohe) seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0]) loss = batch_class_loss + batch_bbox_loss + (seg_loss_dice + seg_loss_ce) / 2 results_dict['torch_loss'] = loss results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': batch_class_loss.item()} results_dict['logger_string'] = \ "loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, mean pix. pr.: {5:.5f}"\ .format(loss.item(), batch_class_loss.item(), batch_bbox_loss.item(), seg_loss_dice.item(), seg_loss_ce.item(), np.mean(results_dict['seg_preds'])) return results_dict def test_forward(self, batch, **kwargs): """ test method. wrapper around forward pass of network without usage of any ground truth information. prepares input data for processing and stores outputs in a dictionary. :param batch: dictionary containing 'data' :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, ..., n_classes] for retina_unet and dummy array for retina_net. """ img = batch['data'] img = torch.from_numpy(img).float().cuda() detections, _, _, seg_logits = self.forward(img) results_dict = get_results(self.cf, img.shape, detections, seg_logits) return results_dict def forward(self, img): """ forward pass of the model. :param img: input img (b, c, y, x, (z)). :return: rpn_pred_logits: (b, n_anchors, 2) :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d)))) :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting. :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score) :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. """ # Feature extraction fpn_outs = self.Fpn(img) seg_logits = self.final_conv(fpn_outs[0]) selected_fmaps = [fpn_outs[i + 1] for i in self.cf.pyramid_levels] # Loop through pyramid layers class_layer_outputs, bb_reg_layer_outputs = [], [] # list of lists for p in selected_fmaps: class_layer_outputs.append(self.Classifier(p)) bb_reg_layer_outputs.append(self.BBRegressor(p)) # Concatenate layer outputs # Convert from list of lists of level outputs to list of lists # of outputs across levels. # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] class_logits = list(zip(*class_layer_outputs)) class_logits = [torch.cat(list(o), dim=1) for o in class_logits][0] bb_outputs = list(zip(*bb_reg_layer_outputs)) bb_outputs = [torch.cat(list(o), dim=1) for o in bb_outputs][0] # merge batch_dimension and store info in batch_ixs for re-allocation. batch_ixs = torch.arange(class_logits.shape[0]).unsqueeze(1).repeat(1, class_logits.shape[1]).view(-1).cuda() flat_class_softmax = F.softmax(class_logits.view(-1, class_logits.shape[-1]), 1) flat_bb_outputs = bb_outputs.view(-1, bb_outputs.shape[-1]) detections = refine_detections(self.anchors, flat_class_softmax, flat_bb_outputs, batch_ixs, self.cf) return detections, class_logits, bb_outputs, seg_logits diff --git a/utils/model_utils.py b/utils/model_utils.py index a150a07..cb0d944 100644 --- a/utils/model_utils.py +++ b/utils/model_utils.py @@ -1,889 +1,891 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn published under MIT license. """ import numpy as np import scipy.misc import scipy.ndimage import torch from torch.autograd import Variable import torch.nn as nn ############################################################ # Bounding Boxes ############################################################ def compute_iou_2D(box, boxes, box_area, boxes_area): """Calculates IoU of the given box with the array of the given boxes. box: 1D vector [y1, x1, y2, x2] THIS IS THE GT BOX boxes: [boxes_count, (y1, x1, y2, x2)] box_area: float. the area of 'box' boxes_area: array of length boxes_count. Note: the areas are passed in rather than calculated here for efficency. Calculate once in the caller to avoid duplicate work. """ # Calculate intersection areas y1 = np.maximum(box[0], boxes[:, 0]) y2 = np.minimum(box[2], boxes[:, 2]) x1 = np.maximum(box[1], boxes[:, 1]) x2 = np.minimum(box[3], boxes[:, 3]) intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) union = box_area + boxes_area[:] - intersection[:] iou = intersection / union return iou def compute_iou_3D(box, boxes, box_volume, boxes_volume): """Calculates IoU of the given box with the array of the given boxes. box: 1D vector [y1, x1, y2, x2, z1, z2] (typically gt box) boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)] box_area: float. the area of 'box' boxes_area: array of length boxes_count. Note: the areas are passed in rather than calculated here for efficency. Calculate once in the caller to avoid duplicate work. """ # Calculate intersection areas y1 = np.maximum(box[0], boxes[:, 0]) y2 = np.minimum(box[2], boxes[:, 2]) x1 = np.maximum(box[1], boxes[:, 1]) x2 = np.minimum(box[3], boxes[:, 3]) z1 = np.maximum(box[4], boxes[:, 4]) z2 = np.minimum(box[5], boxes[:, 5]) intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) * np.maximum(z2 - z1, 0) union = box_volume + boxes_volume[:] - intersection[:] iou = intersection / union return iou def compute_overlaps(boxes1, boxes2): """Computes IoU overlaps between two sets of boxes. boxes1, boxes2: [N, (y1, x1, y2, x2)]. / 3D: (z1, z2)) For better performance, pass the largest set first and the smaller second. """ # Areas of anchors and GT boxes if boxes1.shape[1] == 4: area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) # Compute overlaps to generate matrix [boxes1 count, boxes2 count] # Each cell contains the IoU value. overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0])) for i in range(overlaps.shape[1]): box2 = boxes2[i] #this is the gt box overlaps[:, i] = compute_iou_2D(box2, boxes1, area2[i], area1) return overlaps else: # Areas of anchors and GT boxes volume1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) * (boxes1[:, 5] - boxes1[:, 4]) volume2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) * (boxes2[:, 5] - boxes2[:, 4]) # Compute overlaps to generate matrix [boxes1 count, boxes2 count] # Each cell contains the IoU value. overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0])) for i in range(overlaps.shape[1]): box2 = boxes2[i] # this is the gt box overlaps[:, i] = compute_iou_3D(box2, boxes1, volume2[i], volume1) return overlaps def box_refinement(box, gt_box): """Compute refinement needed to transform box to gt_box. box and gt_box are [N, (y1, x1, y2, x2)] / 3D: (z1, z2)) """ height = box[:, 2] - box[:, 0] width = box[:, 3] - box[:, 1] center_y = box[:, 0] + 0.5 * height center_x = box[:, 1] + 0.5 * width gt_height = gt_box[:, 2] - gt_box[:, 0] gt_width = gt_box[:, 3] - gt_box[:, 1] gt_center_y = gt_box[:, 0] + 0.5 * gt_height gt_center_x = gt_box[:, 1] + 0.5 * gt_width dy = (gt_center_y - center_y) / height dx = (gt_center_x - center_x) / width dh = torch.log(gt_height / height) dw = torch.log(gt_width / width) result = torch.stack([dy, dx, dh, dw], dim=1) if box.shape[1] > 4: depth = box[:, 5] - box[:, 4] center_z = box[:, 4] + 0.5 * depth gt_depth = gt_box[:, 5] - gt_box[:, 4] gt_center_z = gt_box[:, 4] + 0.5 * gt_depth dz = (gt_center_z - center_z) / depth dd = torch.log(gt_depth / depth) result = torch.stack([dy, dx, dz, dh, dw, dd], dim=1) return result def unmold_mask_2D(mask, bbox, image_shape): """Converts a mask generated by the neural network into a format similar to it's original shape. mask: [height, width] of type float. A small, typically 28x28 mask. bbox: [y1, x1, y2, x2]. The box to fit the mask in. Returns a binary mask with the same size as the original image. """ y1, x1, y2, x2 = bbox out_zoom = [y2 - y1, x2 - x1] zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)] mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32) # Put the mask in the right location. full_mask = np.zeros(image_shape[:2]) full_mask[y1:y2, x1:x2] = mask return full_mask def unmold_mask_3D(mask, bbox, image_shape): """Converts a mask generated by the neural network into a format similar to it's original shape. mask: [height, width] of type float. A small, typically 28x28 mask. bbox: [y1, x1, y2, x2, z1, z2]. The box to fit the mask in. Returns a binary mask with the same size as the original image. """ y1, x1, y2, x2, z1, z2 = bbox out_zoom = [y2 - y1, x2 - x1, z2 - z1] zoom_factor = [i/j for i,j in zip(out_zoom, mask.shape)] mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32) # Put the mask in the right location. full_mask = np.zeros(image_shape[:3]) full_mask[y1:y2, x1:x2, z1:z2] = mask return full_mask ############################################################ # Anchors ############################################################ def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride): """ scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] shape: [height, width] spatial shape of the feature map over which to generate anchors. feature_stride: Stride of the feature map relative to the image in pixels. anchor_stride: Stride of anchors on the feature map. For example, if the value is 2 then generate anchors for every other feature map pixel. """ # Get all combinations of scales and ratios scales, ratios = np.meshgrid(np.array(scales), np.array(ratios)) scales = scales.flatten() ratios = ratios.flatten() # Enumerate heights and widths from scales and ratios heights = scales / np.sqrt(ratios) widths = scales * np.sqrt(ratios) # Enumerate shifts in feature space shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y) # Enumerate combinations of shifts, widths, and heights box_widths, box_centers_x = np.meshgrid(widths, shifts_x) box_heights, box_centers_y = np.meshgrid(heights, shifts_y) # Reshape to get a list of (y, x) and a list of (h, w) box_centers = np.stack( [box_centers_y, box_centers_x], axis=2).reshape([-1, 2]) box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2]) # Convert to corner coordinates (y1, x1, y2, x2) boxes = np.concatenate([box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1) return boxes def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride): """ scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] shape: [height, width] spatial shape of the feature map over which to generate anchors. feature_stride: Stride of the feature map relative to the image in pixels. anchor_stride: Stride of anchors on the feature map. For example, if the value is 2 then generate anchors for every other feature map pixel. """ # Get all combinations of scales and ratios scales_xy, ratios_meshed = np.meshgrid(np.array(scales_xy), np.array(ratios)) scales_xy = scales_xy.flatten() ratios_meshed = ratios_meshed.flatten() # Enumerate heights and widths from scales and ratios heights = scales_xy / np.sqrt(ratios_meshed) widths = scales_xy * np.sqrt(ratios_meshed) depths = np.tile(np.array(scales_z), len(ratios_meshed)//np.array(scales_z)[..., None].shape[0]) # Enumerate shifts in feature space shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy #translate from fm positions to input coords. shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy shifts_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z) shifts_x, shifts_y, shifts_z = np.meshgrid(shifts_x, shifts_y, shifts_z) # Enumerate combinations of shifts, widths, and heights box_widths, box_centers_x = np.meshgrid(widths, shifts_x) box_heights, box_centers_y = np.meshgrid(heights, shifts_y) box_depths, box_centers_z = np.meshgrid(depths, shifts_z) # Reshape to get a list of (y, x, z) and a list of (h, w, d) box_centers = np.stack( [box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3]) box_sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3]) # Convert to corner coordinates (y1, x1, y2, x2, z1, z2) boxes = np.concatenate([box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1) boxes = np.transpose(np.array([boxes[:, 0], boxes[:, 1], boxes[:, 3], boxes[:, 4], boxes[:, 2], boxes[:, 5]]), axes=(1, 0)) return boxes def generate_pyramid_anchors(logger, cf): """Generate anchors at different levels of a feature pyramid. Each scale is associated with a level of the pyramid, but each ratio is used in all levels of the pyramid. from configs: :param scales: cf.RPN_ANCHOR_SCALES , e.g. [4, 8, 16, 32] :param ratios: cf.RPN_ANCHOR_RATIOS , e.g. [0.5, 1, 2] :param feature_shapes: cf.BACKBONE_SHAPES , e.g. [array of shapes per feature map] [80, 40, 20, 10, 5] :param feature_strides: cf.BACKBONE_STRIDES , e.g. [2, 4, 8, 16, 32, 64] :param anchors_stride: cf.RPN_ANCHOR_STRIDE , e.g. 1 :return anchors: (N, (y1, x1, y2, x2, (z1), (z2)). All generated anchors in one array. Sorted with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on. """ scales = cf.rpn_anchor_scales ratios = cf.rpn_anchor_ratios feature_shapes = cf.backbone_shapes anchor_stride = cf.rpn_anchor_stride pyramid_levels = cf.pyramid_levels feature_strides = cf.backbone_strides anchors = [] logger.info("feature map shapes: {}".format(feature_shapes)) logger.info("anchor scales: {}".format(scales)) expected_anchors = [np.prod(feature_shapes[ii]) * len(ratios) * len(scales['xy'][ii]) for ii in pyramid_levels] for lix, level in enumerate(pyramid_levels): if len(feature_shapes[level]) == 2: anchors.append(generate_anchors(scales['xy'][level], ratios, feature_shapes[level], feature_strides['xy'][level], anchor_stride)) else: anchors.append(generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, feature_shapes[level], feature_strides['xy'][level], feature_strides['z'][level], anchor_stride)) logger.info("level {}: built anchors {} / expected anchors {} ||| total build {} / total expected {}".format( level, anchors[-1].shape, expected_anchors[lix], np.concatenate(anchors).shape, np.sum(expected_anchors))) out_anchors = np.concatenate(anchors, axis=0) return out_anchors def apply_box_deltas_2D(boxes, deltas): """Applies the given deltas to the given boxes. boxes: [N, 4] where each row is y1, x1, y2, x2 deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)] """ # Convert to y, x, h, w height = boxes[:, 2] - boxes[:, 0] width = boxes[:, 3] - boxes[:, 1] center_y = boxes[:, 0] + 0.5 * height center_x = boxes[:, 1] + 0.5 * width # Apply deltas center_y += deltas[:, 0] * height center_x += deltas[:, 1] * width height *= torch.exp(deltas[:, 2]) width *= torch.exp(deltas[:, 3]) # Convert back to y1, x1, y2, x2 y1 = center_y - 0.5 * height x1 = center_x - 0.5 * width y2 = y1 + height x2 = x1 + width result = torch.stack([y1, x1, y2, x2], dim=1) return result def apply_box_deltas_3D(boxes, deltas): """Applies the given deltas to the given boxes. boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2 deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)] """ # Convert to y, x, h, w height = boxes[:, 2] - boxes[:, 0] width = boxes[:, 3] - boxes[:, 1] depth = boxes[:, 5] - boxes[:, 4] center_y = boxes[:, 0] + 0.5 * height center_x = boxes[:, 1] + 0.5 * width center_z = boxes[:, 4] + 0.5 * depth # Apply deltas center_y += deltas[:, 0] * height center_x += deltas[:, 1] * width center_z += deltas[:, 2] * depth height *= torch.exp(deltas[:, 3]) width *= torch.exp(deltas[:, 4]) depth *= torch.exp(deltas[:, 5]) # Convert back to y1, x1, y2, x2 y1 = center_y - 0.5 * height x1 = center_x - 0.5 * width z1 = center_z - 0.5 * depth y2 = y1 + height x2 = x1 + width z2 = z1 + depth result = torch.stack([y1, x1, y2, x2, z1, z2], dim=1) return result def clip_boxes_2D(boxes, window): """ boxes: [N, 4] each col is y1, x1, y2, x2 window: [4] in the form y1, x1, y2, x2 """ boxes = torch.stack( \ [boxes[:, 0].clamp(float(window[0]), float(window[2])), boxes[:, 1].clamp(float(window[1]), float(window[3])), boxes[:, 2].clamp(float(window[0]), float(window[2])), boxes[:, 3].clamp(float(window[1]), float(window[3]))], 1) return boxes def clip_boxes_3D(boxes, window): """ boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2 window: [6] in the form y1, x1, y2, x2, z1, z2 """ boxes = torch.stack( \ [boxes[:, 0].clamp(float(window[0]), float(window[2])), boxes[:, 1].clamp(float(window[1]), float(window[3])), boxes[:, 2].clamp(float(window[0]), float(window[2])), boxes[:, 3].clamp(float(window[1]), float(window[3])), boxes[:, 4].clamp(float(window[4]), float(window[5])), boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1) return boxes def clip_boxes_numpy(boxes, window): """ boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D. window: iamge shape (y, x, (z)) """ if boxes.shape[1] == 4: boxes = np.concatenate( (np.clip(boxes[:, 0], 0, window[0])[:, None], np.clip(boxes[:, 1], 0, window[0])[:, None], np.clip(boxes[:, 2], 0, window[1])[:, None], np.clip(boxes[:, 3], 0, window[1])[:, None]), 1 ) else: boxes = np.concatenate( (np.clip(boxes[:, 0], 0, window[0])[:, None], np.clip(boxes[:, 1], 0, window[0])[:, None], np.clip(boxes[:, 2], 0, window[1])[:, None], np.clip(boxes[:, 3], 0, window[1])[:, None], np.clip(boxes[:, 4], 0, window[2])[:, None], np.clip(boxes[:, 5], 0, window[2])[:, None]), 1 ) return boxes def bbox_overlaps_2D(boxes1, boxes2): """Computes IoU overlaps between two sets of boxes. boxes1, boxes2: [N, (y1, x1, y2, x2)]. """ # 1. Tile boxes2 and repeate boxes1. This allows us to compare # every boxes1 against every boxes2 without loops. # TF doesn't have an equivalent to np.repeate() so simulate it # using tf.tile() and tf.reshape. boxes1_repeat = boxes2.size()[0] boxes2_repeat = boxes1.size()[0] boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,4) boxes2 = boxes2.repeat(boxes2_repeat,1) # 2. Compute intersections b1_y1, b1_x1, b1_y2, b1_x2 = boxes1.chunk(4, dim=1) b2_y1, b2_x1, b2_y2, b2_x2 = boxes2.chunk(4, dim=1) y1 = torch.max(b1_y1, b2_y1)[:, 0] x1 = torch.max(b1_x1, b2_x1)[:, 0] y2 = torch.min(b1_y2, b2_y2)[:, 0] x2 = torch.min(b1_x2, b2_x2)[:, 0] zeros = Variable(torch.zeros(y1.size()[0]), requires_grad=False) if y1.is_cuda: zeros = zeros.cuda() intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) # 3. Compute unions b1_area = (b1_y2 - b1_y1) * (b1_x2 - b1_x1) b2_area = (b2_y2 - b2_y1) * (b2_x2 - b2_x1) union = b1_area[:,0] + b2_area[:,0] - intersection # 4. Compute IoU and reshape to [boxes1, boxes2] iou = intersection / union overlaps = iou.view(boxes2_repeat, boxes1_repeat) return overlaps def bbox_overlaps_3D(boxes1, boxes2): """Computes IoU overlaps between two sets of boxes. boxes1, boxes2: [N, (y1, x1, y2, x2, z1, z2)]. """ # 1. Tile boxes2 and repeate boxes1. This allows us to compare # every boxes1 against every boxes2 without loops. # TF doesn't have an equivalent to np.repeate() so simulate it # using tf.tile() and tf.reshape. boxes1_repeat = boxes2.size()[0] boxes2_repeat = boxes1.size()[0] boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,6) boxes2 = boxes2.repeat(boxes2_repeat,1) # 2. Compute intersections b1_y1, b1_x1, b1_y2, b1_x2, b1_z1, b1_z2 = boxes1.chunk(6, dim=1) b2_y1, b2_x1, b2_y2, b2_x2, b2_z1, b2_z2 = boxes2.chunk(6, dim=1) y1 = torch.max(b1_y1, b2_y1)[:, 0] x1 = torch.max(b1_x1, b2_x1)[:, 0] y2 = torch.min(b1_y2, b2_y2)[:, 0] x2 = torch.min(b1_x2, b2_x2)[:, 0] z1 = torch.max(b1_z1, b2_z1)[:, 0] z2 = torch.min(b1_z2, b2_z2)[:, 0] zeros = Variable(torch.zeros(y1.size()[0]), requires_grad=False) if y1.is_cuda: zeros = zeros.cuda() intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros) # 3. Compute unions b1_volume = (b1_y2 - b1_y1) * (b1_x2 - b1_x1) * (b1_z2 - b1_z1) b2_volume = (b2_y2 - b2_y1) * (b2_x2 - b2_x1) * (b2_z2 - b2_z1) union = b1_volume[:,0] + b2_volume[:,0] - intersection # 4. Compute IoU and reshape to [boxes1, boxes2] iou = intersection / union overlaps = iou.view(boxes2_repeat, boxes1_repeat) return overlaps def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None): """Given the anchors and GT boxes, compute overlaps and identify positive anchors and deltas to refine them to match their corresponding GT boxes. anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))] gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))] gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN, set all positive matches to 1 (foreground) Returns: anchor_class_matches: [N] (int32) matches between anchors and GT boxes. - 1 = positive anchor, -1 = negative anchor, 0 = neutral + 1 = positive anchor, -1 = negative anchor, 0 = neutral. + In case of one stage detectors like RetinaNet/RetinaUNet this flag takes + class_ids as positive anchor values, i.e. values >= 1! anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas. """ anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32) anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim)) anchor_matching_iou = cf.anchor_matching_iou if gt_boxes is None: anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1) return anchor_class_matches, anchor_delta_targets # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground) if gt_class_ids is None: gt_class_ids = np.array([1] * len(gt_boxes)) # Compute overlaps [num_anchors, num_gt_boxes] overlaps = compute_overlaps(anchors, gt_boxes) # Match anchors to GT Boxes # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive. # If an anchor overlaps a GT box with IoU < 0.1 then it's negative. # Neutral anchors are those that don't match the conditions above, # and they don't influence the loss function. # However, don't keep any GT box unmatched (rare, but happens). Instead, # match it to the closest anchor (even if its max IoU is < 0.1). # 1. Set negative anchors first. They get overwritten below if a GT box is # matched to them. Skip boxes in crowd areas. anchor_iou_argmax = np.argmax(overlaps, axis=1) anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] if anchors.shape[1] == 4: anchor_class_matches[(anchor_iou_max < 0.1)] = -1 elif anchors.shape[1] == 6: anchor_class_matches[(anchor_iou_max < 0.01)] = -1 else: raise ValueError('anchor shape wrong {}'.format(anchors.shape)) # 2. Set an anchor for each GT box (regardless of IoU value). gt_iou_argmax = np.argmax(overlaps, axis=0) for ix, ii in enumerate(gt_iou_argmax): anchor_class_matches[ii] = gt_class_ids[ix] # 3. Set anchors with high overlap as positive. above_trhesh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) anchor_class_matches[above_trhesh_ixs] = gt_class_ids[anchor_iou_argmax[above_trhesh_ixs]] # Subsample to balance positive anchors. ids = np.where(anchor_class_matches > 0)[0] extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) if extra > 0: # Reset the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) anchor_class_matches[ids] = 0 # Leave all negative proposals negative now and sample from them in online hard example mining. # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. ids = np.where(anchor_class_matches > 0)[0] ix = 0 # index into anchor_delta_targets for i, a in zip(ids, anchors[ids]): # closest gt box (it might have IoU < anchor_matching_iou) gt = gt_boxes[anchor_iou_argmax[i]] # convert coordinates to center plus width/height. gt_h = gt[2] - gt[0] gt_w = gt[3] - gt[1] gt_center_y = gt[0] + 0.5 * gt_h gt_center_x = gt[1] + 0.5 * gt_w # Anchor a_h = a[2] - a[0] a_w = a[3] - a[1] a_center_y = a[0] + 0.5 * a_h a_center_x = a[1] + 0.5 * a_w if cf.dim == 2: anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, np.log(gt_h / a_h), np.log(gt_w / a_w), ] else: gt_d = gt[5] - gt[4] gt_center_z = gt[4] + 0.5 * gt_d a_d = a[5] - a[4] a_center_z = a[4] + 0.5 * a_d anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, (gt_center_z - a_center_z) / a_d, np.log(gt_h / a_h), np.log(gt_w / a_w), np.log(gt_d / a_d) ] # normalize. anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev ix += 1 return anchor_class_matches, anchor_delta_targets def clip_to_window(window, boxes): """ window: (y1, x1, y2, x2) / 3D: (z1, z2). The window in the image we want to clip to. boxes: [N, (y1, x1, y2, x2)] / 3D: (z1, z2) """ boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[2])) boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[3])) boxes[:, 2] = boxes[:, 2].clamp(float(window[0]), float(window[2])) boxes[:, 3] = boxes[:, 3].clamp(float(window[1]), float(window[3])) if boxes.shape[1] > 5: boxes[:, 4] = boxes[:, 4].clamp(float(window[4]), float(window[5])) boxes[:, 5] = boxes[:, 5].clamp(float(window[4]), float(window[5])) return boxes ############################################################ # Pytorch Utility Functions ############################################################ def unique1d(tensor): if tensor.size()[0] == 0 or tensor.size()[0] == 1: return tensor tensor = tensor.sort()[0] unique_bool = tensor[1:] != tensor [:-1] first_element = Variable(torch.ByteTensor([True]), requires_grad=False) if tensor.is_cuda: first_element = first_element.cuda() unique_bool = torch.cat((first_element, unique_bool),dim=0) return tensor[unique_bool.data] def log2(x): """Implementatin of Log2. Pytorch doesn't have a native implemenation.""" ln2 = Variable(torch.log(torch.FloatTensor([2.0])), requires_grad=False) if x.is_cuda: ln2 = ln2.cuda() return torch.log(x) / ln2 def intersect1d(tensor1, tensor2): aux = torch.cat((tensor1, tensor2), dim=0) aux = aux.sort(descending=True)[0] return aux[:-1][(aux[1:] == aux[:-1]).data] def shem(roi_probs_neg, negative_count, ohem_poolsize): """ stochastic hard example mining: from a list of indices (referring to non-matched predictions), determine a pool of highest scoring (worst false positives) of size negative_count*ohem_poolsize. Then, sample n (= negative_count) predictions of this pool as negative examples for loss. :param roi_probs_neg: tensor of shape (n_predictions, n_classes). :param negative_count: int. :param ohem_poolsize: int. :return: (negative_count). indices refer to the positions in roi_probs_neg. If pool smaller than expected due to limited negative proposals availabel, this function will return sampled indices of number < negative_count without throwing an error. """ # sort according to higehst foreground score. probs, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True) select = torch.tensor((ohem_poolsize * int(negative_count), order.size()[0])).min().int() pool_indices = order[:select] rand_idx = torch.randperm(pool_indices.size()[0]) return pool_indices[rand_idx[:negative_count].cuda()] def initialize_weights(net): """ Initialize model weights. Current Default in Pytorch (version 0.4.1) is initialization from a uniform distriubtion. Will expectably be changed to kaiming_uniform in future versions. """ init_type = net.cf.weight_init for m in [module for module in net.modules() if type(module) in [nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.Linear]]: if init_type == 'xavier_uniform': nn.init.xavier_uniform_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif init_type == 'xavier_normal': nn.init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif init_type == "kaiming_uniform": nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0) if m.bias is not None: fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / np.sqrt(fan_out) nn.init.uniform_(m.bias, -bound, bound) elif init_type == "kaiming_normal": nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0) if m.bias is not None: fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / np.sqrt(fan_out) nn.init.normal_(m.bias, -bound, bound) class NDConvGenerator(object): """ generic wrapper around conv-layers to avoid 2D vs. 3D distinguishing in code. """ def __init__(self, dim): self.dim = dim def __call__(self, c_in, c_out, ks, pad=0, stride=1, norm=None, relu='relu'): """ :param c_in: number of in_channels. :param c_out: number of out_channels. :param ks: kernel size. :param pad: pad size. :param stride: kernel stride. :param norm: string specifying type of feature map normalization. If None, no normalization is applied. :param relu: string specifying type of nonlinearity. If None, no nonlinearity is applied. :return: convolved feature_map. """ if self.dim == 2: conv = nn.Conv2d(c_in, c_out, kernel_size=ks, padding=pad, stride=stride) if norm is not None: if norm == 'instance_norm': norm_layer = nn.InstanceNorm2d(c_out) elif norm == 'batch_norm': norm_layer = nn.BatchNorm2d(c_out) else: raise ValueError('norm type as specified in configs is not implemented...') conv = nn.Sequential(conv, norm_layer) else: conv = nn.Conv3d(c_in, c_out, kernel_size=ks, padding=pad, stride=stride) if norm is not None: if norm == 'instance_norm': norm_layer = nn.InstanceNorm3d(c_out) elif norm == 'batch_norm': norm_layer = nn.BatchNorm3d(c_out) else: raise ValueError('norm type as specified in configs is not implemented... {}'.format(norm)) conv = nn.Sequential(conv, norm_layer) if relu is not None: if relu == 'relu': relu_layer = nn.ReLU(inplace=True) elif relu == 'leaky_relu': relu_layer = nn.LeakyReLU(inplace=True) else: raise ValueError('relu type as specified in configs is not implemented...') conv = nn.Sequential(conv, relu_layer) return conv def get_one_hot_encoding(y, n_classes): """ transform a numpy label array to a one-hot array of the same shape. :param y: array of shape (b, 1, y, x, (z)). :param n_classes: int, number of classes to unfold in one-hot encoding. :return y_ohe: array of shape (b, n_classes, y, x, (z)) """ dim = len(y.shape) - 2 if dim == 2: y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32') if dim ==3: y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32') for cl in range(n_classes): y_ohe[:, cl][y[:, 0] == cl] = 1 return y_ohe def get_dice_per_batch_and_class(pred, y, n_classes): ''' computes dice scores per batch instance and class. :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1) :param y: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes] :param n_classes: int :return: dice scores of shape (b, c) ''' pred = get_one_hot_encoding(pred, n_classes) y = get_one_hot_encoding(y, n_classes) axes = tuple(range(2, len(pred.shape))) intersect = np.sum(pred*y, axis=axes) denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes) + 1e-8 dice = 2.0*intersect / denominator return dice def sum_tensor(input, axes, keepdim=False): axes = np.unique(axes) if keepdim: for ax in axes: input = input.sum(ax, keepdim=True) else: for ax in sorted(axes, reverse=True): input = input.sum(int(ax)) return input def batch_dice(pred, y, false_positive_weight=1.0, eps=1e-6): ''' compute soft dice over batch. this is a diffrentiable score and can be used as a loss function. only dice scores of foreground classes are returned, since training typically does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of. This way, single patches with missing foreground classes can not produce faulty gradients. :param pred: (b, c, y, x, (z)), softmax probabilities (network output). :param y: (b, c, y, x, (z)), one hote encoded segmentation mask. :param false_positive_weight: float [0,1]. For weighting of imbalanced classes, reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances. :return: soft dice score (float).This function discards the background score and returns the mena of foreground scores. ''' if len(pred.size()) == 4: axes = (0, 2, 3) intersect = sum_tensor(pred * y, axes, keepdim=False) denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False) return torch.mean((2 * intersect / (denom + eps))[1:]) #only fg dice here. if len(pred.size()) == 5: axes = (0, 2, 3, 4) intersect = sum_tensor(pred * y, axes, keepdim=False) denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False) return torch.mean((2 * intersect / (denom + eps))[1:]) #only fg dice here. else: raise ValueError('wrong input dimension in dice loss') def batch_dice_mask(pred, y, mask, false_positive_weight=1.0, eps=1e-6): ''' compute soft dice over batch. this is a diffrentiable score and can be used as a loss function. only dice scores of foreground classes are returned, since training typically does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of. This way, single patches with missing foreground classes can not produce faulty gradients. :param pred: (b, c, y, x, (z)), softmax probabilities (network output). :param y: (b, c, y, x, (z)), one hote encoded segmentation mask. :param false_positive_weight: float [0,1]. For weighting of imbalanced classes, reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances. :return: soft dice score (float).This function discards the background score and returns the mena of foreground scores. ''' mask = mask.unsqueeze(1).repeat(1, 2, 1, 1) if len(pred.size()) == 4: axes = (0, 2, 3) intersect = sum_tensor(pred * y * mask, axes, keepdim=False) denom = sum_tensor(false_positive_weight*pred * mask + y * mask, axes, keepdim=False) return torch.mean((2 * intersect / (denom + eps))[1:]) #only fg dice here. if len(pred.size()) == 5: axes = (0, 2, 3, 4) intersect = sum_tensor(pred * y, axes, keepdim=False) denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False) return torch.mean((2 * intersect / (denom + eps))[1:]) #only fg dice here. else: raise ValueError('wrong input dimension in dice loss') \ No newline at end of file