diff --git a/datasets/lidc/configs.py b/datasets/lidc/configs.py index ae8d96e..738ecee 100644 --- a/datasets/lidc/configs.py +++ b/datasets/lidc/configs.py @@ -1,445 +1,444 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os from collections import namedtuple sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/../..") from default_configs import DefaultConfigs # legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope Label = namedtuple("Label", ['id', 'name', 'color', 'm_scores']) # m_scores = malignancy scores binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'm_scores', 'bin_vals']) class Configs(DefaultConfigs): def __init__(self, server_env=None): super(Configs, self).__init__(server_env) ######################### # Preprocessing # ######################### self.root_dir = '/home/gregor/networkdrives/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI' self.raw_data_dir = '{}/new_nrrd'.format(self.root_dir) self.pp_dir = '/media/gregor/HDD2TB/data/lidc/pp_20200309_dev' # 'merged' for one gt per image, 'single_annotator' for four gts per image. self.gts_to_produce = ["single_annotator", "merged"] self.target_spacing = (0.7, 0.7, 1.25) ######################### # I/O # ######################### # path to preprocessed data. self.pp_name = 'pp_20190805' self.input_df_name = 'info_df.pickle' self.data_sourcedir = '/media/gregor/HDD2TB/data/lidc/{}/'.format(self.pp_name) #self.data_sourcedir = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc/data/{}/'.format(self.pp_name) # settings for deployment on cluster. if server_env: # path to preprocessed data. self.data_sourcedir = '/datasets/datasets_ramien/lidc/data/{}_npz/'.format(self.pp_name) # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn']. - self.model = 'retina_net' + self.model = 'retina_unet' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) ######################### # Architecture # ######################### # dimension the model operates in. one out of [2, 3]. self.dim = 2 # 'class': standard object classification per roi, pairwise combinable with each of below tasks. # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN. # 'regression': regress some vector per each roi # 'regression_ken_gal': use kendall-gal uncertainty sigma # 'regression_bin': classify each roi into a bin related to a regression scale self.prediction_tasks = ['class'] self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' - self.norm = None # one of None, 'instance_norm', 'batch_norm' + self.norm = "instance_norm" # one of None, 'instance_norm', 'batch_norm' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') self.weight_init = None self.regression_n_features = 1 ######################### # Data Loader # ######################### # distorted gt experiments: train on single-annotator gts in a random fashion to investigate network's # handling of noisy gts. # choose 'merged' for single, merged gt per image, or 'single_annotator' for four gts per image. # validation is always performed on same gt kind as training, testing always on merged gt. self.training_gts = "merged" # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.pre_crop_size_2D = [320, 320] self.patch_size_2D = [320, 320] self.pre_crop_size_3D = [160, 160, 96] self.patch_size_3D = [160, 160, 96] self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D # ratio of free sampled batch elements before class balancing is triggered - # (>0 to include "empty"/background patches.) - self.batch_random_ratio = 0.3 + self.batch_random_ratio = 0.2 self.balance_target = "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets' # set 2D network to match 3D gt boxes. self.merge_2D_to_3D_preds = self.dim==2 self.observables_rois = [] #self.rg_map = {1:1, 2:2, 3:3, 4:4, 5:5} ######################### # Colors and Legends # ######################### self.plot_frequency = 5 binary_cl_labels = [Label(1, 'benign', (*self.dark_green, 1.), (1, 2)), Label(2, 'malignant', (*self.red, 1.), (3, 4, 5))] quintuple_cl_labels = [Label(1, 'MS1', (*self.dark_green, 1.), (1,)), Label(2, 'MS2', (*self.dark_yellow, 1.), (2,)), Label(3, 'MS3', (*self.orange, 1.), (3,)), Label(4, 'MS4', (*self.bright_red, 1.), (4,)), Label(5, 'MS5', (*self.red, 1.), (5,))] # choose here if to do 2-way or 5-way regression-bin classification task_spec_cl_labels = quintuple_cl_labels self.class_labels = [ # #id #name #color #malignancy score Label( 0, 'bg', (*self.gray, 0.), (0,))] if "class" in self.prediction_tasks: self.class_labels += task_spec_cl_labels else: self.class_labels += [Label(1, 'lesion', (*self.orange, 1.), (1,2,3,4,5))] if any(['regression' in task for task in self.prediction_tasks]): self.bin_labels = [binLabel(0, 'MS0', (*self.gray, 1.), (0,), (0,))] self.bin_labels += [binLabel(cll.id, cll.name, cll.color, cll.m_scores, tuple([ms for ms in cll.m_scores])) for cll in task_spec_cl_labels] self.bin_id2label = {label.id: label for label in self.bin_labels} self.ms2bin_label = {ms: label for label in self.bin_labels for ms in label.m_scores} bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] if self.class_specific_seg: self.seg_labels = self.class_labels else: self.seg_labels = [ # id #name #color Label(0, 'bg', (*self.gray, 0.)), Label(1, 'fg', (*self.orange, 1.)) ] self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be # evaluated in debugging self.class_cmap = {label.id: label.color for label in self.class_labels} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.has_colorchannels = False self.plot_class_ids = True self.num_classes = len(self.class_dict) # for instance classification (excl background) self.num_seg_classes = len(self.seg_labels) # incl background ######################### # Data Augmentation # ######################### self.da_kwargs={ 'mirror': True, 'mirror_axes': tuple(np.arange(0, self.dim, 1)), 'do_elastic_deform': True, 'alpha':(0., 1500.), 'sigma':(30., 50.), 'do_rotation':True, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': True, 'scale':(0.8, 1.1), 'random_crop':False, 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1} if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!! self.da_kwargs['angle_z'] = (0., 2 * np.pi) ################################# # Schedule / Selection / Optim # ################################# self.num_epochs = 130 if self.dim == 2 else 150 self.num_train_batches = 200 if self.dim == 2 else 200 self.batch_size = 20 if self.dim == 2 else 8 # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_sampling' # only 'val_sampling', 'val_patient' not implemented if self.val_mode == 'val_patient': raise NotImplementedError if self.val_mode == 'val_sampling': self.num_val_batches = 70 self.save_n_models = 4 # set a minimum epoch number for saving in case of instabilities in the first phase of training. self.min_save_thresh = 0 if self.dim == 2 else 0 # criteria to average over for saving epochs, 'criterion':weight. if "class" in self.prediction_tasks: # 'criterion': weight if len(self.class_labels)==3: self.model_selection_criteria = {"benign_ap": 0.5, "malignant_ap": 0.5} elif len(self.class_labels)==6: self.model_selection_criteria = {str(label.name)+"_ap": 1./5 for label in self.class_labels if label.id!=0} elif any("regression" in task for task in self.prediction_tasks): self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8} - self.weight_decay = 0 + self.weight_decay = 3e-5 self.clip_norm = 200 if 'regression_ken_gal' in self.prediction_tasks else None # number or None # int in [0, dataset_size]. select n patients from dataset for prototyping. If None, all data is used. self.select_prototype_subset = None #self.batch_size ######################### # Testing # ######################### # set the top-n-epochs to be saved for temporal averaging in testing. self.test_n_epochs = self.save_n_models self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x). self.held_out_test_set = False self.max_test_patients = "all" # "all" or number self.report_score_level = ['rois', 'patient'] # choose list from 'patient', 'rois' self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1 self.metrics = ['ap', 'auc'] if any(['regression' in task for task in self.prediction_tasks]): self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] if 'aleatoric' in self.model: self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] self.evaluate_fold_means = True self.ap_match_ious = [0.1] # list of ious to be evaluated for ap-scoring. self.min_det_thresh = 0.1 # minimum confidence value to select predictions for evaluation. # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!) for regarding two preds as concerning the same ROI self.clustering_iou = 0.1 # has to be larger than desired possible overlap iou of model predictions self.plot_prediction_histograms = True self.plot_stat_curves = False self.n_test_plots = 1 ######################### # Assertions # ######################### if not 'class' in self.prediction_tasks: assert self.num_classes == 1 ######################### # Add model specifics # ######################### {'detection_fpn': self.add_det_fpn_configs, 'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, }[self.model]() def rg_val_to_bin_id(self, rg_val): return float(np.digitize(np.mean(rg_val), self.bin_edges)) def add_det_fpn_configs(self): self.learning_rate = [1e-4] * self.num_epochs self.dynamic_lr_scheduling = False # RoI score assigned to aggregation from pixel prediction (connected component). One of ['max', 'median']. self.score_det = 'max' # max number of roi candidates to identify per batch element and class. self.n_roi_candidates = 10 if self.dim == 2 else 30 # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' # if <1, false positive predictions in foreground are penalized less. self.fp_dice_weight = 1 if self.dim == 2 else 1 if len(self.class_labels)==3: self.wce_weights = [1., 1., 1.] if self.seg_loss_mode=="dice_wce" else [0.1, 1., 1.] elif len(self.class_labels)==6: self.wce_weights = [1., 1., 1., 1., 1., 1.] if self.seg_loss_mode == "dice_wce" else [0.1, 1., 1., 1., 1., 1.] else: raise Exception("mismatch loss weights & nr of classes") self.detection_min_confidence = self.min_det_thresh self.head_classes = self.num_seg_classes def add_mrcnn_configs(self): # learning rate is a list with one entry per epoch. - self.learning_rate = [1e-4] * self.num_epochs + self.learning_rate = [3e-4] * self.num_epochs self.dynamic_lr_scheduling = False # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask-outputs are optional. self.return_masks_in_train = False self.return_masks_in_val = True self.return_masks_in_test = False # set number of proposal boxes to plot after each epoch. self.n_plot_rpn_props = 5 if self.dim == 2 else 30 # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 self.frcnn_mode = False # feature map strides per pyramid level are inferred from architecture. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 128 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1, 2] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7 # loss sampling settings. - self.rpn_train_anchors_per_image = 6 #per batch element + self.rpn_train_anchors_per_image = 64 #per batch element self.train_rois_per_image = 6 #per batch element self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.7 # factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples. self.shem_poolsize = 10 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 3000 if self.dim == 2 else 6000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage in as one "batch". self.roi_chunk_size = 2500 if self.dim == 2 else 600 self.post_nms_rois_training = 500 if self.dim == 2 else 75 self.post_nms_rois_inference = 500 # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30 # per batch element and class. self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.1 if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': self.focal_loss = True # implement extra anchor-scales according to retina-net publication. self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 self.n_rpn_features = 256 if self.dim == 2 else 128 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.5 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/datasets/toy_mdt/configs.py b/datasets/toy_mdt/configs.py index 3c8748d..c78225e 100644 --- a/datasets/toy_mdt/configs.py +++ b/datasets/toy_mdt/configs.py @@ -1,355 +1,351 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np from collections import namedtuple from default_configs import DefaultConfigs Label = namedtuple("Label", ['id', 'name', 'color']) class Configs(DefaultConfigs): def __init__(self, server_env=None): ######################### # Preprocessing # ######################### self.root_dir = '/home/gregor/datasets/toy_mdt' ######################### # I/O # ######################### # one out of [2, 3]. dimension the model operates in. self.dim = 2 DefaultConfigs.__init__(self, server_env, self.dim) # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn']. - self.model = 'detection_fpn' + self.model = 'mrcnn' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) # int [0 < dataset_size]. select n patients from dataset for prototyping. self.select_prototype_subset = None self.held_out_test_set = True self.n_train_data = 2500 # choose one of the 3 toy experiments described in https://arxiv.org/pdf/1811.08661.pdf # one of ['donuts_shape', 'donuts_pattern', 'circles_scale']. toy_mode = 'donuts_shape_noise' # path to preprocessed data. self.info_df_name = 'info_df.pickle' self.pp_name = os.path.join(toy_mode, 'train') self.data_sourcedir = os.path.join(self.root_dir, self.pp_name) self.pp_test_name = os.path.join(toy_mode, 'test') self.test_data_sourcedir = os.path.join(self.root_dir, self.pp_test_name) # settings for deployment in cloud. if server_env: # path to preprocessed data. pp_root_dir = '/datasets/datasets_ramien/toy_exp/data' self.pp_name = os.path.join(toy_mode, 'train') self.data_sourcedir = os.path.join(pp_root_dir, self.pp_name) self.pp_test_name = os.path.join(toy_mode, 'test') self.test_data_sourcedir = os.path.join(pp_root_dir, self.pp_test_name) self.select_prototype_subset = None ######################### # Data Loader # ######################### # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) self.plot_bg_chan = 0 # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.pre_crop_size_2D = [320, 320] self.patch_size_2D = [320, 320] self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D # ratio of free sampled batch elements before class balancing is triggered # (>0 to include "empty"/background patches.) self.batch_random_ratio = 0.2 # set 2D network to operate in 3D images. self.merge_2D_to_3D_preds = False - # feed +/- n neighbouring slices into channel dimension. set to None for no context. - self.n_3D_context = None - if self.n_3D_context is not None and self.dim == 2: - self.n_channels *= (self.n_3D_context * 2 + 1) - - ######################### # Architecture # ######################### self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' self.norm = None # one of None, 'instance_norm', 'batch_norm' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') - self.weight_init = None + self.weight_init = "xavier_uniform" # compatibility self.regression_n_features = 1 self.num_classes = 2 # excluding bg self.num_seg_classes = 3 # incl bg ######################### # Schedule / Selection # ######################### - self.num_epochs = 32 + self.num_epochs = 24 self.num_train_batches = 100 if self.dim == 2 else 200 self.batch_size = 20 if self.dim == 2 else 8 self.do_validation = True # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_patient' # one of 'val_sampling' , 'val_patient' if self.val_mode == 'val_patient': self.max_val_patients = "all" # if 'None' iterates over entire val_set once. if self.val_mode == 'val_sampling': self.num_val_batches = 50 + self.optimizer = "ADAMW" + # set dynamic_lr_scheduling to True to apply LR scheduling with below settings. self.dynamic_lr_scheduling = True - self.lr_decay_factor = 0.5 + self.lr_decay_factor = 0.25 self.scheduling_patience = np.ceil(2400 / (self.num_train_batches * self.batch_size)) self.scheduling_criterion = 'donuts_ap' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' - self.weight_decay = 0 + self.weight_decay = 3e-5 self.clip_norm = None ######################### # Testing / Plotting # ######################### # set the top-n-epochs to be saved for temporal averaging in testing. self.save_n_models = 5 self.test_n_epochs = 5 self.test_aug_axes = (0, 1, (0, 1)) self.n_test_plots = 2 self.clustering = "wbc" self.clustering_iou = 1e-5 # set a minimum epoch number for saving in case of instabilities in the first phase of training. self.min_save_thresh = 0 if self.dim == 2 else 0 self.report_score_level = ['patient', 'rois'] # choose list from 'patient', 'rois' self.class_labels = [Label(0, 'bg', (*self.white, 0.)), Label(1, 'circles', (*self.orange, .9)), Label(2, 'donuts', (*self.blue, .9)),] if self.class_specific_seg: self.seg_labels = self.class_labels self.box_type2label = {label.name: label for label in self.box_labels} self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} - + self.metrics = ["ap", "auc"] self.patient_class_of_interest = 2 # patient metrics are only plotted for one class. self.ap_match_ious = [0.1] # list of ious to be evaluated for ap-scoring. self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}# criteria to average over for saving epochs. self.min_det_thresh = 0.1 # minimum confidence value to select predictions for evaluation. self.plot_prediction_histograms = True self.plot_stat_curves = False self.plot_class_ids = True ######################### # Data Augmentation # ######################### self.do_aug = False self.da_kwargs={ 'do_elastic_deform': True, 'alpha':(0., 1500.), 'sigma':(30., 50.), 'do_rotation':True, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': True, 'scale':(0.8, 1.1), 'random_crop':False, 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1 } if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!! self.da_kwargs['angle_z'] = (0., 2 * np.pi) ######################### # Add model specifics # ######################### {'detection_fpn': self.add_det_fpn_configs, 'mrcnn': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, }[self.model]() def add_det_fpn_configs(self): self.learning_rate = [1 * 1e-3] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = 'torch_loss' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' self.n_roi_candidates = 4 if self.dim == 2 else 6 # max number of roi candidates to identify per image (slice in 2D, volume in 3D) # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1., 1.] self.fp_dice_weight = 1 if self.dim == 2 else 1 # if <1, false positive predictions in foreground are penalized less. self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' def add_mrcnn_configs(self): # learning rate is a list with one entry per epoch. - self.learning_rate = [1e-3] * self.num_epochs + self.learning_rate = [3e-4] * self.num_epochs # disable mask head loss. (e.g. if no pixelwise annotations available) self.frcnn_mode = False # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask-outputs are optional. self.return_masks_in_val = True self.return_masks_in_test = False # set number of proposal boxes to plot after each epoch. self.n_plot_rpn_props = 0 if self.dim == 2 else 0 # number of classes for head networks: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 # feature map strides per pyramid level are inferred from architecture. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 128 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1., 2.] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7 # loss sampling settings. self.rpn_train_anchors_per_image = 64 #per batch element self.train_rois_per_image = 2 #per batch element self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.7 # factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples. self.shem_poolsize = 10 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1]]) if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 3000 if self.dim == 2 else 6000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage in as one "batch". self.roi_chunk_size = 800 if self.dim == 2 else 600 self.post_nms_rois_training = 500 if self.dim == 2 else 75 self.post_nms_rois_inference = 500 # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30 # per batch element and class. self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.1 if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': # implement extra anchor-scales according to retina-net publication. self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 self.n_rpn_features = 256 if self.dim == 2 else 64 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = 10000 if self.dim == 2 else 50000 # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.5 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/default_configs.py b/default_configs.py index 58cdc8d..c415e98 100644 --- a/default_configs.py +++ b/default_configs.py @@ -1,204 +1,204 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Default Configurations script. Avoids changing configs of all experiments if general settings are to be changed.""" import os from collections import namedtuple boxLabel = namedtuple('boxLabel', ["name", "color"]) class DefaultConfigs: def __init__(self, server_env=None, dim=2): self.server_env = server_env self.cuda_benchmark = True self.sysmetrics_interval = 2 # set > 0 to record system metrics to tboard with this time span in seconds. ######################### # I/O # ######################### self.dim = dim # int [0 < dataset_size]. select n patients from dataset for prototyping. self.select_prototype_subset = None # some default paths. self.source_dir = os.path.dirname(os.path.realpath(__file__)) # current dir. self.backbone_path = os.path.join(self.source_dir, 'models/backbone.py') self.input_df_name = 'info_df.pickle' if server_env: self.select_prototype_subset = None ######################### # Colors/legends # ######################### # in part from solarized theme. self.black = (0.1, 0.05, 0.) self.gray = (0.514, 0.580, 0.588) self.beige = (1., 1., 0.85) self.white = (0.992, 0.965, 0.890) self.green = (0.659, 0.792, 0.251) # [168, 202, 64] self.dark_green = (0.522, 0.600, 0.000) # [133.11, 153. , 0. ] self.cyan = (0.165, 0.631, 0.596) # [ 42.075, 160.905, 151.98 ] self.bright_blue = (0.85, 0.95, 1.) self.blue = (0.149, 0.545, 0.824) # [ 37.995, 138.975, 210.12 ] self.dkfz_blue = (0, 75. / 255, 142. / 255) self.dark_blue = (0.027, 0.212, 0.259) # [ 6.885, 54.06 , 66.045] self.purple = (0.424, 0.443, 0.769) # [108.12 , 112.965, 196.095] self.aubergine = (0.62, 0.21, 0.44) # [ 157, 53 , 111] self.magenta = (0.827, 0.212, 0.510) # [210.885, 54.06 , 130.05 ] self.coral = (1., 0.251, 0.4) # [255,64,102] self.bright_red = (1., 0.15, 0.1) # [255, 38.25, 25.5] self.brighter_red = (0.863, 0.196, 0.184) # [220.065, 49.98 , 46.92 ] self.red = (0.87, 0.05, 0.01) # [ 223, 13, 2] self.dark_red = (0.6, 0.04, 0.005) self.orange = (0.91, 0.33, 0.125) # [ 232.05 , 84.15 , 31.875] self.dark_orange = (0.796, 0.294, 0.086) #[202.98, 74.97, 21.93] self.yellow = (0.95, 0.9, 0.02) # [ 242.25, 229.5 , 5.1 ] self.dark_yellow = (0.710, 0.537, 0.000) # [181.05 , 136.935, 0. ] self.color_palette = [self.blue, self.dark_blue, self.aubergine, self.green, self.yellow, self.orange, self.red, self.cyan, self.black] self.box_labels = [ # name color boxLabel("det", self.blue), boxLabel("prop", self.gray), boxLabel("pos_anchor", self.cyan), boxLabel("neg_anchor", self.cyan), boxLabel("neg_class", self.green), boxLabel("pos_class", self.aubergine), boxLabel("gt", self.red) ] # neg and pos in a medical sense, i.e., pos=positive diagnostic finding self.box_type2label = {label.name: label for label in self.box_labels} self.box_color_palette = {label.name: label.color for label in self.box_labels} # whether the input data is mono-channel or RGB/rgb self.has_colorchannels = False ######################### # Data Loader # ######################### #random seed for fold_generator and batch_generator. self.seed = 0 #number of threads for multithreaded tasks like batch generation, wcs, merge2dto3d self.n_workers = 16 if server_env else os.cpu_count() self.create_bounding_box_targets = True self.class_specific_seg = True # False if self.model=="mrcnn" else True self.max_val_patients = "all" ######################### # Architecture # ######################### self.prediction_tasks = ["class"] # 'class', 'regression_class', 'regression_kendall', 'regression_feindt' self.weight_decay = 0.0 # nonlinearity to be applied after convs with nonlinearity. one of 'relu' or 'leaky_relu' self.relu = 'relu' # if True initializes weights as specified in model script. else use default Pytorch init. self.weight_init = None # if True adds high-res decoder levels to feature pyramid: P1 + P0. (e.g. set to true in retina_unet configs) self.operate_stride1 = False ######################### # Optimization # ######################### - self.optimizer = "ADAM" # "ADAM" or "SGD" or implemented additionals + self.optimizer = "ADAMW" # "ADAMW" or "SGD" or implemented additionals ######################### # Schedule # ######################### # number of folds in cross validation. self.n_cv_splits = 5 ######################### # Testing / Plotting # ######################### # perform mirroring at test time. (only XY. Z not done to not blow up predictions times). self.test_aug = True # if True, test data lies in a separate folder and is not part of the cross validation. self.held_out_test_set = False # if hold-out test set: eval each fold's parameters separately on the test set self.eval_test_fold_wise = True # if held_out_test_set provided, ensemble predictions over models of all trained cv-folds. self.ensemble_folds = False # what metrics to evaluate self.metrics = ['ap'] # whether to evaluate fold means when evaluating over more than one fold self.evaluate_fold_means = False # how often (in nr of epochs) to plot example batches during train/val self.plot_frequency = 1 # color specifications for all box_types in prediction_plot. self.box_color_palette = {'det': 'b', 'gt': 'r', 'neg_class': 'purple', 'prop': 'w', 'pos_class': 'g', 'pos_anchor': 'c', 'neg_anchor': 'c'} # scan over confidence score in evaluation to optimize it on the validation set. self.scan_det_thresh = False # plots roc-curves / prc-curves in evaluation. self.plot_stat_curves = False # if True: evaluate average precision per patient id and average over per-pid results, # instead of computing one ap over whole data set. self.per_patient_ap = False # threshold for clustering 2D box predictions to 3D Cubes. Overlap is computed in XY. self.merge_3D_iou = 0.1 # number or "all" for all self.max_test_patients = "all" ######################### # MRCNN # ######################### # if True, mask loss is not applied. used for data sets, where no pixel-wise annotations are provided. self.frcnn_mode = False self.return_masks_in_train = False # if True, unmolds masks in Mask R-CNN to full-res for plotting/monitoring. self.return_masks_in_val = False self.return_masks_in_test = False # needed if doing instance segmentation. evaluation not yet implemented. # add P6 to Feature Pyramid Network. self.sixth_pooling = False ######################### # RetinaNet # ######################### self.focal_loss = False self.focal_loss_gamma = 2. diff --git a/evaluator.py b/evaluator.py index 682d686..6d53912 100644 --- a/evaluator.py +++ b/evaluator.py @@ -1,971 +1,980 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os from multiprocessing import Pool import pickle import time import numpy as np import pandas as pd from sklearn.metrics import roc_auc_score, average_precision_score from sklearn.metrics import roc_curve, precision_recall_curve from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score import torch import utils.model_utils as mutils import plotting as plg import warnings def get_roi_ap_from_df(inputs): ''' :param df: data frame. :param det_thresh: min_threshold for filtering out low confidence predictions. :param per_patient_ap: boolean flag. evaluate average precision per patient id and average over per-pid results, instead of computing one ap over whole data set. :return: average_precision (float) ''' df, det_thresh, per_patient_ap = inputs if per_patient_ap: pids_list = df.pid.unique() aps = [] for match_iou in df.match_iou.unique(): iou_df = df[df.match_iou == match_iou] for pid in pids_list: pid_df = iou_df[iou_df.pid == pid] all_p = len(pid_df[pid_df.class_label == 1]) pid_df = pid_df[(pid_df.det_type == 'det_fp') | (pid_df.det_type == 'det_tp')].sort_values('pred_score', ascending=False) pid_df = pid_df[pid_df.pred_score > det_thresh] if (len(pid_df) ==0 and all_p == 0): pass elif (len(pid_df) > 0 and all_p == 0): aps.append(0) else: aps.append(compute_roi_ap(pid_df, all_p)) return np.mean(aps) else: aps = [] for match_iou in df.match_iou.unique(): iou_df = df[df.match_iou == match_iou] # it's important to not apply the threshold before counting all_p in order to not lose the fn! all_p = len(iou_df[(iou_df.det_type == 'det_tp') | (iou_df.det_type == 'det_fn')]) # sorting out all entries that are not fp or tp or have confidence(=pred_score) <= detection_threshold iou_df = iou_df[(iou_df.det_type == 'det_fp') | (iou_df.det_type == 'det_tp')].sort_values('pred_score', ascending=False) iou_df = iou_df[iou_df.pred_score > det_thresh] if all_p>0: aps.append(compute_roi_ap(iou_df, all_p)) return np.mean(aps) def compute_roi_ap(df, all_p): """ adapted from: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py :param df: dataframe containing class labels of predictions sorted in descending manner by their prediction score. :param all_p: number of all ground truth objects. (for denominator of recall.) :return: """ tp = df.class_label.values fp = (tp == 0) * 1 #recall thresholds, where precision will be measured R = np.linspace(0., 1., np.round((1. - 0.) / .01).astype(int) + 1, endpoint=True) tp_sum = np.cumsum(tp) fp_sum = np.cumsum(fp) n_dets = len(tp) rc = tp_sum / all_p pr = tp_sum / (fp_sum + tp_sum) # initialize precision array over recall steps (q=queries). q = [0. for _ in range(len(R))] # numpy is slow without cython optimization for accessing elements # python array gets significant speed improvement pr = pr.tolist() for i in range(n_dets - 1, 0, -1): if pr[i] > pr[i - 1]: pr[i - 1] = pr[i] #--> pr[i]<=pr[i-1] for all i since we want to consider the maximum #precision value for a queried interval # discretize empiric recall steps with given bins. assert np.all(rc[:-1]<=rc[1:]), "recall not sorted ascendingly" inds = np.searchsorted(rc, R, side='left') try: for rc_ix, pr_ix in enumerate(inds): q[rc_ix] = pr[pr_ix] except IndexError: #now q is filled with pr values up to first non-available index pass return np.mean(q) def roi_avp(inputs): ''' :param df: data frame. :param det_thresh: min_threshold for filtering out low confidence predictions. :param per_patient_ap: boolean flag. evaluate average precision per patient id and average over per-pid results, instead of computing one ap over whole data set. :return: average_precision (float) ''' df, det_thresh, per_patient_ap = inputs if per_patient_ap: pids_list = df.pid.unique() aps = [] for match_iou in df.match_iou.unique(): iou_df = df[df.match_iou == match_iou] for pid in pids_list: pid_df = iou_df[iou_df.pid == pid] all_p = len(pid_df[pid_df.class_label == 1]) mask = ((pid_df.rg_bins == pid_df.rg_bin_target) & (pid_df.det_type == 'det_tp')) | (pid_df.det_type == 'det_fp') pid_df = pid_df[mask].sort_values('pred_score', ascending=False) pid_df = pid_df[pid_df.pred_score > det_thresh] if (len(pid_df) ==0 and all_p == 0): pass elif (len(pid_df) > 0 and all_p == 0): aps.append(0) else: aps.append(compute_roi_ap(pid_df, all_p)) return np.mean(aps) else: aps = [] for match_iou in df.match_iou.unique(): iou_df = df[df.match_iou == match_iou] #it's important to not apply the threshold before counting all_positives! all_p = len(iou_df[(iou_df.det_type == 'det_tp') | (iou_df.det_type == 'det_fn')]) # filtering out tps which don't match rg_bin target at this point is same as reclassifying them as fn. # also sorting out all entries that are not fp or have confidence(=pred_score) <= detection_threshold mask = ((iou_df.rg_bins == iou_df.rg_bin_target) & (iou_df.det_type == 'det_tp')) | (iou_df.det_type == 'det_fp') iou_df = iou_df[mask].sort_values('pred_score', ascending=False) iou_df = iou_df[iou_df.pred_score > det_thresh] if all_p>0: aps.append(compute_roi_ap(iou_df, all_p)) return np.mean(aps) def compute_prc(df): """compute precision-recall curve with maximum precision per recall interval. :param df: :param all_p: # of all positive samples in data. :return: array: [precisions, recall query values] """ assert (df.class_label==1).any(), "cannot compute prc when no positives in data." all_p = len(df[(df.det_type == 'det_tp') | (df.det_type == 'det_fn')]) df = df[(df.det_type=="det_tp") | (df.det_type=="det_fp")] df = df.sort_values("pred_score", ascending=False) # recall thresholds, where precision will be measured scores = df.pred_score.values labels = df.class_label.values n_dets = len(scores) pr = np.zeros((n_dets,)) rc = pr.copy() for rank in range(n_dets): tp = np.count_nonzero(labels[:rank+1]==1) fp = np.count_nonzero(labels[:rank+1]==0) pr[rank] = tp/(tp+fp) rc[rank] = tp/all_p #after obj detection convention/ coco-dataset template: take maximum pr within intervals: # --> pr[i]<=pr[i-1] for all i since we want to consider the maximum # precision value for a queried interval for i in range(n_dets - 1, 0, -1): if pr[i] > pr[i - 1]: pr[i - 1] = pr[i] R = np.linspace(0., 1., np.round((1. - 0.) / .01).astype(int) + 1, endpoint=True)#precision queried at R points inds = np.searchsorted(rc, R, side='left') queries = np.zeros((len(R),)) try: for q_ix, rank in enumerate(inds): queries[q_ix] = pr[rank] except IndexError: pass return np.array((queries, R)) def RMSE(y_true, y_pred, weights=None): if len(y_true)>0: return np.sqrt(mean_squared_error(y_true, y_pred, sample_weight=weights)) else: return np.nan def MAE_w_std(y_true, y_pred, weights=None): if len(y_true)>0: y_true, y_pred = np.array(y_true), np.array(y_pred) deltas = np.abs(y_true-y_pred) mae = np.average(deltas, weights=weights, axis=0).item() skmae = mean_absolute_error(y_true, y_pred, sample_weight=weights) assert np.allclose(mae, skmae, atol=1e-6), "mae {}, sklearn mae {}".format(mae, skmae) std = np.std(weights*deltas) return mae, std else: return np.nan, np.nan def MAE(y_true, y_pred, weights=None): if len(y_true)>0: return mean_absolute_error(y_true, y_pred, sample_weight=weights) else: return np.nan def accuracy(y_true, y_pred, weights=None): if len(y_true)>0: return accuracy_score(y_true, y_pred, sample_weight=weights) else: return np.nan # noinspection PyCallingNonCallable class Evaluator(): """ Evaluates given results dicts. Can return results as updated monitor_metrics. Can save test data frames to file. """ def __init__(self, cf, logger, mode='test'): """ :param mode: either 'train', 'val_sampling', 'val_patient' or 'test'. handles prediction lists of different forms. """ self.cf = cf self.logger = logger self.mode = mode self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks]) self.plot_dir = self.cf.plot_dir if not self.mode == "test" else self.cf.test_dir if self.cf.plot_prediction_histograms: self.hist_dir = os.path.join(self.plot_dir, 'histograms') os.makedirs(self.hist_dir, exist_ok=True) if self.cf.plot_stat_curves: self.curves_dir = os.path.join(self.plot_dir, 'stat_curves') os.makedirs(self.curves_dir, exist_ok=True) def eval_losses(self, batch_res_dicts): if hasattr(self.cf, "losses_to_monitor"): loss_names = self.cf.losses_to_monitor else: loss_names = {name for b_res_dict in batch_res_dicts for name in b_res_dict if 'loss' in name} self.epoch_losses = {l_name: torch.tensor([b_res_dict[l_name] for b_res_dict in batch_res_dicts if l_name in b_res_dict.keys()]).mean().item() for l_name in loss_names} def eval_segmentations(self, batch_res_dicts, pid_list): batch_dices = [b_res_dict['batch_dices'] for b_res_dict in batch_res_dicts if 'batch_dices' in b_res_dict.keys()] # shape (n_batches, n_seg_classes) if len(batch_dices) > 0: batch_dices = np.array(batch_dices) # dims n_batches x 1 in sampling / n_test_epochs x n_classes assert batch_dices.shape[1] == self.cf.num_seg_classes, "bdices shp {}, n seg cl {}, pid lst len {}".format( batch_dices.shape, self.cf.num_seg_classes, len(pid_list)) self.seg_df = pd.DataFrame() for seg_id in range(batch_dices.shape[1]): self.seg_df[self.cf.seg_id2label[seg_id].name + "_dice"] = batch_dices[:, seg_id] # one row== one batch, one column== one class # self.seg_df[self.cf.seg_id2label[seg_id].name+"_dice"] = np.concatenate(batch_dices[:,:,seg_id]) self.seg_df['fold'] = self.cf.fold if self.mode == "val_patient" or self.mode == "test": # need to make it more conform between sampling and patient-mode self.seg_df["pid"] = [pid for pix, pid in enumerate(pid_list)] # for b_inst in batch_inst_boxes[pix]] else: self.seg_df["pid"] = np.nan def eval_boxes(self, batch_res_dicts, pid_list, obj_cl_dict, obj_cl_identifiers={"gt":'class_targets', "pred":'box_pred_class_id'}): """ :param batch_res_dicts: :param pid_list: [pid_0, pid_1, ...] :return: """ if self.mode == 'train' or self.mode == 'val_sampling': # one pid per batch element # batch_size > 1, with varying patients across batch: # [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] # -> [results_0, results_1, ..] batch_inst_boxes = [b_res_dict['boxes'] for b_res_dict in batch_res_dicts] # len: nr of batches in epoch batch_inst_boxes = [[b_inst_boxes] for whole_batch_boxes in batch_inst_boxes for b_inst_boxes in whole_batch_boxes] # len: batch instances of whole epoch assert np.all(len(b_boxes_list) == self.cf.batch_size for b_boxes_list in batch_inst_boxes) elif self.mode == "val_patient" or self.mode == "test": # patient processing, one element per batch = one patient. # [[results_0, pid_0], [results_1, pid_1], ...] -> [results_0, results_1, ..] # in patientbatchiterator there is only one pid per batch batch_inst_boxes = [b_res_dict['boxes'] for b_res_dict in batch_res_dicts] # in patient mode not actually per batch instance, but per whole batch! if hasattr(self.cf, "eval_test_separately") and self.cf.eval_test_separately: """ you could write your own routines to add GTs to raw predictions for evaluation. implemented standard is: cf.eval_test_separately = False or not set --> GTs are saved at same time and in same file as raw prediction results. """ raise NotImplementedError assert len(batch_inst_boxes) == len(pid_list) df_list_preds = [] df_list_labels = [] df_list_class_preds = [] df_list_pids = [] df_list_type = [] df_list_match_iou = [] df_list_n_missing = [] df_list_regressions = [] df_list_rg_targets = [] df_list_rg_bins = [] df_list_rg_bin_targets = [] df_list_rg_uncs = [] for match_iou in self.cf.ap_match_ious: self.logger.info('evaluating with ap_match_iou: {}'.format(match_iou)) for cl in list(obj_cl_dict.keys()): for pix, pid in enumerate(pid_list): len_df_list_before_patient = len(df_list_pids) # input of each batch element is a list of boxes, where each box is a dictionary. for b_inst_ix, b_boxes_list in enumerate(batch_inst_boxes[pix]): b_tar_boxes = [] b_cand_boxes, b_cand_scores, b_cand_n_missing = [], [], [] if self.regress_flag: b_tar_regs, b_tar_rg_bins = [], [] b_cand_regs, b_cand_rg_bins, b_cand_rg_uncs = [], [], [] for box in b_boxes_list: # each box is either gt or detection or proposal/anchor # we need all gts in the same order & all dets in same order if box['box_type'] == 'gt' and box[obj_cl_identifiers["gt"]] == cl: b_tar_boxes.append(box["box_coords"]) if self.regress_flag: b_tar_regs.append(np.array(box['regression_targets'], dtype='float32')) b_tar_rg_bins.append(box['rg_bin_targets']) if box['box_type'] == 'det' and box[obj_cl_identifiers["pred"]] == cl: b_cand_boxes.append(box["box_coords"]) b_cand_scores.append(box["box_score"]) b_cand_n_missing.append(box["cluster_n_missing"] if 'cluster_n_missing' in box.keys() else np.nan) if self.regress_flag: b_cand_regs.append(box["regression"]) b_cand_rg_bins.append(box["rg_bin"]) b_cand_rg_uncs.append(box["rg_uncertainty"] if 'rg_uncertainty' in box.keys() else np.nan) b_tar_boxes = np.array(b_tar_boxes) b_cand_boxes, b_cand_scores, b_cand_n_missing = np.array(b_cand_boxes), np.array(b_cand_scores), np.array(b_cand_n_missing) if self.regress_flag: b_tar_regs, b_tar_rg_bins = np.array(b_tar_regs), np.array(b_tar_rg_bins) b_cand_regs, b_cand_rg_bins, b_cand_rg_uncs = np.array(b_cand_regs), np.array(b_cand_rg_bins), np.array(b_cand_rg_uncs) # check if predictions and ground truth boxes exist and match them according to match_iou. if not 0 in b_cand_boxes.shape and not 0 in b_tar_boxes.shape: assert np.all(np.round(b_cand_scores,6) <= 1.), "there is a box score>1: {}".format(b_cand_scores[~(b_cand_scores<=1.)]) #coords_check = np.array([len(coords)==self.cf.dim*2 for coords in b_cand_boxes]) #assert np.all(coords_check), "cand box with wrong bcoords dim: {}, mode: {}".format(b_cand_boxes[~coords_check], self.mode) expected_dim = len(b_cand_boxes[0]) assert np.all([len(coords) == expected_dim for coords in b_tar_boxes]), \ "gt/cand box coords mismatch, expected dim: {}.".format(expected_dim) # overlaps: shape len(cand_boxes) x len(tar_boxes) overlaps = mutils.compute_overlaps(b_cand_boxes, b_tar_boxes) # match_cand_ixs: shape (nr_of_matches,) # theses indices are the indices of b_cand_boxes match_cand_ixs = np.argwhere(np.max(overlaps, axis=1) > match_iou)[:, 0] non_match_cand_ixs = np.argwhere(np.max(overlaps, 1) <= match_iou)[:, 0] # the corresponding gt assigned to the pred boxes by highest iou overlap, # i.e., match_gt_ixs holds index into b_tar_boxes for each entry in match_cand_ixs, # i.e., gt_ixs and cand_ixs are paired via their position in their list # (cand_ixs[j] corresponds to gt_ixs[j]) match_gt_ixs = np.argmax(overlaps[match_cand_ixs, :], axis=1) if \ not 0 in match_cand_ixs.shape else np.array([]) assert len(match_gt_ixs)==len(match_cand_ixs) #match_gt_ixs: shape (nr_of_matches,) or 0 non_match_gt_ixs = np.array( [ii for ii in np.arange(b_tar_boxes.shape[0]) if ii not in match_gt_ixs]) unique, counts = np.unique(match_gt_ixs, return_counts=True) # check for double assignments, i.e. two predictions having been assigned to the same gt. # according to the COCO-metrics, only one prediction counts as true positive, the rest counts as # false positive. This case is supposed to be avoided by the model itself by, # e.g. using a low enough NMS threshold. if np.any(counts > 1): double_match_gt_ixs = unique[np.argwhere(counts > 1)[:, 0]] keep_max = [] double_match_list = [] for dg in double_match_gt_ixs: double_match_cand_ixs = match_cand_ixs[np.argwhere(match_gt_ixs == dg)] keep_max.append(double_match_cand_ixs[np.argmax(b_cand_scores[double_match_cand_ixs])]) double_match_list += [ii for ii in double_match_cand_ixs] fp_ixs = np.array([ii for ii in match_cand_ixs if (ii in double_match_list and ii not in keep_max)]) # count as fp: boxes that match gt above match_iou threshold but have not highest class confidence score match_gt_ixs = np.array([gt_ix for ii, gt_ix in enumerate(match_gt_ixs) if match_cand_ixs[ii] not in fp_ixs]) match_cand_ixs = np.array([cand_ix for cand_ix in match_cand_ixs if cand_ix not in fp_ixs]) assert len(match_gt_ixs) == len(match_cand_ixs) df_list_preds += [ii for ii in b_cand_scores[fp_ixs]] df_list_labels += [0] * fp_ixs.shape[0] # means label==gt==0==bg for all these fp_ixs df_list_class_preds += [cl] * fp_ixs.shape[0] df_list_n_missing += [n for n in b_cand_n_missing[fp_ixs]] if self.regress_flag: df_list_regressions += [r for r in b_cand_regs[fp_ixs]] df_list_rg_bins += [r for r in b_cand_rg_bins[fp_ixs]] df_list_rg_uncs += [r for r in b_cand_rg_uncs[fp_ixs]] df_list_rg_targets += [[0.]*self.cf.regression_n_features] * fp_ixs.shape[0] df_list_rg_bin_targets += [0.] * fp_ixs.shape[0] df_list_pids += [pid] * fp_ixs.shape[0] df_list_type += ['det_fp'] * fp_ixs.shape[0] # matched/tp: if not 0 in match_cand_ixs.shape: df_list_preds += list(b_cand_scores[match_cand_ixs]) df_list_labels += [1] * match_cand_ixs.shape[0] df_list_class_preds += [cl] * match_cand_ixs.shape[0] df_list_n_missing += list(b_cand_n_missing[match_cand_ixs]) if self.regress_flag: df_list_regressions += list(b_cand_regs[match_cand_ixs]) df_list_rg_bins += list(b_cand_rg_bins[match_cand_ixs]) df_list_rg_uncs += list(b_cand_rg_uncs[match_cand_ixs]) assert len(match_cand_ixs)==len(match_gt_ixs) df_list_rg_targets += list(b_tar_regs[match_gt_ixs]) df_list_rg_bin_targets += list(b_tar_rg_bins[match_gt_ixs]) df_list_pids += [pid] * match_cand_ixs.shape[0] df_list_type += ['det_tp'] * match_cand_ixs.shape[0] # rest fp: if not 0 in non_match_cand_ixs.shape: df_list_preds += list(b_cand_scores[non_match_cand_ixs]) df_list_labels += [0] * non_match_cand_ixs.shape[0] df_list_class_preds += [cl] * non_match_cand_ixs.shape[0] df_list_n_missing += list(b_cand_n_missing[non_match_cand_ixs]) if self.regress_flag: df_list_regressions += list(b_cand_regs[non_match_cand_ixs]) df_list_rg_bins += list(b_cand_rg_bins[non_match_cand_ixs]) df_list_rg_uncs += list(b_cand_rg_uncs[non_match_cand_ixs]) df_list_rg_targets += [[0.]*self.cf.regression_n_features] * non_match_cand_ixs.shape[0] df_list_rg_bin_targets += [0.] * non_match_cand_ixs.shape[0] df_list_pids += [pid] * non_match_cand_ixs.shape[0] df_list_type += ['det_fp'] * non_match_cand_ixs.shape[0] # fn: if not 0 in non_match_gt_ixs.shape: df_list_preds += [0] * non_match_gt_ixs.shape[0] df_list_labels += [1] * non_match_gt_ixs.shape[0] df_list_class_preds += [cl] * non_match_gt_ixs.shape[0] df_list_n_missing += [np.nan] * non_match_gt_ixs.shape[0] if self.regress_flag: df_list_regressions += [[0.]*self.cf.regression_n_features] * non_match_gt_ixs.shape[0] df_list_rg_bins += [0.] * non_match_gt_ixs.shape[0] df_list_rg_uncs += [np.nan] * non_match_gt_ixs.shape[0] df_list_rg_targets += list(b_tar_regs[non_match_gt_ixs]) df_list_rg_bin_targets += list(b_tar_rg_bins[non_match_gt_ixs]) df_list_pids += [pid] * non_match_gt_ixs.shape[0] df_list_type += ['det_fn'] * non_match_gt_ixs.shape[0] # only fp: if not 0 in b_cand_boxes.shape and 0 in b_tar_boxes.shape: # means there is no gt in all samples! any preds have to be fp. df_list_preds += list(b_cand_scores) df_list_labels += [0] * b_cand_boxes.shape[0] df_list_class_preds += [cl] * b_cand_boxes.shape[0] df_list_n_missing += list(b_cand_n_missing) if self.regress_flag: df_list_regressions += list(b_cand_regs) df_list_rg_bins += list(b_cand_rg_bins) df_list_rg_uncs += list(b_cand_rg_uncs) df_list_rg_targets += [[0.]*self.cf.regression_n_features] * b_cand_boxes.shape[0] df_list_rg_bin_targets += [0.] * b_cand_boxes.shape[0] df_list_pids += [pid] * b_cand_boxes.shape[0] df_list_type += ['det_fp'] * b_cand_boxes.shape[0] # only fn: if 0 in b_cand_boxes.shape and not 0 in b_tar_boxes.shape: df_list_preds += [0] * b_tar_boxes.shape[0] df_list_labels += [1] * b_tar_boxes.shape[0] df_list_class_preds += [cl] * b_tar_boxes.shape[0] df_list_n_missing += [np.nan] * b_tar_boxes.shape[0] if self.regress_flag: df_list_regressions += [[0.]*self.cf.regression_n_features] * b_tar_boxes.shape[0] df_list_rg_bins += [0.] * b_tar_boxes.shape[0] df_list_rg_uncs += [np.nan] * b_tar_boxes.shape[0] df_list_rg_targets += list(b_tar_regs) df_list_rg_bin_targets += list(b_tar_rg_bins) df_list_pids += [pid] * b_tar_boxes.shape[0] df_list_type += ['det_fn'] * b_tar_boxes.shape[0] # empty patient with 0 detections needs empty patient score, in order to not disappear from stats. # filtered out for roi-level evaluation later. During training (and val_sampling), # tn are assigned per sample independently of associated patients. # i.e., patient_tn is also meant as sample_tn if a list of samples is evaluated instead of whole patient if len(df_list_pids) == len_df_list_before_patient: df_list_preds += [0] df_list_labels += [0] df_list_class_preds += [cl] df_list_n_missing += [np.nan] if self.regress_flag: df_list_regressions += [[0.]*self.cf.regression_n_features] df_list_rg_bins += [0.] df_list_rg_uncs += [np.nan] df_list_rg_targets += [[0.]*self.cf.regression_n_features] df_list_rg_bin_targets += [0.] df_list_pids += [pid] df_list_type += ['patient_tn'] # true negative: no ground truth boxes, no detections. df_list_match_iou += [match_iou] * (len(df_list_preds) - len(df_list_match_iou)) self.test_df = pd.DataFrame() self.test_df['pred_score'] = df_list_preds self.test_df['class_label'] = df_list_labels # class labels are gt, 0,1, only indicate neg/pos (or bg/fg) remapped from all classes self.test_df['pred_class'] = df_list_class_preds # can be diff than 0,1 self.test_df['pid'] = df_list_pids self.test_df['det_type'] = df_list_type self.test_df['fold'] = self.cf.fold self.test_df['match_iou'] = df_list_match_iou self.test_df['cluster_n_missing'] = df_list_n_missing if self.regress_flag: self.test_df['regressions'] = df_list_regressions self.test_df['rg_targets'] = df_list_rg_targets self.test_df['rg_uncertainties'] = df_list_rg_uncs self.test_df['rg_bins'] = df_list_rg_bins # super weird error: pandas does not properly add an attribute if column is named "rg_bin_targets" ... ?!? self.test_df['rg_bin_target'] = df_list_rg_bin_targets assert hasattr(self.test_df, "rg_bin_target") #fn_df = self.test_df[self.test_df["det_type"] == "det_fn"] pass def evaluate_predictions(self, results_list, monitor_metrics=None): """ Performs the matching of predicted boxes and ground truth boxes. Loops over list of matching IoUs and foreground classes. Resulting info of each prediction is stored as one line in an internal dataframe, with the keys: det_type: 'tp' (true positive), 'fp' (false positive), 'fn' (false negative), 'tn' (true negative) pred_class: foreground class which the object predicts. pid: corresponding patient-id. pred_score: confidence score [0, 1] fold: corresponding fold of CV. match_iou: utilized IoU for matching. :param results_list: list of model predictions. Either from train/val_sampling (patch processing) for monitoring with form: [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] Or from val_patient/testing (patient processing), with form: [[results_0, pid_0], [results_1, pid_1], ...]) :param monitor_metrics (optional): dict of dicts with all metrics of previous epochs. :return monitor_metrics: if provided (during training), return monitor_metrics now including results of current epoch. """ # gets results_list = [[batch_instances_box_lists], [batch_instances_pids]]*n_batches # we want to evaluate one batch_instance (= 2D or 3D image) at a time. self.logger.info('evaluating in mode {}'.format(self.mode)) batch_res_dicts = [batch[0] for batch in results_list] # len: nr of batches in epoch if self.mode == 'train' or self.mode=='val_sampling': # one pid per batch element # [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] # -> [pid_0, pid_1, ...] # additional list wrapping to make conform with below per-patient batches, where one pid is linked to more than one batch instance pid_list = [batch_instance_pid for batch in results_list for batch_instance_pid in batch[1]] elif self.mode == "val_patient" or self.mode=="test": # [[results_0, pid_0], [results_1, pid_1], ...] -> [pid_0, pid_1, ...] # in patientbatchiterator there is only one pid per batch pid_list = [np.unique(batch[1]) for batch in results_list] assert np.all([len(pid)==1 for pid in pid_list]), "pid list in patient-eval mode, should only contain a single scalar per patient: {}".format(pid_list) pid_list = [pid[0] for pid in pid_list] else: raise Exception("undefined run mode encountered") self.eval_losses(batch_res_dicts) self.eval_segmentations(batch_res_dicts, pid_list) self.eval_boxes(batch_res_dicts, pid_list, self.cf.class_dict) if monitor_metrics is not None: # return all_stats, updated monitor_metrics return self.return_metrics(self.test_df, self.cf.class_dict, monitor_metrics) def return_metrics(self, df, obj_cl_dict, monitor_metrics=None, boxes_only=False): """ Calculates metric scores for internal data frame. Called directly from evaluate_predictions during training for monitoring, or from score_test_df during inference (for single folds or aggregated test set). Loops over foreground classes and score_levels ('roi' and/or 'patient'), gets scores and stores them. Optionally creates plots of prediction histograms and ROC/PR curves. :param df: Data frame that holds evaluated predictions. :param obj_cl_dict: Dict linking object-class ids to object-class names. E.g., {1: "bikes", 2 : "cars"}. Set in configs as cf.class_dict. :param monitor_metrics: dict of dicts with all metrics of previous epochs. This function adds metrics for current epoch and returns the same object. :param boxes_only: whether to produce metrics only for the boxes, not the segmentations. :return: all_stats: list. Contains dicts with resulting scores for each combination of foreground class and score_level. :return: monitor_metrics """ # -------------- monitoring independent of class, score level ------------ if monitor_metrics is not None: for l_name in self.epoch_losses: monitor_metrics[l_name] = [self.epoch_losses[l_name]] # -------------- metrics calc dependent on class, score level ------------ all_stats = [] # all_stats: one entry per score_level per class + for cl in list(obj_cl_dict.keys()): # bg eval is neglected cl_name = obj_cl_dict[cl] cl_df = df[df.pred_class == cl] + if hasattr(self, "seg_df") and not boxes_only: dice_col = self.cf.seg_id2label[cl].name+"_dice" seg_cl_df = self.seg_df.loc[:,['pid', dice_col, 'fold']] for score_level in self.cf.report_score_level: stats_dict = {} stats_dict['name'] = 'fold_{} {} {}'.format(self.cf.fold, score_level, cl_name) # -------------- RoI-based ----------------- if score_level == 'rois': stats_dict['auc'] = np.nan stats_dict['roc'] = np.nan if monitor_metrics is not None: tn = len(cl_df[cl_df.det_type == "patient_tn"]) tp = len(cl_df[(cl_df.det_type == "det_tp")&(cl_df.pred_score>self.cf.min_det_thresh)]) fp = len(cl_df[(cl_df.det_type == "det_fp")&(cl_df.pred_score>self.cf.min_det_thresh)]) fn = len(cl_df[cl_df.det_type == "det_fn"]) sens = np.divide(tp, (fn + tp)) monitor_metrics.update({"Bin_Stats/" + cl_name + "_fp": [fp], "Bin_Stats/" + cl_name + "_tp": [tp], "Bin_Stats/" + cl_name + "_fn": [fn], "Bin_Stats/" + cl_name + "_tn": [tn], "Bin_Stats/" + cl_name + "_sensitivity": [sens]}) # list wrapping only needed bc other metrics are recorded over all epochs; spec_df = cl_df[cl_df.det_type != 'patient_tn'] if self.regress_flag: # filter false negatives out for regression-only eval since regressor didn't predict truncd_df = spec_df[(((spec_df.det_type == "det_fp") | ( spec_df.det_type == "det_tp")) & spec_df.pred_score > self.cf.min_det_thresh)] truncd_df_tp = truncd_df[truncd_df.det_type == "det_tp"] weights, weights_tp = truncd_df.pred_score.tolist(), truncd_df_tp.pred_score.tolist() y_true, y_pred = truncd_df.rg_targets.tolist(), truncd_df.regressions.tolist() stats_dict["rg_RMSE"] = RMSE(y_true, y_pred) stats_dict["rg_MAE"] = MAE(y_true, y_pred) stats_dict["rg_RMSE_weighted"] = RMSE(y_true, y_pred, weights) stats_dict["rg_MAE_weighted"] = MAE(y_true, y_pred, weights) y_true, y_pred = truncd_df_tp.rg_targets.tolist(), truncd_df_tp.regressions.tolist() stats_dict["rg_MAE_weighted_tp"] = MAE(y_true, y_pred, weights_tp) stats_dict["rg_MAE_w_std_weighted_tp"] = MAE_w_std(y_true, y_pred, weights_tp) y_true, y_pred = truncd_df.rg_bin_target.tolist(), truncd_df.rg_bins.tolist() stats_dict["rg_bin_accuracy"] = accuracy(y_true, y_pred) stats_dict["rg_bin_accuracy_weighted"] = accuracy(y_true, y_pred, weights) y_true, y_pred = truncd_df_tp.rg_bin_target.tolist(), truncd_df_tp.rg_bins.tolist() stats_dict["rg_bin_accuracy_weighted_tp"] = accuracy(y_true, y_pred, weights_tp) if np.any(~truncd_df.rg_uncertainties.isna()): # det_fn are expected to be NaN so they drop out in means stats_dict.update({"rg_uncertainty": truncd_df.rg_uncertainties.mean(), "rg_uncertainty_tp": truncd_df_tp.rg_uncertainties.mean(), "rg_uncertainty_tp_weighted": (truncd_df_tp.rg_uncertainties * truncd_df_tp.pred_score).sum() / truncd_df_tp.pred_score.sum() }) if (spec_df.class_label==1).any(): stats_dict['ap'] = get_roi_ap_from_df((spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap)) stats_dict['prc'] = precision_recall_curve(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) if self.regress_flag: stats_dict['avp'] = roi_avp((spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap)) else: stats_dict['ap'] = np.nan stats_dict['prc'] = np.nan stats_dict['avp'] = np.nan # np.nan is formattable by __format__ as a float, None-type is not if hasattr(self, "seg_df") and not boxes_only: stats_dict["dice"] = seg_cl_df.loc[:,dice_col].mean() # mean per all rois in this epoch stats_dict["dice_std"] = seg_cl_df.loc[:,dice_col].std() # for the aggregated test set case, additionally get the scores of averaging over fold results. if self.cf.evaluate_fold_means and len(df.fold.unique()) > 1: aps = [] for fold in df.fold.unique(): fold_df = spec_df[spec_df.fold == fold] if (fold_df.class_label==1).any(): aps.append(get_roi_ap_from_df((fold_df, self.cf.min_det_thresh, self.cf.per_patient_ap))) stats_dict['ap_folds_mean'] = np.mean(aps) if len(aps)>0 else np.nan stats_dict['ap_folds_std'] = np.std(aps) if len(aps)>0 else np.nan stats_dict['auc_folds_mean'] = np.nan stats_dict['auc_folds_std'] = np.nan if self.regress_flag: avps, accuracies, MAEs = [], [], [] for fold in df.fold.unique(): fold_df = spec_df[spec_df.fold == fold] if (fold_df.class_label == 1).any(): avps.append(roi_avp((fold_df, self.cf.min_det_thresh, self.cf.per_patient_ap))) truncd_df_tp = fold_df[((fold_df.det_type == "det_tp") & fold_df.pred_score > self.cf.min_det_thresh)] weights_tp = truncd_df_tp.pred_score.tolist() y_true, y_pred = truncd_df_tp.rg_bin_target.tolist(), truncd_df_tp.rg_bins.tolist() accuracies.append(accuracy(y_true, y_pred, weights_tp)) y_true, y_pred = truncd_df_tp.rg_targets.tolist(), truncd_df_tp.regressions.tolist() MAEs.append(MAE_w_std(y_true, y_pred, weights_tp)) stats_dict['avp_folds_mean'] = np.mean(avps) if len(avps) > 0 else np.nan stats_dict['avp_folds_std'] = np.std(avps) if len(avps) > 0 else np.nan stats_dict['rg_bin_accuracy_weighted_tp_folds_mean'] = np.mean(accuracies) if len(accuracies) > 0 else np.nan stats_dict['rg_bin_accuracy_weighted_tp_folds_std'] = np.std(accuracies) if len(accuracies) > 0 else np.nan stats_dict['rg_MAE_w_std_weighted_tp_folds_mean'] = np.mean(MAEs, axis=0) if len(MAEs) > 0 else np.nan stats_dict['rg_MAE_w_std_weighted_tp_folds_std'] = np.std(MAEs, axis=0) if len(MAEs) > 0 else np.nan if hasattr(self, "seg_df") and not boxes_only and self.cf.evaluate_fold_means and len(seg_cl_df.fold.unique()) > 1: fold_means = seg_cl_df.groupby(['fold'], as_index=True).agg({dice_col:"mean"}) stats_dict["dice_folds_mean"] = float(fold_means.mean()) stats_dict["dice_folds_std"] = float(fold_means.std()) # -------------- patient-based ----------------- # on patient level, aggregate predictions per patient (pid): The patient predicted score is the highest # confidence prediction for this class. The patient class label is 1 if roi of this class exists in patient, else 0. if score_level == 'patient': #this is the critical part in patient scoring: only the max gt and max pred score are taken per patient! #--> does mix up values from separate detections spec_df = cl_df.groupby(['pid'], as_index=False) agg_args = {'class_label': 'max', 'pred_score': 'max', 'fold': 'first'} if self.regress_flag: # pandas throws error if aggregated value is np.array, not if is list. agg_args.update({'regressions': lambda series: list(series.iloc[np.argmax(series.apply(np.linalg.norm).values)]), 'rg_targets': lambda series: list(series.iloc[np.argmax(series.apply(np.linalg.norm).values)]), 'rg_bins': 'max', 'rg_bin_target': 'max', 'rg_uncertainties': 'max' }) if hasattr(cl_df, "cluster_n_missing"): agg_args.update({'cluster_n_missing': 'mean'}) spec_df = spec_df.agg(agg_args) if len(spec_df.class_label.unique()) > 1: stats_dict['auc'] = roc_auc_score(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) stats_dict['roc'] = roc_curve(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) else: stats_dict['auc'] = np.nan stats_dict['roc'] = np.nan if (spec_df.class_label == 1).any(): patient_cl_labels = spec_df.class_label.tolist() stats_dict['ap'] = average_precision_score(patient_cl_labels, spec_df.pred_score.tolist()) stats_dict['prc'] = precision_recall_curve(patient_cl_labels, spec_df.pred_score.tolist()) if self.regress_flag: avp_scores = spec_df[spec_df.rg_bins == spec_df.rg_bin_target].pred_score.tolist() avp_scores += [0.] * (len(patient_cl_labels) - len(avp_scores)) stats_dict['avp'] = average_precision_score(patient_cl_labels, avp_scores) else: stats_dict['ap'] = np.nan stats_dict['prc'] = np.nan stats_dict['avp'] = np.nan if self.regress_flag: y_true, y_pred = spec_df.rg_targets.tolist(), spec_df.regressions.tolist() stats_dict["rg_RMSE"] = RMSE(y_true, y_pred) stats_dict["rg_MAE"] = MAE(y_true, y_pred) stats_dict["rg_bin_accuracy"] = accuracy(spec_df.rg_bin_target.tolist(), spec_df.rg_bins.tolist()) stats_dict["rg_uncertainty"] = spec_df.rg_uncertainties.mean() if hasattr(self, "seg_df") and not boxes_only: seg_cl_df = seg_cl_df.groupby(['pid'], as_index=False).agg( {dice_col: "mean", "fold": "first"}) # mean of all rois per patient in this epoch stats_dict["dice"] = seg_cl_df.loc[:,dice_col].mean() #mean of all patients stats_dict["dice_std"] = seg_cl_df.loc[:, dice_col].std() # for the aggregated test set case, additionally get the scores for averaging over fold results. if self.cf.evaluate_fold_means and len(df.fold.unique()) > 1 and self.mode in ["test", "analysis"]: aucs = [] aps = [] for fold in df.fold.unique(): fold_df = spec_df[spec_df.fold == fold] if (fold_df.class_label==1).any(): aps.append( average_precision_score(fold_df.class_label.tolist(), fold_df.pred_score.tolist())) if len(fold_df.class_label.unique())>1: aucs.append(roc_auc_score(fold_df.class_label.tolist(), fold_df.pred_score.tolist())) stats_dict['auc_folds_mean'] = np.mean(aucs) stats_dict['auc_folds_std'] = np.std(aucs) stats_dict['ap_folds_mean'] = np.mean(aps) stats_dict['ap_folds_std'] = np.std(aps) if hasattr(self, "seg_df") and not boxes_only and self.cf.evaluate_fold_means and len(seg_cl_df.fold.unique()) > 1: fold_means = seg_cl_df.groupby(['fold'], as_index=True).agg({dice_col:"mean"}) stats_dict["dice_folds_mean"] = float(fold_means.mean()) stats_dict["dice_folds_std"] = float(fold_means.std()) all_stats.append(stats_dict) # -------------- monitoring, visualisation ----------------- # fill new results into monitor_metrics dict. for simplicity, only one class (of interest) is monitored on patient level. patient_interests = [self.cf.class_dict[self.cf.patient_class_of_interest],] if hasattr(self.cf, "bin_dict"): patient_interests += [self.cf.bin_dict[self.cf.patient_bin_of_interest]] if monitor_metrics is not None and (score_level != 'patient' or cl_name in patient_interests): name = 'patient_'+cl_name if score_level == 'patient' else cl_name for metric in self.cf.metrics: if metric in stats_dict.keys(): monitor_metrics[name + '_'+metric].append(stats_dict[metric]) else: print("WARNING: skipped monitor metric {}_{} since not avail".format(name, metric)) # histograms if self.cf.plot_prediction_histograms: out_filename = os.path.join(self.hist_dir, 'pred_hist_{}_{}_{}_{}'.format( self.cf.fold, self.mode, score_level, cl_name)) plg.plot_prediction_hist(self.cf, spec_df, out_filename) # analysis of the hyper-parameter cf.min_det_thresh, for optimization on validation set. if self.cf.scan_det_thresh and "val" in self.mode: conf_threshs = list(np.arange(0.8, 1, 0.02)) pool = Pool(processes=self.cf.n_workers) mp_inputs = [[spec_df, ii, self.cf.per_patient_ap] for ii in conf_threshs] aps = pool.map(get_roi_ap_from_df, mp_inputs, chunksize=1) pool.close() pool.join() self.logger.info('results from scanning over det_threshs: {}'.format([[i, j] for i, j in zip(conf_threshs, aps)])) + class_means = pd.DataFrame(columns=self.cf.report_score_level) + for slevel in self.cf.report_score_level: + level_stats = pd.DataFrame([stats for stats in all_stats if slevel in stats["name"]])[self.cf.metrics] + class_means.loc[:, slevel] = level_stats.mean() + all_stats.extend([{"name": 'fold_{} {} {}'.format(self.cf.fold, slevel, "class_means"), **level_means} for + slevel, level_means in class_means.to_dict().items()]) + if self.cf.plot_stat_curves: out_filename = os.path.join(self.curves_dir, '{}_{}_stat_curves'.format(self.cf.fold, self.mode)) plg.plot_stat_curves(self.cf, all_stats, out_filename) if self.cf.plot_prediction_histograms and hasattr(df, "cluster_n_missing") and df.cluster_n_missing.notna().any(): out_filename = os.path.join(self.hist_dir, 'n_missing_hist_{}_{}.png'.format(self.cf.fold, self.mode)) plg.plot_wbc_n_missing(self.cf, df, outfile=out_filename) return all_stats, monitor_metrics def score_test_df(self, max_fold=None, internal_df=True): """ Writes out resulting scores to text files: First checks for class-internal-df (typically current) fold, gets resulting scores, writes them to a text file and pickles data frame. Also checks if data-frame pickles of all folds of cross-validation exist in exp_dir. If true, loads all dataframes, aggregates test sets over folds, and calculates and writes out overall metrics. """ # this should maybe be extended to auc, ap stds. metrics_to_score = self.cf.metrics # + [ m+ext for m in self.cf.metrics if "dice" in m for ext in ["_std"]] if internal_df: self.test_df.to_pickle(os.path.join(self.cf.test_dir, '{}_test_df.pkl'.format(self.cf.fold))) if hasattr(self, "seg_df"): self.seg_df.to_pickle(os.path.join(self.cf.test_dir, '{}_test_seg_df.pkl'.format(self.cf.fold))) stats, _ = self.return_metrics(self.test_df, self.cf.class_dict) with open(os.path.join(self.cf.test_dir, 'results.txt'), 'a') as handle: handle.write('\n****************************\n') handle.write('\nresults for fold {}, {} \n'.format(self.cf.fold, time.strftime("%d/%m/%y %H:%M:%S"))) handle.write('\n****************************\n') handle.write('\nfold df shape {}\n \n'.format(self.test_df.shape)) for s in stats: for metric in metrics_to_score: if metric in s.keys(): #needed as long as no dice on patient level poss if "accuracy" in metric: handle.write('{} {:0.4f} '.format(metric, s[metric])) else: handle.write('{} {:0.3f} '.format(metric, s[metric])) else: print("WARNING: skipped metric {} since not avail".format(metric)) handle.write('{} \n'.format(s['name'])) fold_df_paths = sorted([ii for ii in os.listdir(self.cf.test_dir) if 'test_df.pkl' in ii]) fold_seg_df_paths = sorted([ii for ii in os.listdir(self.cf.test_dir) if 'test_seg_df.pkl' in ii]) for paths in [fold_df_paths, fold_seg_df_paths]: assert len(paths)<= self.cf.n_cv_splits, "found {} > nr of cv splits results dfs in {}".format(len(paths), self.cf.test_dir) if max_fold is None: max_fold = self.cf.n_cv_splits-1 if self.cf.fold == max_fold: print("max fold/overall stats triggered") if self.cf.evaluate_fold_means: metrics_to_score += [m + ext for m in self.cf.metrics for ext in ("_folds_mean", "_folds_std")] with open(os.path.join(self.cf.test_dir, 'results.txt'), 'a') as handle: self.cf.fold = 'overall' dfs_list = [pd.read_pickle(os.path.join(self.cf.test_dir, ii)) for ii in fold_df_paths] seg_dfs_list = [pd.read_pickle(os.path.join(self.cf.test_dir, ii)) for ii in fold_seg_df_paths] self.test_df = pd.concat(dfs_list, sort=True) if len(seg_dfs_list)>0: self.seg_df = pd.concat(seg_dfs_list, sort=True) stats, _ = self.return_metrics(self.test_df, self.cf.class_dict) handle.write('\n****************************\n') handle.write('\nOVERALL RESULTS \n') handle.write('\n****************************\n') handle.write('\ndf shape \n \n'.format(self.test_df.shape)) for s in stats: for metric in metrics_to_score: if metric in s.keys(): handle.write('{} {:0.3f} '.format(metric, s[metric])) handle.write('{} \n'.format(s['name'])) results_table_path = os.path.join(self.cf.test_dir,"../../", 'results_table.csv') with open(results_table_path, 'a') as handle: #---column headers--- handle.write('\n{},'.format("Experiment Name")) handle.write('{},'.format("Time Stamp")) handle.write('{},'.format("Samples Seen")) handle.write('{},'.format("Spatial Dim")) handle.write('{},'.format("Patch Size")) handle.write('{},'.format("CV Folds")) handle.write('{},'.format("{}-clustering IoU".format(self.cf.clustering))) handle.write('{},'.format("Merge-2D-to-3D IoU")) if hasattr(self.cf, "test_against_exact_gt"): handle.write('{},'.format('Exact GT')) for s in stats: assert "overall" in s['name'].split(" ")[0] - if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name']: + if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name'] or "mean" in s["name"]: for metric in metrics_to_score: if metric in s.keys() and not np.isnan(s[metric]): if metric=='ap': handle.write('{}_{} : {}_{},'.format(*s['name'].split(" ")[1:], metric, int(np.mean(self.cf.ap_match_ious)*100))) elif not "folds_std" in metric: handle.write('{}_{} : {},'.format(*s['name'].split(" ")[1:], metric)) else: print("WARNING: skipped metric {} since not avail".format(metric)) handle.write('\n') #--- columns content--- handle.write('{},'.format(self.cf.exp_dir.split(os.sep)[-1])) handle.write('{},'.format(time.strftime("%d%b%y %H:%M:%S"))) handle.write('{},'.format(self.cf.num_epochs*self.cf.num_train_batches*self.cf.batch_size)) handle.write('{}D,'.format(self.cf.dim)) handle.write('{},'.format("x".join([str(self.cf.patch_size[i]) for i in range(self.cf.dim)]))) handle.write('{},'.format(str(self.test_df.fold.unique().tolist()).replace(",", ""))) handle.write('{},'.format(self.cf.clustering_iou if self.cf.clustering else str("N/A"))) handle.write('{},'.format(self.cf.merge_3D_iou if self.cf.merge_2D_to_3D_preds else str("N/A"))) if hasattr(self.cf, "test_against_exact_gt"): handle.write('{},'.format(self.cf.test_against_exact_gt)) for s in stats: - if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name']: + if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name'] or "mean" in s["name"]: for metric in metrics_to_score: if metric in s.keys() and not np.isnan(s[metric]): # needed as long as no dice on patient level possible if "folds_mean" in metric: handle.write('{:0.3f}\u00B1{:0.3f}, '.format(s[metric], s["_".join((*metric.split("_")[:-1], "std"))])) elif not "folds_std" in metric: handle.write('{:0.3f}, '.format(s[metric])) handle.write('\n') with open(os.path.join(self.cf.test_dir, 'results_extr_scores.txt'), 'w') as handle: handle.write('\n****************************\n') handle.write('\nextremal scores for fold {} \n'.format(self.cf.fold)) handle.write('\n****************************\n') # want: pid & fold (&other) of highest scoring tp & fp in test_df for cl in self.cf.class_dict.keys(): print("\nClass {}".format(self.cf.class_dict[cl]), file=handle) cl_df = self.test_df[self.test_df.pred_class == cl] #.dropna(axis=1) for det_type in ['det_tp', 'det_fp']: filtered_df = cl_df[cl_df.det_type==det_type] print("\nHighest scoring {} of class {}".format(det_type, self.cf.class_dict[cl]), file=handle) if len(filtered_df)>0: print(filtered_df.loc[filtered_df.pred_score.idxmax()], file=handle) else: print("No detections of type {} for class {} in this df".format(det_type, self.cf.class_dict[cl]), file=handle) handle.write('\n****************************\n') diff --git a/exec.py b/exec.py index 7c7df4f..c343d47 100644 --- a/exec.py +++ b/exec.py @@ -1,341 +1,343 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ execution script. this where all routines come together and the only script you need to call. refer to parse args below to see options for execution. """ import plotting as plg import os import warnings import argparse import time import torch import utils.exp_utils as utils from evaluator import Evaluator from predictor import Predictor for msg in ["Attempting to set identical bottom==top results", "This figure includes Axes that are not compatible with tight_layout", "Data has no positive values, and therefore cannot be log-scaled.", ".*invalid value encountered in true_divide.*"]: warnings.filterwarnings("ignore", msg) def train(cf, logger): """ performs the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. logs to file and tensorboard. """ logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format( cf.dim, cf.fold, cf.exp_dir, cf.model)) logger.time("train_val") # -------------- inits and settings ----------------- net = model.net(cf, logger).cuda() - if cf.optimizer == "ADAM": - optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) + if cf.optimizer == "ADAMW": + optimizer = torch.optim.AdamW(utils.parse_params_for_optim(net, weight_decay=cf.weight_decay), + lr=cf.learning_rate[0]) elif cf.optimizer == "SGD": - optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3) + optimizer = torch.optim.SGD(utils.parse_params_for_optim(net, weight_decay=cf.weight_decay), + lr=cf.learning_rate[0], momentum=0.3) if cf.dynamic_lr_scheduling: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, - patience=cf.scheduling_patience) + patience=cf.scheduling_patience) model_selector = utils.ModelSelector(cf, logger) starting_epoch = 1 if cf.resume: checkpoint_path = os.path.join(cf.fold_dir, "last_state.pth") starting_epoch, net, optimizer, model_selector = \ utils.load_checkpoint(checkpoint_path, net, optimizer, model_selector) logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch)) # prepare monitoring monitor_metrics = utils.prepare_monitoring(cf) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) # -------------- training ----------------- for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs)) logger.time("train_epoch") net.train() train_results_list = [] train_evaluator = Evaluator(cf, logger, mode='train') for i in range(cf.num_train_batches): logger.time("train_batch_loadfw") batch = next(batch_gen['train']) batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts'] batch_gen['train'].generator.stats['empty_counts'] += batch['empty_counts'] logger.time("train_batch_loadfw") logger.time("train_batch_netfw") results_dict = net.train_forward(batch) logger.time("train_batch_netfw") logger.time("train_batch_bw") optimizer.zero_grad() results_dict['torch_loss'].backward() if cf.clip_norm: torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping optimizer.step() train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict if not cf.server_env: print("\rFinished training batch " + "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches, logger.get_time("train_batch_loadfw")+ logger.get_time("train_batch_netfw") +logger.time("train_batch_bw"), logger.get_time("train_batch_loadfw",reset=True), logger.get_time("train_batch_netfw", reset=True), logger.get_time("train_batch_bw", reset=True)), end="", flush=True) print() #--------------- train eval ---------------- if (epoch-1)%cf.plot_frequency==0: # view an example batch utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, get_time="train-example plot", out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold))) logger.time("evals") _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train']) logger.time("evals") logger.time("train_epoch", toggle=False) del train_results_list #----------- validation ------------ logger.info('starting validation in mode {}.'.format(cf.val_mode)) logger.time("val_epoch") with torch.no_grad(): net.eval() val_results_list = [] val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) val_predictor = Predictor(cf, net, logger, mode='val') for i in range(batch_gen['n_val']): logger.time("val_batch") batch = next(batch_gen[cf.val_mode]) if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) val_results_list.append([results_dict, batch["pid"]]) if not cf.server_env: print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch', i + 1, batch_gen['n_val'], logger.time("val_batch")), end="", flush=True) print() #------------ val eval ------------- if (epoch - 1) % cf.plot_frequency == 0: utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, get_time="val-example plot", out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold))) logger.time("evals") _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val']) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) del val_results_list #----------- monitoring ------------- monitor_metrics.update({"lr": {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}}) logger.metrics2tboard(monitor_metrics, global_step=epoch) logger.time("evals") logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format( epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"), logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"), logger.get_time("val_epoch", reset=True)/batch_gen["n_val"])) logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True))) #-------------- scheduling ----------------- if cf.dynamic_lr_scheduling: scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) else: for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch-1] logger.time("train_val") logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True))) batch_gen['train'].generator.print_stats(logger, plot=True) def test(cf, logger, max_fold=None): """performs testing for a given fold (or held out set). saves stats in evaluator. """ logger.time("test_fold") logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir)) net = model.net(cf, logger).cuda() batch_gen = data_loader.get_test_generator(cf, logger) test_predictor = Predictor(cf, net, logger, mode='test') test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr( cf, "eval_test_separately") or not cf.eval_test_separately) if test_results_list is not None: test_evaluator = Evaluator(cf, logger, mode='test') test_evaluator.evaluate_predictions(test_results_list) test_evaluator.score_test_df(max_fold=max_fold) logger.info('Testing of fold {} took {}.\n'.format(cf.fold, logger.get_time("test_fold", reset=True, format="hms"))) if __name__ == '__main__': stime = time.time() parser = argparse.ArgumentParser() parser.add_argument('--dataset_name', type=str, default='toy', help="path to the dataset-specific code in source_dir/datasets") parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/regrcnn/datasets/toy/experiments/dev', help='path to experiment dir. will be created if non existent.') parser.add_argument('-m', '--mode', type=str, default='train_test', help='one out of: create_exp, analysis, train, train_test, or test') parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.') parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.') parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config") parser.add_argument('--use_stored_settings', default=False, action='store_true', help='load configs from existing exp_dir instead of source dir. always done for testing, ' 'but can be set to true to do the same for training. useful in job scheduler environment, ' 'where source code might change before the job actually runs.') parser.add_argument('--resume', action="store_true", default=False, help='if given, resume from checkpoint(s) of the specified folds.') parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything") args = parser.parse_args() args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name folds = args.folds resume = None if args.resume in ['None', 'none'] else args.resume if args.mode == 'create_exp': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, -1) logger.info('created experiment directory at {}'.format(args.exp_dir)) elif args.mode == 'train' or args.mode == 'train_test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings) if args.dev: folds = [0,1] cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 2, 0, 2 cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1 cf.test_n_epochs = cf.save_n_models cf.max_test_patients = 1 torch.backends.cudnn.benchmark = cf.dim==3 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) if folds is None: folds = range(cf.n_cv_splits) for fold in folds: """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation, one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the splits. k==folds, fold in [0,folds) says which split is used for testing. """ cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold logger.set_logfile(fold=fold) cf.resume = resume if not os.path.exists(cf.fold_dir): os.mkdir(cf.fold_dir) train(cf, logger) cf.resume = None if args.mode == 'train_test': test(cf, logger) elif args.mode == 'test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if folds is None: folds = range(cf.n_cv_splits) if args.dev: folds = folds[:2] cf.batch_size, cf.max_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark for fold in folds: cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold logger.set_logfile(fold=fold) if cf.fold_dir in fold_dirs: test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs])) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation. elif args.mode == 'analysis': """ analyse already saved predictions. """ cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) if cf.held_out_test_set and not cf.eval_test_fold_wise: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() logger.info('starting evaluation...') cf.fold = 0 evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) evaluator.score_test_df(max_fold=0) else: fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if args.dev: fold_dirs = fold_dirs[:1] if folds is None: folds = range(cf.n_cv_splits) for fold in folds: cf.fold = fold; cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) logger.set_logfile(fold=fold) if cf.fold_dir in fold_dirs: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x logger.info('starting evaluation...') evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) max_fold = max([int(f[-1]) for f in fold_dirs]) evaluator.score_test_df(max_fold=max_fold) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) else: raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode)) mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t)) del logger torch.cuda.empty_cache() diff --git a/models/retina_net.py b/models/retina_net.py index aa28d41..ee1b266 100644 --- a/models/retina_net.py +++ b/models/retina_net.py @@ -1,779 +1,778 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Retina Net. According to https://arxiv.org/abs/1708.02002""" import utils.model_utils as mutils import utils.exp_utils as utils import sys import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils sys.path.append('..') from custom_extensions.nms import nms class Classifier(nn.Module): def __init__(self, cf, conv): """ Builds the classifier sub-network. """ super(Classifier, self).__init__() self.dim = conv.dim self.n_classes = cf.head_classes n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * cf.head_classes anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: class_logits (b, n_anchors, n_classes) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) class_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) class_logits = class_logits.permute(*axes) class_logits = class_logits.contiguous() class_logits = class_logits.view(x.shape[0], -1, self.n_classes) return [class_logits] class BBRegressor(nn.Module): def __init__(self, cf, conv): """ Builds the bb-regression sub-network. """ super(BBRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * self.dim * 2 anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) bb_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) bb_logits = bb_logits.permute(*axes) bb_logits = bb_logits.contiguous() bb_logits = bb_logits.view(x.shape[0], -1, self.dim * 2) return [bb_logits] class RoIRegressor(nn.Module): def __init__(self, cf, conv, rg_feats): """ Builds the RoI-item-regression sub-network. Regression items can be, e.g., malignancy scores of tumors. """ super(RoIRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features self.rg_feats = rg_feats n_output_channels = cf.n_anchors_per_pos * self.rg_feats anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) x = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) x = x.permute(*axes) x = x.contiguous() x = x.view(x.shape[0], -1, self.rg_feats) return [x] ############################################################ # Loss Functions ############################################################ # def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20): """ :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param shem_poolsize: int. factor of top-k candidates to draw from per negative sample (online-hard-example-mining). :return: loss: torch tensor :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0) neg_indices = torch.nonzero(anchor_matches == -1) # get positive samples and calucalte loss. if not 0 in pos_indices.size(): pos_indices = pos_indices.squeeze(1) roi_logits_pos = class_pred_logits[pos_indices] targets_pos = anchor_matches[pos_indices].detach() pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long()) else: pos_loss = torch.FloatTensor([0]).cuda() # get negative samples, such that the amount matches the number of positive samples, but at least 1. # get high scoring negatives by applying online-hard-example-mining. if not 0 in neg_indices.size(): neg_indices = neg_indices.squeeze(1) roi_logits_neg = class_pred_logits[neg_indices] negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) roi_probs_neg = F.softmax(roi_logits_neg, dim=1) neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) # return the indices of negative samples, who contributed to the loss for monitoring plots. np_neg_ix = neg_ix.cpu().data.numpy() else: neg_loss = torch.FloatTensor([0]).cuda() np_neg_ix = np.array([]).astype('int32') loss = (pos_loss + neg_loss) / 2 return loss, np_neg_ix def compute_bbox_loss(target_deltas, pred_deltas, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unused bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param anchor_matches: tensor (n_anchors). value in [-1, 0, class_ids] for negative, neutral, and positive matched anchors. i.e., positively matched anchors are marked by class_id >0 :return: loss: torch 1D tensor. """ if not 0 in torch.nonzero(anchor_matches>0).shape: indices = torch.nonzero(anchor_matches>0).squeeze(1) # Pick bbox deltas that contribute to the loss pred_deltas = pred_deltas[indices] # Trim target bounding box deltas to the same length as pred_deltas. target_deltas = target_deltas[:pred_deltas.shape[0], :].detach() # Smooth L1 loss loss = F.smooth_l1_loss(pred_deltas, target_deltas) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_rg_loss(tasks, target, pred, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unsed bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :return: loss: torch 1D tensor. """ if not 0 in target.shape and not 0 in torch.nonzero(anchor_matches>0).shape: indices = torch.nonzero(anchor_matches>0).squeeze(1) # Pick rgs that contribute to the loss pred = pred[indices] # Trim target target = target[:pred.shape[0]].detach() if 'regression_bin' in tasks: loss = F.cross_entropy(pred, target.long()) else: loss = F.smooth_l1_loss(pred, target) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_focal_class_loss(anchor_matches, class_pred_logits, gamma=2.): """ Focal Loss FL = -(1-q)^g log(q) with q = pred class probability. :param anchor_matches: (n_anchors). [-1, 0, class] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param gamma: g in above formula, good results with g=2 in original paper. :return: loss: torch tensor :return: focal loss """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0).squeeze(-1) # dim=-1 instead of 1 or 0 to cover empty matches. neg_indices = torch.nonzero(anchor_matches == -1).squeeze(-1) target_classes = torch.cat( (anchor_matches[pos_indices].long(), torch.LongTensor([0] * neg_indices.shape[0]).cuda()) ) non_neutral_indices = torch.cat( (pos_indices, neg_indices) ) q = F.softmax(class_pred_logits[non_neutral_indices], dim=1) # q shape: (n_non_neutral_anchors, n_classes) # one-hot encoded target classes: keep only the pred probs of the correct class. it will receive incentive to be maximized. # log(q_i) where i = target class --> FL shape (n_anchors,) # need to transform to indices into flattened tensor to use torch.take target_locs_flat = q.shape[1] * torch.arange(q.shape[0]).cuda() + target_classes q = torch.take(q, target_locs_flat) FL = torch.log(q) # element-wise log FL *= -(1-q)**gamma # take mean over all considered anchors FL = FL.sum() / FL.shape[0] return FL def refine_detections(anchors, probs, deltas, regressions, batch_ixs, cf): """Refine classified proposals, filter overlaps and return final detections. n_proposals here is typically a very large number: batch_size * n_anchors. This function is hence optimized on trimming down n_proposals. :param anchors: (n_anchors, 2 * dim) :param probs: (n_proposals, n_classes) softmax probabilities for all rois as predicted by classifier head. :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by bbox regressor head. :param regressions: (n_proposals, n_classes, n_rg_feats) :param batch_ixs: (n_proposals) batch element assignemnt info for re-allocation. :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regr)) """ anchors = anchors.repeat(batch_ixs.unique().shape[0], 1) #flatten foreground probabilities, sort and trim down to highest confidences by pre_nms limit. fg_probs = probs[:, 1:].contiguous() flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True) keep_ix = flat_probs_order[:cf.pre_nms_limit] # reshape indices to 2D index array with shape like fg_probs. keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) pre_nms_scores = flat_probs[:cf.pre_nms_limit] pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again. pre_nms_batch_ixs = batch_ixs[keep_arr[:, 0]] pre_nms_anchors = anchors[keep_arr[:, 0]] pre_nms_deltas = deltas[keep_arr[:, 0]] pre_nms_regressions = regressions[keep_arr[:, 0]] keep = torch.arange(pre_nms_scores.size()[0]).long().cuda() # apply bounding box deltas. re-scale to image coordinates. std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() scale = torch.from_numpy(cf.scale).float().cuda() refined_rois = mutils.apply_box_deltas_2D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale \ if cf.dim == 2 else mutils.apply_box_deltas_3D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale # round and cast to int since we're deadling with pixels now refined_rois = mutils.clip_to_window(cf.window, refined_rois) pre_nms_rois = torch.round(refined_rois) for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)): bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] bix_class_ids = pre_nms_class_ids[bixs] bix_rois = pre_nms_rois[bixs] bix_scores = pre_nms_scores[bixs] for i, class_id in enumerate(mutils.unique1d(bix_class_ids)): ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] # nms expects boxes sorted by score. ix_rois = bix_rois[ixs] ix_scores = bix_scores[ixs] ix_scores, order = ix_scores.sort(descending=True) ix_rois = ix_rois[order, :] ix_scores = ix_scores class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold) # map indices back. class_keep = keep[bixs[ixs[order[class_keep]]]] # merge indices over classes for current batch element b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep))) # only keep top-k boxes of current batch-element. top_ids = pre_nms_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] b_keep = b_keep[top_ids] # merge indices over batch elements. batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep))) keep = batch_keep # arrange output. result = torch.cat((pre_nms_rois[keep], pre_nms_batch_ixs[keep].unsqueeze(1).float(), pre_nms_class_ids[keep].unsqueeze(1).float(), pre_nms_scores[keep].unsqueeze(1), pre_nms_regressions[keep]), dim=1) return result def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None, gt_regressions=None): """Given the anchors and GT boxes, compute overlaps and identify positive anchors and deltas to refine them to match their corresponding GT boxes. anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))] gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))] gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN, set all positive matches to 1 (foreground) gt_regressions: [num_gt_rgs, n_rg_feats], if None empty rg_targets are returned Returns: anchor_class_matches: [N] (int32) matches between anchors and GT boxes. class_id = positive anchor, -1 = negative anchor, 0 = neutral. i.e., positively matched anchors are marked by class_id (which is >0). anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas. anchor_rg_targets: [n_anchors, n_rg_feats] """ anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32) anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim)) if gt_regressions is not None: if 'regression_bin' in cf.prediction_tasks: anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image,)) else: anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image, cf.regression_n_features)) else: anchor_rg_targets = np.array([]) anchor_matching_iou = cf.anchor_matching_iou if gt_boxes is None: anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1) return anchor_class_matches, anchor_delta_targets, anchor_rg_targets # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground) if gt_class_ids is None: gt_class_ids = np.array([1] * len(gt_boxes)) # Compute overlaps [num_anchors, num_gt_boxes] overlaps = mutils.compute_overlaps(anchors, gt_boxes) # Match anchors to GT Boxes # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive. # If an anchor overlaps a GT box with IoU < 0.1 then it's negative. # Neutral anchors are those that don't match the conditions above, # and they don't influence the loss function. # However, don't keep any GT box unmatched (rare, but happens). Instead, # match it to the closest anchor (even if its max IoU is < 0.1). # 1. Set negative anchors first. They get overwritten below if a GT box is # matched to them. Skip boxes in crowd areas. anchor_iou_argmax = np.argmax(overlaps, axis=1) anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] if anchors.shape[1] == 4: anchor_class_matches[(anchor_iou_max < 0.1)] = -1 elif anchors.shape[1] == 6: anchor_class_matches[(anchor_iou_max < 0.01)] = -1 else: raise ValueError('anchor shape wrong {}'.format(anchors.shape)) # 2. Set an anchor for each GT box (regardless of IoU value). gt_iou_argmax = np.argmax(overlaps, axis=0) for ix, ii in enumerate(gt_iou_argmax): anchor_class_matches[ii] = gt_class_ids[ix] # 3. Set anchors with high overlap as positive. above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]] # Subsample to balance positive anchors. ids = np.where(anchor_class_matches > 0)[0] extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) if extra > 0: # Reset the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) anchor_class_matches[ids] = 0 # Leave all negative proposals negative for now and sample from them later in online hard example mining. # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. ids = np.where(anchor_class_matches > 0)[0] ix = 0 # index into anchor_delta_targets for i, a in zip(ids, anchors[ids]): # closest gt box (it might have IoU < anchor_matching_iou) gt = gt_boxes[anchor_iou_argmax[i]] # convert coordinates to center plus width/height. gt_h = gt[2] - gt[0] gt_w = gt[3] - gt[1] gt_center_y = gt[0] + 0.5 * gt_h gt_center_x = gt[1] + 0.5 * gt_w # Anchor a_h = a[2] - a[0] a_w = a[3] - a[1] a_center_y = a[0] + 0.5 * a_h a_center_x = a[1] + 0.5 * a_w if cf.dim == 2: anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, np.log(gt_h / a_h), np.log(gt_w / a_w)] else: gt_d = gt[5] - gt[4] gt_center_z = gt[4] + 0.5 * gt_d a_d = a[5] - a[4] a_center_z = a[4] + 0.5 * a_d anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, (gt_center_z - a_center_z) / a_d, np.log(gt_h / a_h), np.log(gt_w / a_w), np.log(gt_d / a_d)] # normalize. anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev if gt_regressions is not None: anchor_rg_targets[ix] = gt_regressions[anchor_iou_argmax[i]] ix += 1 return anchor_class_matches, anchor_delta_targets, anchor_rg_targets ############################################################ # RetinaNet Class ############################################################ class net(nn.Module): """Encapsulates the RetinaNet model functionality. """ def __init__(self, cf, logger): """ cf: A Sub-class of the cf class model_dir: Directory to save training logs and trained weights """ super(net, self).__init__() self.cf = cf self.logger = logger self.build() if self.cf.weight_init is not None: - logger.info("using pytorch weight init of type {}".format(self.cf.weight_init)) mutils.initialize_weights(self) else: logger.info("using default pytorch weight init") self.debug_acm = [] def build(self): """Build Retina Net architecture.""" # Image size must be dividable by 2 multiple times. h, w = self.cf.patch_size[:2] if h / 2 ** 5 != int(h / 2 ** 5) or w / 2 ** 5 != int(w / 2 ** 5): raise Exception("Image size must be divisible by 2 at least 5 times " "to avoid fractions when downscaling and upscaling." "For example, use 256, 320, 384, 448, 512, ... etc. ") backbone = utils.import_module('bbone', self.cf.backbone_path) self.logger.info("loaded backbone from {}".format(self.cf.backbone_path)) conv = backbone.ConvGenerator(self.cf.dim) # build Anchors, FPN, Classifier / Bbox-Regressor -head self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) self.anchors = torch.from_numpy(self.np_anchors).float().cuda() self.fpn = backbone.FPN(self.cf, conv, operate_stride1=self.cf.operate_stride1).cuda() self.classifier = Classifier(self.cf, conv).cuda() self.bb_regressor = BBRegressor(self.cf, conv).cuda() if 'regression' in self.cf.prediction_tasks: self.roi_regressor = RoIRegressor(self.cf, conv, self.cf.regression_n_features).cuda() elif 'regression_bin' in self.cf.prediction_tasks: # classify into bins of regression values self.roi_regressor = RoIRegressor(self.cf, conv, len(self.cf.bin_labels)).cuda() else: self.roi_regressor = lambda x: [torch.tensor([]).cuda()] if self.cf.model == 'retina_unet': self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=None, relu=None) def forward(self, img): """ :param img: input img (b, c, y, x, (z)). """ # Feature extraction fpn_outs = self.fpn(img) if self.cf.model == 'retina_unet': seg_logits = self.final_conv(fpn_outs[0]) selected_fmaps = [fpn_outs[i + 1] for i in self.cf.pyramid_levels] else: seg_logits = None selected_fmaps = [fpn_outs[i] for i in self.cf.pyramid_levels] # Loop through pyramid layers class_layer_outputs, bb_reg_layer_outputs, roi_reg_layer_outputs = [], [], [] # list of lists for p in selected_fmaps: class_layer_outputs.append(self.classifier(p)) bb_reg_layer_outputs.append(self.bb_regressor(p)) roi_reg_layer_outputs.append(self.roi_regressor(p)) # Concatenate layer outputs # Convert from list of lists of level outputs to list of lists # of outputs across levels. # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] class_logits = list(zip(*class_layer_outputs)) class_logits = [torch.cat(list(o), dim=1) for o in class_logits][0] bb_outputs = list(zip(*bb_reg_layer_outputs)) bb_outputs = [torch.cat(list(o), dim=1) for o in bb_outputs][0] if not 0 == roi_reg_layer_outputs[0][0].shape[0]: rg_outputs = list(zip(*roi_reg_layer_outputs)) rg_outputs = [torch.cat(list(o), dim=1) for o in rg_outputs][0] else: if self.cf.dim == 2: n_feats = np.array([p.shape[-2] * p.shape[-1] * self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() else: n_feats = np.array([p.shape[-3]*p.shape[-2]*p.shape[-1]*self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() rg_outputs = torch.zeros((selected_fmaps[0].shape[0], n_feats, self.cf.regression_n_features), dtype=torch.float32).fill_(float('NaN')).cuda() # merge batch_dimension and store info in batch_ixs for re-allocation. batch_ixs = torch.arange(class_logits.shape[0]).unsqueeze(1).repeat(1, class_logits.shape[1]).view(-1).cuda() flat_class_softmax = F.softmax(class_logits.view(-1, class_logits.shape[-1]), 1) flat_bb_outputs = bb_outputs.view(-1, bb_outputs.shape[-1]) flat_rg_outputs = rg_outputs.view(-1, rg_outputs.shape[-1]) detections = refine_detections(self.anchors, flat_class_softmax, flat_bb_outputs, flat_rg_outputs, batch_ixs, self.cf) return detections, class_logits, bb_outputs, rg_outputs, seg_logits def get_results(self, img_shape, detections, seg_logits, box_results_list=None): """ Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. :param img_shape: :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regression) :param box_results_list: None or list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. class-specific return of masks will come with implementation of instance segmentation evaluation. """ detections = detections.cpu().data.numpy() batch_ixs = detections[:, self.cf.dim*2] detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] if box_results_list == None: # for test_forward, where no previous list exists. box_results_list = [[] for _ in range(img_shape[0])] for ix in range(img_shape[0]): if not 0 in detections[ix].shape: boxes = detections[ix][:, :2 * self.cf.dim].astype(np.int32) class_ids = detections[ix][:, 2 * self.cf.dim + 1].astype(np.int32) scores = detections[ix][:, 2 * self.cf.dim + 2] regressions = detections[ix][:, 2 * self.cf.dim + 3:] # Filter out detections with zero area. Often only happens in early # stages of training when the network weights are still a bit random. if self.cf.dim == 2: exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] else: exclude_ix = np.where( (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] if exclude_ix.shape[0] > 0: boxes = np.delete(boxes, exclude_ix, axis=0) class_ids = np.delete(class_ids, exclude_ix, axis=0) scores = np.delete(scores, exclude_ix, axis=0) regressions = np.delete(regressions, exclude_ix, axis=0) if not 0 in boxes.shape: for ix2, score in enumerate(scores): if score >= self.cf.model_min_confidence: box = {'box_type': 'det', 'box_coords': boxes[ix2], 'box_score': score, 'box_pred_class_id': class_ids[ix2]} if "regression_bin" in self.cf.prediction_tasks: # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin stands for box['rg_bin'] = regressions[ix2].argmax() box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']] else: box['regression'] = regressions[ix2] if hasattr(self.cf, "rg_val_to_bin_id") and \ any(['regression' in task for task in self.cf.prediction_tasks]): box['rg_bin'] = self.cf.rg_val_to_bin_id(regressions[ix2]) box_results_list[ix].append(box) results_dict = {} results_dict['boxes'] = box_results_list if seg_logits is None: # output dummy segmentation for retina_net. out_logits_shape = list(img_shape) out_logits_shape[1] = self.cf.num_seg_classes results_dict['seg_preds'] = np.zeros(out_logits_shape, dtype=np.float16) #todo: try with seg_preds=None? as to not carry heavy dummy preds. else: # output label maps for retina_unet. results_dict['seg_preds'] = F.softmax(seg_logits, 1).cpu().data.numpy() return results_dict def train_forward(self, batch, is_validation=False): """ train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data for processing, computes losses, and stores outputs in a dictionary. :param batch: dictionary containing 'data', 'seg', etc. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes]. 'torch_loss': 1D torch tensor for backprop. 'class_loss': classification loss for monitoring. """ img = batch['data'] gt_class_ids = batch['class_targets'] gt_boxes = batch['bb_target'] if 'regression' in self.cf.prediction_tasks: gt_regressions = batch["regression_targets"] elif 'regression_bin' in self.cf.prediction_tasks: gt_regressions = batch["rg_bin_targets"] else: gt_regressions = None if self.cf.model == 'retina_unet': var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda() var_seg = torch.LongTensor(batch['seg']).cuda() img = torch.from_numpy(img).float().cuda() torch_loss = torch.FloatTensor([0]).cuda() # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. box_results_list = [[] for _ in range(img.shape[0])] detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(img) # loop over batch for b in range(img.shape[0]): # add gt boxes to results dict for monitoring. if len(gt_boxes[b]) > 0: for tix in range(len(gt_boxes[b])): gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]} for name in self.cf.roi_items: gt_box.update({name: batch[name][b][tix]}) box_results_list[b].append(gt_box) # match gt boxes with anchors to generate targets. anchor_class_match, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching( self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b], gt_regressions[b] if gt_regressions is not None else None) # add positive anchors used for loss to results_dict for monitoring. pos_anchors = mutils.clip_boxes_numpy( self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:]) for p in pos_anchors: box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) else: anchor_class_match = np.array([-1]*self.np_anchors.shape[0]) anchor_target_deltas = np.array([]) anchor_target_rgs = np.array([]) anchor_class_match = torch.from_numpy(anchor_class_match).cuda() anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda() anchor_target_rgs = torch.from_numpy(anchor_target_rgs).float().cuda() if self.cf.focal_loss: # compute class loss as focal loss as suggested in original publication, but multi-class. class_loss = compute_focal_class_loss(anchor_class_match, class_logits[b], gamma=self.cf.focal_loss_gamma) # sparing appendix of negative anchors for monitoring as not really relevant else: # compute class loss with SHEM. class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b]) # add negative anchors used for loss to results_dict for monitoring. neg_anchors = mutils.clip_boxes_numpy( self.np_anchors[np.argwhere(anchor_class_match.cpu().numpy() == -1)][neg_anchor_ix, 0], img.shape[2:]) for n in neg_anchors: box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) rg_loss = compute_rg_loss(self.cf.prediction_tasks, anchor_target_rgs, pred_rgs[b], anchor_class_match) bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match) torch_loss += (class_loss + bbox_loss + rg_loss) / img.shape[0] results_dict = self.get_results(img.shape, detections, seg_logits, box_results_list) results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:, np.newaxis] if self.cf.model == 'retina_unet': seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),var_seg_ohe) seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0]) torch_loss += (seg_loss_dice + seg_loss_ce) / 2 #self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, " # "mean pixel preds: {5:.5f}".format(torch_loss.item(), batch_class_loss.item(), batch_bbox_loss.item(), # seg_loss_dice.item(), seg_loss_ce.item(), np.mean(results_dict['seg_preds']))) if 'dice' in self.cf.metrics: results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True) #else: #self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}".format( # torch_loss.item(), class_loss.item(), bbox_loss.item())) results_dict['torch_loss'] = torch_loss results_dict['class_loss'] = class_loss.item() return results_dict def test_forward(self, batch, **kwargs): """ test method. wrapper around forward pass of network without usage of any ground truth information. prepares input data for processing and stores outputs in a dictionary. :param batch: dictionary containing 'data' :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': actually contain seg probabilities since evaluated to seg_preds (via argmax) in predictor. or dummy seg logits for real retina net (detection only) """ img = torch.from_numpy(batch['data']).float().cuda() detections, _, _, _, seg_logits = self.forward(img) results_dict = self.get_results(img.shape, detections, seg_logits) return results_dict \ No newline at end of file diff --git a/plotting.py b/plotting.py index 1bb78aa..72b29ee 100644 --- a/plotting.py +++ b/plotting.py @@ -1,2139 +1,2139 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import matplotlib # matplotlib.rcParams['font.family'] = ['serif'] # matplotlib.rcParams['font.serif'] = ['Times New Roman'] matplotlib.rcParams['mathtext.fontset'] = 'cm' matplotlib.rcParams['font.family'] = 'STIXGeneral' matplotlib.use('Agg') #complains with spyder editor, bc spyder imports mpl at startup from matplotlib.ticker import FormatStrFormatter import matplotlib.colors as mcolors import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import matplotlib.patches as mpatches from matplotlib.ticker import StrMethodFormatter, ScalarFormatter import SimpleITK as sitk from tensorboard.backend.event_processing.event_multiplexer import EventMultiplexer import sys import os import warnings import time from copy import deepcopy import numpy as np import pandas as pd import scipy.interpolate as interpol from utils.exp_utils import IO_safe warnings.filterwarnings("ignore", module="matplotlib.image") def make_colormap(seq): """ Return a LinearSegmentedColormap seq: a sequence of floats and RGB-tuples. The floats should be increasing and in the interval (0,1). """ seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3] cdict = {'red': [], 'green': [], 'blue': []} for i, item in enumerate(seq): if isinstance(item, float): r1, g1, b1 = seq[i - 1] r2, g2, b2 = seq[i + 1] cdict['red'].append([item, r1, r2]) cdict['green'].append([item, g1, g2]) cdict['blue'].append([item, b1, b2]) return mcolors.LinearSegmentedColormap('CustomMap', cdict) bw_cmap = make_colormap([(1.,1.,1.), (0.,0.,0.)]) #------------------------------------------------------------------------ #------------- plotting functions, not all are used --------------------- def shape_small_first(shape): """sort a tuple so that the smallest entry is swapped to the beginning """ if len(shape) <= 2: # no changing dimensions if channel-dim is missing return shape smallest_dim = np.argmin(shape) if smallest_dim != 0: # assume that smallest dim is color channel new_shape = np.array(shape) # to support mask indexing new_shape = (new_shape[smallest_dim], *new_shape[(np.arange(len(shape), dtype=int) != smallest_dim)]) return new_shape else: return shape def RGB_to_rgb(RGB): rgb = np.array(RGB) / 255. return rgb def mod_to_rgb(arr, cmap=None): """convert a single-channel modality img to 3-color-channel img. :param arr: input img, expected in shape (b,c,)x,y with c=1 :return: img of shape (...,c') with c'=3 """ if len(arr.shape) == 3: arr = np.squeeze(arr) elif len(arr.shape) != 2: raise Exception("Invalid input arr shape: {}".format(arr.shape)) if cmap is None: cmap = "gray" norm = matplotlib.colors.Normalize() norm.autoscale(arr) arr = norm(arr) arr = np.stack((arr,) * 3, axis=-1) return arr def to_rgb(arr, cmap): """ Transform an integer-labeled segmentation map using an rgb color-map. :param arr: img_arr w/o a color-channel :param cmap: dictionary mapping from integer class labels to rgb values :return: img of shape (...,c) """ new_arr = np.zeros(shape=(arr.shape) + (3,)) for l in cmap.keys(): ixs = np.where(arr == l) new_arr[ixs] = np.array([cmap[l][i] for i in range(3)]) return new_arr def to_rgba(arr, cmap): """ Transform an integer-labeled segmentation map using an rgba color-map. :param arr: img_arr w/o a color-channel :param cmap: dictionary mapping from integer class labels to rgba values :return: new array holding rgba-image """ new_arr = np.zeros(shape=(arr.shape) + (4,)) for lab, val in cmap.items(): # in case no alpha, complement with 100% alpha if len(val) == 3: cmap[lab] = (*val, 1.) assert len(cmap[lab]) == 4, "cmap has color with {} entries".format(len(val)) for lab in cmap.keys(): ixs = np.where(arr == lab) rgb = np.array(cmap[lab][:3]) new_arr[ixs] = np.append(rgb, cmap[lab][3]) return new_arr def bin_seg_to_rgba(arr, color): """ Transform a continuously labelled binary segmentation map using an rgba color-map. values are expected to be 0-1, will give alpha-value :param arr: img_arr w/o a color-channel :param color: color to give img :return: new array holding rgba-image """ new_arr = np.zeros(shape=(arr.shape) + (4,)) for i in range(arr.shape[0]): for j in range(arr.shape[1]): new_arr[i][j] = (*color, arr[i][j]) return new_arr def suppress_axes_lines(ax): """ :param ax: pyplot axes object """ ax.axes.get_xaxis().set_ticks([]) ax.axes.get_yaxis().set_ticks([]) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.spines['left'].set_visible(False) return def label_bar(ax, rects, labels=None, colors=None, fontsize=10): """Attach a text label above each bar displaying its height :param ax: :param rects: rectangles as returned by plt.bar() :param labels: :param colors: """ for ix, rect in enumerate(rects): height = rect.get_height() if labels is not None and labels[ix] is not None: label = labels[ix] else: label = '{:g}'.format(height) if colors is not None and colors[ix] is not None and np.any(np.array(colors[ix])<1): color = colors[ix] else: color = 'black' ax.text(rect.get_x() + rect.get_width() / 2., 1.007 * height, label, color=color, ha='center', va='bottom', bbox=dict(facecolor=(1., 1., 1.), edgecolor='none', clip_on=True, pad=0, alpha=0.75), fontsize=fontsize) def draw_box_into_arr(arr, box_coords, box_color=None, lw=2): """ :param arr: imgs shape, (3,y,x) :param box_coords: (x1,y1,x2,y2), in ascending order :param box_color: arr of shape (3,) :param lw: linewidth in pixels """ if box_color is None: box_color = [1., 0.4, 0.] (x1, y1, x2, y2) = box_coords[:4] arr = np.swapaxes(arr, 0, -1) arr[..., y1:y2, x1:x1 + lw, :], arr[..., y1:y2 + lw, x2:x2 + lw, :] = box_color, box_color arr[..., y1:y1 + lw, x1:x2, :], arr[..., y2:y2 + lw, x1:x2, :] = box_color, box_color arr = np.swapaxes(arr, 0, -1) return arr def draw_boxes_into_batch(imgs, batch_boxes, type2color=None, cmap=None): """ :param imgs: either the actual batch imgs or a tuple with shape of batch imgs, need to have 3 color channels, need to be rgb; """ if isinstance(imgs, tuple): img_oshp = imgs imgs = None else: img_oshp = imgs[0].shape img_shp = shape_small_first(img_oshp) # c,x/y,y/x now imgs = np.reshape(imgs, (-1, *img_shp)) box_imgs = np.empty((len(batch_boxes), *(img_shp))) for sample, boxes in enumerate(batch_boxes): # imgs in batch have shape b,c,x,y, swap c to end sample_img = np.full(img_shp, 1.) if imgs is None else imgs[sample] for box in boxes: if len(box["box_coords"]) > 0: if type2color is not None and "box_type" in box.keys(): sample_img = draw_box_into_arr(sample_img, box["box_coords"].astype(np.int32), type2color[box["box_type"]]) else: sample_img = draw_box_into_arr(sample_img, box["box_coords"].astype(np.int32)) box_imgs[sample] = sample_img return box_imgs def plot_prediction_hist(cf, spec_df, outfile, title=None, fs=11, ax=None): labels = spec_df.class_label.values preds = spec_df.pred_score.values type_list = spec_df.det_type.tolist() if hasattr(spec_df, "det_type") else None if title is None: title = outfile.split('/')[-1] + ' count:{}'.format(len(labels)) close=False if ax is None: fig = plt.figure(tight_layout=True) ax = fig.add_subplot(1,1,1) close=True ax.set_yscale('log') ax.set_xlabel("Prediction Score", fontsize=fs) ax.set_ylabel("Occurences", fontsize=fs) ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="fp") ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="fn at score 0 and tp") ax.axvline(x=cf.min_det_thresh, alpha=1, color=cf.orange, linewidth=1.5, label="min det thresh") if type_list is not None: fp_count = type_list.count('det_fp') fn_count = type_list.count('det_fn') tp_count = type_list.count('det_tp') pos_count = fn_count + tp_count title += '\ntp:{} fp:{} fn:{} pos:{}'.format(tp_count, fp_count, fn_count, pos_count) ax.set_title(title, fontsize=fs) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) if close: ax.legend(loc="best", fontsize=fs) if cf.server_env: IO_safe(plt.savefig, fname=outfile, _raise=False) else: plt.savefig(outfile) pass plt.close() def plot_wbc_n_missing(cf, df, outfile, fs=11, ax=None): """ WBC (weighted box clustering) has parameter n_missing, which shows how many boxes are missing per cluster. This function plots the average relative amount of missing boxes sorted by cluster score. :param cf: config. :param df: dataframe. :param outfile: path to save image under. :param fs: fontsize. :param ax: axes object. """ bins = np.linspace(0., 1., 10) names = ["{:.1f}".format((bins[i]+(bins[i+1]-bins[i])/2.)*100) for i in range(len(bins)-1)] classes = df.pred_class.unique() colors = [cf.class_id2label[cl_id].color for cl_id in classes] binned_df = df.copy() binned_df.loc[:,"pred_score"] = pd.cut(binned_df["pred_score"], bins) close=False if ax is None: ax = plt.subplot() close=True width = 1 / (len(classes) + 1) group_positions = np.arange(len(names)) legend_handles = [] for ix, cl_id in enumerate(classes): cl_df = binned_df[binned_df.pred_class==cl_id].groupby("pred_score").agg({"cluster_n_missing": 'mean'}) ax.bar(group_positions + ix * width, cl_df.cluster_n_missing.values, width=width, color=colors[ix], alpha=0.4 + ix / 2 / len(classes), edgecolor=colors[ix]) legend_handles.append(mpatches.Patch(color=colors[ix], label=cf.class_dict[cl_id])) title = "Fold {} WBC Missing Preds\nAverage over scores and classes: {:.1f}%".format(cf.fold, df.cluster_n_missing.mean()) ax.set_title(title, fontsize=fs) ax.legend(handles=legend_handles, title="Class", loc="best", fontsize=fs, title_fontsize=fs) ax.set_xticks(group_positions + (len(classes) - 1) * width / 2) # ax.xaxis.set_major_formatter(StrMethodFormatter('{x:.1f}')) THIS WONT WORK... no clue! ax.set_xticklabels(names) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"Average Missing Preds per Cluster (%)", fontsize=fs) ax.set_xlabel("Prediction Score", fontsize=fs) if close: if cf.server_env: IO_safe(plt.savefig, fname=outfile, _raise=False) else: plt.savefig(outfile) plt.close() def plot_stat_curves(cf, stats, outfile, fill=False): """ Plot precision-recall and/or receiver-operating-characteristic curve(s). :param cf: config. :param stats: statistics as supplied by Evaluator. :param outfile: path to save plot under. :param fill: whether to colorize space between plot and x-axis. :return: """ for c in ['roc', 'prc']: plt.figure() empty_plot = True for ix, s in enumerate(stats): if s[c] is not np.nan: plt.plot(s[c][1], s[c][0], label=s['name'] + '_' + c, marker=None, color=cf.color_palette[ix%len(cf.color_palette)]) empty_plot = False if fill: plt.fill_between(s[c][1], s[c][0], alpha=0.33, color=cf.color_palette[ix%len(cf.color_palette)]) if not empty_plot: plt.title(outfile.split('/')[-1] + '_' + c) plt.legend(loc=3 if c == 'prc' else 4) plt.ylabel('precision' if c == 'prc' else '1-spec.') plt.ylim((0.,1)) plt.xlabel('recall') plt.savefig(outfile + '_' + c) plt.close() def plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=None, alphas=None, errors=None, ylabel='', xlabel='', xticklabels=None, yticks=None, yticklabels=None, ylim=None, label_format="{:.3f}", title=None, ax=None, out_file=None, legend=False, fs=11): """ Plot a categorically grouped bar chart. :param cf: config. :param bar_values: values of the bars. :param groups: groups/categories that bars belong to. :param splits: splits within groups, i.e., names of bars. :param colors: colors. :param alphas: 1-opacity. :param errors: values for errorbars. :param ylabel: label of y-axis. :param xlabel: label of x-axis. :param title: plot title. :param ax: axes object to draw into. if None, new is created. :param out_file: path to save plot. :param legend: whether to show a legend. :param fs: fontsize. :return: legend handles. """ bar_values = np.array(bar_values) if alphas is None: alphas = [1.,] * len(splits) if colors is None: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(splits))] if errors is None: errors = np.zeros_like(bar_values) # patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') # patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) close=False if ax is None: ax = plt.subplot() close=True width = 1 / (len(splits) +0.25) group_positions = np.arange(len(groups)) for ix, split in enumerate(splits): rects = ax.bar(group_positions + ix * width, bar_values[ix], width=width, color=(*colors[ix], 0.8), edgecolor=colors[ix], yerr=errors[ix], ecolor=(*np.array(colors[ix])*0.8, 1.), capsize=5) # for ix, bar in enumerate(rects): # bar.set_hatch(patterns[ix]) labels = [label_format.format(val) for val in bar_values[ix]] label_bar(ax, rects, labels, [colors[ix]]*len(labels), fontsize=fs) legend_handles = [mpatches.Patch(color=colors[ix], alpha=alphas[ix], label=split) for ix, split in enumerate(splits)] if legend: ax.legend(handles=legend_handles, fancybox=True, framealpha=1., loc="lower center") legend_handles = [(colors[ix], alphas[ix], split) for ix, split in enumerate(splits)] if title is not None: ax.set_title(title, fontsize=fs) ax.set_xticks(group_positions + (len(splits) - 1) * width / 2) if xticklabels is None: ax.set_xticklabels(groups, fontsize=fs) else: ax.set_xticklabels(xticklabels, fontsize=fs) ax.set_axisbelow(True) ax.set_xlabel(xlabel, fontsize=fs) ax.tick_params(labelsize=fs) ax.grid(axis='y') ax.set_ylabel(ylabel, fontsize=fs) if yticks is not None: ax.set_yticks(yticks) if yticklabels is not None: ax.set_yticklabels(yticklabels, fontsize=fs) if ylim is not None: ax.set_ylim(ylim) if out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() return legend_handles def plot_binned_rater_dissent(cf, binned_stats, out_file=None, ax=None, legend=True, fs=11): """ LIDC-specific plot: rater disagreement as standard deviations within each bin. :param cf: config. :param binned_stats: list, ix==bin_id, item: [(roi_mean, roi_std, roi_max, roi_bin_id-roi_max_bin_id) for roi in bin] :return: """ dissent = [np.array([roi[1] for roi in bin]) for bin in binned_stats] avg_dissent_first_degree = [np.mean(bin) for bin in dissent] groups = list(cf.bin_id2label.keys()) splits = [r"$1^{st}$ std. dev.",] colors = [cf.bin_id2label[bin_id].color[:3] for bin_id in groups] #colors = [cf.blue for bin_id in groups] alphas = [0.9,] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) close=False if ax is None: ax = plt.subplot() close=True width = 1/(len(splits)+1) group_positions = np.arange(len(groups)) #total_counts = [df.loc[split].sum() for split in splits] dissent = np.array(avg_dissent_first_degree) ix=0 rects = ax.bar(group_positions+ix*width, dissent, color=colors, alpha=alphas[ix], edgecolor=colors) #for ix, bar in enumerate(rects): #bar.set_hatch(patterns[ix]) labels = ["{:.2f}".format(diss) for diss in dissent] label_bar(ax, rects, labels, colors, fontsize=fs) bin_edge_color = cf.blue ax.axhline(y=0.5, color=bin_edge_color) ax.text(2.5, 0.38, "bin edge", color=cf.white, fontsize=fs, horizontalalignment="center", bbox=dict(boxstyle='round', facecolor=(*bin_edge_color, 0.85), edgecolor='none', clip_on=True, pad=0)) if legend: legend_handles = [mpatches.Patch(color=cf.blue ,alpha=alphas[ix], label=split) for ix, split in enumerate(splits)] ax.legend(handles=legend_handles, loc='lower center', fontsize=fs) title = "LIDC-IDRI: Average Std Deviation per Lesion" plt.title(title) ax.set_xticks(group_positions + (len(splits)-1)*width/2) ax.set_xticklabels(groups, fontsize=fs) ax.set_axisbelow(True) #ax.tick_params(axis='both', which='major', labelsize=fs) #ax.tick_params(axis='both', which='minor', labelsize=fs) ax.grid() ax.set_ylabel(r"Average Dissent (MS)", fontsize=fs) ax.set_xlabel("binned malignancy-score value (ms)", fontsize=fs) ax.tick_params(labelsize=fs) if out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() return def plot_confusion_matrix(cf, cm, out_file=None, ax=None, fs=11, cmap=plt.cm.Blues, color_bar=True): """ Plot a confusion matrix. :param cf: config. :param cm: confusion matrix, e.g., as supplied by metrics.confusion_matrix from scikit-learn. :return: """ close=False if ax is None: ax = plt.subplot() close=True im = ax.imshow(cm, interpolation='nearest', cmap=cmap) if color_bar: ax.figure.colorbar(im, ax=ax) # Rotate the tick labels and set their alignment. #plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # Loop over data dimensions and create text annotations. fmt = '.0%' if np.mod(cm, 1).any() else 'd' thresh = cm.max() / 2. for i in range(cm.shape[0]): for j in range(cm.shape[1]): ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black") ax.set_ylabel(r"Binned Mean MS", fontsize=fs) ax.set_xlabel("Single-Annotator MS", fontsize=fs) #ax.tick_params(labelsize=fs) if close and out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() else: return ax def plot_data_stats(cf, df, labels=None, out_file=None, ax=None, fs=11): """ Plot data-set statistics. Shows target counts. Mainly used by Dataset Class in dataloader.py. :param cf: configs obj :param df: pandas dataframe :param out_file: path to save fig in """ names = df.columns if labels is not None: colors = [label.color for name in names for label in labels if label.name==name] else: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(names))] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) if ax is None: fig, ax = plt.subplots(figsize=(14,6), dpi=300) return_ax = False else: return_ax = True plt.margins(x=0.01) plt.subplots_adjust(bottom=0.15) bar_positions = np.arange(len(names)) name_counts = df.sum() total_count = name_counts.sum() rects = ax.bar(bar_positions, name_counts, color=colors, alpha=0.9, edgecolor=colors) labels = ["{:.0f}%".format(count/ total_count*100) for count in name_counts] label_bar(ax, rects, labels, colors, fontsize=fs) title= "Data Set RoI-Target Balance\nTotal #RoIs: {}".format(int(total_count)) ax.set_title(title, fontsize=fs) ax.set_xticks(bar_positions) rotation = "vertical" if np.any([len(str(name)) > 3 for name in names]) else None if all([isinstance(name, (float, int)) for name in names]): ax.set_xticklabels(["{:.2f}".format(name) for name in names], rotation=rotation, fontsize=fs) else: ax.set_xticklabels(names, rotation=rotation, fontsize=fs) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs", fontsize=fs) ax.set_xlabel(str(df._metadata[0]), fontsize=fs) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) if out_file is not None: plt.savefig(out_file) if return_ax: return ax else: plt.close() def plot_fold_stats(cf, df, labels=None, out_file=None, ax=None): """ Similar as plot_data_stats but per single cross-val fold. :param cf: configs obj :param df: pandas dataframe :param out_file: path to save fig in """ names = df.columns splits = df.index if labels is not None: colors = [label.color for name in names for label in labels if label.name==name] else: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(names))] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) if ax is None: ax = plt.subplot() return_ax = False else: return_ax = True width = 1/(len(names)+1) group_positions = np.arange(len(splits)) legend_handles = [] total_counts = [df.loc[split].sum() for split in splits] for ix, name in enumerate(names): rects = ax.bar(group_positions+ix*width, df.loc[:,name], width=width, color=colors[ix], alpha=0.9, edgecolor=colors[ix]) #for ix, bar in enumerate(rects): #bar.set_hatch(patterns[ix]) labels = ["{:.0f}%".format(df.loc[split, name]/ total_counts[ii]*100) for ii, split in enumerate(splits)] label_bar(ax, rects, labels, [colors[ix]]*len(group_positions)) legend_handles.append(mpatches.Patch(color=colors[ix] ,alpha=0.9, label=name)) title= "Fold {} RoI-Target Balances\nTotal #RoIs: {}".format(cf.fold, int(df.values.sum())) plt.title(title) ax.legend(handles=legend_handles) ax.set_xticks(group_positions + (len(names)-1)*width/2) ax.set_xticklabels(splits, rotation="vertical" if len(splits)>2 else None, size=12) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs") ax.set_xlabel("Set split") if out_file is not None: plt.savefig(out_file) if return_ax: return ax plt.close() def plot_batchgen_distribution(cf, pids, p_probs, balance_target, out_file=None): """plot top n_pids probabilities for drawing a pid into a batch. :param cf: experiment config object :param pids: sorted iterable of patient ids :param p_probs: pid's drawing likelihood, order needs to match the one of pids. :param out_file: :return: """ n_pids = len(pids) zip_sorted = np.array(sorted(list(zip(p_probs, pids)), reverse=True)) names, probs = zip_sorted[:n_pids,1], zip_sorted[:n_pids,0].astype('float32') * 100 try: names = [str(int(n)) for n in names] except ValueError: names = [str(n) for n in names] lowest_p = min(p_probs)*100 fig, ax = plt.subplots(1,1,figsize=(17,5), dpi=200) rects = ax.bar(names, probs, color=cf.blue, alpha=0.9, edgecolor=cf.blue) ax = plt.gca() ax.text(0.8, 0.92, "Lowest prob.: {:.5f}%".format(lowest_p), transform=ax.transAxes, color=cf.white, bbox=dict(boxstyle='round', facecolor=cf.blue, edgecolor='none', alpha=0.9)) ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}')) ax.set_xticklabels(names, rotation="vertical", fontsize=7) plt.margins(x=0.01) plt.subplots_adjust(bottom=0.15) if balance_target=="class_targets": balance_target = "Class" elif balance_target=="lesion_gleasons": balance_target = "GS" ax.set_title(str(balance_target)+"-Balanced Train Generator: Sampling Likelihood per PID") ax.set_axisbelow(True) ax.grid(axis='y') ax.set_ylabel("Sampling Likelihood (%)") ax.set_xlabel("PID") plt.tight_layout() if out_file is not None: plt.savefig(out_file) plt.close() def plot_batchgen_stats(cf, stats, empties, target_name, unique_ts, out_file=None): """Plot bar chart showing RoI frequencies and empty-sample count of batch stats recorded by BatchGenerator. :param cf: config. :param stats: statistics as supplied by BatchGenerator class. :param out_file: path to save plot. """ total_samples = cf.num_epochs*cf.num_train_batches*cf.batch_size if target_name=="class_targets": target_name = "Class" label_dict = {cl_id: label for (cl_id, label) in cf.class_id2label.items()} elif target_name=="lesion_gleasons": target_name = "Lesion's Gleason Score" label_dict = cf.gs2label elif target_name=="rg_bin_targets": target_name = "Regression-Bin ID" label_dict = cf.bin_id2label else: raise NotImplementedError names = [label_dict[t_id].name for t_id in unique_ts] colors = [label_dict[t_id].color for t_id in unique_ts] title = "Training Target Frequencies" title += "\nempty samples: {}".format(empties) rects = plt.bar(names, stats['roi_counts'], color=colors, alpha=0.9, edgecolor=colors) ax = plt.gca() ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}')) ax.set_title(title) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs") ax.set_xlabel(target_name) total_count = np.sum(stats["roi_counts"]) labels = ["{:.0f}%".format(count/total_count*100) for count in stats["roi_counts"]] label_bar(ax, rects, labels, colors) if out_file is not None: plt.savefig(out_file) plt.close() def view_3D_array(arr, outfile, elev=30, azim=30): from mpl_toolkits.mplot3d import Axes3D fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.set_aspect("equal") ax.set_xlabel("x") ax.set_ylabel("y") ax.set_zlabel("z") ax.voxels(arr) ax.view_init(elev=elev, azim=azim) plt.savefig(outfile) def view_batch(cf, batch, res_dict=None, out_file=None, legend=True, show_info=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, roi_items="all", sample_picks=None, vol_slice_picks=None, box_score_thres=None, plot_mods=True, - dpi=200, vmin=None, return_fig=False, get_time=True): + dpi=300, vmin=None, return_fig=False, get_time=True): r""" View data and target entries of a batch. Batch expected as dic with entries 'data' and 'seg' holding np.arrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. Pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. :param cf: config. :param batch: batch. :param res_dict: results dictionary. :param out_file: path to save plot. :param legend: whether to show a legend. :param show_info: whether to show text info about img sizes and type in plot. :param has_colorchannels: whether image has color channels. :param isRGB: if image is RGB. :param show_seg_ids: "all" or None or list with seg classes to show (seg_ids) :param show_seg_pred: whether to the predicted segmentation. :param show_gt_boxes: whether to show ground-truth boxes. :param show_gt_labels: whether to show labels of ground-truth boxes. :param roi_items: which roi items to show: strings "all" or "targets". --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no threshold. :param plot_mods: whether to plot input modality/modalities. :param dpi: graphics resolution. :param vmin: min value for gray-scale cmap in imshow, set to a fix value for inter-batch normalization, or None for intra-batch. :param return_fig: whether to return created figure. """ stime = time.time() # pfix = prefix, ptfix = postfix patched_patient = 'patch_crop_coords' in list(batch.keys()) pfix = 'patient_' if patched_patient else '' ptfix = '_2d' if (patched_patient and cf.dim == 2 and pfix + 'class_targets_2d' in batch.keys()) else '' # -------------- get data, set flags ----------------- try: btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] except AttributeError: # in this case: assume it's single-annotator ground truths btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'][0] print("Showing only gts of rater 0") data_init_shp, seg_init_shp = data.shape, seg.shape seg = np.copy(seg) if show_seg_ids else None plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) elif hasattr(vol_slice_picks, "__iter__") and vol_slice_picks[0]=="random": z_ics = np.random.choice(np.arange(0, d), size=min(vol_slice_picks[1], d), replace=False) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1 and vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) if seg_preds is not None and not type(show_seg_ids)==str: seg_preds = np.copy(seg_preds) seg_preds = np.where(np.isin(seg_preds, show_seg_ids), seg_preds, 0) if seg is not None: if not type(show_seg_ids)==str: #to save time seg = np.where(np.isin(seg, show_seg_ids), seg, 0) legend_items = {cf.seg_id2label[seg_id] for seg_id in np.unique(seg) if seg_id != 0} # add seg labels else: legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = 12 # fontsize sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: plt.title("Ground Truth", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number suppress_axes_lines(ax) else: plt.axis('off') col += 1 if seg is not None and seg.shape[1] == 1: ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cf.cmap), alpha=0.8) elif seg is not None: ax.imshow(to_rgba(np.argmax(seg[s_ix][...,z_ix], axis=0), cf.cmap), alpha=0.8) # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] width, height = x2 - x1, y2 - y1 if class_targets is not None: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels: text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 text_fs = title_fs // 3 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name and cf.plot_class_ids: text_x = x2 #- 2 * text_fs * (len(str(class_targets[s_ix][j]))) # avoid overlap of scores text_y = y1 #+ 2 * text_fs text_str = '{}'.format(class_targets[s_ix][j]) elif 'regression_targets' in name: text_x, text_y = (x2, y2) text_str = "[" + " ".join( ["{:.1f}".format(x) for x in batch[name][s_ix][j]]) + "]" elif 'rg_bin_targets' in name: text_x, text_y = (x1, y2) text_str = '{}'.format(batch[name][s_ix][j]) else: text_pos = text_poss.pop(0) text_x = text_pos[0] #- 2 * text_fs * len(str(batch[name][s_ix][j])) text_y = text_pos[1] #+ 2 * text_fs text_str = '{}'.format(batch[name][s_ix][j]) ax.text(text_x, text_y, text_str, color=cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') ax.add_patch(bbox) # -----evtly visualise predictions ------------- if pr_boxes is not None or seg_preds is not None: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray") ax.axis("off") col += 1 if row == 0: plt.title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": label = cf.class_id2label[box["box_pred_class_id"]] legend_items.add(label) text_x, text_y = x2, y1 id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" text_str = '{}{:.0f}'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "["+" ".join(["{:.1f}".format(x) for x in box["regression"]])+"]" #str(box["regression"]) #+ "|" if cf.plot_class_ids else "" if 'rg_uncertainty' in box.keys() and not np.isnan(box['rg_uncertainty']): id_text += " | {:.1f}".format(box['rg_uncertainty']) text_str = '{}'.format(id_text) #, box["box_score"] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) # ------------ pred segs -------- if seg_preds is not None: # and seg_preds.shape[1] == 1: if cf.class_specific_seg: ax.imshow(to_rgba(seg_preds[s_ix][0][...,z_ix], cf.cmap), alpha=0.8) else: ax.imshow(bin_seg_to_rgba(seg_preds[s_ix][0][...,z_ix], cf.orange), alpha=0.8) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if show_info: plt.figtext(0, 0, "Batch content is of type\n{}\nand has shapes\n".format(btype) + \ "{} for 'data' and {} for 'seg'".format(data_init_shp, seg_init_shp)) if out_file is not None: if cf.server_env: IO_safe(plt.savefig, fname=out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', _raise=False) else: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight') if get_time: print("generated {} in {:.3f}s".format("plot" if not isinstance(get_time, str) else get_time, time.time()-stime)) if return_fig: return plt.gcf() plt.clf() plt.close() def view_batch_paper(cf, batch, res_dict=None, out_file=None, legend=True, show_info=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, roi_items="all", split_ens_ics=False, server_env=True, sample_picks=None, vol_slice_picks=None, patient_items=False, box_score_thres=None, plot_mods=True, dpi=400, vmin=None, return_fig=False): r"""view data and target entries of a batch. batch expected as dic with entries 'data' and 'seg' holding tensors or nparrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. :param cf: :param batch: :param res_dict: :param out_file: :param legend: :param show_info: :param has_colorchannels: :param isRGB: :param show_seg_ids: :param show_seg_pred: :param show_gt_boxes: :param show_gt_labels: :param roi_items: strings "all" or "targets" --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param split_ens_ics: :param server_env: :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param patient_items: set to true if patient-wise batch items should be displayed (need to be contained in batch and marked via 'patient_' prefix. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no thres. :param plot_mods: :param dpi: graphics resolution :param vmin: min value for gs cmap in imshow, set to fix inter-batch, or None for intra-batch. pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. show_seg_ids: "all" or None or list with seg classes to show (seg_ids) """ # pfix = prefix, ptfix = postfix pfix = 'patient_' if patient_items else '' ptfix = '_2d' if (patient_items and cf.dim==2) else '' # -------------- get data, set flags ----------------- btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] # seg = np.array(seg).mean(axis=0, keepdims=True) # seg[seg>0] = 1. print("Showing multirater GT") data_init_shp, seg_init_shp = data.shape, seg.shape fg_slices = np.where(np.sum(np.sum(np.squeeze(seg), axis=0), axis=0)>0)[0] if len(fg_slices)==0: print("skipping empty patient") return if vol_slice_picks is None: vol_slice_picks = fg_slices print("data shp, seg shp", data_init_shp, seg_init_shp) plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) # elif hasattr(vol_slice_picks, "__iter__") and vol_slice_picks[0]=="random": # z_ics = np.random.choice(np.arange(0, d), size=min(vol_slice_picks[1], d), replace=False) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1: if vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) else: z_ics = np.random.choice(vol_slice_picks, size=min(len(vol_slice_picks), max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = 12 # fontsize sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: plt.title("Ground Truth+ Pred", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number suppress_axes_lines(ax) else: plt.axis('off') col += 1 if seg is not None and seg.shape[1] == 1: cmap = {1: cf.orange} ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cmap), alpha=0.8) # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] # [x1,y1,x2,y2] = box[:4]#:return: coords (x1, y1, x2, y2) width, height = x2 - x1, y2 - y1 if class_targets is not None: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels and cf.plot_class_ids: text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 text_fs = title_fs // 3 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name: text_x = x2 #- 2 * text_fs * (len(str(class_targets[s_ix][j]))) # avoid overlap of scores text_y = y1 #+ 2 * text_fs text_str = '{}'.format(class_targets[s_ix][j]) elif 'regression_targets' in name: text_x, text_y = (x2, y2) text_str = "[" + " ".join( ["{:.1f}".format(x) for x in batch[name][s_ix][j]]) + "]" elif 'rg_bin_targets' in name: text_x, text_y = (x1, y2) text_str = '{}'.format(batch[name][s_ix][j]) else: text_pos = text_poss.pop(0) text_x = text_pos[0] #- 2 * text_fs * len(str(batch[name][s_ix][j])) text_y = text_pos[1] #+ 2 * text_fs text_str = '{}'.format(batch[name][s_ix][j]) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') ax.add_patch(bbox) # # -----evtly visualise predictions ------------- # if pr_boxes is not None or seg_preds is not None: # ax = fig.add_subplot(grid[row, col]) # ax.imshow(bg_img, cmap="gray") # ax.axis("off") # col += 1 # if row == 0: # plt.title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": label = cf.bin_id2label[box["rg_bin"]] color = cf.aubergine legend_items.add(label) text_x, text_y = x2, y1 #id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" id_text = "fg: " text_str = '{}{:.0f}'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=0.5, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, bbox=text_settings, fontsize=title_fs // 2) edgecolor = color #label.color if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "ms: "+" ".join(["{:.1f}".format(x) for x in box["regression"]])+"" text_str = '{}'.format(id_text) #, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=0.5, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, bbox=text_settings, fontsize=title_fs // 2) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) # ax.text(text_x, text_y, text_str, color=cf.white, # bbox=text_settings, fontsize=title_fs // 4) if split_ens_ics and "ens_ix" in box.keys(): n_aug = box["ens_ix"].split("_")[1] edgecolor = [c for c in cf.color_palette if not c == cf.green][ int(n_aug) % (len(cf.color_palette) - 1)] text_x, text_y = x1, y2 text_str = "{}".format(box["ens_ix"][2:]) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 6) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if show_info: plt.figtext(0, 0, "Batch content is of type\n{}\nand has shapes\n".format(btype) + \ "{} for 'data' and {} for 'seg'".format(data_init_shp, seg_init_shp)) if out_file is not None: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', tight_layout=True) if return_fig: return plt.gcf() if not (server_env or cf.server_env): plt.show() plt.clf() plt.close() def view_batch_thesis(cf, batch, res_dict=None, out_file=None, legend=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, show_cl_ids=True, roi_items="all", server_env=True, sample_picks=None, vol_slice_picks=None, fontsize=12, seg_cmap="class", patient_items=False, box_score_thres=None, plot_mods=True, dpi=400, vmin=None, return_fig=False, axes=None): r"""view data and target entries of a batch. batch expected as dic with entries 'data' and 'seg' holding tensors or nparrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. :param cf: :param batch: :param res_dict: :param out_file: :param legend: :param show_info: :param has_colorchannels: :param isRGB: :param show_seg_ids: :param show_seg_pred: :param show_gt_boxes: :param show_gt_labels: :param roi_items: strings "all" or "targets" --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param split_ens_ics: :param server_env: :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param patient_items: set to true if patient-wise batch items should be displayed (need to be contained in batch and marked via 'patient_' prefix. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no thres. :param plot_mods: :param dpi: graphics resolution :param vmin: min value for gs cmap in imshow, set to fix inter-batch, or None for intra-batch. pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. show_seg_ids: "all" or None or list with seg classes to show (seg_ids) """ # pfix = prefix, ptfix = postfix pfix = 'patient_' if patient_items else '' ptfix = '_2d' if (patient_items and cf.dim==2) else '' # -------------- get data, set flags ----------------- btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] data_init_shp, seg_init_shp = data.shape, seg.shape fg_slices = np.where(np.sum(np.sum(np.squeeze(seg), axis=0), axis=0)>0)[0] if len(fg_slices)==0: print("skipping empty patient") return if vol_slice_picks is None: vol_slice_picks = fg_slices #print("data shp, seg shp", data_init_shp, seg_init_shp) plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cl_targets_sa = batch[pfix+'class_targets_sa'+ptfix] if pfix+'class_targets_sa'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1 and vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = fontsize # fontsize text_fs = title_fs * 2 / 3 sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt if axes is not None and 'gt' in axes.keys(): ax = axes['gt'] else: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: ax.set_title("Ground Truth", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) # str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=text_fs*1.3) # show id-number suppress_axes_lines(ax) else: ax.axis('off') col += 1 # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] # [x1,y1,x2,y2] = box[:4]#:return: coords (x1, y1, x2, y2) width, height = x2 - x1, y2 - y1 if class_targets is not None: try: label = cf.bin_id2label[cf.rg_val_to_bin_id(batch['patient_regression_targets'][s_ix][j])] except AttributeError: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels and cf.plot_class_ids: bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') if height<=text_fs*6: y1 -= text_fs*1.5 y2 += text_fs*2 text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name: text_str = '{}'.format(class_targets[s_ix][j]) text_x, text_y = (x2 + 0 * len(text_str) // 4, y2) elif 'regression_targets' in name: text_str = 'agg. MS: {:.2f}'.format(batch[name][s_ix][j][0]) text_x, text_y = (x2 + 0 * len(text_str) // 4, y2) elif 'rg_bin_targets_sa' in name: text_str = 'sa. MS: {}'.format(batch[name][s_ix][j]) text_x, text_y = (x2-0*len(text_str)*text_fs//4, y1) # elif 'rg_bin_targets' in name: # text_str = 'agg. ms:{}'.format(batch[name][s_ix][j]) # text_x, text_y = (x2+0*len(text_str)//4, y1) ax.text(text_x, text_y, text_str, color=cf.black if (label.color[:3]==cf.yellow or label.color[:3]==cf.green) else cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 ax.add_patch(bbox) if seg is not None and seg.shape[1] == 1: #cmap = {1: cf.orange} # cmap = {label_id: label.color for label_id, label in cf.bin_id2label.items()} # this whole function is totally only hacked together for a quick very specific case if seg_cmap == "rg" or seg_cmap=="regression": cmap = {1: cf.bin_id2label[cf.rg_val_to_bin_id(batch['patient_regression_targets'][s_ix][0])].color} else: cmap = cf.class_cmap ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cmap), alpha=0.8) # # -----evtly visualise predictions ------------- if pr_boxes is not None or seg_preds is not None: if axes is not None and 'pred' in axes.keys(): ax = axes['pred'] else: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray") ax.axis("off") col += 1 if row == 0: ax.set_title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: alpha = 0.7 box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": try: label = cf.bin_id2label[cf.rg_val_to_bin_id(box['regression'])] except AttributeError: label = cf.class_id2label[box['box_pred_class_id']] # assert box["rg_bin"] == cf.rg_val_to_bin_id(box['regression']), \ # "box bin: {}, rg-bin {}".format(box["rg_bin"], cf.rg_val_to_bin_id(box['regression'])) color = label.color#cf.aubergine edgecolor = color # label.color text_color = cf.black if (color[:3]==cf.yellow or color[:3]==cf.green) else cf.white legend_items.add(label) bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') if height<=text_fs*6: y1 -= text_fs*1.5 y2 += text_fs*2 text_x, text_y = x2, y1 #id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" id_text = "FG: " text_str = r'{}{:.0f}%'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs ) if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "MS: "+" ".join(["{:.2f}".format(x) for x in box["regression"]])+"" text_str = '{}'.format(id_text) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0) # ax.text(text_x, text_y, text_str, color=cf.white, # bbox=text_settings, fontsize=title_fs // 4) if 'box_pred_class_id' in box.keys() and show_cl_ids: text_x, text_y = x2, y2 id_text = box["box_pred_class_id"] text_str = '{}'.format(id_text) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color ax.add_patch(bbox) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if out_file is not None: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', tight_layout=True) if return_fig: return plt.gcf() if not (server_env or cf.server_env): plt.show() plt.clf() plt.close() def view_slices(cf, img, seg=None, ids=None, title="", out_dir=None, legend=True, cmap=None, label_remap=None, instance_labels=False): """View slices of a 3D image overlayed with corresponding segmentations. :params img, seg: expected as 3D-arrays """ if isinstance(img, sitk.SimpleITK.Image): img = sitk.GetArrayViewFromImage(img) elif isinstance(img, np.ndarray): #assume channels dim is smallest and in either first or last place if np.argmin(img.shape)==2: img = np.moveaxis(img, 2,0) else: raise Exception("view_slices got unexpected img type.") if seg is not None: if isinstance(seg, sitk.SimpleITK.Image): seg = sitk.GetArrayViewFromImage(seg) elif isinstance(img, np.ndarray): if np.argmin(seg.shape)==2: seg = np.moveaxis(seg, 2,0) else: raise Exception("view_slices got unexpected seg type.") if label_remap is not None: for (key, val) in label_remap.items(): seg[seg==key] = val if instance_labels: class Label(): def __init__(self, id, name, color): self.id = id self.name = name self.color = color legend_items = {Label(seg_id, "instance_{}".format(seg_id), cf.color_palette[seg_id%len(cf.color_palette)]) for seg_id in np.unique(seg)} if cmap is None: cmap = {label.id : label.color for label in legend_items} else: legend_items = {cf.seg_id2label[seg_id] for seg_id in np.unique(seg)} if cmap is None: cmap = {label.id : label.color for label in legend_items} slices = img.shape[0] if seg is not None: assert slices==seg.shape[0], "Img and seg have different amt of slices." grid = gridspec.GridSpec(int(np.ceil(slices/4)),4) fig = plt.figure(figsize=(10, slices/4*2.5)) rng = np.arange(slices, dtype='uint8') if not ids is None: rng = rng[ids] for s in rng: ax = fig.add_subplot(grid[int(s/4),int(s%4)]) ax.imshow(img[s], cmap="gray") if not seg is None: ax.imshow(to_rgba(seg[s], cmap), alpha=0.9) if legend and int(s/4)==0 and int(s%4)==3: patches = [mpatches.Patch(color=label.color, label="{}".format(label.name)) for label in legend_items] ncols = 1 plt.legend(handles=patches,bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=ncols) plt.title("slice {}, {}".format(s, img[s].shape)) plt.axis('off') plt.suptitle(title) if out_dir is not None: plt.savefig(out_dir, dpi=300, pad_inches=0.0, bbox_inches='tight') if not cf.server_env: plt.show() plt.close() def plot_txt(cf, txts, labels=None, title="", x_label="", y_labels=["",""], y_ranges=(None,None), twin_axes=(), smooth=None, out_dir=None): """Read and plot txt data, either from file (txts is paths) or directly (txts is arrays). :param twin_axes: plot two y-axis over same x-axis. twin_axes expected as tuple defining which txt files (determined via indices) share the second y-axis. """ if isinstance(txts, str) or not hasattr(txts, '__iter__'): txts = [txts] fig = plt.figure() ax1 = fig.add_subplot(1,1,1) if len(twin_axes)>0: ax2 = ax1.twinx() for i, txt in enumerate(txts): if isinstance(txt, str): arr = np.genfromtxt(txt, delimiter=',',skip_header=1, usecols=(1,2)) else: arr = txt if i in twin_axes: ax = ax2 else: ax = ax1 if smooth is not None: spline_graph = interpol.UnivariateSpline(arr[:,0], arr[:,1], k=5, s=float(smooth)) ax.plot(arr[:, 0], spline_graph(arr[:,0]), color=cf.color_palette[i % len(cf.color_palette)], marker='', markersize=2, linestyle='solid') ax.plot(arr[:,0], arr[:,1], color=cf.color_palette[i%len(cf.color_palette)], marker='', markersize=2, linestyle='solid', label=labels[i], alpha=0.5 if smooth else 1.) plt.title(title) ax1.set_xlabel(x_label) ax1.set_ylabel(y_labels[0]) if y_ranges[0] is not None: ax1.set_ylim(y_ranges[0]) if len(twin_axes)>0: ax2.set_ylabel(y_labels[1]) if y_ranges[1] is not None: ax2.set_ylim(y_ranges[1]) plt.grid() if labels is not None: ax1.legend(loc="upper center") if len(twin_axes)>0: ax2.legend(loc=4) if out_dir is not None: plt.savefig(out_dir, dpi=200) return fig def plot_tboard_logs(cf, log_dir, tag_filters=[""], inclusive_filters=True, out_dir=None, x_label="", y_labels=["",""], y_ranges=(None,None), twin_axes=(), smooth=None): """Plot (only) tboard scalar logs from given log_dir for multiple runs sorted by tag. """ print("log dir", log_dir) mpl = EventMultiplexer().AddRunsFromDirectory(log_dir) #EventAccumulator(log_dir) mpl.Reload() # Print tags of contained entities, use these names to retrieve entities as below #print(mpl.Runs()) scalars = {runName : data['scalars'] for (runName, data) in mpl.Runs().items() if len(data['scalars'])>0} print("scalars", scalars) tags = {} tag_filters = [tag_filter.lower() for tag_filter in tag_filters] for (runName, runtags) in scalars.items(): print("rn", runName.lower()) check = np.any if inclusive_filters else np.all if np.any([tag_filter in runName.lower() for tag_filter in tag_filters]): for runtag in runtags: #if tag_filter in runtag.lower(): if runtag not in tags: tags[runtag] = [runName] else: tags[runtag].append(runName) print("tags ", tags) for (tag, runNames) in tags.items(): print("runnames ", runNames) print("tag", tag) tag_scalars = [] labels = [] for run in runNames: #mpl.Scalars returns ScalarEvents array holding wall_time, step, value per time step (shape series_length x 3) #print(mpl.Scalars(runName, tag)[0]) run_scalars = [(s.step, s.value) for s in mpl.Scalars(run, tag)] print(np.array(run_scalars).shape) tag_scalars.append(np.array(run_scalars)) print("run", run) labels.append("/".join(run.split("/")[-2:])) #print("tag scalars ", tag_scalars) if out_dir is not None: out_path = os.path.join(out_dir,tag.replace("/","_")) else: out_path = None plot_txt(txts=tag_scalars, labels=labels, title=tag, out_dir=out_path, cf=cf, x_label=x_label, y_labels=y_labels, y_ranges=y_ranges, twin_axes=twin_axes, smooth=smooth) def plot_box_legend(cf, box_coords=None, class_id=None, out_dir=None): """plot a blank box explaining box annotations. :param cf: :return: """ if class_id is None: class_id = 1 img = np.ones(cf.patch_size[:2]) dim_max = max(cf.patch_size[:2]) width, height = cf.patch_size[0] // 2, cf.patch_size[1] // 2 if box_coords is None: # lower left corner x1, y1 = width // 2, height // 2 x2, y2 = x1 + width, y1 + height else: y1, x1, y2, x2 = box_coords fig = plt.figure(tight_layout=True, dpi=300) ax = fig.add_subplot(111) title_fs = 36 label = cf.class_id2label[class_id] # legend_items.add(label) ax.set_facecolor(cf.beige) ax.imshow(img, cmap='gray', vmin=0., vmax=1., alpha=0) # ax.axis('off') # suppress_axes_lines(ax) ax.set_xticks([]) ax.set_yticks([]) text_x, text_y = x2 * 0.85, y1 id_text = "class id" + " | " if cf.plot_class_ids else "" text_str = '{}{}'.format(id_text, "confidence") text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color if any(['regression' in task for task in cf.prediction_tasks]): text_x, text_y = x2 * 0.85, y2 id_text = "regression" if any(['ken_gal' in task or 'feindt' in task for task in cf.prediction_tasks]): id_text += " | uncertainty" text_str = '{}'.format(id_text) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'regression_bin' in cf.prediction_tasks or hasattr(cf, "rg_val_to_bin_id"): text_x, text_y = x1, y2 text_str = 'Rg. Bin' ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'lesion_gleasons' in cf.observables_rois: text_x, text_y = x1, y1 text_str = 'Gleason Score' ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=1., edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) if out_dir is not None: plt.savefig(os.path.join(out_dir, "box_legend.png")) def plot_boxes(cf, box_coords, patch_size=None, scores=None, class_ids=None, out_file=None, ax=None): if patch_size is None: patch_size = cf.patch_size[:2] if class_ids is None: class_ids = np.ones((len(box_coords),), dtype='uint8') if scores is None: scores = np.ones((len(box_coords),), dtype='uint8') img = np.ones(patch_size) y1, x1, y2, x2 = box_coords[:,0], box_coords[:,1], box_coords[:,2], box_coords[:,3] width, height = x2-x1, y2-y1 close = False if ax is None: fig = plt.figure(tight_layout=True, dpi=300) ax = fig.add_subplot(111) close = True title_fs = 56 ax.set_facecolor((*cf.gray,0.15)) ax.imshow(img, cmap='gray', vmin=0., vmax=1., alpha=0) #ax.axis('off') #suppress_axes_lines(ax) ax.set_xticks([]) ax.set_yticks([]) for bix, cl_id in enumerate(class_ids): label = cf.class_id2label[cl_id] text_x, text_y = x2[bix] -20, y1[bix] +5 id_text = class_ids[bix] if cf.plot_class_ids else "" text_str = '{}{}{:.0f}'.format(id_text, " | ", scores[bix] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color bbox = mpatches.Rectangle((x1[bix], y1[bix]), width[bix], height[bix], linewidth=1., edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) if out_file is not None: plt.savefig(out_file) if close: plt.close() if __name__=="__main__": cluster_exp_root = "/mnt/E132-Cluster-Projects" #dataset="prostate/" dataset = "lidc/" exp_name = "ms13_mrcnnal3d_rg_bs8_480k" #exp_dir = os.path.join("datasets", dataset, "experiments", exp_name) # exp_dir = os.path.join(cluster_exp_root, dataset, "experiments", exp_name) # log_dir = os.path.join(exp_dir, "logs") # sys.path.append(exp_dir) # from configs import Configs # cf = configs() # # #print("logdir", log_dir) # #out_dir = os.path.join(cf.source_dir, log_dir.replace("/", "_")) # #print("outdir", out_dir) # log_dir = os.path.join(cf.source_dir, log_dir) # plot_tboard_logs(cf, log_dir, tag_filters=["train/lesion_avp", "val/lesion_ap", "val/lesion_avp", "val/patient_lesion_avp"], smooth=2.2, out_dir=log_dir, # y_ranges=([0,900], [0,0.8]), # twin_axes=[1], y_labels=["counts",""], x_label="epoch") #plot_box_legend(cf, out_dir=exp_dir) diff --git a/utils/exp_utils.py b/utils/exp_utils.py index d36a5dc..ec20a53 100644 --- a/utils/exp_utils.py +++ b/utils/exp_utils.py @@ -1,694 +1,726 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# import plotting as plg - +from typing import Union, Iterable import sys import os import subprocess from multiprocessing import Process import threading import pickle import importlib.util import psutil import time import nvidia_smi import logging from torch.utils.tensorboard import SummaryWriter from collections import OrderedDict import numpy as np import pandas as pd import torch def import_module(name, path): """ correct way of importing a module dynamically in python 3. :param name: name given to module instance. :param path: path to module. :return: module: returned module instance. """ spec = importlib.util.spec_from_file_location(name, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module def save_obj(obj, name): """Pickle a python object.""" with open(name + '.pkl', 'wb') as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) def IO_safe(func, *args, _tries=5, _raise=True, **kwargs): """ Wrapper calling function func with arguments args and keyword arguments kwargs to catch input/output errors on cluster. :param func: function to execute (intended to be read/write operation to a problematic cluster drive, but can be any function). :param args: positional args of func. :param kwargs: kw args of func. :param _tries: how many attempts to make executing func. """ for _try in range(_tries): try: return func(*args, **kwargs) except OSError as e: # to catch cluster issues with network drives if _raise: raise e else: print("After attempting execution {} time{}, following error occurred:\n{}".format(_try + 1, "" if _try == 0 else "s", e)) continue def split_off_process(target, *args, daemon=False, **kwargs): """Start a process that won't block parent script. No join(), no return value. If daemon=False: before parent exits, it waits for this to finish. """ p = Process(target=target, args=tuple(args), kwargs=kwargs, daemon=daemon) p.start() return p def query_nvidia_gpu(device_id, d_keyword=None, no_units=False): """ :param device_id: :param d_keyword: -d, --display argument (keyword(s) for selective display), all are selected if None :return: dict of gpu-info items """ cmd = ['nvidia-smi', '-i', str(device_id), '-q'] if d_keyword is not None: cmd += ['-d', d_keyword] outp = subprocess.check_output(cmd).strip().decode('utf-8').split("\n") outp = [x for x in outp if len(x) > 0] headers = [ix for ix, item in enumerate(outp) if len(item.split(":")) == 1] + [len(outp)] out_dict = {} for lix, hix in enumerate(headers[:-1]): head = outp[hix].strip().replace(" ", "_").lower() out_dict[head] = {} for lix2 in range(hix, headers[lix + 1]): try: key, val = [x.strip().lower() for x in outp[lix2].split(":")] if no_units: val = val.split()[0] out_dict[head][key] = val except: pass return out_dict +class _AnsiColorizer(object): + """ + A colorizer is an object that loosely wraps around a stream, allowing + callers to write text to the stream in a particular color. + + Colorizer classes must implement C{supported()} and C{write(text, color)}. + """ + _colors = dict(black=30, red=31, green=32, yellow=33, + blue=34, magenta=35, cyan=36, white=37, default=39) + + def __init__(self, stream): + self.stream = stream + + @classmethod + def supported(cls, stream=sys.stdout): + """ + A class method that returns True if the current platform supports + coloring terminal output using this method. Returns False otherwise. + """ + if not stream.isatty(): + return False # auto color only on TTYs + try: + import curses + except ImportError: + return False + else: + try: + try: + return curses.tigetnum("colors") > 2 + except curses.error: + curses.setupterm() + return curses.tigetnum("colors") > 2 + except: + raise + # guess false in case of error + return False + + def write(self, text, color): + """ + Write the given text to the stream in the given color. + + @param text: Text to be written to the stream. + + @param color: A string label for a color. e.g. 'red', 'white'. + """ + color = self._colors[color] + self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text)) + +class ColorHandler(logging.StreamHandler): + + def __init__(self, stream=sys.stdout): + super(ColorHandler, self).__init__(_AnsiColorizer(stream)) + + def emit(self, record): + msg_colors = { + logging.DEBUG: "green", + logging.INFO: "default", + logging.WARNING: "red", + logging.ERROR: "red" + } + color = msg_colors.get(record.levelno, "blue") + self.stream.write(record.msg + "\n", color) class CombinedPrinter(object): """combined print function. prints to logger and/or file if given, to normal print if non given. """ def __init__(self, logger=None, file=None): if logger is None and file is None: self.out = [print] elif logger is None: self.out = [file.write] elif file is None: self.out = [logger.info] else: self.out = [logger.info, file.write] def __call__(self, string): for fct in self.out: fct(string) - class Nvidia_GPU_Logger(object): def __init__(self): self.count = None def get_vals(self): # cmd = ['nvidia-settings', '-t', '-q', 'GPUUtilization'] # gpu_util = subprocess.check_output(cmd).strip().decode('utf-8').split(",") # gpu_util = dict([f.strip().split("=") for f in gpu_util]) # cmd[-1] = 'UsedDedicatedGPUMemory' # gpu_used_mem = subprocess.check_output(cmd).strip().decode('utf-8') nvidia_smi.nvmlInit() # card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate self.gpu_handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0) util_res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.gpu_handle) #mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(self.gpu_handle) # current_vals = {"gpu_mem_alloc": mem_res.used / (1024**2), "gpu_graphics_util": int(gpu_util['graphics']), # "gpu_mem_util": gpu_util['memory'], "time": time.time()} current_vals = {"gpu_graphics_util": float(util_res.gpu), "time": time.time()} return current_vals def loop(self, interval): i = 0 while True: current_vals = self.get_vals() self.log["time"].append(time.time()) self.log["gpu_util"].append(current_vals["gpu_graphics_util"]) if self.count is not None: i += 1 if i == self.count: exit(0) time.sleep(self.interval) def start(self, interval=1.): self.interval = interval self.start_time = time.time() self.log = {"time": [], "gpu_util": []} if self.interval is not None: thread = threading.Thread(target=self.loop) thread.daemon = True thread.start() class CombinedLogger(object): """Combine console and tensorboard logger and record system metrics. """ def __init__(self, name, log_dir, server_env=True, fold="all", sysmetrics_interval=2): self.pylogger = logging.getLogger(name) self.tboard = SummaryWriter(log_dir=os.path.join(log_dir, "tboard")) self.times = {} self.log_dir = log_dir self.fold = str(fold) self.server_env = server_env self.pylogger.setLevel(logging.DEBUG) self.log_file = os.path.join(log_dir, "fold_"+self.fold, 'exec.log') os.makedirs(os.path.dirname(self.log_file), exist_ok=True) self.pylogger.addHandler(logging.FileHandler(self.log_file)) if not server_env: self.pylogger.addHandler(ColorHandler()) else: self.pylogger.addHandler(logging.StreamHandler()) self.pylogger.propagate = False # monitor system metrics (cpu, mem, ...) if not server_env and sysmetrics_interval > 0: self.sysmetrics = pd.DataFrame( columns=["global_step", "rel_time", r"CPU (%)", "mem_used (GB)", r"mem_used (%)", r"swap_used (GB)", r"gpu_utilization (%)"], dtype="float16") for device in range(torch.cuda.device_count()): self.sysmetrics[ "mem_allocd (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan self.sysmetrics[ "mem_cached (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan self.sysmetrics_start(sysmetrics_interval) pass else: print("NOT logging sysmetrics") def __getattr__(self, attr): """delegate all undefined method requests to objects of this class in order pylogger, tboard (first find first serve). E.g., combinedlogger.add_scalars(...) should trigger self.tboard.add_scalars(...) """ for obj in [self.pylogger, self.tboard]: if attr in dir(obj): return getattr(obj, attr) print("logger attr not found") #raise AttributeError("CombinedLogger has no attribute {}".format(attr)) def set_logfile(self, fold=None, log_file=None): if fold is not None: self.fold = str(fold) if log_file is None: self.log_file = os.path.join(self.log_dir, "fold_"+self.fold, 'exec.log') else: self.log_file = log_file os.makedirs(os.path.dirname(self.log_file), exist_ok=True) for hdlr in self.pylogger.handlers: hdlr.close() self.pylogger.handlers = [] self.pylogger.addHandler(logging.FileHandler(self.log_file)) if not self.server_env: self.pylogger.addHandler(ColorHandler()) else: self.pylogger.addHandler(logging.StreamHandler()) def time(self, name, toggle=None): """record time-spans as with a stopwatch. :param name: :param toggle: True^=On: start time recording, False^=Off: halt rec. if None determine from current status. :return: either start-time or last recorded interval """ if toggle is None: if name in self.times.keys(): toggle = not self.times[name]["toggle"] else: toggle = True if toggle: if not name in self.times.keys(): self.times[name] = {"total": 0, "last": 0} elif self.times[name]["toggle"] == toggle: self.info("restarting running stopwatch") self.times[name]["last"] = time.time() self.times[name]["toggle"] = toggle return time.time() else: if toggle == self.times[name]["toggle"]: self.info("WARNING: tried to stop stopped stop watch: {}.".format(name)) self.times[name]["last"] = time.time() - self.times[name]["last"] self.times[name]["total"] += self.times[name]["last"] self.times[name]["toggle"] = toggle return self.times[name]["last"] def get_time(self, name=None, kind="total", format=None, reset=False): """ :param name: :param kind: 'total' or 'last' :param format: None for float, "hms"/"ms" for (hours), mins, secs as string :param reset: reset time after retrieving :return: """ if name is None: times = self.times if reset: self.reset_time() return times else: if self.times[name]["toggle"]: self.time(name, toggle=False) time = self.times[name][kind] if format == "hms": m, s = divmod(time, 60) h, m = divmod(m, 60) time = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(m), int(s)) elif format == "ms": m, s = divmod(time, 60) time = "{:02d}m:{:02d}s".format(int(m), int(s)) if reset: self.reset_time(name) return time def reset_time(self, name=None): if name is None: self.times = {} else: del self.times[name] def sysmetrics_update(self, global_step=None): if global_step is None: global_step = time.strftime("%x_%X") mem = psutil.virtual_memory() mem_used = (mem.total - mem.available) gpu_vals = self.gpu_logger.get_vals() rel_time = time.time() - self.sysmetrics_start_time self.sysmetrics.loc[len(self.sysmetrics)] = [global_step, rel_time, psutil.cpu_percent(), mem_used / 1024 ** 3, mem_used / mem.total * 100, psutil.swap_memory().used / 1024 ** 3, int(gpu_vals['gpu_graphics_util']), *[torch.cuda.memory_allocated(d) / 1024 ** 3 for d in range(torch.cuda.device_count())], *[torch.cuda.memory_cached(d) / 1024 ** 3 for d in range(torch.cuda.device_count())] ] return self.sysmetrics.loc[len(self.sysmetrics) - 1].to_dict() def sysmetrics2tboard(self, metrics=None, global_step=None, suptitle=None): tag = "per_time" if metrics is None: metrics = self.sysmetrics_update(global_step=global_step) tag = "per_epoch" if suptitle is not None: suptitle = str(suptitle) elif self.fold != "": suptitle = "Fold_" + str(self.fold) if suptitle is not None: self.tboard.add_scalars(suptitle + "/System_Metrics/" + tag, {k: v for (k, v) in metrics.items() if (k != "global_step" and k != "rel_time")}, global_step) def sysmetrics_loop(self): try: os.nice(-19) self.info("Logging system metrics with superior process priority.") except: self.info("Logging system metrics without superior process priority.") while True: metrics = self.sysmetrics_update() self.sysmetrics2tboard(metrics, global_step=metrics["rel_time"]) # print("thread alive", self.thread.is_alive()) time.sleep(self.sysmetrics_interval) def sysmetrics_start(self, interval): if interval is not None and interval > 0: self.sysmetrics_interval = interval self.gpu_logger = Nvidia_GPU_Logger() self.sysmetrics_start_time = time.time() self.sys_metrics_process = split_off_process(target=self.sysmetrics_loop, daemon=True) # self.thread = threading.Thread(target=self.sysmetrics_loop) # self.thread.daemon = True # self.thread.start() def sysmetrics_save(self, out_file): self.sysmetrics.to_pickle(out_file) def metrics2tboard(self, metrics, global_step=None, suptitle=None): """ :param metrics: {'train': dataframe, 'val':df}, df as produced in evaluator.py.evaluate_predictions """ # print("metrics", metrics) if global_step is None: global_step = len(metrics['train'][list(metrics['train'].keys())[0]]) - 1 if suptitle is not None: suptitle = str(suptitle) else: suptitle = "Fold_" + str(self.fold) for key in ['train', 'val']: # series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k} loss_series = {} unc_series = {} bin_stat_series = {} mon_met_series = {} for tag, val in metrics[key].items(): val = val[-1] # maybe remove list wrapping, recording in evaluator? if 'bin_stats' in tag.lower() and not np.isnan(val): bin_stat_series["{}".format(tag.split("/")[-1])] = val elif 'uncertainty' in tag.lower() and not np.isnan(val): unc_series["{}".format(tag)] = val elif 'loss' in tag.lower() and not np.isnan(val): loss_series["{}".format(tag)] = val elif not np.isnan(val): mon_met_series["{}".format(tag)] = val self.tboard.add_scalars(suptitle + "/Binary_Statistics/{}".format(key), bin_stat_series, global_step) self.tboard.add_scalars(suptitle + "/Uncertainties/{}".format(key), unc_series, global_step) self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step) self.tboard.add_scalars(suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series, global_step) self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step) return def batchImgs2tboard(self, batch, results_dict, cmap, boxtype2color, img_bg=False, global_step=None): raise NotImplementedError("not up-to-date, problem with importing plotting-file, torchvision dependency.") if len(batch["seg"].shape) == 5: # 3D imgs slice_ix = np.random.randint(batch["seg"].shape[-1]) seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :, slice_ix], cmap) seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :, slice_ix], cmap) mod_img = plg.mod_to_rgb(batch["data"][:, 0, :, :, slice_ix]) if img_bg else None elif len(batch["seg"].shape) == 4: seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :], cmap) seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :], cmap) mod_img = plg.mod_to_rgb(batch["data"][:, 0]) if img_bg else None else: raise Exception("batch content has wrong format: {}".format(batch["seg"].shape)) # from here on only works in 2D seg_gt = np.transpose(seg_gt, axes=(0, 3, 1, 2)) # previous shp: b,x,y,c seg_pred = np.transpose(seg_pred, axes=(0, 3, 1, 2)) seg = np.concatenate((seg_gt, seg_pred), axis=0) # todo replace torchvision (tv) dependency seg = tv.utils.make_grid(torch.from_numpy(seg), nrow=2) self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.", seg, global_step=global_step) if img_bg: bg_img = np.transpose(mod_img, axes=(0, 3, 1, 2)) else: bg_img = seg_gt box_imgs = plg.draw_boxes_into_batch(bg_img, results_dict["boxes"], boxtype2color) box_imgs = tv.utils.make_grid(torch.from_numpy(box_imgs), nrow=4) self.tboard.add_image("Batch bboxes", box_imgs, global_step=global_step) return def __del__(self): # otherwise might produce multiple prints e.g. in ipython console #self.sys_metrics_process.terminate() for hdlr in self.pylogger.handlers: hdlr.close() self.pylogger.handlers = [] del self.pylogger self.tboard.flush() # close holds up main script exit. maybe revise this issue with a later pytorch version. #self.tboard.close() - def get_logger(exp_dir, server_env=False, sysmetrics_interval=2): log_dir = os.path.join(exp_dir, "logs") logger = CombinedLogger('Reg R-CNN', log_dir, server_env=server_env, sysmetrics_interval=sysmetrics_interval) print("logging to {}".format(logger.log_file)) return logger - -def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True): +def prepare_monitoring(cf): """ - I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir. - This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime. - Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone). - Provides robust structure for cloud deployment. - :param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp) - :param exp_path: path to experiment directory. - :param server_env: boolean flag. pass to configs script for cloud deployment. - :param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing - experiment directory, else creates experiment directory on the fly using configs/model scripts from source code. - :param is_training: boolean flag. distinguishes train vs. inference mode. - :return: configs object. + creates dictionaries, where train/val metrics are stored. """ + metrics = {} + # first entry for loss dict accounts for epoch starting at 1. + metrics['train'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) + metrics['val'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) + metric_classes = [] + if 'rois' in cf.report_score_level: + metric_classes.extend([v for k, v in cf.class_dict.items()]) + if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: + metric_classes.extend([v for k, v in cf.bin_dict.items()]) + if 'patient' in cf.report_score_level: + metric_classes.extend(['patient_' + cf.class_dict[cf.patient_class_of_interest]]) + if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: + metric_classes.extend(['patient_' + cf.bin_dict[cf.patient_bin_of_interest]]) + for cl in metric_classes: + for m in cf.metrics: + metrics['train'][cl + '_' + m] = [np.nan] + metrics['val'][cl + '_' + m] = [np.nan] - if is_training: - - if use_stored_settings: - cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) - cf = cf_file.Configs(server_env) - # in this mode, previously saved model and backbone need to be found in exp dir. - if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \ - not os.path.isfile(os.path.join(exp_path, 'backbone.py')): - raise Exception( - "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.") - cf.model_path = os.path.join(exp_path, 'model.py') - cf.backbone_path = os.path.join(exp_path, 'backbone.py') - else: # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model - os.makedirs(exp_path, exist_ok=True) - # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.) - subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), - shell=True) - subprocess.call( - 'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), - shell=True) - cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py')) - cf = cf_file.Configs(server_env) - subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True) - subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True) - if os.path.isfile(os.path.join(exp_path, "fold_ids.pickle")): - subprocess.call('rm {}'.format(os.path.join(exp_path, "fold_ids.pickle")), shell=True) - - else: # testing, use model and backbone stored in exp dir. - cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) - cf = cf_file.Configs(server_env) - cf.model_path = os.path.join(exp_path, 'model.py') - cf.backbone_path = os.path.join(exp_path, 'backbone.py') - - cf.exp_dir = exp_path - cf.test_dir = os.path.join(cf.exp_dir, 'test') - cf.plot_dir = os.path.join(cf.exp_dir, 'plots') - if not os.path.exists(cf.test_dir): - os.mkdir(cf.test_dir) - if not os.path.exists(cf.plot_dir): - os.mkdir(cf.plot_dir) - cf.experiment_name = exp_path.split("/")[-1] - cf.dataset_name = dataset_path - cf.server_env = server_env - cf.created_fold_id_pickle = False - - return cf + return metrics class ModelSelector: ''' saves a checkpoint after each epoch as 'last_state' (can be loaded to continue interrupted training). saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be ensembled to improve performance. ''' def __init__(self, cf, logger): self.cf = cf self.logger = logger self.model_index = pd.DataFrame(columns=["rank", "score", "criteria_values", "file_name"], index=pd.RangeIndex(self.cf.min_save_thresh, self.cf.num_epochs, name="epoch")) def run_model_selection(self, net, optimizer, monitor_metrics, epoch): """rank epoch via weighted mean from self.cf.model_selection_criteria: {criterion : weight} :param net: :param optimizer: :param monitor_metrics: :param epoch: :return: """ crita = self.cf.model_selection_criteria # shorter alias metrics = monitor_metrics['val'] epoch_score = np.sum([metrics[criterion][-1] * weight for criterion, weight in crita.items() if not np.isnan(metrics[criterion][-1])]) if not self.cf.resume: epoch_score_check = np.sum([metrics[criterion][epoch] * weight for criterion, weight in crita.items() if not np.isnan(metrics[criterion][epoch])]) assert np.all(epoch_score == epoch_score_check) self.model_index.loc[epoch, ["score", "criteria_values"]] = epoch_score, {cr: metrics[cr][-1] for cr in crita.keys()} nonna_ics = self.model_index["score"].dropna(axis=0).index order = np.argsort(self.model_index.loc[nonna_ics, "score"].to_numpy(), kind="stable")[::-1] self.model_index.loc[nonna_ics, "rank"] = np.argsort(order) + 1 # no zero-indexing for ranks (best rank is 1). rank = int(self.model_index.loc[epoch, "rank"]) if rank <= self.cf.save_n_models: name = '{}_best_params.pth'.format(epoch) if self.cf.server_env: IO_safe(torch.save, net.state_dict(), os.path.join(self.cf.fold_dir, name)) else: torch.save(net.state_dict(), os.path.join(self.cf.fold_dir, name)) self.model_index.loc[epoch, "file_name"] = name self.logger.info("saved current epoch {} at rank {}".format(epoch, rank)) clean_up = self.model_index.dropna(axis=0, subset=["file_name"]) clean_up = clean_up[clean_up["rank"] > self.cf.save_n_models] if clean_up.size > 0: file_name = clean_up["file_name"].to_numpy().item() subprocess.call("rm {}".format(os.path.join(self.cf.fold_dir, file_name)), shell=True) self.logger.info("removed outranked epoch {} at {}".format(clean_up.index.values.item(), os.path.join(self.cf.fold_dir, file_name))) self.model_index.loc[clean_up.index, "file_name"] = np.nan state = { 'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), 'model_index': self.model_index, } if self.cf.server_env: IO_safe(torch.save, state, os.path.join(self.cf.fold_dir, 'last_state.pth')) else: torch.save(state, os.path.join(self.cf.fold_dir, 'last_state.pth')) +def parse_params_for_optim(net: torch.nn.Module, weight_decay: float = 0., exclude_from_wd: Iterable = ("norm", "bias")): + """Format network parameters for the optimizer. + Convenience function to include options for group-specific settings like weight decay. + :param net: + :param weight_decay: + :param exclude_from_wd: List of strings of parameter-group names to exclude from weight decay. Options: "norm", "bias". + :return: + """ + # pytorch implements parameter groups as dicts {'params': ...} and + # weight decay as p.data.mul_(1 - group['lr'] * group['weight_decay']) + norm_types = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, + torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.SyncBatchNorm, torch.nn.LocalResponseNorm + ] + level_map = {"bias": "weight", + "norm": "module"} + type_map = {"norm": norm_types} + + exclude_from_wd = [str(name).lower() for name in exclude_from_wd] + exclude_weight_names = [k for k, v in level_map.items() if k in exclude_from_wd and v == "weight"] + exclude_module_types = tuple([type_ for k, v in level_map.items() if (k in exclude_from_wd and v == "module") + for type_ in type_map[k]]) + + print("excluding {} from weight decay.".format(exclude_from_wd)) + + with_dec, no_dec = [], [] + for name, module in net.named_modules(): + if isinstance(module, exclude_module_types): + no_dec.extend(module.parameters()) + else: + for param_name, param in module.named_parameters(): + if np.any([ename in param_name for ename in exclude_weight_names]): + no_dec.append(param) + else: + with_dec.append(param) + + groups = [{'params': gr, 'weight_decay': wd} for gr, wd in [(no_dec, 0.), (with_dec, weight_decay)] if len(gr) > 0] + + return groups + def load_checkpoint(checkpoint_path, net, optimizer, model_selector): checkpoint = torch.load(checkpoint_path) net.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) model_selector.model_index = checkpoint["model_index"] return checkpoint['epoch'] + 1, net, optimizer, model_selector - -def prepare_monitoring(cf): - """ - creates dictionaries, where train/val metrics are stored. - """ - metrics = {} - # first entry for loss dict accounts for epoch starting at 1. - metrics['train'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) - metrics['val'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) - metric_classes = [] - if 'rois' in cf.report_score_level: - metric_classes.extend([v for k, v in cf.class_dict.items()]) - if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: - metric_classes.extend([v for k, v in cf.bin_dict.items()]) - if 'patient' in cf.report_score_level: - metric_classes.extend(['patient_' + cf.class_dict[cf.patient_class_of_interest]]) - if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: - metric_classes.extend(['patient_' + cf.bin_dict[cf.patient_bin_of_interest]]) - for cl in metric_classes: - for m in cf.metrics: - metrics['train'][cl + '_' + m] = [np.nan] - metrics['val'][cl + '_' + m] = [np.nan] - - return metrics - - -class _AnsiColorizer(object): +def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True): """ - A colorizer is an object that loosely wraps around a stream, allowing - callers to write text to the stream in a particular color. - - Colorizer classes must implement C{supported()} and C{write(text, color)}. + I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir. + This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime. + Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone). + Provides robust structure for cloud deployment. + :param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp) + :param exp_path: path to experiment directory. + :param server_env: boolean flag. pass to configs script for cloud deployment. + :param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing + experiment directory, else creates experiment directory on the fly using configs/model scripts from source code. + :param is_training: boolean flag. distinguishes train vs. inference mode. + :return: configs object. """ - _colors = dict(black=30, red=31, green=32, yellow=33, - blue=34, magenta=35, cyan=36, white=37, default=39) - - def __init__(self, stream): - self.stream = stream - - @classmethod - def supported(cls, stream=sys.stdout): - """ - A class method that returns True if the current platform supports - coloring terminal output using this method. Returns False otherwise. - """ - if not stream.isatty(): - return False # auto color only on TTYs - try: - import curses - except ImportError: - return False - else: - try: - try: - return curses.tigetnum("colors") > 2 - except curses.error: - curses.setupterm() - return curses.tigetnum("colors") > 2 - except: - raise - # guess false in case of error - return False - - def write(self, text, color): - """ - Write the given text to the stream in the given color. - @param text: Text to be written to the stream. - - @param color: A string label for a color. e.g. 'red', 'white'. - """ - color = self._colors[color] - self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text)) + if is_training: + if use_stored_settings: + cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) + cf = cf_file.Configs(server_env) + # in this mode, previously saved model and backbone need to be found in exp dir. + if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \ + not os.path.isfile(os.path.join(exp_path, 'backbone.py')): + raise Exception( + "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.") + cf.model_path = os.path.join(exp_path, 'model.py') + cf.backbone_path = os.path.join(exp_path, 'backbone.py') + else: # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model + os.makedirs(exp_path, exist_ok=True) + # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.) + subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), + shell=True) + subprocess.call( + 'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), + shell=True) + cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py')) + cf = cf_file.Configs(server_env) + subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True) + subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True) + if os.path.isfile(os.path.join(exp_path, "fold_ids.pickle")): + subprocess.call('rm {}'.format(os.path.join(exp_path, "fold_ids.pickle")), shell=True) -class ColorHandler(logging.StreamHandler): + else: # testing, use model and backbone stored in exp dir. + cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) + cf = cf_file.Configs(server_env) + cf.model_path = os.path.join(exp_path, 'model.py') + cf.backbone_path = os.path.join(exp_path, 'backbone.py') - def __init__(self, stream=sys.stdout): - super(ColorHandler, self).__init__(_AnsiColorizer(stream)) + cf.exp_dir = exp_path + cf.test_dir = os.path.join(cf.exp_dir, 'test') + cf.plot_dir = os.path.join(cf.exp_dir, 'plots') + if not os.path.exists(cf.test_dir): + os.mkdir(cf.test_dir) + if not os.path.exists(cf.plot_dir): + os.mkdir(cf.plot_dir) + cf.experiment_name = exp_path.split("/")[-1] + cf.dataset_name = dataset_path + cf.server_env = server_env + cf.created_fold_id_pickle = False - def emit(self, record): - msg_colors = { - logging.DEBUG: "green", - logging.INFO: "default", - logging.WARNING: "red", - logging.ERROR: "red" - } - color = msg_colors.get(record.levelno, "blue") - self.stream.write(record.msg + "\n", color) + return cf