diff --git a/models/mrcnn.py b/models/mrcnn.py
index a9535fe..e0b7982 100644
--- a/models/mrcnn.py
+++ b/models/mrcnn.py
@@ -1,751 +1,752 @@
#!/usr/bin/env python
# Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
published under MIT license.
"""
import os
from multiprocessing import Pool
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils

import utils.model_utils as mutils
import utils.exp_utils as utils


class RPN(nn.Module):
    """
    Region Proposal Network.
    """

    def __init__(self, cf, conv):
        super(RPN, self).__init__()
        self.dim = conv.dim

        self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1,
                                relu=cf.relu)
        self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
        self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)

    def forward(self, x):
        """
        :param x: input feature maps (b, in_channels, y, x, (z))
        :return: rpn_class_logits (b, n_anchors, 2)
        :return: rpn_probs (b, n_anchors, 2)
        :return: rpn_bbox (b, n_anchors, 2 * dim)
        """
        # Shared convolutional base of the RPN.
        x = self.conv_shared(x)

        # Anchor Score. (batch, anchors per location * 2, y, x, (z)).
        rpn_class_logits = self.conv_class(x)
        # Reshape to (batch, anchors, 2)
        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
        rpn_class_logits = rpn_class_logits.permute(*axes)
        rpn_class_logits = rpn_class_logits.contiguous()
        rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2)

        # Softmax on last dimension (fg vs. bg).
        rpn_probs = F.softmax(rpn_class_logits, dim=2)

        # Bounding box refinement: (batch, anchors_per_location * 2 * dim, y, x, (z)), where the
        # 2 * dim deltas per anchor are (dy, dx, (dz), log(dh), log(dw), (log(dd))).
        rpn_bbox = self.conv_bbox(x)
        # Reshape to (batch, anchors, 2 * dim)
        rpn_bbox = rpn_bbox.permute(*axes)
        rpn_bbox = rpn_bbox.contiguous()
        rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2)

        return [rpn_class_logits, rpn_probs, rpn_bbox]
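
# Illustrative shape sketch for the reshapes above (assuming a 2D model with
# len(cf.rpn_anchor_ratios) == 3 and a (b, c, 32, 32) feature map): conv_class
# outputs (b, 6, 32, 32); permuting to (b, 32, 32, 6) and viewing as
# (b, 32*32*3, 2) yields one (bg, fg) logit pair per anchor. conv_bbox
# analogously goes from (b, 12, 32, 32) to (b, 32*32*3, 4), i.e. one
# (dy, dx, log(dh), log(dw)) delta vector per anchor.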
""" def __init__(self, cf, conv): super(Classifier, self).__init__() self.cf = cf self.dim = conv.dim self.in_channels = cf.end_filts self.pool_size = cf.pool_size self.pyramid_levels = cf.pyramid_levels # instance_norm does not work with spatial dims (1, 1, (1)) norm = cf.norm if cf.norm != 'instance_norm' else None self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu) self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu) self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim) if 'regression' in self.cf.prediction_tasks: self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * cf.regression_n_features) self.rg_n_feats = cf.regression_n_features #classify into bins of regression values elif 'regression_bin' in self.cf.prediction_tasks: self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * len(cf.bin_labels)) self.rg_n_feats = len(cf.bin_labels) else: self.linear_regressor = lambda x: torch.zeros((x.shape[0], cf.head_classes * 1), dtype=torch.float32).fill_(float('NaN')).cuda() self.rg_n_feats = 1 #cf.regression_n_features if 'class' in self.cf.prediction_tasks: self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes) else: assert cf.head_classes == 2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes".format(cf.head_classes) self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda() def forward(self, x, rois): """ :param x: input feature maps (b, in_channels, y, x, (z)) :param rois: normalized box coordinates as proposed by the RPN to be forwarded through the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements have been merged to one vector, while the origin info has been stored for re-allocation. :return: mrcnn_class_logits (n_proposals, n_head_classes) :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. """ x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) x = self.conv1(x) x = self.conv2(x) x = x.view(-1, self.in_channels * 4) mrcnn_bbox = self.linear_bbox(x) mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2) mrcnn_class_logits = self.linear_class(x) mrcnn_regress = self.linear_regressor(x) mrcnn_regress = mrcnn_regress.view(mrcnn_regress.size()[0], -1, self.rg_n_feats) return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress] class Mask(nn.Module): """ Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the output logits to allow for overlapping classes. 
""" def __init__(self, cf, conv): super(Mask, self).__init__() self.pool_size = cf.mask_pool_size self.pyramid_levels = cf.pyramid_levels self.dim = conv.dim self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) if conv.dim == 2: self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) else: self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True) self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None) self.sigmoid = nn.Sigmoid() def forward(self, x, rois): """ :param x: input feature maps (b, in_channels, y, x, (z)) :param rois: normalized box coordinates as proposed by the RPN to be forwarded through the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements have been merged to one vector, while the origin info has been stored for re-allocation. :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z)) """ x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.conv4(x) x = self.relu(self.deconv(x)) x = self.conv5(x) x = self.sigmoid(x) return x ############################################################ # Loss Functions ############################################################ def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize): """ :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier. :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). :return: loss: torch tensor :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. """ # Filter out netural anchors pos_indices = torch.nonzero(rpn_match == 1) neg_indices = torch.nonzero(rpn_match == -1) # loss for positive samples if not 0 in pos_indices.size(): pos_indices = pos_indices.squeeze(1) roi_logits_pos = rpn_class_logits[pos_indices] pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda()) else: pos_loss = torch.FloatTensor([0]).cuda() # loss for negative samples: draw hard negative examples (SHEM) # that match the number of positive samples, but at least 1. 

def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match):
    """
    :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
    :param rpn_target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))).
        Uses 0 padding to fill in unused bbox deltas.
    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
    :return: loss: torch 1D tensor.
    """
    if not 0 in torch.nonzero(rpn_match == 1).size():

        indices = torch.nonzero(rpn_match == 1).squeeze(1)
        # Pick bbox deltas that contribute to the loss
        rpn_pred_deltas = rpn_pred_deltas[indices]
        # Trim target bounding box deltas to the same length as rpn_bbox.
        target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :]
        # Smooth L1 loss
        loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas)
    else:
        loss = torch.FloatTensor([0]).cuda()

    return loss


def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids):
    """
    :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
    :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
    :param target_class_ids: (n_sampled_rois)
    :return: loss: torch 1D tensor.
    """
    if not 0 in torch.nonzero(target_class_ids > 0).size():
        positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
        positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
        target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach()
        pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :]
        loss = F.smooth_l1_loss(pred_bbox, target_bbox)
    else:
        loss = torch.FloatTensor([0]).cuda()

    return loss


def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids):
    """
    :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1].
    :param target_masks: (n_sampled_rois, y, x, (z)) float32 tensor of values 0 or 1. Uses zero padding to fill array.
    :param target_class_ids: (n_sampled_rois)
    :return: loss: torch 1D tensor.
    """
    #print("targ masks", target_masks.unique(return_counts=True))
    if not 0 in torch.nonzero(target_class_ids > 0).size():
        # Only positive ROIs contribute to the loss, and only
        # the class-specific mask of each ROI.
        positive_ix = torch.nonzero(target_class_ids > 0)[:, 0]
        positive_class_ids = target_class_ids[positive_ix].long()
        y_true = target_masks[positive_ix, :, :].detach()
        y_pred = pred_masks[positive_ix, positive_class_ids, :, :]
        loss = F.binary_cross_entropy(y_pred, y_true)
    else:
        loss = torch.FloatTensor([0]).cuda()

    return loss
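
# Indexing sketch for the class-specific selections above (hypothetical values): with
# positive_ix = [0, 2] and positive_class_ids = [1, 3], pred_masks[positive_ix, positive_class_ids]
# picks mask channel 1 of roi 0 and mask channel 3 of roi 2, so every roi is penalized only on
# the mask of its target class. The same advanced-indexing pattern selects class-specific box
# deltas in compute_mrcnn_bbox_loss and class-specific outputs in compute_mrcnn_regression_loss.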
""" if 'class' in tasks and not 0 in target_class_ids.size(): loss = F.cross_entropy(pred_class_logits, target_class_ids.long()) else: loss = torch.FloatTensor([0.]).cuda() return loss def compute_mrcnn_regression_loss(tasks, pred, target, target_class_ids): """regression loss is a distance metric between target vector and predicted regression vector. :param pred: (n_sampled_rois, n_classes, [n_rg_feats if real regression or 1 if rg_bin task) :param target: (n_sampled_rois, [n_rg_feats or n_rg_bins]) :return: differentiable loss, torch 1D tensor on cuda """ if not 0 in target.shape and not 0 in torch.nonzero(target_class_ids > 0).shape: positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] positive_roi_class_ids = target_class_ids[positive_roi_ix].long() target = target[positive_roi_ix].detach() pred = pred[positive_roi_ix, positive_roi_class_ids] if "regression_bin" in tasks: loss = F.cross_entropy(pred, target.long()) else: loss = F.smooth_l1_loss(pred, target) #loss = F.mse_loss(pred, target) else: loss = torch.FloatTensor([0.]).cuda() return loss ############################################################ # Detection Layer ############################################################ def compute_roi_scores(tasks, batch_rpn_proposals, mrcnn_cl_logits): """ Depending on the predicition tasks: if no class prediction beyong fg/bg (--> means no additional class head was applied) use RPN objectness scores as roi scores, otherwise class head scores. :param cf: :param batch_rpn_proposals: :param mrcnn_cl_logits: :return: """ if not 'class' in tasks: scores = batch_rpn_proposals[:, :, -1].view(-1, 1) scores = torch.cat((1 - scores, scores), dim=1) else: scores = F.softmax(mrcnn_cl_logits, dim=1) return scores ############################################################ # MaskRCNN Class ############################################################ class net(nn.Module): def __init__(self, cf, logger): super(net, self).__init__() self.cf = cf self.logger = logger self.build() loss_order = ['rpn_class', 'rpn_bbox', 'mrcnn_bbox', 'mrcnn_mask', 'mrcnn_class', 'mrcnn_rg'] if hasattr(cf, "mrcnn_loss_weights"): # bring into right order self.loss_weights = np.array([cf.mrcnn_loss_weights[k] for k in loss_order]) else: self.loss_weights = np.array([1.]*len(loss_order)) if self.cf.weight_init=="custom": logger.info("Tried to use custom weight init which is not defined. Using pytorch default.") elif self.cf.weight_init: mutils.initialize_weights(self) else: logger.info("using default pytorch weight init") def build(self): """Build Mask R-CNN architecture.""" # Image size must be dividable by 2 multiple times. h, w = self.cf.patch_size[:2] if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5): raise Exception("Image size must be divisible by 2 at least 5 times " "to avoid fractions when downscaling and upscaling." "For example, use 256, 288, 320, 384, 448, 512, ... etc.,i.e.," "any number x*32 will do!") # instantiate abstract multi-dimensional conv generator and load backbone module. 

    def build(self):
        """Build Mask R-CNN architecture."""

        # Image size must be divisible by 2 multiple times.
        h, w = self.cf.patch_size[:2]
        if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5):
            raise Exception("Image size must be divisible by 2 at least 5 times "
                            "to avoid fractions when downscaling and upscaling. "
                            "For example, use 256, 288, 320, 384, 448, 512, ..., etc., i.e., "
                            "any number x*32 will do!")

        # instantiate abstract multi-dimensional conv generator and load backbone module.
        backbone = utils.import_module('bbone', self.cf.backbone_path)
        self.logger.info("loaded backbone from {}".format(self.cf.backbone_path))
        conv = backbone.ConvGenerator(self.cf.dim)

        # build anchors, FPN, RPN, classifier/bbox-regressor head, mask head.
        self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
        self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
        self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda()
        self.rpn = RPN(self.cf, conv)
        self.classifier = Classifier(self.cf, conv)
        self.mask = Mask(self.cf, conv)

    def forward(self, img, is_training=True):
        """
        :param img: input images (b, c, y, x, (z)).
        :return: rpn_pred_logits: (b, n_anchors, 2)
        :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
        :return: batch_unnormed_props: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), score)) only for monitoring/plotting.
        :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score))
        :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
        """
        # extract features.
        fpn_outs = self.fpn(img)
        rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels]
        self.mrcnn_feature_maps = rpn_feature_maps

        # loop through pyramid layers and apply RPN.
        layer_outputs = [self.rpn(p_feats) for p_feats in rpn_feature_maps]

        # concatenate layer outputs.
        # convert from list of lists of level outputs to list of lists of outputs across levels.
        # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
        outputs = list(zip(*layer_outputs))
        outputs = [torch.cat(list(o), dim=1) for o in outputs]
        rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs

        # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier.
        proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference
        batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas,
                                                                           proposal_count, self.anchors, self.cf)

        # merge batch dimension of proposals while storing allocation info in coordinate dimension.
        batch_ixs = torch.arange(
            batch_normed_props.shape[0]).cuda().unsqueeze(1).repeat(1, batch_normed_props.shape[1]).view(-1).float()
        rpn_rois = batch_normed_props[:, :, :-1].view(-1, batch_normed_props[:, :, :-1].shape[2])
        self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1)
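
        # Illustrative example of the merge above: with batch size 2 and 3 proposals per
        # element, batch_ixs is [0, 0, 0, 1, 1, 1]; appended as an extra column, it lets
        # roi-align and the heads process all 6 proposals as one flat batch while every
        # result can still be re-assigned to its source image later.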

        # this is the first of two forward passes in the second stage, where no activations are stored for backprop.
        # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois)
        # for inference/monitoring as well as sampling of rois for the loss functions.
        # processed in chunks of roi_chunk_size to re-adjust to gpu memory.
        chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size)
        bboxes_list, class_logits_list, regressions_list = [], [], []
        with torch.no_grad():
            for chunk in chunked_rpn_rois:
                chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk)
                bboxes_list.append(chunk_bboxes)
                class_logits_list.append(chunk_class_logits)
                regressions_list.append(chunk_regressions)
            mrcnn_bbox = torch.cat(bboxes_list, 0)
            mrcnn_class_logits = torch.cat(class_logits_list, 0)
            mrcnn_regressions = torch.cat(regressions_list, 0)
        self.mrcnn_roi_scores = compute_roi_scores(self.cf.prediction_tasks, batch_normed_props, mrcnn_class_logits)

        # refine classified proposals, filter and return final detections.
        # returns (cf.max_inst_per_batch_element, n_coords+1+...)
        detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores,
                                              mrcnn_regressions)

        # forward remaining detections through mask-head to generate corresponding masks.
        scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2
        scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda()

        # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ix.
        detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale
        with torch.no_grad():
            detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes)

        return [rpn_pred_logits, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks]

    def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions=None):
        """
        this is the second forward pass through the second stage (features from stage one are re-used).
        samples few rois in loss_example_mining and forwards only those for loss computation.
        :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
        :param batch_gt_masks: (b, n(b), c, y, x (,z)) list over batch elements. Each element holds n_gt_rois(b)
            (i.e., dependent on the batch element) binary masks of shape (c, y, x, (z)).
        :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
        :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
        :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal.
        :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores.
        :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement.
        :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals.
        :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals.
        :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting.
        """
        # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses.
        sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \
            mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks,
                                       self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions)

        # re-use feature maps and RPN output from first forward pass.
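        # Only the few mined rois are forwarded a second time, now with gradients enabled,
        # so backprop through the heads touches this small, balanced subset instead of
        # all batch_size * post_nms_rois proposals of the first (no_grad) pass.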
        sample_proposals = self.rpn_rois_batch_info[sample_ics]
        if not 0 in sample_proposals.size():
            sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals)
            sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals)
        else:
            sample_logits = torch.FloatTensor().cuda()
            sample_deltas = torch.FloatTensor().cuda()
            sample_regressions = torch.FloatTensor().cuda()
            sample_mask = torch.FloatTensor().cuda()

        return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals,
                sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions]

    def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True):
        """
        Restores batch dimension of merged detections, unmolds detections, creates and fills results dict.
        :param img_shape: shape of the input batch (b, c, y, x, (z)).
        :param detections: shape (n_final_detections, len(info)), where
            info = (y1, x1, y2, x2, (z1, z2), batch_ix, pred_class_id, pred_score)
        :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
        :param box_results_list: None or list of output boxes for monitoring/plotting.
            each element is a list of boxes per batch element.
        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
        :return: results_dict: dictionary with keys:
            'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
            'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1], only fg vs. bg for now.
                class-specific return of masks will come with the implementation of instance segmentation evaluation.
        """
        detections = detections.cpu().data.numpy()
        if self.cf.dim == 2:
            detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy()
        else:
            detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy()
        # det masks shape now (n_dets, y, x, (z), n_classes)

        # restore batch dimension of merged detections using the batch_ix info.
        batch_ixs = detections[:, self.cf.dim * 2]
        detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])]
        mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])]
        # mrcnn_mask: shape (b_size, variable, variable, n_classes), variable because it depends on single instance mask size.

        if box_results_list is None:  # for test_forward, where no previous list exists.
            box_results_list = [[] for _ in range(img_shape[0])]

        # seg_logits == seg_probs in mrcnn since mask head finishes with sigmoid (--> image space = [0, 1])
        seg_probs = []
        # loop over batch and unmold detections.
        for ix in range(img_shape[0]):
            # final masks are one-hot encoded (b, n_classes, y, x, (z)); +1 class for bg.
            # mask head classifies only bg/fg with logits between 0 and 1 --> bg is < 0.5.
            final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:]))
            if self.cf.num_classes + 1 != self.cf.num_seg_classes:
                self.logger.warning("n of roi-classifier head classes {} doesn't match cf.num_seg_classes {}".format(
                    self.cf.num_classes + 1, self.cf.num_seg_classes))

            if not 0 in detections[ix].shape:
                boxes = detections[ix][:, :self.cf.dim * 2].astype(np.int32)
                class_ids = detections[ix][:, self.cf.dim * 2 + 1].astype(np.int32)
                scores = detections[ix][:, self.cf.dim * 2 + 2]
                masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids]
                regressions = detections[ix][:, self.cf.dim * 2 + 3:]
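
                # Note on the masks selection above (hypothetical values): with 5 detections and
                # class_ids = [1, 2, 1, 1, 3], mrcnn_mask[ix][np.arange(5), ..., class_ids] keeps,
                # per detection, only the mask channel of its predicted class, giving shape (5, y, x, (z)).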

                # Filter out detections with zero area. Often only happens in early
                # stages of training when the network weights are still a bit random.
                if self.cf.dim == 2:
                    exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
                else:
                    exclude_ix = np.where(
                        (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0]

                if exclude_ix.shape[0] > 0:
                    boxes = np.delete(boxes, exclude_ix, axis=0)
                    masks = np.delete(masks, exclude_ix, axis=0)
                    class_ids = np.delete(class_ids, exclude_ix, axis=0)
                    scores = np.delete(scores, exclude_ix, axis=0)
                    regressions = np.delete(regressions, exclude_ix, axis=0)

                # Resize masks to original image size and set boundary threshold.
                if return_masks:
                    for i in range(masks.shape[0]):  # masks per this batch instance/element/image.
                        # Convert neural network mask to full size mask.
                        if self.cf.dim == 2:
                            full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:])
                        else:
                            full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:])
                        # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class
                        # has the max seg_logit value over all instances of that class in one sample.
                        final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0)
                    final_masks[0] = np.full(final_masks[0].shape, 0.49999999)  # effectively min_det_thres at 0.5 per pixel

                # add final predictions to results.
                if not 0 in boxes.shape:
                    for ix2, coords in enumerate(boxes):
                        box = {'box_coords': coords, 'box_type': 'det', 'box_score': scores[ix2],
                               'box_pred_class_id': class_ids[ix2]}
                        #if (hasattr(self.cf, "convert_cl_to_rg") and self.cf.convert_cl_to_rg):
                        if "regression_bin" in self.cf.prediction_tasks:
                            # in this case, regression preds are actually the rg_bin_ids --> map to the rg value the bin represents.
                            box['rg_bin'] = regressions[ix2].argmax()
                            box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']]
                        else:
                            box['regression'] = regressions[ix2]
                            if hasattr(self.cf, "rg_val_to_bin_id") and \
                                    any(['regression' in task for task in self.cf.prediction_tasks]):
                                box.update({'rg_bin': self.cf.rg_val_to_bin_id(regressions[ix2])})

                        box_results_list[ix].append(box)

            # if no detections were made --> keep full bg mask (zeros).
            seg_probs.append(final_masks)

        # create and fill results dictionary.
        results_dict = {}
        results_dict['boxes'] = box_results_list
        results_dict['seg_preds'] = np.array(seg_probs)

        return results_dict
+
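
    # Sketch of one detection entry in results_dict['boxes'][b] (hypothetical values):
    # {'box_coords': np.array([y1, x1, y2, x2]), 'box_type': 'det', 'box_score': 0.87,
    #  'box_pred_class_id': 1, 'regression': np.array([...]), 'rg_bin': 2}
    # Ground-truth boxes appended by train_forward carry 'box_type': 'gt' instead.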
""" img = batch['data'] gt_boxes = batch['bb_target'] #axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1) #gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))] gt_masks = batch['roi_masks'] gt_class_ids = batch['class_targets'] if 'regression' in self.cf.prediction_tasks: gt_regressions = batch["regression_targets"] elif 'regression_bin' in self.cf.prediction_tasks: gt_regressions = batch["rg_bin_targets"] else: gt_regressions = None img = torch.from_numpy(img).cuda().float() batch_rpn_class_loss = torch.FloatTensor([0]).cuda() batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda() # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. box_results_list = [[] for _ in range(img.shape[0])] #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop. rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img) mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \ mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \ self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions) # loop over batch for b in range(img.shape[0]): if len(gt_boxes[b]) > 0: # add gt boxes to output list for tix in range(len(gt_boxes[b])): gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]} for name in self.cf.roi_items: gt_box.update({name: batch[name][b][tix]}) box_results_list[b].append(gt_box) # match gt boxes with anchors to generate targets for RPN losses. rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b]) # add positive anchors used for loss to output list for monitoring. pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:]) for p in pos_anchors: box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) else: rpn_match = np.array([-1]*self.np_anchors.shape[0]) rpn_target_deltas = np.array([0]) rpn_match_gpu = torch.from_numpy(rpn_match).cuda() rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda() # compute RPN losses. rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match_gpu, self.cf.shem_poolsize) rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match_gpu) batch_rpn_class_loss += rpn_class_loss /img.shape[0] batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0] # add negative anchors used for loss to output list for monitoring. # neg_anchor_ix = neg_ix come from shem and mark positions in roi_probs_neg = rpn_class_logits[neg_indices] # with neg_indices = rpn_match == -1 neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[rpn_match == -1][neg_anchor_ix], img.shape[2:]) for n in neg_anchors: box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) # add highest scoring proposals to output list for monitoring. rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1] for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]: box_results_list[b].append({'box_coords': r, 'box_type': 'prop'}) # add positive and negative roi samples used for mrcnn losses to output list for monitoring. 
        if not 0 in sample_proposals.shape:
            rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
            for ix, r in enumerate(rois):
                box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
                                                     'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})

        # compute mrcnn losses.
        mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids)
        mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids)
        mrcnn_regressions_loss = compute_mrcnn_regression_loss(self.cf.prediction_tasks, mrcnn_regressions,
                                                               target_regressions, target_class_ids)
        # mrcnn can be run without pixel-wise annotations available (Faster R-CNN mode).
        # In this case, the mask loss is taken out of training.
        if self.cf.frcnn_mode:
            mrcnn_mask_loss = torch.FloatTensor([0]).cuda()
        else:
            mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids)

        loss = batch_rpn_class_loss + batch_rpn_bbox_loss + \
               mrcnn_bbox_loss + mrcnn_mask_loss + mrcnn_class_loss + mrcnn_regressions_loss

        # run unmolding of predictions for monitoring and merge all results into one dictionary.
        return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train
        results_dict = self.get_results(img.shape, detections, detection_masks, box_results_list,
                                        return_masks=return_masks)
-        results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:, np.newaxis]
+        #results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:, np.newaxis]
        if 'dice' in self.cf.metrics:
            results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)

        results_dict['torch_loss'] = loss
        results_dict['class_loss'] = mrcnn_class_loss.item()
        results_dict['bbox_loss'] = mrcnn_bbox_loss.item()
+        results_dict['mask_loss'] = mrcnn_mask_loss.item()
        results_dict['rg_loss'] = mrcnn_regressions_loss.item()
        results_dict['rpn_class_loss'] = rpn_class_loss.item()
        results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item()
-
        return results_dict

    def test_forward(self, batch, return_masks=True):
        """
        test method. wrapper around forward pass of network without usage of any ground-truth information.
        prepares input data for processing and stores outputs in a dictionary.
        :param batch: dictionary containing 'data'
        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
        :return: results_dict: dictionary with keys:
            'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
            'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
        """
        img = batch['data']
        img = torch.from_numpy(img).float().cuda()
        _, _, _, detections, detection_masks = self.forward(img)
        results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks)

        return results_dict
\ No newline at end of file