diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py
index 1d1b393..df1997b 100644
--- a/datasets/toy/configs.py
+++ b/datasets/toy/configs.py
@@ -1,490 +1,490 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 
 import sys
 import os
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 import numpy as np
 from default_configs import DefaultConfigs
 from collections import namedtuple
 
 boxLabel = namedtuple('boxLabel', ["name", "color"])
 Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion'])
 binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals'])
 
 class Configs(DefaultConfigs):
 
     def __init__(self, server_env=None):
         super(Configs, self).__init__(server_env)
 
         #########################
         #         Prepro        #
         #########################
 
         self.pp_rootdir = os.path.join('/home/gregor/datasets/toy', "cyl1ps_dev")
         self.pp_npz_dir = self.pp_rootdir+"_npz"
 
         self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now)
         self.min_2d_radius = 6 #in pixels
         self.n_train_samples, self.n_test_samples = 1200, 1000
 
         # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes.
         self.pp_create_ohe_seg = False
         self.pp_empty_samples_ratio = 0.1
 
         self.pp_place_radii_mid_bin = True
         self.pp_only_distort_2d = True
         # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing.
         # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity.
         self.pp_blur_min_intensity = 0.2
 
         self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d)
         self.max_instances_per_class = self.max_instances_per_sample  # how many max instances per image per class
         self.noise_scale = 0.  # std-dev of gaussian noise
 
         self.ambigs_sampling = "gaussian" #"gaussian" or "uniform"
         """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from
             image by distinguishing it from the rest of the object.
             blurring width around edge will be shifted so that symmetric rel to orig radius.
             blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale
             since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius.
             if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2.
         """
         self.ambiguities = {
              #set which classes to apply which ambs to below in class labels
              #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'.
              #kind              #probability   #scale (gaussian std, relative to unperturbed value)
             #"outer_radius":     (1.,            0.5),
             #"outer_radius_xy":  (1.,            0.5),
             #"inner_radius":     (0.5,            0.1),
             #"radii_relations":  (0.5,            0.1),
             "radius_calib":     (1.,            1./6)
         }
 
         # shape choices: 'cylinder', 'block'
         #                        id,    name,       shape,      radius,                 color,              regression,     ambiguities,    gt_distortion
         self.pp_classes = [Label(1,     'cylinder', 'cylinder', ((6,6,1),(40,40,8)),    (*self.blue, 1.),   "radius_2d",    (),             ()),
                            #Label(2,      'block',      'block',        ((6,6,1),(40,40,8)),  (*self.aubergine,1.),  "radii_2d", (), ('radius_calib',))
             ]
 
 
         #########################
         #         I/O           #
         #########################
 
         self.data_sourcedir = '/home/gregor/datasets/toy/cyl1ps_dev'
 
         if server_env:
             self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_dev_npz'
 
 
         self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test')
         self.data_sourcedir = os.path.join(self.data_sourcedir, "train")
 
         self.info_df_name = 'info_df.pickle'
 
         # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn'].
-        self.model = 'retina_net'
+        self.model = 'mrcnn'
         self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net')
         self.model_path = os.path.join(self.source_dir, self.model_path)
 
 
         #########################
         #      Architecture     #
         #########################
 
         # one out of [2, 3]. dimension the model operates in.
         self.dim = 2
 
         # 'class', 'regression', 'regression_bin', 'regression_ken_gal'
         # currently only tested mode is a single-task at a time (i.e., only one task in below list)
         # but, in principle, tasks could be combined (e.g., object classes and regression per class)
         self.prediction_tasks = ['class', ]
 
         self.start_filts = 48 if self.dim == 2 else 18
         self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
         self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
         self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm'
         self.relu = 'relu'
         # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
         self.weight_init = None
 
         self.regression_n_features = 1  # length of regressor target vector
 
 
         #########################
         #      Data Loader      #
         #########################
 
         self.num_epochs = 32
         self.num_train_batches = 120 if self.dim == 2 else 80
         self.batch_size = 8 if self.dim == 2 else 4
 
         self.n_cv_splits = 4
         # select modalities from preprocessed data
         self.channels = [0]
         self.n_channels = len(self.channels)
 
         # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels
         self.plot_bg_chan = 0
         self.crop_margin = [20, 20, 1]  # has to be smaller than respective patch_size//2
         self.patch_size_2D = self.pre_crop_size[:2]
         self.patch_size_3D = self.pre_crop_size[:2]+[8]
 
         # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
         self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
 
         # ratio of free sampled batch elements before class balancing is triggered
         # (>0 to include "empty"/background patches.)
         self.batch_random_ratio = 0.2
         self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets"
 
         self.observables_patient = []
         self.observables_rois = []
 
         self.seed = 3 #for generating folds
 
         #############################
         # Colors, Classes, Legends  #
         #############################
         self.plot_frequency = 1
 
         binary_bin_labels = [binLabel(1,  'r<=25',      (*self.green, 1.),      (1,25)),
                              binLabel(2,  'r>25',       (*self.red, 1.),        (25,))]
         quintuple_bin_labels = [binLabel(1,  'r2-10',   (*self.green, 1.),      (2,10)),
                                 binLabel(2,  'r10-20',  (*self.yellow, 1.),     (10,20)),
                                 binLabel(3,  'r20-30',  (*self.orange, 1.),     (20,30)),
                                 binLabel(4,  'r30-40',  (*self.bright_red, 1.), (30,40)),
                                 binLabel(5,  'r>40',    (*self.red, 1.), (40,))]
 
         # choose here if to do 2-way or 5-way regression-bin classification
         task_spec_bin_labels = quintuple_bin_labels
 
         self.class_labels = [
             # regression: regression-task label, either value or "(x,y,z)_radius" or "radii".
             # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables!
             # gt_distortion: name of ambig to apply to gt only; needs to be iterable!
             #      #id  #name   #shape  #radius     #color              #regression #ambiguities    #gt_distortion
             Label(  0,  'bg',   None,   (0, 0, 0),  (*self.white, 0.),  (0, 0, 0),  (),             ())]
         if "class" in self.prediction_tasks:
             self.class_labels += self.pp_classes
         else:
             self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))]
 
 
         if any(['regression' in task for task in self.prediction_tasks]):
             self.bin_labels = [binLabel(0,  'bg',       (*self.white, 1.),      (0,))]
             self.bin_labels += task_spec_bin_labels
             self.bin_id2label = {label.id: label for label in self.bin_labels}
             bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels]
             self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)}
             self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)]
             self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0}
 
         if self.class_specific_seg:
           self.seg_labels = self.class_labels
 
         self.box_type2label = {label.name: label for label in self.box_labels}
         self.class_id2label = {label.id: label for label in self.class_labels}
         self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0}
 
         self.seg_id2label = {label.id: label for label in self.seg_labels}
         self.cmap = {label.id: label.color for label in self.seg_labels}
 
         self.plot_prediction_histograms = True
         self.plot_stat_curves = False
         self.has_colorchannels = False
         self.plot_class_ids = True
 
         self.num_classes = len(self.class_dict)
         self.num_seg_classes = len(self.seg_labels)
 
         #########################
         #   Data Augmentation   #
         #########################
         self.do_aug = True
         self.da_kwargs = {
             'mirror': True,
             'mirror_axes': tuple(np.arange(0, self.dim, 1)),
             'do_elastic_deform': False,
             'alpha': (500., 1500.),
             'sigma': (40., 45.),
             'do_rotation': False,
             'angle_x': (0., 2 * np.pi),
             'angle_y': (0., 0),
             'angle_z': (0., 0),
             'do_scale': False,
             'scale': (0.8, 1.1),
             'random_crop': False,
             'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
             'border_mode_data': 'constant',
             'border_cval_data': 0,
             'order_data': 1
         }
 
         if self.dim == 3:
             self.da_kwargs['do_elastic_deform'] = False
             self.da_kwargs['angle_x'] = (0, 0.0)
             self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0!!
             self.da_kwargs['angle_z'] = (0., 2 * np.pi)
 
         #########################
         #  Schedule / Selection #
         #########################
 
         # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
         # the former is morge accurate, while the latter is faster (depending on volume size)
         self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
         if self.val_mode == 'val_patient':
             self.max_val_patients = 220  # if 'all' iterates over entire val_set once.
         if self.val_mode == 'val_sampling':
             self.num_val_batches = 35 if self.dim==2 else 25
 
         self.save_n_models = 2
         self.min_save_thresh = 1 if self.dim == 2 else 1  # =wait time in epochs
         if "class" in self.prediction_tasks:
             self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}
         elif any("regression" in task for task in self.prediction_tasks):
             self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()}
             self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()})
 
         self.lr_decay_factor = 0.5
         self.scheduling_patience = int(self.num_epochs / 5)
         self.weight_decay = 1e-5
         self.clip_norm = None  # number or None
 
         #########################
         #   Testing / Plotting  #
         #########################
 
         self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1)
         self.held_out_test_set = True
         self.max_test_patients = "all"  # number or "all" for all
 
         self.test_against_exact_gt = True # only True implemented
         self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario.
         self.report_score_level = ['rois']  # 'patient' or 'rois' (incl)
         self.patient_class_of_interest = 1
         self.patient_bin_of_interest = 2
 
         self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False
         self.metrics = ['ap', 'auc', 'dice']
         if any(['regression' in task for task in self.prediction_tasks]):
             self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp',
                              'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp']
         if 'aleatoric' in self.model:
             self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted']
         self.evaluate_fold_means = True
 
         self.ap_match_ious = [0.5]  # threshold(s) for considering a prediction as true positive
         self.min_det_thresh = 0.3
 
         self.model_max_iou_resolution = 0.2
 
         # aggregation method for test and val_patient predictions.
         # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf,
         # nms = standard non-maximum suppression, or None = no clustering
         self.clustering = 'wbc'
         # iou thresh (exclusive!) for regarding two preds as concerning the same ROI
         self.clustering_iou = self.model_max_iou_resolution  # has to be larger than desired possible overlap iou of model predictions
 
         self.merge_2D_to_3D_preds = False
         self.merge_3D_iou = self.model_max_iou_resolution
         self.n_test_plots = 1  # per fold and rank
 
         self.test_n_epochs = self.save_n_models  # should be called n_test_ens, since is number of models to ensemble over during testing
         # is multiplied by (1 + nr of test augs)
 
         #########################
         #   Assertions          #
         #########################
         if not 'class' in self.prediction_tasks:
             assert self.num_classes == 1
 
         #########################
         #   Add model specifics #
         #########################
 
         {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs,
          'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs,
          'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs
          }[self.model]()
 
     def rg_val_to_bin_id(self, rg_val):
         #only meant for isotropic radii!!
         # only 2D radii (x and y dims) or 1D (x or y) are expected
         return np.round(np.digitize(rg_val, self.bin_edges).mean())
 
 
     def add_det_fpn_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = 'torch_loss'
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       self.n_roi_candidates = 4 if self.dim == 2 else 6
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
 
       self.fp_dice_weight = 1 if self.dim == 2 else 1
       # if <1, false positive predictions in foreground are penalized less.
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
     def add_det_unet_configs(self):
 
       self.learning_rate = [1 * 1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True
       self.scheduling_criterion = "torch_loss"
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # max number of roi candidates to identify per image (slice in 2D, volume in 3D)
       self.n_roi_candidates = 4 if self.dim == 2 else 6
 
       # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
       self.seg_loss_mode = 'wce'
       self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1]
       # if <1, false positive predictions in foreground are penalized less.
       self.fp_dice_weight = 1 if self.dim == 2 else 1
 
       self.detection_min_confidence = 0.05
       # how to determine score of roi: 'max' or 'median'
       self.score_det = 'max'
 
       self.init_filts = 32
       self.kernel_size = 3  # ks for horizontal, normal convs
       self.kernel_size_m = 2  # ks for max pool
       self.pad = "same"  # "same" or integer, padding of horizontal convs
 
     def add_mrcnn_configs(self):
 
       self.learning_rate = [1e-4] * self.num_epochs
       self.dynamic_lr_scheduling = True  # with scheduler set in exec
       self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get)
       self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max'
 
       # number of classes for network heads: n_foreground_classes + 1 (background)
       self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2
 
       # feed +/- n neighbouring slices into channel dimension. set to None for no context.
       self.n_3D_context = None
       if self.n_3D_context is not None and self.dim == 2:
         self.n_channels *= (self.n_3D_context * 2 + 1)
 
       self.detect_while_training = True
       # disable the re-sampling of mask proposals to original size for speed-up.
       # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
       # mask outputs are optional.
       self.return_masks_in_train = True
       self.return_masks_in_val = True
       self.return_masks_in_test = True
 
       # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly.
       self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
       # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
       # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
       self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]}
       # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
       self.pyramid_levels = [0, 1, 2, 3]
       # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
       self.n_rpn_features = 512 if self.dim == 2 else 64
 
       # anchor ratios and strides per position in feature maps.
       self.rpn_anchor_ratios = [0.5, 1., 2.]
       self.rpn_anchor_stride = 1
       # Threshold for first stage (RPN) non-maximum suppression (NMS):  LOWER == HARDER SELECTION
       self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution)
 
       # loss sampling settings.
       self.rpn_train_anchors_per_image = 4
       self.train_rois_per_image = 6 # per batch_instance
       self.roi_positive_ratio = 0.5
       self.anchor_matching_iou = 0.8
 
       # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining),
       # where k<=#positive examples.
       self.shem_poolsize = 2
 
       self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
       self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
       self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
 
       self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
       self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
       self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
                              self.patch_size_3D[2], self.patch_size_3D[2]])  # y1,x1,y2,x2,z1,z2
 
       if self.dim == 2:
         self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
         self.bbox_std_dev = self.bbox_std_dev[:4]
         self.window = self.window[:4]
         self.scale = self.scale[:4]
 
       self.plot_y_max = 1.5
       self.n_plot_rpn_props = 5 if self.dim == 2 else 30  # per batch_instance (slice in 2D / patient in 3D)
 
       # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
       self.pre_nms_limit = 2000 if self.dim == 2 else 4000
 
       # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
       # since proposals of the entire batch are forwarded through second stage as one "batch".
       self.roi_chunk_size = 1300 if self.dim == 2 else 500
       self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400
       self.post_nms_rois_inference = 200 * (self.head_classes-1)
 
       # Final selection of detections (refine_detections)
       self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class.
       self.detection_nms_threshold = self.model_max_iou_resolution  # needs to be > 0, otherwise all predictions are one cluster.
       self.model_min_confidence = 0.2  # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py
 
       if self.dim == 2:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride))]
            for stride in self.backbone_strides['xy']])
       else:
         self.backbone_shapes = np.array(
           [[int(np.ceil(self.patch_size[0] / stride)),
             int(np.ceil(self.patch_size[1] / stride)),
             int(np.ceil(self.patch_size[2] / stride_z))]
            for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
                                        )])
 
       if self.model == 'retina_net' or self.model == 'retina_unet':
         # whether to use focal loss or SHEM for loss-sample selection
         self.focal_loss = False
         # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002
         self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                         self.rpn_anchor_scales['xy']]
         self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
                                        self.rpn_anchor_scales['z']]
         self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
 
         # pre-selection of detections for NMS-speedup. per entire batch.
         self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size
 
         # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
         self.anchor_matching_iou = 0.7
 
         if self.model == 'retina_unet':
           self.operate_stride1 = True
diff --git a/utils/exp_utils.py b/utils/exp_utils.py
index 4dd5d88..73c1b43 100644
--- a/utils/exp_utils.py
+++ b/utils/exp_utils.py
@@ -1,691 +1,692 @@
 #!/usr/bin/env python
 # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 # import plotting as plg
 
 import sys
 import os
 import subprocess
 from multiprocessing import Process
 import threading
 import pickle
 import importlib.util
 import psutil
 import time
 import nvidia_smi
 
 import logging
 from torch.utils.tensorboard import SummaryWriter
 
 from collections import OrderedDict
 import numpy as np
 import pandas as pd
 import torch
 
 
 def import_module(name, path):
     """
     correct way of importing a module dynamically in python 3.
     :param name: name given to module instance.
     :param path: path to module.
     :return: module: returned module instance.
     """
     spec = importlib.util.spec_from_file_location(name, path)
     module = importlib.util.module_from_spec(spec)
     spec.loader.exec_module(module)
     return module
 
 
 def save_obj(obj, name):
     """Pickle a python object."""
     with open(name + '.pkl', 'wb') as f:
         pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
 
 
 def load_obj(file_path):
     with open(file_path, 'rb') as handle:
         return pickle.load(handle)
 
 
 def IO_safe(func, *args, _tries=5, _raise=True, **kwargs):
     """ Wrapper calling function func with arguments args and keyword arguments kwargs to catch input/output errors
         on cluster.
     :param func: function to execute (intended to be read/write operation to a problematic cluster drive, but can be
         any function).
     :param args: positional args of func.
     :param kwargs: kw args of func.
     :param _tries: how many attempts to make executing func.
     """
     for _try in range(_tries):
         try:
             return func(*args, **kwargs)
         except OSError as e:  # to catch cluster issues with network drives
             if _raise:
                 raise e
             else:
                 print("After attempting execution {} time{}, following error occurred:\n{}".format(_try + 1,
                                                                                                    "" if _try == 0 else "s",
                                                                                                    e))
                 continue
 
 def split_off_process(target, *args, daemon=False, **kwargs):
     """Start a process that won't block parent script.
     No join(), no return value. If daemon=False: before parent exits, it waits for this to finish.
     """
     p = Process(target=target, args=tuple(args), kwargs=kwargs, daemon=daemon)
     p.start()
     return p
 
 
 def query_nvidia_gpu(device_id, d_keyword=None, no_units=False):
     """
     :param device_id:
     :param d_keyword: -d, --display argument (keyword(s) for selective display), all are selected if None
     :return: dict of gpu-info items
     """
     cmd = ['nvidia-smi', '-i', str(device_id), '-q']
     if d_keyword is not None:
         cmd += ['-d', d_keyword]
     outp = subprocess.check_output(cmd).strip().decode('utf-8').split("\n")
     outp = [x for x in outp if len(x) > 0]
     headers = [ix for ix, item in enumerate(outp) if len(item.split(":")) == 1] + [len(outp)]
 
     out_dict = {}
     for lix, hix in enumerate(headers[:-1]):
         head = outp[hix].strip().replace(" ", "_").lower()
         out_dict[head] = {}
         for lix2 in range(hix, headers[lix + 1]):
             try:
                 key, val = [x.strip().lower() for x in outp[lix2].split(":")]
                 if no_units:
                     val = val.split()[0]
                 out_dict[head][key] = val
             except:
                 pass
 
     return out_dict
 
 
 class CombinedPrinter(object):
     """combined print function.
     prints to logger and/or file if given, to normal print if non given.
 
     """
 
     def __init__(self, logger=None, file=None):
 
         if logger is None and file is None:
             self.out = [print]
         elif logger is None:
             self.out = [file.write]
         elif file is None:
             self.out = [logger.info]
         else:
             self.out = [logger.info, file.write]
 
     def __call__(self, string):
         for fct in self.out:
             fct(string)
 
 
 class Nvidia_GPU_Logger(object):
     def __init__(self):
         self.count = None
 
     def get_vals(self):
 
         # cmd = ['nvidia-settings', '-t', '-q', 'GPUUtilization']
         # gpu_util = subprocess.check_output(cmd).strip().decode('utf-8').split(",")
         # gpu_util = dict([f.strip().split("=") for f in gpu_util])
         # cmd[-1] = 'UsedDedicatedGPUMemory'
         # gpu_used_mem = subprocess.check_output(cmd).strip().decode('utf-8')
 
 
         nvidia_smi.nvmlInit()
         # card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate
         self.gpu_handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
         util_res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.gpu_handle)
         #mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(self.gpu_handle)
         # current_vals = {"gpu_mem_alloc": mem_res.used / (1024**2), "gpu_graphics_util": int(gpu_util['graphics']),
         #                 "gpu_mem_util": gpu_util['memory'], "time": time.time()}
         current_vals = {"gpu_graphics_util": float(util_res.gpu),
                         "time": time.time()}
         return current_vals
 
     def loop(self, interval):
         i = 0
         while True:
             current_vals = self.get_vals()
             self.log["time"].append(time.time())
             self.log["gpu_util"].append(current_vals["gpu_graphics_util"])
             if self.count is not None:
                 i += 1
                 if i == self.count:
                     exit(0)
             time.sleep(self.interval)
 
     def start(self, interval=1.):
         self.interval = interval
         self.start_time = time.time()
         self.log = {"time": [], "gpu_util": []}
         if self.interval is not None:
             thread = threading.Thread(target=self.loop)
             thread.daemon = True
             thread.start()
 
 class CombinedLogger(object):
     """Combine console and tensorboard logger and record system metrics.
     """
 
     def __init__(self, name, log_dir, server_env=True, fold="all", sysmetrics_interval=2):
         self.pylogger = logging.getLogger(name)
         self.tboard = SummaryWriter(log_dir=os.path.join(log_dir, "tboard"))
         self.times = {}
         self.log_dir = log_dir
         self.fold = str(fold)
         self.server_env = server_env
 
         self.pylogger.setLevel(logging.DEBUG)
         self.log_file = os.path.join(log_dir, "fold_"+self.fold, 'exec.log')
         os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
         self.pylogger.addHandler(logging.FileHandler(self.log_file))
         if not server_env:
             self.pylogger.addHandler(ColorHandler())
         else:
             self.pylogger.addHandler(logging.StreamHandler())
         self.pylogger.propagate = False
 
         # monitor system metrics (cpu, mem, ...)
         if not server_env and sysmetrics_interval > 0:
             self.sysmetrics = pd.DataFrame(
                 columns=["global_step", "rel_time", r"CPU (%)", "mem_used (GB)", r"mem_used (%)",
                          r"swap_used (GB)", r"gpu_utilization (%)"], dtype="float16")
             for device in range(torch.cuda.device_count()):
                 self.sysmetrics[
                     "mem_allocd (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan
                 self.sysmetrics[
                     "mem_cached (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan
             self.sysmetrics_start(sysmetrics_interval)
             pass
         else:
             print("NOT logging sysmetrics")
 
     def __getattr__(self, attr):
         """delegate all undefined method requests to objects of
         this class in order pylogger, tboard (first find first serve).
         E.g., combinedlogger.add_scalars(...) should trigger self.tboard.add_scalars(...)
         """
         for obj in [self.pylogger, self.tboard]:
             if attr in dir(obj):
                 return getattr(obj, attr)
         print("logger attr not found")
         #raise AttributeError("CombinedLogger has no attribute {}".format(attr))
 
     def set_logfile(self, fold=None, log_file=None):
         if fold is not None:
             self.fold = str(fold)
         if log_file is None:
             self.log_file = os.path.join(self.log_dir, "fold_"+self.fold, 'exec.log')
         else:
             self.log_file = log_file
         os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
         for hdlr in self.pylogger.handlers:
             hdlr.close()
         self.pylogger.handlers = []
         self.pylogger.addHandler(logging.FileHandler(self.log_file))
         if not self.server_env:
             self.pylogger.addHandler(ColorHandler())
         else:
             self.pylogger.addHandler(logging.StreamHandler())
 
     def time(self, name, toggle=None):
         """record time-spans as with a stopwatch.
         :param name:
         :param toggle: True^=On: start time recording, False^=Off: halt rec. if None determine from current status.
         :return: either start-time or last recorded interval
         """
         if toggle is None:
             if name in self.times.keys():
                 toggle = not self.times[name]["toggle"]
             else:
                 toggle = True
 
         if toggle:
             if not name in self.times.keys():
                 self.times[name] = {"total": 0, "last": 0}
             elif self.times[name]["toggle"] == toggle:
                 self.info("restarting running stopwatch")
             self.times[name]["last"] = time.time()
             self.times[name]["toggle"] = toggle
             return time.time()
         else:
             if toggle == self.times[name]["toggle"]:
                 self.info("WARNING: tried to stop stopped stop watch: {}.".format(name))
             self.times[name]["last"] = time.time() - self.times[name]["last"]
             self.times[name]["total"] += self.times[name]["last"]
             self.times[name]["toggle"] = toggle
             return self.times[name]["last"]
 
     def get_time(self, name=None, kind="total", format=None, reset=False):
         """
         :param name:
         :param kind: 'total' or 'last'
         :param format: None for float, "hms"/"ms" for (hours), mins, secs as string
         :param reset: reset time after retrieving
         :return:
         """
         if name is None:
             times = self.times
             if reset:
                 self.reset_time()
             return times
 
         else:
             if self.times[name]["toggle"]:
                 self.time(name, toggle=False)
             time = self.times[name][kind]
             if format == "hms":
                 m, s = divmod(time, 60)
                 h, m = divmod(m, 60)
                 time = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(m), int(s))
             elif format == "ms":
                 m, s = divmod(time, 60)
                 time = "{:02d}m:{:02d}s".format(int(m), int(s))
             if reset:
                 self.reset_time(name)
             return time
 
     def reset_time(self, name=None):
         if name is None:
             self.times = {}
         else:
             del self.times[name]
 
     def sysmetrics_update(self, global_step=None):
         if global_step is None:
             global_step = time.strftime("%x_%X")
         mem = psutil.virtual_memory()
         mem_used = (mem.total - mem.available)
         gpu_vals = self.gpu_logger.get_vals()
         rel_time = time.time() - self.sysmetrics_start_time
         self.sysmetrics.loc[len(self.sysmetrics)] = [global_step, rel_time,
                                                      psutil.cpu_percent(), mem_used / 1024 ** 3,
                                                      mem_used / mem.total * 100,
                                                      psutil.swap_memory().used / 1024 ** 3,
                                                      int(gpu_vals['gpu_graphics_util']),
                                                      *[torch.cuda.memory_allocated(d) / 1024 ** 3 for d in
                                                        range(torch.cuda.device_count())],
                                                      *[torch.cuda.memory_cached(d) / 1024 ** 3 for d in
                                                        range(torch.cuda.device_count())]
                                                      ]
         return self.sysmetrics.loc[len(self.sysmetrics) - 1].to_dict()
 
     def sysmetrics2tboard(self, metrics=None, global_step=None, suptitle=None):
         tag = "per_time"
         if metrics is None:
             metrics = self.sysmetrics_update(global_step=global_step)
             tag = "per_epoch"
 
         if suptitle is not None:
             suptitle = str(suptitle)
         elif self.fold != "":
             suptitle = "Fold_" + str(self.fold)
         if suptitle is not None:
             self.tboard.add_scalars(suptitle + "/System_Metrics/" + tag,
                                     {k: v for (k, v) in metrics.items() if (k != "global_step"
                                                                             and k != "rel_time")}, global_step)
 
     def sysmetrics_loop(self):
         try:
             os.nice(-19)
             self.info("Logging system metrics with superior process priority.")
         except:
             self.info("Logging system metrics without superior process priority.")
         while True:
             metrics = self.sysmetrics_update()
             self.sysmetrics2tboard(metrics, global_step=metrics["rel_time"])
             # print("thread alive", self.thread.is_alive())
             time.sleep(self.sysmetrics_interval)
 
     def sysmetrics_start(self, interval):
         if interval is not None and interval > 0:
             self.sysmetrics_interval = interval
             self.gpu_logger = Nvidia_GPU_Logger()
             self.sysmetrics_start_time = time.time()
             self.sys_metrics_process = split_off_process(target=self.sysmetrics_loop, daemon=True)
             # self.thread = threading.Thread(target=self.sysmetrics_loop)
             # self.thread.daemon = True
             # self.thread.start()
 
     def sysmetrics_save(self, out_file):
         self.sysmetrics.to_pickle(out_file)
 
     def metrics2tboard(self, metrics, global_step=None, suptitle=None):
         """
         :param metrics: {'train': dataframe, 'val':df}, df as produced in
             evaluator.py.evaluate_predictions
         """
         # print("metrics", metrics)
         if global_step is None:
             global_step = len(metrics['train'][list(metrics['train'].keys())[0]]) - 1
         if suptitle is not None:
             suptitle = str(suptitle)
         else:
             suptitle = "Fold_" + str(self.fold)
 
         for key in ['train', 'val']:
             # series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k}
             loss_series = {}
             unc_series = {}
             bin_stat_series = {}
             mon_met_series = {}
             for tag, val in metrics[key].items():
                 val = val[-1]  # maybe remove list wrapping, recording in evaluator?
                 if 'bin_stats' in tag.lower() and not np.isnan(val):
                     bin_stat_series["{}".format(tag.split("/")[-1])] = val
                 elif 'uncertainty' in tag.lower() and not np.isnan(val):
                     unc_series["{}".format(tag)] = val
                 elif 'loss' in tag.lower() and not np.isnan(val):
                     loss_series["{}".format(tag)] = val
                 elif not np.isnan(val):
                     mon_met_series["{}".format(tag)] = val
 
             self.tboard.add_scalars(suptitle + "/Binary_Statistics/{}".format(key), bin_stat_series, global_step)
             self.tboard.add_scalars(suptitle + "/Uncertainties/{}".format(key), unc_series, global_step)
             self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step)
             self.tboard.add_scalars(suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series, global_step)
         self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step)
         return
 
     def batchImgs2tboard(self, batch, results_dict, cmap, boxtype2color, img_bg=False, global_step=None):
         raise NotImplementedError("not up-to-date, problem with importing plotting-file, torchvision dependency.")
         if len(batch["seg"].shape) == 5:  # 3D imgs
             slice_ix = np.random.randint(batch["seg"].shape[-1])
             seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :, slice_ix], cmap)
             seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :, slice_ix], cmap)
 
             mod_img = plg.mod_to_rgb(batch["data"][:, 0, :, :, slice_ix]) if img_bg else None
 
         elif len(batch["seg"].shape) == 4:
             seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :], cmap)
             seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :], cmap)
             mod_img = plg.mod_to_rgb(batch["data"][:, 0]) if img_bg else None
         else:
             raise Exception("batch content has wrong format: {}".format(batch["seg"].shape))
 
         # from here on only works in 2D
         seg_gt = np.transpose(seg_gt, axes=(0, 3, 1, 2))  # previous shp: b,x,y,c
         seg_pred = np.transpose(seg_pred, axes=(0, 3, 1, 2))
 
         seg = np.concatenate((seg_gt, seg_pred), axis=0)
         # todo replace torchvision (tv) dependency
         seg = tv.utils.make_grid(torch.from_numpy(seg), nrow=2)
         self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.", seg, global_step=global_step)
 
         if img_bg:
             bg_img = np.transpose(mod_img, axes=(0, 3, 1, 2))
         else:
             bg_img = seg_gt
         box_imgs = plg.draw_boxes_into_batch(bg_img, results_dict["boxes"], boxtype2color)
         box_imgs = tv.utils.make_grid(torch.from_numpy(box_imgs), nrow=4)
         self.tboard.add_image("Batch bboxes", box_imgs, global_step=global_step)
 
         return
 
     def __del__(self):  # otherwise might produce multiple prints e.g. in ipython console
         self.sys_metrics_process.terminate()
         for hdlr in self.pylogger.handlers:
             hdlr.close()
         self.pylogger.handlers = []
         del self.pylogger
         self.tboard.close()
 
 
 def get_logger(exp_dir, server_env=False, sysmetrics_interval=2):
     log_dir = os.path.join(exp_dir, "logs")
     logger = CombinedLogger('Reg R-CNN', log_dir, server_env=server_env,
                             sysmetrics_interval=sysmetrics_interval)
     print("logging to {}".format(logger.log_file))
     return logger
 
 
 def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True):
     """
     I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir.
     This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime.
     Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone).
     Provides robust structure for cloud deployment.
     :param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp)
     :param exp_path: path to experiment directory.
     :param server_env: boolean flag. pass to configs script for cloud deployment.
     :param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing
         experiment directory, else creates experiment directory on the fly using configs/model scripts from source code.
     :param is_training: boolean flag. distinguishes train vs. inference mode.
     :return: configs object.
     """
 
     if is_training:
 
         if use_stored_settings:
             cf_file = import_module('cf', os.path.join(exp_path, 'configs.py'))
             cf = cf_file.Configs(server_env)
             # in this mode, previously saved model and backbone need to be found in exp dir.
             if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \
                     not os.path.isfile(os.path.join(exp_path, 'backbone.py')):
                 raise Exception(
                     "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.")
             cf.model_path = os.path.join(exp_path, 'model.py')
             cf.backbone_path = os.path.join(exp_path, 'backbone.py')
         else:  # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model
             os.makedirs(exp_path, exist_ok=True)
             # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.)
             subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')),
                             shell=True)
             subprocess.call(
                 'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')),
                 shell=True)
             cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py'))
             cf = cf_file.Configs(server_env)
             subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True)
             subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True)
             if os.path.isfile(os.path.join(exp_path, "fold_ids.pickle")):
                 subprocess.call('rm {}'.format(os.path.join(exp_path, "fold_ids.pickle")), shell=True)
 
     else:  # testing, use model and backbone stored in exp dir.
         cf_file = import_module('cf', os.path.join(exp_path, 'configs.py'))
         cf = cf_file.Configs(server_env)
         cf.model_path = os.path.join(exp_path, 'model.py')
         cf.backbone_path = os.path.join(exp_path, 'backbone.py')
 
     cf.exp_dir = exp_path
     cf.test_dir = os.path.join(cf.exp_dir, 'test')
     cf.plot_dir = os.path.join(cf.exp_dir, 'plots')
     if not os.path.exists(cf.test_dir):
         os.mkdir(cf.test_dir)
     if not os.path.exists(cf.plot_dir):
         os.mkdir(cf.plot_dir)
     cf.experiment_name = exp_path.split("/")[-1]
     cf.dataset_name = dataset_path
     cf.server_env = server_env
     cf.created_fold_id_pickle = False
 
     return cf
 
 
 class ModelSelector:
     '''
     saves a checkpoint after each epoch as 'last_state' (can be loaded to continue interrupted training).
     saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be ensembled
     to improve performance.
     '''
 
     def __init__(self, cf, logger):
 
         self.cf = cf
         self.logger = logger
 
         self.model_index = pd.DataFrame(columns=["rank", "score", "criteria_values", "file_name"],
                                         index=pd.RangeIndex(self.cf.min_save_thresh, self.cf.num_epochs, name="epoch"))
 
     def run_model_selection(self, net, optimizer, monitor_metrics, epoch):
         """rank epoch via weighted mean from self.cf.model_selection_criteria: {criterion : weight}
         :param net:
         :param optimizer:
         :param monitor_metrics:
         :param epoch:
         :return:
         """
         crita = self.cf.model_selection_criteria  # shorter alias
         metrics =  monitor_metrics['val']
 
         epoch_score = np.sum([metrics[criterion][-1] * weight for criterion, weight in crita.items() if
                               not np.isnan(metrics[criterion][-1])])
         if not self.cf.resume_from_checkpoint:
             epoch_score_check = np.sum([metrics[criterion][epoch] * weight for criterion, weight in crita.items() if
                                   not np.isnan(metrics[criterion][epoch])])
             assert np.all(epoch_score == epoch_score_check)
 
         self.model_index.loc[epoch, ["score", "criteria_values"]] = epoch_score, {cr: metrics[cr][-1] for cr in crita.keys()}
 
         nonna_ics = self.model_index["score"].dropna(axis=0).index
-        order = np.argsort(self.model_index.loc[nonna_ics, "score"].values)[::-1]
+        order = np.argsort(self.model_index.loc[nonna_ics, "score"].to_numpy(), kind="stable")[::-1]
         self.model_index.loc[nonna_ics, "rank"] = np.argsort(order) + 1 # no zero-indexing for ranks (best rank is 1).
 
         rank = int(self.model_index.loc[epoch, "rank"])
         if rank <= self.cf.save_n_models:
             name = '{}_best_params.pth'.format(epoch)
             if self.cf.server_env:
                 IO_safe(torch.save, net.state_dict(), os.path.join(self.cf.fold_dir, name))
             else:
                 torch.save(net.state_dict(), os.path.join(self.cf.fold_dir, name))
             self.model_index.loc[epoch, "file_name"] = name
             self.logger.info("saved current epoch {} at rank {}".format(epoch, rank))
 
             clean_up = self.model_index.dropna(axis=0, subset=["file_name"])
             clean_up = clean_up[clean_up["rank"] > self.cf.save_n_models]
             if clean_up.size > 0:
-                subprocess.call("rm {}".format(os.path.join(self.cf.fold_dir, clean_up["file_name"].to_numpy().item())), shell=True)
+                file_name = clean_up["file_name"].to_numpy().item()
+                subprocess.call("rm {}".format(os.path.join(self.cf.fold_dir, file_name)), shell=True)
                 self.logger.info("removed outranked epoch {} at {}".format(clean_up.index.values.item(),
-                                                                       os.path.join(self.cf.fold_dir, clean_up["file_name"].to_numpy().item())))
+                                                                       os.path.join(self.cf.fold_dir, file_name)))
                 self.model_index.loc[clean_up.index, "file_name"] = np.nan
 
         state = {
             'epoch': epoch,
             'state_dict': net.state_dict(),
             'optimizer': optimizer.state_dict(),
             'model_index': self.model_index,
         }
 
         if self.cf.server_env:
             IO_safe(torch.save, state, os.path.join(self.cf.fold_dir, 'last_state.pth'))
         else:
             torch.save(state, os.path.join(self.cf.fold_dir, 'last_state.pth'))
 
 def load_checkpoint(checkpoint_path, net, optimizer, model_selector):
     checkpoint = torch.load(checkpoint_path)
     net.load_state_dict(checkpoint['state_dict'])
     optimizer.load_state_dict(checkpoint['optimizer'])
     model_selector.model_index = checkpoint["model_index"]
     return checkpoint['epoch'] + 1, net, optimizer, model_selector
 
 
 def prepare_monitoring(cf):
     """
     creates dictionaries, where train/val metrics are stored.
     """
     metrics = {}
     # first entry for loss dict accounts for epoch starting at 1.
     metrics['train'] = OrderedDict()  # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] )
     metrics['val'] = OrderedDict()  # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] )
     metric_classes = []
     if 'rois' in cf.report_score_level:
         metric_classes.extend([v for k, v in cf.class_dict.items()])
         if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately:
             metric_classes.extend([v for k, v in cf.bin_dict.items()])
     if 'patient' in cf.report_score_level:
         metric_classes.extend(['patient_' + cf.class_dict[cf.patient_class_of_interest]])
         if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately:
             metric_classes.extend(['patient_' + cf.bin_dict[cf.patient_bin_of_interest]])
     for cl in metric_classes:
         for m in cf.metrics:
             metrics['train'][cl + '_' + m] = [np.nan]
             metrics['val'][cl + '_' + m] = [np.nan]
 
     return metrics
 
 
 class _AnsiColorizer(object):
     """
     A colorizer is an object that loosely wraps around a stream, allowing
     callers to write text to the stream in a particular color.
 
     Colorizer classes must implement C{supported()} and C{write(text, color)}.
     """
     _colors = dict(black=30, red=31, green=32, yellow=33,
                    blue=34, magenta=35, cyan=36, white=37, default=39)
 
     def __init__(self, stream):
         self.stream = stream
 
     @classmethod
     def supported(cls, stream=sys.stdout):
         """
         A class method that returns True if the current platform supports
         coloring terminal output using this method. Returns False otherwise.
         """
         if not stream.isatty():
             return False  # auto color only on TTYs
         try:
             import curses
         except ImportError:
             return False
         else:
             try:
                 try:
                     return curses.tigetnum("colors") > 2
                 except curses.error:
                     curses.setupterm()
                     return curses.tigetnum("colors") > 2
             except:
                 raise
                 # guess false in case of error
                 return False
 
     def write(self, text, color):
         """
         Write the given text to the stream in the given color.
 
         @param text: Text to be written to the stream.
 
         @param color: A string label for a color. e.g. 'red', 'white'.
         """
         color = self._colors[color]
         self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text))
 
 
 class ColorHandler(logging.StreamHandler):
 
     def __init__(self, stream=sys.stdout):
         super(ColorHandler, self).__init__(_AnsiColorizer(stream))
 
     def emit(self, record):
         msg_colors = {
             logging.DEBUG: "green",
             logging.INFO: "default",
             logging.WARNING: "red",
             logging.ERROR: "red"
         }
         color = msg_colors.get(record.levelno, "blue")
         self.stream.write(record.msg + "\n", color)