diff --git a/datasets/lidc/configs.py b/datasets/lidc/configs.py index e037756..413ce8f 100644 --- a/datasets/lidc/configs.py +++ b/datasets/lidc/configs.py @@ -1,445 +1,445 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os from collections import namedtuple sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/../..") from default_configs import DefaultConfigs # legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope Label = namedtuple("Label", ['id', 'name', 'color', 'm_scores']) # m_scores = malignancy scores binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'm_scores', 'bin_vals']) class Configs(DefaultConfigs): def __init__(self, server_env=None): super(Configs, self).__init__(server_env) ######################### # Preprocessing # ######################### self.root_dir = '/home/gregor/networkdrives/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI' self.raw_data_dir = '{}/new_nrrd'.format(self.root_dir) self.pp_dir = '/media/gregor/HDD2TB/data/lidc/pp_20200309_dev' # 'merged' for one gt per image, 'single_annotator' for four gts per image. self.gts_to_produce = ["single_annotator", "merged"] self.target_spacing = (0.7, 0.7, 1.25) ######################### # I/O # ######################### # path to preprocessed data. #self.pp_name = 'pp_20190318' self.pp_name = 'pp_20200309_dev' self.input_df_name = 'info_df.pickle' self.data_sourcedir = '/media/gregor/HDD2TB/data/lidc/{}/'.format(self.pp_name) # settings for deployment on cluster. if server_env: # path to preprocessed data. self.data_sourcedir = '/datasets/data_ramien/lidc/{}_npz/'.format(self.pp_name) # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn']. self.model = 'mrcnn' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) ######################### # Architecture # ######################### # dimension the model operates in. one out of [2, 3]. self.dim = 2 # 'class': standard object classification per roi, pairwise combinable with each of below tasks. # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN. # 'regression': regress some vector per each roi # 'regression_ken_gal': use kendall-gal uncertainty sigma # 'regression_bin': classify each roi into a bin related to a regression scale - self.prediction_tasks = ['regression'] + self.prediction_tasks = ['class'] self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' self.norm = None # one of None, 'instance_norm', 'batch_norm' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') self.weight_init = None self.regression_n_features = 1 ######################### # Data Loader # ######################### # distorted gt experiments: train on single-annotator gts in a random fashion to investigate network's # handling of noisy gts. # choose 'merged' for single, merged gt per image, or 'single_annotator' for four gts per image. # validation is always performed on same gt kind as training, testing always on merged gt. - self.training_gts = "sa" + self.training_gts = "merged" # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.pre_crop_size_2D = [320, 320] self.patch_size_2D = [320, 320] self.pre_crop_size_3D = [160, 160, 96] self.patch_size_3D = [160, 160, 96] self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D # ratio of free sampled batch elements before class balancing is triggered # (>0 to include "empty"/background patches.) self.batch_random_ratio = 0.3 self.balance_target = "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets' # set 2D network to match 3D gt boxes. self.merge_2D_to_3D_preds = self.dim==2 self.observables_rois = [] #self.rg_map = {1:1, 2:2, 3:3, 4:4, 5:5} ######################### # Colors and Legends # ######################### self.plot_frequency = 5 binary_cl_labels = [Label(1, 'benign', (*self.dark_green, 1.), (1, 2)), Label(2, 'malignant', (*self.red, 1.), (3, 4, 5))] quintuple_cl_labels = [Label(1, 'MS1', (*self.dark_green, 1.), (1,)), Label(2, 'MS2', (*self.dark_yellow, 1.), (2,)), Label(3, 'MS3', (*self.orange, 1.), (3,)), Label(4, 'MS4', (*self.bright_red, 1.), (4,)), Label(5, 'MS5', (*self.red, 1.), (5,))] # choose here if to do 2-way or 5-way regression-bin classification task_spec_cl_labels = quintuple_cl_labels self.class_labels = [ # #id #name #color #malignancy score Label( 0, 'bg', (*self.gray, 0.), (0,))] if "class" in self.prediction_tasks: self.class_labels += task_spec_cl_labels else: self.class_labels += [Label(1, 'lesion', (*self.orange, 1.), (1,2,3,4,5))] if any(['regression' in task for task in self.prediction_tasks]): self.bin_labels = [binLabel(0, 'MS0', (*self.gray, 1.), (0,), (0,))] self.bin_labels += [binLabel(cll.id, cll.name, cll.color, cll.m_scores, tuple([ms for ms in cll.m_scores])) for cll in task_spec_cl_labels] self.bin_id2label = {label.id: label for label in self.bin_labels} self.ms2bin_label = {ms: label for label in self.bin_labels for ms in label.m_scores} bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] if self.class_specific_seg: self.seg_labels = self.class_labels else: self.seg_labels = [ # id #name #color Label(0, 'bg', (*self.gray, 0.)), Label(1, 'fg', (*self.orange, 1.)) ] self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be # evaluated in debugging self.class_cmap = {label.id: label.color for label in self.class_labels} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.has_colorchannels = False self.plot_class_ids = True self.num_classes = len(self.class_dict) # for instance classification (excl background) self.num_seg_classes = len(self.seg_labels) # incl background ######################### # Data Augmentation # ######################### self.da_kwargs={ 'mirror': True, 'mirror_axes': tuple(np.arange(0, self.dim, 1)), 'do_elastic_deform': True, 'alpha':(0., 1500.), 'sigma':(30., 50.), 'do_rotation':True, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': True, 'scale':(0.8, 1.1), 'random_crop':False, 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1} if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!! self.da_kwargs['angle_z'] = (0., 2 * np.pi) ################################# # Schedule / Selection / Optim # ################################# self.num_epochs = 130 if self.dim == 2 else 150 self.num_train_batches = 200 if self.dim == 2 else 200 self.batch_size = 20 if self.dim == 2 else 8 # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_sampling' # only 'val_sampling', 'val_patient' not implemented if self.val_mode == 'val_patient': raise NotImplementedError if self.val_mode == 'val_sampling': self.num_val_batches = 70 self.save_n_models = 4 # set a minimum epoch number for saving in case of instabilities in the first phase of training. self.min_save_thresh = 0 if self.dim == 2 else 0 # criteria to average over for saving epochs, 'criterion':weight. if "class" in self.prediction_tasks: # 'criterion': weight if len(self.class_labels)==3: self.model_selection_criteria = {"benign_ap": 0.5, "malignant_ap": 0.5} elif len(self.class_labels)==6: self.model_selection_criteria = {str(label.name)+"_ap": 1./5 for label in self.class_labels if label.id!=0} elif any("regression" in task for task in self.prediction_tasks): self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8} self.weight_decay = 0 self.clip_norm = 200 if 'regression_ken_gal' in self.prediction_tasks else None # number or None # int in [0, dataset_size]. select n patients from dataset for prototyping. If None, all data is used. self.select_prototype_subset = None #self.batch_size ######################### # Testing # ######################### # set the top-n-epochs to be saved for temporal averaging in testing. self.test_n_epochs = self.save_n_models self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x). self.held_out_test_set = False self.max_test_patients = "all" # "all" or number self.report_score_level = ['rois', 'patient'] # choose list from 'patient', 'rois' self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1 self.metrics = ['ap', 'auc'] if any(['regression' in task for task in self.prediction_tasks]): self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] if 'aleatoric' in self.model: self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] self.evaluate_fold_means = True self.ap_match_ious = [0.1] # list of ious to be evaluated for ap-scoring. self.min_det_thresh = 0.1 # minimum confidence value to select predictions for evaluation. # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!) for regarding two preds as concerning the same ROI self.clustering_iou = 0.1 # has to be larger than desired possible overlap iou of model predictions self.plot_prediction_histograms = True self.plot_stat_curves = False self.n_test_plots = 1 ######################### # Assertions # ######################### if not 'class' in self.prediction_tasks: assert self.num_classes == 1 ######################### # Add model specifics # ######################### {'detection_fpn': self.add_det_fpn_configs, 'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, }[self.model]() def rg_val_to_bin_id(self, rg_val): return float(np.digitize(np.mean(rg_val), self.bin_edges)) def add_det_fpn_configs(self): self.learning_rate = [1e-4] * self.num_epochs self.dynamic_lr_scheduling = False # RoI score assigned to aggregation from pixel prediction (connected component). One of ['max', 'median']. self.score_det = 'max' # max number of roi candidates to identify per batch element and class. self.n_roi_candidates = 10 if self.dim == 2 else 30 # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' # if <1, false positive predictions in foreground are penalized less. self.fp_dice_weight = 1 if self.dim == 2 else 1 if len(self.class_labels)==3: self.wce_weights = [1., 1., 1.] if self.seg_loss_mode=="dice_wce" else [0.1, 1., 1.] elif len(self.class_labels)==6: self.wce_weights = [1., 1., 1., 1., 1., 1.] if self.seg_loss_mode == "dice_wce" else [0.1, 1., 1., 1., 1., 1.] else: raise Exception("mismatch loss weights & nr of classes") self.detection_min_confidence = self.min_det_thresh self.head_classes = self.num_seg_classes def add_mrcnn_configs(self): # learning rate is a list with one entry per epoch. self.learning_rate = [1e-4] * self.num_epochs self.dynamic_lr_scheduling = False # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask-outputs are optional. self.return_masks_in_train = False self.return_masks_in_val = True self.return_masks_in_test = False # set number of proposal boxes to plot after each epoch. self.n_plot_rpn_props = 5 if self.dim == 2 else 30 # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 self.frcnn_mode = False # feature map strides per pyramid level are inferred from architecture. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 128 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1, 2] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7 # loss sampling settings. self.rpn_train_anchors_per_image = 6 #per batch element self.train_rois_per_image = 6 #per batch element self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.7 # factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples. self.shem_poolsize = 10 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 3000 if self.dim == 2 else 6000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage in as one "batch". self.roi_chunk_size = 2500 if self.dim == 2 else 600 self.post_nms_rois_training = 500 if self.dim == 2 else 75 self.post_nms_rois_inference = 500 # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30 # per batch element and class. self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.1 if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': self.focal_loss = True # implement extra anchor-scales according to retina-net publication. self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 self.n_rpn_features = 256 if self.dim == 2 else 128 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.5 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/datasets/lidc/data_loader.py b/datasets/lidc/data_loader.py index b04b9b4..222364d 100644 --- a/datasets/lidc/data_loader.py +++ b/datasets/lidc/data_loader.py @@ -1,1025 +1,1014 @@ # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== ''' Data Loader for the LIDC data set. This dataloader expects preprocessed data in .npy or .npz files per patient and a pandas dataframe containing the meta info e.g. file paths, and some ground-truth info like labels, foreground slice ids. LIDC 4-fold annotations storage capacity problem: keep segmentation gts compressed (npz), unpack at each batch generation. ''' import plotting as plg import os import pickle import time +from multiprocessing import Pool import numpy as np import pandas as pd from collections import OrderedDict # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates def save_obj(obj, name): """Pickle a python object.""" with open(name + '.pkl', 'wb') as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) def vector(item): """ensure item is vector-like (list or array or tuple) :param item: anything """ if not isinstance(item, (list, tuple, np.ndarray)): item = [item] return item class Dataset(dutils.Dataset): r"""Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dest. :param cf: config object. :param logger: logger. :param subset_ids: subset of patient/sample identifiers to load from whole set. :param data_sourcedir: directory in which to find data, defaults to cf.data_sourcedir if None. :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None, mode='train'): super(Dataset,self).__init__(cf, data_sourcedir) if mode == 'train' and not cf.training_gts == "merged": self.gt_dir = "patient_gts_sa" self.gt_kind = cf.training_gts else: self.gt_dir = "patient_gts_merged" self.gt_kind = "merged" if logger is not None: logger.info("loading {} ground truths for {}".format(self.gt_kind, 'training and validation' if mode=='train' else 'testing')) p_df = pd.read_pickle(os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)) #exclude_pids = ["0305a", "0447a"] # due to non-bg segmentation but bg mal label in nodules 5728, 8840 #p_df = p_df[~p_df.pid.isin(exclude_pids)] if subset_ids is not None: p_df = p_df[p_df.pid.isin(subset_ids)] if logger is not None: logger.info('subset: selected {} instances from df'.format(len(p_df))) if cf.select_prototype_subset is not None: prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset] p_df = p_df[p_df.pid.isin(prototype_pids)] if logger is not None: logger.warning('WARNING: using prototyping data subset of length {}!!!'.format(len(p_df))) pids = p_df.pid.tolist() # evtly copy data from data_sourcedir to data_dest if cf.server_env and not hasattr(cf, 'data_dir') and hasattr(cf, "data_dest"): # copy and unpack images file_subset = ["{}_img.npz".format(pid) for pid in pids if not os.path.isfile(os.path.join(cf.data_dest,'{}_img.npy'.format(pid)))] file_subset += [os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)] self.copy_data(cf, file_subset=file_subset, keep_packed=False, del_after_unpack=True) # copy and do not unpack segmentations file_subset = [os.path.join(self.gt_dir, "{}_rois.np*".format(pid)) for pid in pids] keep_packed = not cf.training_gts == "merged" self.copy_data(cf, file_subset=file_subset, keep_packed=keep_packed, del_after_unpack=(not keep_packed)) else: cf.data_dir = self.data_sourcedir ext = 'npy' if self.gt_kind == "merged" else 'npz' imgs = [os.path.join(self.data_dir, '{}_img.npy'.format(pid)) for pid in pids] segs = [os.path.join(self.data_dir, self.gt_dir, '{}_rois.{}'.format(pid, ext)) for pid in pids] orig_class_targets = p_df['class_target'].tolist() data = OrderedDict() if self.gt_kind == 'merged': for ix, pid in enumerate(pids): data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} data[pid]['fg_slices'] = np.array(p_df['fg_slices'].tolist()[ix]) if 'class' in cf.prediction_tasks: if len(cf.class_labels)==3: # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) data[pid]['class_targets'] = np.array([2 if ii >= 3 else 1 for ii in orig_class_targets[ix]], dtype='uint8') elif len(cf.class_labels)==6: # classify each malignancy score data[pid]['class_targets'] = np.array([1 if ii==0.5 else np.round(ii) for ii in orig_class_targets[ix]], dtype='uint8') else: raise Exception("mismatch class labels and data-loading implementations.") else: data[pid]['class_targets'] = np.ones_like(np.array(orig_class_targets[ix]), dtype='uint8') if any(['regression' in task for task in cf.prediction_tasks]): data[pid]["regression_targets"] = np.array([vector(v) for v in orig_class_targets[ix]], dtype='float16') data[pid]["rg_bin_targets"] = np.array( [cf.rg_val_to_bin_id(v) for v in data[pid]["regression_targets"]], dtype='uint8') else: for ix, pid in enumerate(pids): data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} data[pid]['fg_slices'] = np.array(p_df['fg_slices'].values[ix]) if 'class' in cf.prediction_tasks: # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) raise NotImplementedError # todo need to consider bg # data[pid]['class_targets'] = np.array( # [[2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]]) else: data[pid]['class_targets'] = np.array( [[1 if ii > 0 else 0 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='uint8') if any(['regression' in task for task in cf.prediction_tasks]): data[pid]["regression_targets"] = np.array( [[vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='float16') data[pid]["rg_bin_targets"] = np.array( [[cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in data[pid]["regression_targets"]], dtype='uint8') cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if any(['regression' in task for task in cf.prediction_tasks]): cf.roi_items += ['regression_targets'] cf.roi_items += ['rg_bin_targets'] self.data = data self.set_ids = np.array(list(self.data.keys())) self.df = None # merged GTs class BatchGenerator_merged(dutils.BatchGenerator): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ def __init__(self, cf, data, name="train"): super(BatchGenerator_merged, self).__init__(cf, data) self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.class_targets = {k: v["class_targets"] for (k, v) in self._data.items()} self.balance_target_distribution(plot=name=="train") def generate_train_batch(self): # samples patients towards equilibrium of foreground classes on a roi-level after sampling a random ratio # fully random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False, p=self.p_probs)) batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count of classes in batch batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') # empty count for full bg samples (empty slices in 2D/patients in 3D) per class for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) batch_pids.append(patient['pid']) (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1)<=self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0])-1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) if len(elig_slices)>0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] # pad data if smaller than pre_crop_size. if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] data = dutils.pad_nd_image(data, new_shape, mode='constant') seg = dutils.pad_nd_image(seg, new_shape, mode='constant') # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] if len(crop_dims) > 0: if self.cf.dim == 3: choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio)\ or np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg)[1:] for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) break roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] assert seg[tuple(roi_anchor_pixel)] > 0 # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by # distance to the desired ROI < patch_size /2. # (here final patch size to account for center_crop after data augmentation). sample_seg_center = {} for ii in crop_dims: low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) # happens if lesion on the edge of the image. dont care about roi anymore, # just make sure pre-crop is inside image. if low >= high: low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) sample_seg_center[ii] = np.random.randint(low=low, high=high) else: # not guaranteed to be empty. probability of emptiness depends on the data. sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} for ii in crop_dims: min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item batch_roi_items[o].append(patient[o]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero==0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 data = np.array(batch_data).astype(np.float16) seg = np.array(batch_segs).astype(np.uint8) batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'roi_counts':batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) batch[key] = np.array(val) return batch class PatientBatchIterator_merged(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actualy evaluation (done in 3D), if willing to accept speed-loss during training. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . """ def __init__(self, cf, data): # threads in augmenter super(PatientBatchIterator_merged, self).__init__(cf, data) self.patient_ix = 0 self.patch_size = cf.patch_size + [1] if cf.dim == 2 else cf.patch_size def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) # pad data if smaller than patch_size seen during training. if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: out_data = data[np.newaxis, np.newaxis] out_seg = seg[np.newaxis, np.newaxis] batch_3D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_3D[o] = np.array([patient[o]]) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, x, y ) out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis] batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[o]]), out_data.shape[0], axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension. # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): patient_batch = out_batch patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) new_img_batch, new_seg_batch = [], [] for cix, c in enumerate(patch_crop_coords_list): seg_patch = seg[c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) tmp_c_5 = c[5] new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) seg = np.array(new_seg_batch)[:, np.newaxis] # (n_patches, 1, x, y, z) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), 'pid': np.array([patient['pid']] * data.shape[0])} for o in self.cf.roi_items: patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_" + o] = patient_batch['patient_' + o] if self.cf.dim == 2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_" + o + "_2d"] = patient_batch[o] # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 # and enables calculating test-dice/viewing patient-wise results in test # remove, but also remove dice from metrics, when like to save memory patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim == 2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) patch_batch = converter(**patch_batch) out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch # single-annotator GTs class BatchGenerator_sa(dutils.BatchGenerator): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ # noinspection PyMethodOverriding def balance_target_distribution(self, rater, plot=False): """ :param rater: for which rater slot to generate the distribution :param self.targets: dic holding {patient_specifier : patient-wise-unique ROI targets} :param plot: whether to plot the generated patient distributions :return: probability distribution over all pids. draw without replace from this. """ unique_ts = np.unique([v[rater] for pat in self.targets.values() for v in pat]) sample_stats = pd.DataFrame(columns=[str(ix) + suffix for ix in unique_ts for suffix in ["", "_bg"]], index=list(self.targets.keys())) for pid in sample_stats.index: for targ in unique_ts: fg_count = 0 if len(self.targets[pid]) == 0 else np.count_nonzero(self.targets[pid][:, rater] == targ) sample_stats.loc[pid, str(targ)] = int(fg_count > 0) sample_stats.loc[pid, str(targ) + "_bg"] = int(fg_count == 0) target_stats = sample_stats.agg( ("sum", lambda col: col.sum() / len(self._data)), axis=0, sort=False).rename({"": "relative"}) anchor = 1. - target_stats.loc["relative"].iloc[0] - self.fg_bg_weights = anchor / target_stats.loc["relative"] - cum_weights = anchor * len(self.fg_bg_weights) - self.fg_bg_weights /= cum_weights + fg_bg_weights = anchor / target_stats.loc["relative"] + cum_weights = anchor * len(fg_bg_weights) + fg_bg_weights /= cum_weights - p_probs = sample_stats.apply(self.sample_targets_to_weights, axis=1).sum(axis=1) + p_probs = sample_stats.apply(self.sample_targets_to_weights, args=(fg_bg_weights,), axis=1).sum(axis=1) p_probs = p_probs / p_probs.sum() if plot: - print("Rater: {}. Applying class-weights:\n {}".format(rater, self.fg_bg_weights)) + print("Rater: {}. Applying class-weights:\n {}".format(rater, fg_bg_weights)) if len(sample_stats.columns) == 2: # assert that probs are calc'd correctly: # (p_probs * sample_stats["1"]).sum() == (p_probs * sample_stats["1_bg"]).sum() # only works if one label per patient (multi-label expectations depend on multi-label occurences). - expectations = [] - for targ in sample_stats.columns: - expectations.append((p_probs * sample_stats[targ]).sum()) - assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format( - expectations) - - - # # get unique foreground targets per patient, assign -1 to an "empty" patient (has no foreground) - # patient_ts = [[roi[rater] for roi in patient_rois_lst] for patient_rois_lst in self.targets.values()] - # # assign [-1] to empty patients - # patient_ts = [np.unique(lst) if len([t for t in lst if np.any(t>0)])>0 else [-1] for lst in patient_ts] - # #bg_mask = np.array([np.all(lst == [-1]) for lst in patient_ts]) - # # sort out bg labels (are 0) - # unique_ts, t_counts = np.unique([t for lst in patient_ts for t in lst if t>0], return_counts=True) - # t_probs = t_counts.sum() / t_counts - # t_probs /= t_probs.sum() - # t_probs = {t : t_probs[ix] for ix, t in enumerate(unique_ts)} - # t_probs[-1] = 0. - # t_probs[0] = 0. - # # fail if balance target is not a number (i.e., a vector) - # p_probs = np.array([ max([t_probs[t] for t in lst]) for lst in patient_ts ]) - # #normalize - # p_probs /= p_probs.sum() + for rater in range(self.rater_bsize): + expectations = [] + for targ in sample_stats.columns: + expectations.append((p_probs[rater] * sample_stats[targ]).sum()) + assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format( + expectations) if plot: plg.plot_batchgen_distribution(self.cf, self.dataset_pids, p_probs, self.balance_target, - out_file=os.path.join(self.cf.plot_dir, + out_file=os.path.join(self.plot_dir, "train_gen_distr_"+str(self.cf.fold)+"_rater"+str(rater)+".png")) return p_probs, unique_ts, sample_stats def __init__(self, cf, data, name="train"): super(BatchGenerator_sa, self).__init__(cf, data) self.name = name self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.rater_bsize = 4 unique_ts_total = set() self.p_probs = [] self.sample_stats = [] - for r in range(self.rater_bsize): - # todo multiprocess. takes forever - p_probs, unique_ts, sample_stats = self.balance_target_distribution(r, plot=name=="train") + p = Pool(processes=cf.n_workers) + mp_res = p.starmap(self.balance_target_distribution, [(r, name=="train") for r in range(self.rater_bsize)]) + p.close() + p.join() + + for r, res in enumerate(mp_res): + p_probs, unique_ts, sample_stats = res self.p_probs.append(p_probs) self.sample_stats.append(sample_stats) unique_ts_total.update(unique_ts) + self.unique_ts = sorted(list(unique_ts_total)) def generate_train_batch(self): rater = np.random.randint(self.rater_bsize) # samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio batch_random_ratio). # random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False, p=self.p_probs[rater])) batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count of classes in batch batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') # empty count for full bg samples (empty slices in 2D/patients in 3D) for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] patient_balance_ts = np.array([roi[rater] for roi in patient[self.balance_target]]) data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] seg = np.load(patient['seg'], mmap_mode='r') seg = np.transpose(seg[list(seg.keys())[0]][rater], axes=(1, 2, 0)) batch_pids.append(patient['pid']) (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1) <= self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient_balance_ts[np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'][rater]) if len(elig_slices) > 0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] # pad data if smaller than pre_crop_size. if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] data = dutils.pad_nd_image(data, new_shape, mode='constant') seg = dutils.pad_nd_image(seg, new_shape, mode='constant') # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] if len(crop_dims) > 0: if self.cf.dim == 3: choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg[seg>0]) assert np.all(patient_balance_ts[available_roi_ids-1]>0), "trying to choose roi with rating 0" for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[ patient_balance_ts[available_roi_ids-1] == self.unique_ts[tix] ] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] break assert seg[tuple(roi_anchor_pixel)] > 0, "roi_anchor_pixel not inside roi: {}, pb_ts {}, elig ids {}".format(tuple(roi_anchor_pixel), patient_balance_ts, elig_roi_ids) # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by # distance to the desired ROI < patch_size /2. # (here final patch size to account for center_crop after data augmentation). sample_seg_center = {} for ii in crop_dims: low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) # happens if lesion on the edge of the image. dont care about roi anymore, # just make sure pre-crop is inside image. if low >= high: low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) sample_seg_center[ii] = np.random.randint(low=low, high=high) else: # not guaranteed to be empty. probability of emptiness depends on the data. sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} for ii in crop_dims: min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item batch_roi_items[o].append([roi[rater] for roi in patient[o]]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero==0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 data = np.array(batch_data).astype('float16') seg = np.array(batch_segs).astype('uint8') batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'rater_id': rater, 'roi_counts':batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) batch[key] = np.array(val) return batch class PatientBatchIterator_sa(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actual evaluation (done in 3D), if willing to accept speed loss during training. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . This is the data & gt loader for the 4-fold single-annotator GTs: each data input has separate annotations of 4 annotators. the way the pipeline is currently setup, the single-annotator GTs are only used if training with validation mode val_patient; during testing the Iterator with the merged GTs is used. # todo mode val_patient not implemented yet (since very slow). would need to sample from all available rater GTs. """ def __init__(self, cf, data): #threads in augmenter super(PatientBatchIterator_sa, self).__init__(cf, data) self.cf = cf self.patient_ix = 0 self.dataset_pids = list(self._data.keys()) self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size self.rater_bsize = 4 def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) # all gts are 4-fold and npz! seg = np.load(patient['seg'], mmap_mode='r') seg = np.transpose(seg[list(seg.keys())[0]], axes=(0, 2, 3, 1)) # pad data if smaller than patch_size seen during training. if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: out_data = data[np.newaxis, np.newaxis] out_seg = seg[:, np.newaxis] batch_3D = {'data': out_data, 'seg': out_seg} for item in self.cf.roi_items: batch_3D[item] = [] for r in range(self.rater_bsize): for item in self.cf.roi_items: batch_3D[item].append(np.array([roi[r] for roi in patient[item]])) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, y, x ) out_seg = np.transpose(seg, axes=(0, 3, 1, 2))[:, :, np.newaxis] # (n_raters, z, 1, y,x) batch_2D = {'data': out_data} for item in ["seg", "bb_target"]+self.cf.roi_items: batch_2D[item] = [] converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) for r in range(self.rater_bsize): tmp_batch = {"seg": out_seg[r]} for item in self.cf.roi_items: tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), out_data.shape[0], axis=0) tmp_batch = converter(**tmp_batch) for item in ["seg", "bb_target"]+self.cf.roi_items: batch_2D[item].append(tmp_batch[item]) # for item in ["seg", "bb_target"]+self.cf.roi_items: # batch_2D[item] = np.array(batch_2D[item]) if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * out_data.shape[0])}) # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension. # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): patient_batch = out_batch patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) new_img_batch = [] new_seg_batch = [] for cix, c in enumerate(patch_crop_coords_list): seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) tmp_c_5 = c[5] new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) seg = np.transpose(np.array(new_seg_batch), axes=(1,0,2,3,4))[:,:,np.newaxis] # (n_raters, n_patches, x, y, z) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'pid': np.array([patient['pid']] * data.shape[0])} # for o in self.cf.roi_items: # patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) for item in ["seg", "bb_target"]+self.cf.roi_items: patch_batch[item] = [] # coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, # IndexError: index 2 is out of bounds for axis 1 with size 2 for r in range(self.rater_bsize): tmp_batch = {"seg": seg[r]} for item in self.cf.roi_items: tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), len(patch_crop_coords_list), axis=0) tmp_batch = converter(**tmp_batch) for item in ["seg", "bb_target"]+self.cf.roi_items: patch_batch[item].append(tmp_batch[item]) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_" + o] = patient_batch['patient_'+o] if self.cf.dim==2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_"+o+"_2d"] = patient_batch[o] # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 # and enables calculating test-dice/viewing patient-wise results in test # remove, but also remove dice from metrics, if you like to save memory patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim==2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, is_training=True): """ create multi-threaded train/val/test batch generation and augmentation pipeline. :param cf: configs object. :param patient_data: dictionary containing one dictionary per patient in the train/test subset. :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ BG_name = "train" if is_training else "val" data_gen = BatchGenerator_merged(cf, patient_data, name=BG_name) if cf.training_gts=='merged' else \ BatchGenerator_sa(cf, patient_data, name=BG_name) # add transformations to pipeline. my_transforms = [] if is_training: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) if cf.create_bounding_box_targets: my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) all_transforms = Compose(my_transforms) multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=True): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators. selects patients according to cv folds (generated by first run/fold of experiment): splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) If cf.held_out_test_set is True, adds the test split to the training data. """ dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) if cf.held_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, is_training=True) batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, is_training=False) if cf.val_mode == 'val_patient': assert cf.training_gts == 'merged', 'val_patient not yet implemented for sa gts' batch_gen['val_patient'] = PatientBatchIterator_merged(cf, val_data) if cf.training_gts=='merged' \ else PatientBatchIterator_sa(cf, val_data) batch_gen['n_val'] = len(val_data) if cf.max_val_patients=="all" else min(len(val_data), cf.max_val_patients) else: batch_gen['n_val'] = cf.num_val_batches return batch_gen def get_test_generator(cf, logger): """ wrapper function for creating the test batch generator pipeline. selects patients according to cv folds (generated by first run/fold of experiment) If cf.held_out_test_set is True, gets the data from an external folder instead. """ if cf.held_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_data = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode="test").data logger.info("data set loaded with: {} test patients".format(len(test_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator_merged(cf, test_data) batch_gen['n_test'] = len(test_ids) if cf.max_test_patients == "all" else min(cf.max_test_patients, len(test_ids)) return batch_gen if __name__ == "__main__": import sys sys.path.append('../') import plotting as plg import utils.exp_utils as utils from configs import Configs cf = Configs() cf.batch_size = 3 #dataset_path = os.path.dirname(os.path.realpath(__file__)) #exp_path = os.path.join(dataset_path, "experiments/dev") #cf = utils.prep_exp(dataset_path, exp_path, server_env=False, use_stored_settings=False, is_training=True) cf.created_fold_id_pickle = False total_stime = time.time() times = {} # cf.server_env = True # cf.data_dir = "experiments/dev_data" # dataset = Dataset(cf) # patient = dataset['Master_00018'] cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] for i in range(1): stime = time.time() #ex_batch = next(train_loader) print("train batch", i) times["train_batch"] = time.time() - stime #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True) # # # with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file: # # train_loader.generator.print_stats(logger, file) # val_loader = gens['val_sampling'] stime = time.time() ex_batch = next(val_loader) times["val_batch"] = time.time() - stime stime = time.time() #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, # show_info=False) times["val_plot"] = time.time() - stime # test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch() times["test_batch"] = time.time() - stime stime = time.time() - plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png")#, sample_picks=[0,1,2,3]) + plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", get_time=False)#, sample_picks=[0,1,2,3]) times["test_patchbatch_plot"] = time.time() - stime # ex_batch['data'] = ex_batch['patient_data'] # ex_batch['seg'] = ex_batch['patient_seg'] # ex_batch['bb_target'] = ex_batch['patient_bb_target'] # for item in cf.roi_items: # ex_batch[] # stime = time.time() # #ex_batch = next(test_loader) # ex_batch = next(test_loader) # plg.view_batch(cf, ex_batch, show_gt_labels=False, show_gt_boxes=True, patient_items=True,# vol_slice_picks=[146,148, 218,220], # out_file="experiments/dev/dev_expatientbatch.png") # , sample_picks=[0,1,2,3]) # times["test_patient_batch_plot"] = time.time() - stime print("Times recorded throughout:") for (k, v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) diff --git a/exec.py b/exec.py index 6f96597..155100e 100644 --- a/exec.py +++ b/exec.py @@ -1,343 +1,341 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ execution script. this where all routines come together and the only script you need to call. refer to parse args below to see options for execution. """ import plotting as plg import os import warnings import argparse import time import torch import utils.exp_utils as utils from evaluator import Evaluator from predictor import Predictor for msg in ["Attempting to set identical bottom==top results", "This figure includes Axes that are not compatible with tight_layout", "Data has no positive values, and therefore cannot be log-scaled.", ".*invalid value encountered in true_divide.*"]: warnings.filterwarnings("ignore", msg) def train(cf, logger): """ performs the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. logs to file and tensorboard. """ logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format( cf.dim, cf.fold, cf.exp_dir, cf.model)) logger.time("train_val") # -------------- inits and settings ----------------- net = model.net(cf, logger).cuda() if cf.optimizer == "ADAM": optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) elif cf.optimizer == "SGD": optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3) if cf.dynamic_lr_scheduling: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, patience=cf.scheduling_patience) model_selector = utils.ModelSelector(cf, logger) starting_epoch = 1 if cf.resume_from_checkpoint: starting_epoch = utils.load_checkpoint(cf.resume_from_checkpoint, net, optimizer) logger.info('resumed from checkpoint {} at epoch {}'.format(cf.resume_from_checkpoint, starting_epoch)) # prepare monitoring monitor_metrics = utils.prepare_monitoring(cf) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) # -------------- training ----------------- for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs)) logger.time("train_epoch") net.train() train_results_list = [] train_evaluator = Evaluator(cf, logger, mode='train') for i in range(cf.num_train_batches): logger.time("train_batch_loadfw") batch = next(batch_gen['train']) batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts'] batch_gen['train'].generator.stats['empty_counts'] += batch['empty_counts'] logger.time("train_batch_loadfw") logger.time("train_batch_netfw") results_dict = net.train_forward(batch) logger.time("train_batch_netfw") logger.time("train_batch_bw") optimizer.zero_grad() results_dict['torch_loss'].backward() if cf.clip_norm: torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping optimizer.step() train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict if not cf.server_env: print("\rFinished training batch " + "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches, logger.get_time("train_batch_loadfw")+ logger.get_time("train_batch_netfw") +logger.time("train_batch_bw"), logger.get_time("train_batch_loadfw",reset=True), logger.get_time("train_batch_netfw", reset=True), logger.get_time("train_batch_bw", reset=True)), end="", flush=True) print() #--------------- train eval ---------------- if (epoch-1)%cf.plot_frequency==0: # view an example batch - logger.time("train_plot") - plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, - out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold))) - logger.info("generated train-example plot in {:.2f}s".format(logger.time("train_plot"))) + utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, + show_gt_labels=True, get_time="train-example plot", + out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold))) logger.time("evals") _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train']) logger.time("evals") logger.time("train_epoch", toggle=False) del train_results_list #----------- validation ------------ logger.info('starting validation in mode {}.'.format(cf.val_mode)) logger.time("val_epoch") with torch.no_grad(): net.eval() val_results_list = [] val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) val_predictor = Predictor(cf, net, logger, mode='val') for i in range(batch_gen['n_val']): logger.time("val_batch") batch = next(batch_gen[cf.val_mode]) if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) val_results_list.append([results_dict, batch["pid"]]) if not cf.server_env: print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch', i + 1, batch_gen['n_val'], logger.time("val_batch")), end="", flush=True) print() #------------ val eval ------------- if (epoch - 1) % cf.plot_frequency == 0: - logger.time("val_plot") - plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, - out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold))) - logger.info("generated val plot in {:.2f}s".format(logger.time("val_plot"))) + utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, + show_gt_labels=True, get_time="val-example plot", + out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold))) logger.time("evals") _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val']) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) del val_results_list #----------- monitoring ------------- monitor_metrics.update({"lr": {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}}) logger.metrics2tboard(monitor_metrics, global_step=epoch) logger.time("evals") logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format( epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"), logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"), logger.get_time("val_epoch", reset=True)/batch_gen["n_val"])) logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True))) #-------------- scheduling ----------------- if not cf.dynamic_lr_scheduling: for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch-1] else: scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) logger.time("train_val") logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True))) batch_gen['train'].generator.print_stats(logger, plot=True) def test(cf, logger, max_fold=None): """performs testing for a given fold (or held out set). saves stats in evaluator. """ logger.time("test_fold") logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir)) net = model.net(cf, logger).cuda() batch_gen = data_loader.get_test_generator(cf, logger) test_predictor = Predictor(cf, net, logger, mode='test') test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr( cf, "eval_test_separately") or not cf.eval_test_separately) if test_results_list is not None: test_evaluator = Evaluator(cf, logger, mode='test') test_evaluator.evaluate_predictions(test_results_list) test_evaluator.score_test_df(max_fold=max_fold) logger.info('Testing of fold {} took {}.\n'.format(cf.fold, logger.get_time("test_fold", reset=True, format="hms"))) if __name__ == '__main__': stime = time.time() parser = argparse.ArgumentParser() parser.add_argument('--dataset_name', type=str, default='toy', help="path to the dataset-specific code in source_dir/datasets") parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/regrcnn/datasets/toy/experiments/dev', help='path to experiment dir. will be created if non existent.') parser.add_argument('-m', '--mode', type=str, default='train_test', help='one out of: create_exp, analysis, train, train_test, or test') parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.') parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.') parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config") parser.add_argument('--use_stored_settings', default=False, action='store_true', help='load configs from existing exp_dir instead of source dir. always done for testing, ' 'but can be set to true to do the same for training. useful in job scheduler environment, ' 'where source code might change before the job actually runs.') parser.add_argument('--resume_from_checkpoint', type=str, default=None, help='path to checkpoint. if resuming from checkpoint, the desired fold still needs to be parsed via --folds.') parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything") args = parser.parse_args() args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name folds = args.folds resume_from_checkpoint = None if args.resume_from_checkpoint in ['None', 'none'] else args.resume_from_checkpoint if args.mode == 'create_exp': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, -1) logger.info('created experiment directory at {}'.format(args.exp_dir)) elif args.mode == 'train' or args.mode == 'train_test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings) if args.dev: folds = [0,1] cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 1 cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1 cf.test_n_epochs = cf.save_n_models cf.max_test_patients = 1 torch.backends.cudnn.benchmark = cf.dim==3 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) if folds is None: folds = range(cf.n_cv_splits) for fold in folds: """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation, one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the splits. k==folds, fold in [0,folds) says which split is used for testing. """ cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold logger.set_logfile(fold=fold) cf.resume_from_checkpoint = resume_from_checkpoint if not os.path.exists(cf.fold_dir): os.mkdir(cf.fold_dir) train(cf, logger) cf.resume_from_checkpoint = None if args.mode == 'train_test': test(cf, logger) elif args.mode == 'test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if folds is None: folds = range(cf.n_cv_splits) if args.dev: folds = folds[:2] cf.batch_size, cf.max_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark for fold in folds: cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold logger.set_logfile(fold=fold) if cf.fold_dir in fold_dirs: test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs])) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation. elif args.mode == 'analysis': """ analyse already saved predictions. """ cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) if cf.held_out_test_set and not cf.eval_test_fold_wise: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() logger.info('starting evaluation...') cf.fold = 0 evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) evaluator.score_test_df(max_fold=0) else: fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if args.dev: fold_dirs = fold_dirs[:1] if folds is None: folds = range(cf.n_cv_splits) for fold in folds: cf.fold = fold; cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) logger.set_logfile(fold=fold) if cf.fold_dir in fold_dirs: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x logger.info('starting evaluation...') evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) max_fold = max([int(f[-1]) for f in fold_dirs]) evaluator.score_test_df(max_fold=max_fold) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) else: raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode)) mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t)) del logger torch.cuda.empty_cache() diff --git a/plotting.py b/plotting.py index f80dda0..1bb78aa 100644 --- a/plotting.py +++ b/plotting.py @@ -1,2136 +1,2139 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import matplotlib # matplotlib.rcParams['font.family'] = ['serif'] # matplotlib.rcParams['font.serif'] = ['Times New Roman'] matplotlib.rcParams['mathtext.fontset'] = 'cm' matplotlib.rcParams['font.family'] = 'STIXGeneral' matplotlib.use('Agg') #complains with spyder editor, bc spyder imports mpl at startup from matplotlib.ticker import FormatStrFormatter import matplotlib.colors as mcolors import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import matplotlib.patches as mpatches from matplotlib.ticker import StrMethodFormatter, ScalarFormatter import SimpleITK as sitk from tensorboard.backend.event_processing.event_multiplexer import EventMultiplexer import sys import os import warnings +import time from copy import deepcopy import numpy as np import pandas as pd import scipy.interpolate as interpol from utils.exp_utils import IO_safe warnings.filterwarnings("ignore", module="matplotlib.image") def make_colormap(seq): """ Return a LinearSegmentedColormap seq: a sequence of floats and RGB-tuples. The floats should be increasing and in the interval (0,1). """ seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3] cdict = {'red': [], 'green': [], 'blue': []} for i, item in enumerate(seq): if isinstance(item, float): r1, g1, b1 = seq[i - 1] r2, g2, b2 = seq[i + 1] cdict['red'].append([item, r1, r2]) cdict['green'].append([item, g1, g2]) cdict['blue'].append([item, b1, b2]) return mcolors.LinearSegmentedColormap('CustomMap', cdict) bw_cmap = make_colormap([(1.,1.,1.), (0.,0.,0.)]) #------------------------------------------------------------------------ #------------- plotting functions, not all are used --------------------- def shape_small_first(shape): """sort a tuple so that the smallest entry is swapped to the beginning """ if len(shape) <= 2: # no changing dimensions if channel-dim is missing return shape smallest_dim = np.argmin(shape) if smallest_dim != 0: # assume that smallest dim is color channel new_shape = np.array(shape) # to support mask indexing new_shape = (new_shape[smallest_dim], *new_shape[(np.arange(len(shape), dtype=int) != smallest_dim)]) return new_shape else: return shape def RGB_to_rgb(RGB): rgb = np.array(RGB) / 255. return rgb def mod_to_rgb(arr, cmap=None): """convert a single-channel modality img to 3-color-channel img. :param arr: input img, expected in shape (b,c,)x,y with c=1 :return: img of shape (...,c') with c'=3 """ if len(arr.shape) == 3: arr = np.squeeze(arr) elif len(arr.shape) != 2: raise Exception("Invalid input arr shape: {}".format(arr.shape)) if cmap is None: cmap = "gray" norm = matplotlib.colors.Normalize() norm.autoscale(arr) arr = norm(arr) arr = np.stack((arr,) * 3, axis=-1) return arr def to_rgb(arr, cmap): """ Transform an integer-labeled segmentation map using an rgb color-map. :param arr: img_arr w/o a color-channel :param cmap: dictionary mapping from integer class labels to rgb values :return: img of shape (...,c) """ new_arr = np.zeros(shape=(arr.shape) + (3,)) for l in cmap.keys(): ixs = np.where(arr == l) new_arr[ixs] = np.array([cmap[l][i] for i in range(3)]) return new_arr def to_rgba(arr, cmap): """ Transform an integer-labeled segmentation map using an rgba color-map. :param arr: img_arr w/o a color-channel :param cmap: dictionary mapping from integer class labels to rgba values :return: new array holding rgba-image """ new_arr = np.zeros(shape=(arr.shape) + (4,)) for lab, val in cmap.items(): # in case no alpha, complement with 100% alpha if len(val) == 3: cmap[lab] = (*val, 1.) assert len(cmap[lab]) == 4, "cmap has color with {} entries".format(len(val)) for lab in cmap.keys(): ixs = np.where(arr == lab) rgb = np.array(cmap[lab][:3]) new_arr[ixs] = np.append(rgb, cmap[lab][3]) return new_arr def bin_seg_to_rgba(arr, color): """ Transform a continuously labelled binary segmentation map using an rgba color-map. values are expected to be 0-1, will give alpha-value :param arr: img_arr w/o a color-channel :param color: color to give img :return: new array holding rgba-image """ new_arr = np.zeros(shape=(arr.shape) + (4,)) for i in range(arr.shape[0]): for j in range(arr.shape[1]): new_arr[i][j] = (*color, arr[i][j]) return new_arr def suppress_axes_lines(ax): """ :param ax: pyplot axes object """ ax.axes.get_xaxis().set_ticks([]) ax.axes.get_yaxis().set_ticks([]) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.spines['left'].set_visible(False) return def label_bar(ax, rects, labels=None, colors=None, fontsize=10): """Attach a text label above each bar displaying its height :param ax: :param rects: rectangles as returned by plt.bar() :param labels: :param colors: """ for ix, rect in enumerate(rects): height = rect.get_height() if labels is not None and labels[ix] is not None: label = labels[ix] else: label = '{:g}'.format(height) if colors is not None and colors[ix] is not None and np.any(np.array(colors[ix])<1): color = colors[ix] else: color = 'black' ax.text(rect.get_x() + rect.get_width() / 2., 1.007 * height, label, color=color, ha='center', va='bottom', bbox=dict(facecolor=(1., 1., 1.), edgecolor='none', clip_on=True, pad=0, alpha=0.75), fontsize=fontsize) def draw_box_into_arr(arr, box_coords, box_color=None, lw=2): """ :param arr: imgs shape, (3,y,x) :param box_coords: (x1,y1,x2,y2), in ascending order :param box_color: arr of shape (3,) :param lw: linewidth in pixels """ if box_color is None: box_color = [1., 0.4, 0.] (x1, y1, x2, y2) = box_coords[:4] arr = np.swapaxes(arr, 0, -1) arr[..., y1:y2, x1:x1 + lw, :], arr[..., y1:y2 + lw, x2:x2 + lw, :] = box_color, box_color arr[..., y1:y1 + lw, x1:x2, :], arr[..., y2:y2 + lw, x1:x2, :] = box_color, box_color arr = np.swapaxes(arr, 0, -1) return arr def draw_boxes_into_batch(imgs, batch_boxes, type2color=None, cmap=None): """ :param imgs: either the actual batch imgs or a tuple with shape of batch imgs, need to have 3 color channels, need to be rgb; """ if isinstance(imgs, tuple): img_oshp = imgs imgs = None else: img_oshp = imgs[0].shape img_shp = shape_small_first(img_oshp) # c,x/y,y/x now imgs = np.reshape(imgs, (-1, *img_shp)) box_imgs = np.empty((len(batch_boxes), *(img_shp))) for sample, boxes in enumerate(batch_boxes): # imgs in batch have shape b,c,x,y, swap c to end sample_img = np.full(img_shp, 1.) if imgs is None else imgs[sample] for box in boxes: if len(box["box_coords"]) > 0: if type2color is not None and "box_type" in box.keys(): sample_img = draw_box_into_arr(sample_img, box["box_coords"].astype(np.int32), type2color[box["box_type"]]) else: sample_img = draw_box_into_arr(sample_img, box["box_coords"].astype(np.int32)) box_imgs[sample] = sample_img return box_imgs def plot_prediction_hist(cf, spec_df, outfile, title=None, fs=11, ax=None): labels = spec_df.class_label.values preds = spec_df.pred_score.values type_list = spec_df.det_type.tolist() if hasattr(spec_df, "det_type") else None if title is None: title = outfile.split('/')[-1] + ' count:{}'.format(len(labels)) close=False if ax is None: fig = plt.figure(tight_layout=True) ax = fig.add_subplot(1,1,1) close=True ax.set_yscale('log') ax.set_xlabel("Prediction Score", fontsize=fs) ax.set_ylabel("Occurences", fontsize=fs) ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="fp") ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="fn at score 0 and tp") ax.axvline(x=cf.min_det_thresh, alpha=1, color=cf.orange, linewidth=1.5, label="min det thresh") if type_list is not None: fp_count = type_list.count('det_fp') fn_count = type_list.count('det_fn') tp_count = type_list.count('det_tp') pos_count = fn_count + tp_count title += '\ntp:{} fp:{} fn:{} pos:{}'.format(tp_count, fp_count, fn_count, pos_count) ax.set_title(title, fontsize=fs) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) if close: ax.legend(loc="best", fontsize=fs) if cf.server_env: IO_safe(plt.savefig, fname=outfile, _raise=False) else: plt.savefig(outfile) pass plt.close() def plot_wbc_n_missing(cf, df, outfile, fs=11, ax=None): """ WBC (weighted box clustering) has parameter n_missing, which shows how many boxes are missing per cluster. This function plots the average relative amount of missing boxes sorted by cluster score. :param cf: config. :param df: dataframe. :param outfile: path to save image under. :param fs: fontsize. :param ax: axes object. """ bins = np.linspace(0., 1., 10) names = ["{:.1f}".format((bins[i]+(bins[i+1]-bins[i])/2.)*100) for i in range(len(bins)-1)] classes = df.pred_class.unique() colors = [cf.class_id2label[cl_id].color for cl_id in classes] binned_df = df.copy() binned_df.loc[:,"pred_score"] = pd.cut(binned_df["pred_score"], bins) close=False if ax is None: ax = plt.subplot() close=True width = 1 / (len(classes) + 1) group_positions = np.arange(len(names)) legend_handles = [] for ix, cl_id in enumerate(classes): cl_df = binned_df[binned_df.pred_class==cl_id].groupby("pred_score").agg({"cluster_n_missing": 'mean'}) ax.bar(group_positions + ix * width, cl_df.cluster_n_missing.values, width=width, color=colors[ix], alpha=0.4 + ix / 2 / len(classes), edgecolor=colors[ix]) legend_handles.append(mpatches.Patch(color=colors[ix], label=cf.class_dict[cl_id])) title = "Fold {} WBC Missing Preds\nAverage over scores and classes: {:.1f}%".format(cf.fold, df.cluster_n_missing.mean()) ax.set_title(title, fontsize=fs) ax.legend(handles=legend_handles, title="Class", loc="best", fontsize=fs, title_fontsize=fs) ax.set_xticks(group_positions + (len(classes) - 1) * width / 2) # ax.xaxis.set_major_formatter(StrMethodFormatter('{x:.1f}')) THIS WONT WORK... no clue! ax.set_xticklabels(names) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"Average Missing Preds per Cluster (%)", fontsize=fs) ax.set_xlabel("Prediction Score", fontsize=fs) if close: if cf.server_env: IO_safe(plt.savefig, fname=outfile, _raise=False) else: plt.savefig(outfile) plt.close() def plot_stat_curves(cf, stats, outfile, fill=False): """ Plot precision-recall and/or receiver-operating-characteristic curve(s). :param cf: config. :param stats: statistics as supplied by Evaluator. :param outfile: path to save plot under. :param fill: whether to colorize space between plot and x-axis. :return: """ for c in ['roc', 'prc']: plt.figure() empty_plot = True for ix, s in enumerate(stats): if s[c] is not np.nan: plt.plot(s[c][1], s[c][0], label=s['name'] + '_' + c, marker=None, color=cf.color_palette[ix%len(cf.color_palette)]) empty_plot = False if fill: plt.fill_between(s[c][1], s[c][0], alpha=0.33, color=cf.color_palette[ix%len(cf.color_palette)]) if not empty_plot: plt.title(outfile.split('/')[-1] + '_' + c) plt.legend(loc=3 if c == 'prc' else 4) plt.ylabel('precision' if c == 'prc' else '1-spec.') plt.ylim((0.,1)) plt.xlabel('recall') plt.savefig(outfile + '_' + c) plt.close() def plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=None, alphas=None, errors=None, ylabel='', xlabel='', xticklabels=None, yticks=None, yticklabels=None, ylim=None, label_format="{:.3f}", title=None, ax=None, out_file=None, legend=False, fs=11): """ Plot a categorically grouped bar chart. :param cf: config. :param bar_values: values of the bars. :param groups: groups/categories that bars belong to. :param splits: splits within groups, i.e., names of bars. :param colors: colors. :param alphas: 1-opacity. :param errors: values for errorbars. :param ylabel: label of y-axis. :param xlabel: label of x-axis. :param title: plot title. :param ax: axes object to draw into. if None, new is created. :param out_file: path to save plot. :param legend: whether to show a legend. :param fs: fontsize. :return: legend handles. """ bar_values = np.array(bar_values) if alphas is None: alphas = [1.,] * len(splits) if colors is None: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(splits))] if errors is None: errors = np.zeros_like(bar_values) # patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') # patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) close=False if ax is None: ax = plt.subplot() close=True width = 1 / (len(splits) +0.25) group_positions = np.arange(len(groups)) for ix, split in enumerate(splits): rects = ax.bar(group_positions + ix * width, bar_values[ix], width=width, color=(*colors[ix], 0.8), edgecolor=colors[ix], yerr=errors[ix], ecolor=(*np.array(colors[ix])*0.8, 1.), capsize=5) # for ix, bar in enumerate(rects): # bar.set_hatch(patterns[ix]) labels = [label_format.format(val) for val in bar_values[ix]] label_bar(ax, rects, labels, [colors[ix]]*len(labels), fontsize=fs) legend_handles = [mpatches.Patch(color=colors[ix], alpha=alphas[ix], label=split) for ix, split in enumerate(splits)] if legend: ax.legend(handles=legend_handles, fancybox=True, framealpha=1., loc="lower center") legend_handles = [(colors[ix], alphas[ix], split) for ix, split in enumerate(splits)] if title is not None: ax.set_title(title, fontsize=fs) ax.set_xticks(group_positions + (len(splits) - 1) * width / 2) if xticklabels is None: ax.set_xticklabels(groups, fontsize=fs) else: ax.set_xticklabels(xticklabels, fontsize=fs) ax.set_axisbelow(True) ax.set_xlabel(xlabel, fontsize=fs) ax.tick_params(labelsize=fs) ax.grid(axis='y') ax.set_ylabel(ylabel, fontsize=fs) if yticks is not None: ax.set_yticks(yticks) if yticklabels is not None: ax.set_yticklabels(yticklabels, fontsize=fs) if ylim is not None: ax.set_ylim(ylim) if out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() return legend_handles def plot_binned_rater_dissent(cf, binned_stats, out_file=None, ax=None, legend=True, fs=11): """ LIDC-specific plot: rater disagreement as standard deviations within each bin. :param cf: config. :param binned_stats: list, ix==bin_id, item: [(roi_mean, roi_std, roi_max, roi_bin_id-roi_max_bin_id) for roi in bin] :return: """ dissent = [np.array([roi[1] for roi in bin]) for bin in binned_stats] avg_dissent_first_degree = [np.mean(bin) for bin in dissent] groups = list(cf.bin_id2label.keys()) splits = [r"$1^{st}$ std. dev.",] colors = [cf.bin_id2label[bin_id].color[:3] for bin_id in groups] #colors = [cf.blue for bin_id in groups] alphas = [0.9,] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) close=False if ax is None: ax = plt.subplot() close=True width = 1/(len(splits)+1) group_positions = np.arange(len(groups)) #total_counts = [df.loc[split].sum() for split in splits] dissent = np.array(avg_dissent_first_degree) ix=0 rects = ax.bar(group_positions+ix*width, dissent, color=colors, alpha=alphas[ix], edgecolor=colors) #for ix, bar in enumerate(rects): #bar.set_hatch(patterns[ix]) labels = ["{:.2f}".format(diss) for diss in dissent] label_bar(ax, rects, labels, colors, fontsize=fs) bin_edge_color = cf.blue ax.axhline(y=0.5, color=bin_edge_color) ax.text(2.5, 0.38, "bin edge", color=cf.white, fontsize=fs, horizontalalignment="center", bbox=dict(boxstyle='round', facecolor=(*bin_edge_color, 0.85), edgecolor='none', clip_on=True, pad=0)) if legend: legend_handles = [mpatches.Patch(color=cf.blue ,alpha=alphas[ix], label=split) for ix, split in enumerate(splits)] ax.legend(handles=legend_handles, loc='lower center', fontsize=fs) title = "LIDC-IDRI: Average Std Deviation per Lesion" plt.title(title) ax.set_xticks(group_positions + (len(splits)-1)*width/2) ax.set_xticklabels(groups, fontsize=fs) ax.set_axisbelow(True) #ax.tick_params(axis='both', which='major', labelsize=fs) #ax.tick_params(axis='both', which='minor', labelsize=fs) ax.grid() ax.set_ylabel(r"Average Dissent (MS)", fontsize=fs) ax.set_xlabel("binned malignancy-score value (ms)", fontsize=fs) ax.tick_params(labelsize=fs) if out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() return def plot_confusion_matrix(cf, cm, out_file=None, ax=None, fs=11, cmap=plt.cm.Blues, color_bar=True): """ Plot a confusion matrix. :param cf: config. :param cm: confusion matrix, e.g., as supplied by metrics.confusion_matrix from scikit-learn. :return: """ close=False if ax is None: ax = plt.subplot() close=True im = ax.imshow(cm, interpolation='nearest', cmap=cmap) if color_bar: ax.figure.colorbar(im, ax=ax) # Rotate the tick labels and set their alignment. #plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # Loop over data dimensions and create text annotations. fmt = '.0%' if np.mod(cm, 1).any() else 'd' thresh = cm.max() / 2. for i in range(cm.shape[0]): for j in range(cm.shape[1]): ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black") ax.set_ylabel(r"Binned Mean MS", fontsize=fs) ax.set_xlabel("Single-Annotator MS", fontsize=fs) #ax.tick_params(labelsize=fs) if close and out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() else: return ax def plot_data_stats(cf, df, labels=None, out_file=None, ax=None, fs=11): """ Plot data-set statistics. Shows target counts. Mainly used by Dataset Class in dataloader.py. :param cf: configs obj :param df: pandas dataframe :param out_file: path to save fig in """ names = df.columns if labels is not None: colors = [label.color for name in names for label in labels if label.name==name] else: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(names))] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) if ax is None: fig, ax = plt.subplots(figsize=(14,6), dpi=300) return_ax = False else: return_ax = True plt.margins(x=0.01) plt.subplots_adjust(bottom=0.15) bar_positions = np.arange(len(names)) name_counts = df.sum() total_count = name_counts.sum() rects = ax.bar(bar_positions, name_counts, color=colors, alpha=0.9, edgecolor=colors) labels = ["{:.0f}%".format(count/ total_count*100) for count in name_counts] label_bar(ax, rects, labels, colors, fontsize=fs) title= "Data Set RoI-Target Balance\nTotal #RoIs: {}".format(int(total_count)) ax.set_title(title, fontsize=fs) ax.set_xticks(bar_positions) rotation = "vertical" if np.any([len(str(name)) > 3 for name in names]) else None if all([isinstance(name, (float, int)) for name in names]): ax.set_xticklabels(["{:.2f}".format(name) for name in names], rotation=rotation, fontsize=fs) else: ax.set_xticklabels(names, rotation=rotation, fontsize=fs) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs", fontsize=fs) ax.set_xlabel(str(df._metadata[0]), fontsize=fs) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) if out_file is not None: plt.savefig(out_file) if return_ax: return ax else: plt.close() def plot_fold_stats(cf, df, labels=None, out_file=None, ax=None): """ Similar as plot_data_stats but per single cross-val fold. :param cf: configs obj :param df: pandas dataframe :param out_file: path to save fig in """ names = df.columns splits = df.index if labels is not None: colors = [label.color for name in names for label in labels if label.name==name] else: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(names))] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) if ax is None: ax = plt.subplot() return_ax = False else: return_ax = True width = 1/(len(names)+1) group_positions = np.arange(len(splits)) legend_handles = [] total_counts = [df.loc[split].sum() for split in splits] for ix, name in enumerate(names): rects = ax.bar(group_positions+ix*width, df.loc[:,name], width=width, color=colors[ix], alpha=0.9, edgecolor=colors[ix]) #for ix, bar in enumerate(rects): #bar.set_hatch(patterns[ix]) labels = ["{:.0f}%".format(df.loc[split, name]/ total_counts[ii]*100) for ii, split in enumerate(splits)] label_bar(ax, rects, labels, [colors[ix]]*len(group_positions)) legend_handles.append(mpatches.Patch(color=colors[ix] ,alpha=0.9, label=name)) title= "Fold {} RoI-Target Balances\nTotal #RoIs: {}".format(cf.fold, int(df.values.sum())) plt.title(title) ax.legend(handles=legend_handles) ax.set_xticks(group_positions + (len(names)-1)*width/2) ax.set_xticklabels(splits, rotation="vertical" if len(splits)>2 else None, size=12) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs") ax.set_xlabel("Set split") if out_file is not None: plt.savefig(out_file) if return_ax: return ax plt.close() def plot_batchgen_distribution(cf, pids, p_probs, balance_target, out_file=None): """plot top n_pids probabilities for drawing a pid into a batch. :param cf: experiment config object :param pids: sorted iterable of patient ids :param p_probs: pid's drawing likelihood, order needs to match the one of pids. :param out_file: :return: """ n_pids = len(pids) zip_sorted = np.array(sorted(list(zip(p_probs, pids)), reverse=True)) names, probs = zip_sorted[:n_pids,1], zip_sorted[:n_pids,0].astype('float32') * 100 try: names = [str(int(n)) for n in names] except ValueError: names = [str(n) for n in names] lowest_p = min(p_probs)*100 fig, ax = plt.subplots(1,1,figsize=(17,5), dpi=200) rects = ax.bar(names, probs, color=cf.blue, alpha=0.9, edgecolor=cf.blue) ax = plt.gca() ax.text(0.8, 0.92, "Lowest prob.: {:.5f}%".format(lowest_p), transform=ax.transAxes, color=cf.white, bbox=dict(boxstyle='round', facecolor=cf.blue, edgecolor='none', alpha=0.9)) ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}')) ax.set_xticklabels(names, rotation="vertical", fontsize=7) plt.margins(x=0.01) plt.subplots_adjust(bottom=0.15) if balance_target=="class_targets": balance_target = "Class" elif balance_target=="lesion_gleasons": balance_target = "GS" ax.set_title(str(balance_target)+"-Balanced Train Generator: Sampling Likelihood per PID") ax.set_axisbelow(True) ax.grid(axis='y') ax.set_ylabel("Sampling Likelihood (%)") ax.set_xlabel("PID") plt.tight_layout() if out_file is not None: plt.savefig(out_file) plt.close() def plot_batchgen_stats(cf, stats, empties, target_name, unique_ts, out_file=None): """Plot bar chart showing RoI frequencies and empty-sample count of batch stats recorded by BatchGenerator. :param cf: config. :param stats: statistics as supplied by BatchGenerator class. :param out_file: path to save plot. """ total_samples = cf.num_epochs*cf.num_train_batches*cf.batch_size if target_name=="class_targets": target_name = "Class" label_dict = {cl_id: label for (cl_id, label) in cf.class_id2label.items()} elif target_name=="lesion_gleasons": target_name = "Lesion's Gleason Score" label_dict = cf.gs2label elif target_name=="rg_bin_targets": target_name = "Regression-Bin ID" label_dict = cf.bin_id2label else: raise NotImplementedError names = [label_dict[t_id].name for t_id in unique_ts] colors = [label_dict[t_id].color for t_id in unique_ts] title = "Training Target Frequencies" title += "\nempty samples: {}".format(empties) rects = plt.bar(names, stats['roi_counts'], color=colors, alpha=0.9, edgecolor=colors) ax = plt.gca() ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}')) ax.set_title(title) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs") ax.set_xlabel(target_name) total_count = np.sum(stats["roi_counts"]) labels = ["{:.0f}%".format(count/total_count*100) for count in stats["roi_counts"]] label_bar(ax, rects, labels, colors) if out_file is not None: plt.savefig(out_file) plt.close() def view_3D_array(arr, outfile, elev=30, azim=30): from mpl_toolkits.mplot3d import Axes3D fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.set_aspect("equal") ax.set_xlabel("x") ax.set_ylabel("y") ax.set_zlabel("z") ax.voxels(arr) ax.view_init(elev=elev, azim=azim) plt.savefig(outfile) def view_batch(cf, batch, res_dict=None, out_file=None, legend=True, show_info=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, - roi_items="all", sample_picks=None, vol_slice_picks=None, - box_score_thres=None, plot_mods=True, dpi=200, vmin=None, return_fig=False): + roi_items="all", sample_picks=None, vol_slice_picks=None, box_score_thres=None, plot_mods=True, + dpi=200, vmin=None, return_fig=False, get_time=True): r""" View data and target entries of a batch. Batch expected as dic with entries 'data' and 'seg' holding np.arrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. Pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. :param cf: config. :param batch: batch. :param res_dict: results dictionary. :param out_file: path to save plot. :param legend: whether to show a legend. :param show_info: whether to show text info about img sizes and type in plot. :param has_colorchannels: whether image has color channels. :param isRGB: if image is RGB. :param show_seg_ids: "all" or None or list with seg classes to show (seg_ids) :param show_seg_pred: whether to the predicted segmentation. :param show_gt_boxes: whether to show ground-truth boxes. :param show_gt_labels: whether to show labels of ground-truth boxes. :param roi_items: which roi items to show: strings "all" or "targets". --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no threshold. :param plot_mods: whether to plot input modality/modalities. :param dpi: graphics resolution. :param vmin: min value for gray-scale cmap in imshow, set to a fix value for inter-batch normalization, or None for intra-batch. :param return_fig: whether to return created figure. """ - + stime = time.time() # pfix = prefix, ptfix = postfix patched_patient = 'patch_crop_coords' in list(batch.keys()) pfix = 'patient_' if patched_patient else '' ptfix = '_2d' if (patched_patient and cf.dim == 2 and pfix + 'class_targets_2d' in batch.keys()) else '' # -------------- get data, set flags ----------------- try: btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] except AttributeError: # in this case: assume it's single-annotator ground truths btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'][0] print("Showing only gts of rater 0") data_init_shp, seg_init_shp = data.shape, seg.shape seg = np.copy(seg) if show_seg_ids else None plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) elif hasattr(vol_slice_picks, "__iter__") and vol_slice_picks[0]=="random": z_ics = np.random.choice(np.arange(0, d), size=min(vol_slice_picks[1], d), replace=False) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1 and vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) if seg_preds is not None and not type(show_seg_ids)==str: seg_preds = np.copy(seg_preds) seg_preds = np.where(np.isin(seg_preds, show_seg_ids), seg_preds, 0) if seg is not None: if not type(show_seg_ids)==str: #to save time seg = np.where(np.isin(seg, show_seg_ids), seg, 0) legend_items = {cf.seg_id2label[seg_id] for seg_id in np.unique(seg) if seg_id != 0} # add seg labels else: legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = 12 # fontsize sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: plt.title("Ground Truth", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number suppress_axes_lines(ax) else: plt.axis('off') col += 1 if seg is not None and seg.shape[1] == 1: ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cf.cmap), alpha=0.8) elif seg is not None: ax.imshow(to_rgba(np.argmax(seg[s_ix][...,z_ix], axis=0), cf.cmap), alpha=0.8) # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] width, height = x2 - x1, y2 - y1 if class_targets is not None: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels: text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 text_fs = title_fs // 3 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name and cf.plot_class_ids: text_x = x2 #- 2 * text_fs * (len(str(class_targets[s_ix][j]))) # avoid overlap of scores text_y = y1 #+ 2 * text_fs text_str = '{}'.format(class_targets[s_ix][j]) elif 'regression_targets' in name: text_x, text_y = (x2, y2) text_str = "[" + " ".join( ["{:.1f}".format(x) for x in batch[name][s_ix][j]]) + "]" elif 'rg_bin_targets' in name: text_x, text_y = (x1, y2) text_str = '{}'.format(batch[name][s_ix][j]) else: text_pos = text_poss.pop(0) text_x = text_pos[0] #- 2 * text_fs * len(str(batch[name][s_ix][j])) text_y = text_pos[1] #+ 2 * text_fs text_str = '{}'.format(batch[name][s_ix][j]) ax.text(text_x, text_y, text_str, color=cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') ax.add_patch(bbox) # -----evtly visualise predictions ------------- if pr_boxes is not None or seg_preds is not None: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray") ax.axis("off") col += 1 if row == 0: plt.title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": label = cf.class_id2label[box["box_pred_class_id"]] legend_items.add(label) text_x, text_y = x2, y1 id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" text_str = '{}{:.0f}'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "["+" ".join(["{:.1f}".format(x) for x in box["regression"]])+"]" #str(box["regression"]) #+ "|" if cf.plot_class_ids else "" if 'rg_uncertainty' in box.keys() and not np.isnan(box['rg_uncertainty']): id_text += " | {:.1f}".format(box['rg_uncertainty']) text_str = '{}'.format(id_text) #, box["box_score"] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) # ------------ pred segs -------- if seg_preds is not None: # and seg_preds.shape[1] == 1: if cf.class_specific_seg: ax.imshow(to_rgba(seg_preds[s_ix][0][...,z_ix], cf.cmap), alpha=0.8) else: ax.imshow(bin_seg_to_rgba(seg_preds[s_ix][0][...,z_ix], cf.orange), alpha=0.8) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if show_info: plt.figtext(0, 0, "Batch content is of type\n{}\nand has shapes\n".format(btype) + \ "{} for 'data' and {} for 'seg'".format(data_init_shp, seg_init_shp)) if out_file is not None: if cf.server_env: IO_safe(plt.savefig, fname=out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', _raise=False) else: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight') + if get_time: + print("generated {} in {:.3f}s".format("plot" if not isinstance(get_time, str) else get_time, time.time()-stime)) if return_fig: return plt.gcf() plt.clf() plt.close() def view_batch_paper(cf, batch, res_dict=None, out_file=None, legend=True, show_info=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, roi_items="all", split_ens_ics=False, server_env=True, sample_picks=None, vol_slice_picks=None, patient_items=False, box_score_thres=None, plot_mods=True, dpi=400, vmin=None, return_fig=False): r"""view data and target entries of a batch. batch expected as dic with entries 'data' and 'seg' holding tensors or nparrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. :param cf: :param batch: :param res_dict: :param out_file: :param legend: :param show_info: :param has_colorchannels: :param isRGB: :param show_seg_ids: :param show_seg_pred: :param show_gt_boxes: :param show_gt_labels: :param roi_items: strings "all" or "targets" --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param split_ens_ics: :param server_env: :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param patient_items: set to true if patient-wise batch items should be displayed (need to be contained in batch and marked via 'patient_' prefix. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no thres. :param plot_mods: :param dpi: graphics resolution :param vmin: min value for gs cmap in imshow, set to fix inter-batch, or None for intra-batch. pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. show_seg_ids: "all" or None or list with seg classes to show (seg_ids) """ # pfix = prefix, ptfix = postfix pfix = 'patient_' if patient_items else '' ptfix = '_2d' if (patient_items and cf.dim==2) else '' # -------------- get data, set flags ----------------- btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] # seg = np.array(seg).mean(axis=0, keepdims=True) # seg[seg>0] = 1. print("Showing multirater GT") data_init_shp, seg_init_shp = data.shape, seg.shape fg_slices = np.where(np.sum(np.sum(np.squeeze(seg), axis=0), axis=0)>0)[0] if len(fg_slices)==0: print("skipping empty patient") return if vol_slice_picks is None: vol_slice_picks = fg_slices print("data shp, seg shp", data_init_shp, seg_init_shp) plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) # elif hasattr(vol_slice_picks, "__iter__") and vol_slice_picks[0]=="random": # z_ics = np.random.choice(np.arange(0, d), size=min(vol_slice_picks[1], d), replace=False) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1: if vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) else: z_ics = np.random.choice(vol_slice_picks, size=min(len(vol_slice_picks), max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = 12 # fontsize sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: plt.title("Ground Truth+ Pred", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number suppress_axes_lines(ax) else: plt.axis('off') col += 1 if seg is not None and seg.shape[1] == 1: cmap = {1: cf.orange} ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cmap), alpha=0.8) # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] # [x1,y1,x2,y2] = box[:4]#:return: coords (x1, y1, x2, y2) width, height = x2 - x1, y2 - y1 if class_targets is not None: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels and cf.plot_class_ids: text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 text_fs = title_fs // 3 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name: text_x = x2 #- 2 * text_fs * (len(str(class_targets[s_ix][j]))) # avoid overlap of scores text_y = y1 #+ 2 * text_fs text_str = '{}'.format(class_targets[s_ix][j]) elif 'regression_targets' in name: text_x, text_y = (x2, y2) text_str = "[" + " ".join( ["{:.1f}".format(x) for x in batch[name][s_ix][j]]) + "]" elif 'rg_bin_targets' in name: text_x, text_y = (x1, y2) text_str = '{}'.format(batch[name][s_ix][j]) else: text_pos = text_poss.pop(0) text_x = text_pos[0] #- 2 * text_fs * len(str(batch[name][s_ix][j])) text_y = text_pos[1] #+ 2 * text_fs text_str = '{}'.format(batch[name][s_ix][j]) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') ax.add_patch(bbox) # # -----evtly visualise predictions ------------- # if pr_boxes is not None or seg_preds is not None: # ax = fig.add_subplot(grid[row, col]) # ax.imshow(bg_img, cmap="gray") # ax.axis("off") # col += 1 # if row == 0: # plt.title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": label = cf.bin_id2label[box["rg_bin"]] color = cf.aubergine legend_items.add(label) text_x, text_y = x2, y1 #id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" id_text = "fg: " text_str = '{}{:.0f}'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=0.5, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, bbox=text_settings, fontsize=title_fs // 2) edgecolor = color #label.color if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "ms: "+" ".join(["{:.1f}".format(x) for x in box["regression"]])+"" text_str = '{}'.format(id_text) #, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=0.5, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, bbox=text_settings, fontsize=title_fs // 2) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) # ax.text(text_x, text_y, text_str, color=cf.white, # bbox=text_settings, fontsize=title_fs // 4) if split_ens_ics and "ens_ix" in box.keys(): n_aug = box["ens_ix"].split("_")[1] edgecolor = [c for c in cf.color_palette if not c == cf.green][ int(n_aug) % (len(cf.color_palette) - 1)] text_x, text_y = x1, y2 text_str = "{}".format(box["ens_ix"][2:]) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 6) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if show_info: plt.figtext(0, 0, "Batch content is of type\n{}\nand has shapes\n".format(btype) + \ "{} for 'data' and {} for 'seg'".format(data_init_shp, seg_init_shp)) if out_file is not None: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', tight_layout=True) if return_fig: return plt.gcf() if not (server_env or cf.server_env): plt.show() plt.clf() plt.close() def view_batch_thesis(cf, batch, res_dict=None, out_file=None, legend=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, show_cl_ids=True, roi_items="all", server_env=True, sample_picks=None, vol_slice_picks=None, fontsize=12, seg_cmap="class", patient_items=False, box_score_thres=None, plot_mods=True, dpi=400, vmin=None, return_fig=False, axes=None): r"""view data and target entries of a batch. batch expected as dic with entries 'data' and 'seg' holding tensors or nparrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. :param cf: :param batch: :param res_dict: :param out_file: :param legend: :param show_info: :param has_colorchannels: :param isRGB: :param show_seg_ids: :param show_seg_pred: :param show_gt_boxes: :param show_gt_labels: :param roi_items: strings "all" or "targets" --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param split_ens_ics: :param server_env: :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param patient_items: set to true if patient-wise batch items should be displayed (need to be contained in batch and marked via 'patient_' prefix. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no thres. :param plot_mods: :param dpi: graphics resolution :param vmin: min value for gs cmap in imshow, set to fix inter-batch, or None for intra-batch. pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. show_seg_ids: "all" or None or list with seg classes to show (seg_ids) """ # pfix = prefix, ptfix = postfix pfix = 'patient_' if patient_items else '' ptfix = '_2d' if (patient_items and cf.dim==2) else '' # -------------- get data, set flags ----------------- btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] data_init_shp, seg_init_shp = data.shape, seg.shape fg_slices = np.where(np.sum(np.sum(np.squeeze(seg), axis=0), axis=0)>0)[0] if len(fg_slices)==0: print("skipping empty patient") return if vol_slice_picks is None: vol_slice_picks = fg_slices #print("data shp, seg shp", data_init_shp, seg_init_shp) plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cl_targets_sa = batch[pfix+'class_targets_sa'+ptfix] if pfix+'class_targets_sa'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1 and vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = fontsize # fontsize text_fs = title_fs * 2 / 3 sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt if axes is not None and 'gt' in axes.keys(): ax = axes['gt'] else: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: ax.set_title("Ground Truth", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) # str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=text_fs*1.3) # show id-number suppress_axes_lines(ax) else: ax.axis('off') col += 1 # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] # [x1,y1,x2,y2] = box[:4]#:return: coords (x1, y1, x2, y2) width, height = x2 - x1, y2 - y1 if class_targets is not None: try: label = cf.bin_id2label[cf.rg_val_to_bin_id(batch['patient_regression_targets'][s_ix][j])] except AttributeError: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels and cf.plot_class_ids: bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') if height<=text_fs*6: y1 -= text_fs*1.5 y2 += text_fs*2 text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name: text_str = '{}'.format(class_targets[s_ix][j]) text_x, text_y = (x2 + 0 * len(text_str) // 4, y2) elif 'regression_targets' in name: text_str = 'agg. MS: {:.2f}'.format(batch[name][s_ix][j][0]) text_x, text_y = (x2 + 0 * len(text_str) // 4, y2) elif 'rg_bin_targets_sa' in name: text_str = 'sa. MS: {}'.format(batch[name][s_ix][j]) text_x, text_y = (x2-0*len(text_str)*text_fs//4, y1) # elif 'rg_bin_targets' in name: # text_str = 'agg. ms:{}'.format(batch[name][s_ix][j]) # text_x, text_y = (x2+0*len(text_str)//4, y1) ax.text(text_x, text_y, text_str, color=cf.black if (label.color[:3]==cf.yellow or label.color[:3]==cf.green) else cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 ax.add_patch(bbox) if seg is not None and seg.shape[1] == 1: #cmap = {1: cf.orange} # cmap = {label_id: label.color for label_id, label in cf.bin_id2label.items()} # this whole function is totally only hacked together for a quick very specific case if seg_cmap == "rg" or seg_cmap=="regression": cmap = {1: cf.bin_id2label[cf.rg_val_to_bin_id(batch['patient_regression_targets'][s_ix][0])].color} else: cmap = cf.class_cmap ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cmap), alpha=0.8) # # -----evtly visualise predictions ------------- if pr_boxes is not None or seg_preds is not None: if axes is not None and 'pred' in axes.keys(): ax = axes['pred'] else: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray") ax.axis("off") col += 1 if row == 0: ax.set_title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: alpha = 0.7 box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": try: label = cf.bin_id2label[cf.rg_val_to_bin_id(box['regression'])] except AttributeError: label = cf.class_id2label[box['box_pred_class_id']] # assert box["rg_bin"] == cf.rg_val_to_bin_id(box['regression']), \ # "box bin: {}, rg-bin {}".format(box["rg_bin"], cf.rg_val_to_bin_id(box['regression'])) color = label.color#cf.aubergine edgecolor = color # label.color text_color = cf.black if (color[:3]==cf.yellow or color[:3]==cf.green) else cf.white legend_items.add(label) bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') if height<=text_fs*6: y1 -= text_fs*1.5 y2 += text_fs*2 text_x, text_y = x2, y1 #id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" id_text = "FG: " text_str = r'{}{:.0f}%'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs ) if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "MS: "+" ".join(["{:.2f}".format(x) for x in box["regression"]])+"" text_str = '{}'.format(id_text) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0) # ax.text(text_x, text_y, text_str, color=cf.white, # bbox=text_settings, fontsize=title_fs // 4) if 'box_pred_class_id' in box.keys() and show_cl_ids: text_x, text_y = x2, y2 id_text = box["box_pred_class_id"] text_str = '{}'.format(id_text) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color ax.add_patch(bbox) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if out_file is not None: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', tight_layout=True) if return_fig: return plt.gcf() if not (server_env or cf.server_env): plt.show() plt.clf() plt.close() def view_slices(cf, img, seg=None, ids=None, title="", out_dir=None, legend=True, cmap=None, label_remap=None, instance_labels=False): """View slices of a 3D image overlayed with corresponding segmentations. :params img, seg: expected as 3D-arrays """ if isinstance(img, sitk.SimpleITK.Image): img = sitk.GetArrayViewFromImage(img) elif isinstance(img, np.ndarray): #assume channels dim is smallest and in either first or last place if np.argmin(img.shape)==2: img = np.moveaxis(img, 2,0) else: raise Exception("view_slices got unexpected img type.") if seg is not None: if isinstance(seg, sitk.SimpleITK.Image): seg = sitk.GetArrayViewFromImage(seg) elif isinstance(img, np.ndarray): if np.argmin(seg.shape)==2: seg = np.moveaxis(seg, 2,0) else: raise Exception("view_slices got unexpected seg type.") if label_remap is not None: for (key, val) in label_remap.items(): seg[seg==key] = val if instance_labels: class Label(): def __init__(self, id, name, color): self.id = id self.name = name self.color = color legend_items = {Label(seg_id, "instance_{}".format(seg_id), cf.color_palette[seg_id%len(cf.color_palette)]) for seg_id in np.unique(seg)} if cmap is None: cmap = {label.id : label.color for label in legend_items} else: legend_items = {cf.seg_id2label[seg_id] for seg_id in np.unique(seg)} if cmap is None: cmap = {label.id : label.color for label in legend_items} slices = img.shape[0] if seg is not None: assert slices==seg.shape[0], "Img and seg have different amt of slices." grid = gridspec.GridSpec(int(np.ceil(slices/4)),4) fig = plt.figure(figsize=(10, slices/4*2.5)) rng = np.arange(slices, dtype='uint8') if not ids is None: rng = rng[ids] for s in rng: ax = fig.add_subplot(grid[int(s/4),int(s%4)]) ax.imshow(img[s], cmap="gray") if not seg is None: ax.imshow(to_rgba(seg[s], cmap), alpha=0.9) if legend and int(s/4)==0 and int(s%4)==3: patches = [mpatches.Patch(color=label.color, label="{}".format(label.name)) for label in legend_items] ncols = 1 plt.legend(handles=patches,bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=ncols) plt.title("slice {}, {}".format(s, img[s].shape)) plt.axis('off') plt.suptitle(title) if out_dir is not None: plt.savefig(out_dir, dpi=300, pad_inches=0.0, bbox_inches='tight') if not cf.server_env: plt.show() plt.close() def plot_txt(cf, txts, labels=None, title="", x_label="", y_labels=["",""], y_ranges=(None,None), twin_axes=(), smooth=None, out_dir=None): """Read and plot txt data, either from file (txts is paths) or directly (txts is arrays). :param twin_axes: plot two y-axis over same x-axis. twin_axes expected as tuple defining which txt files (determined via indices) share the second y-axis. """ if isinstance(txts, str) or not hasattr(txts, '__iter__'): txts = [txts] fig = plt.figure() ax1 = fig.add_subplot(1,1,1) if len(twin_axes)>0: ax2 = ax1.twinx() for i, txt in enumerate(txts): if isinstance(txt, str): arr = np.genfromtxt(txt, delimiter=',',skip_header=1, usecols=(1,2)) else: arr = txt if i in twin_axes: ax = ax2 else: ax = ax1 if smooth is not None: spline_graph = interpol.UnivariateSpline(arr[:,0], arr[:,1], k=5, s=float(smooth)) ax.plot(arr[:, 0], spline_graph(arr[:,0]), color=cf.color_palette[i % len(cf.color_palette)], marker='', markersize=2, linestyle='solid') ax.plot(arr[:,0], arr[:,1], color=cf.color_palette[i%len(cf.color_palette)], marker='', markersize=2, linestyle='solid', label=labels[i], alpha=0.5 if smooth else 1.) plt.title(title) ax1.set_xlabel(x_label) ax1.set_ylabel(y_labels[0]) if y_ranges[0] is not None: ax1.set_ylim(y_ranges[0]) if len(twin_axes)>0: ax2.set_ylabel(y_labels[1]) if y_ranges[1] is not None: ax2.set_ylim(y_ranges[1]) plt.grid() if labels is not None: ax1.legend(loc="upper center") if len(twin_axes)>0: ax2.legend(loc=4) if out_dir is not None: plt.savefig(out_dir, dpi=200) return fig def plot_tboard_logs(cf, log_dir, tag_filters=[""], inclusive_filters=True, out_dir=None, x_label="", y_labels=["",""], y_ranges=(None,None), twin_axes=(), smooth=None): """Plot (only) tboard scalar logs from given log_dir for multiple runs sorted by tag. """ print("log dir", log_dir) mpl = EventMultiplexer().AddRunsFromDirectory(log_dir) #EventAccumulator(log_dir) mpl.Reload() # Print tags of contained entities, use these names to retrieve entities as below #print(mpl.Runs()) scalars = {runName : data['scalars'] for (runName, data) in mpl.Runs().items() if len(data['scalars'])>0} print("scalars", scalars) tags = {} tag_filters = [tag_filter.lower() for tag_filter in tag_filters] for (runName, runtags) in scalars.items(): print("rn", runName.lower()) check = np.any if inclusive_filters else np.all if np.any([tag_filter in runName.lower() for tag_filter in tag_filters]): for runtag in runtags: #if tag_filter in runtag.lower(): if runtag not in tags: tags[runtag] = [runName] else: tags[runtag].append(runName) print("tags ", tags) for (tag, runNames) in tags.items(): print("runnames ", runNames) print("tag", tag) tag_scalars = [] labels = [] for run in runNames: #mpl.Scalars returns ScalarEvents array holding wall_time, step, value per time step (shape series_length x 3) #print(mpl.Scalars(runName, tag)[0]) run_scalars = [(s.step, s.value) for s in mpl.Scalars(run, tag)] print(np.array(run_scalars).shape) tag_scalars.append(np.array(run_scalars)) print("run", run) labels.append("/".join(run.split("/")[-2:])) #print("tag scalars ", tag_scalars) if out_dir is not None: out_path = os.path.join(out_dir,tag.replace("/","_")) else: out_path = None plot_txt(txts=tag_scalars, labels=labels, title=tag, out_dir=out_path, cf=cf, x_label=x_label, y_labels=y_labels, y_ranges=y_ranges, twin_axes=twin_axes, smooth=smooth) def plot_box_legend(cf, box_coords=None, class_id=None, out_dir=None): """plot a blank box explaining box annotations. :param cf: :return: """ if class_id is None: class_id = 1 img = np.ones(cf.patch_size[:2]) dim_max = max(cf.patch_size[:2]) width, height = cf.patch_size[0] // 2, cf.patch_size[1] // 2 if box_coords is None: # lower left corner x1, y1 = width // 2, height // 2 x2, y2 = x1 + width, y1 + height else: y1, x1, y2, x2 = box_coords fig = plt.figure(tight_layout=True, dpi=300) ax = fig.add_subplot(111) title_fs = 36 label = cf.class_id2label[class_id] # legend_items.add(label) ax.set_facecolor(cf.beige) ax.imshow(img, cmap='gray', vmin=0., vmax=1., alpha=0) # ax.axis('off') # suppress_axes_lines(ax) ax.set_xticks([]) ax.set_yticks([]) text_x, text_y = x2 * 0.85, y1 id_text = "class id" + " | " if cf.plot_class_ids else "" text_str = '{}{}'.format(id_text, "confidence") text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color if any(['regression' in task for task in cf.prediction_tasks]): text_x, text_y = x2 * 0.85, y2 id_text = "regression" if any(['ken_gal' in task or 'feindt' in task for task in cf.prediction_tasks]): id_text += " | uncertainty" text_str = '{}'.format(id_text) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'regression_bin' in cf.prediction_tasks or hasattr(cf, "rg_val_to_bin_id"): text_x, text_y = x1, y2 text_str = 'Rg. Bin' ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'lesion_gleasons' in cf.observables_rois: text_x, text_y = x1, y1 text_str = 'Gleason Score' ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=1., edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) if out_dir is not None: plt.savefig(os.path.join(out_dir, "box_legend.png")) def plot_boxes(cf, box_coords, patch_size=None, scores=None, class_ids=None, out_file=None, ax=None): if patch_size is None: patch_size = cf.patch_size[:2] if class_ids is None: class_ids = np.ones((len(box_coords),), dtype='uint8') if scores is None: scores = np.ones((len(box_coords),), dtype='uint8') img = np.ones(patch_size) y1, x1, y2, x2 = box_coords[:,0], box_coords[:,1], box_coords[:,2], box_coords[:,3] width, height = x2-x1, y2-y1 close = False if ax is None: fig = plt.figure(tight_layout=True, dpi=300) ax = fig.add_subplot(111) close = True title_fs = 56 ax.set_facecolor((*cf.gray,0.15)) ax.imshow(img, cmap='gray', vmin=0., vmax=1., alpha=0) #ax.axis('off') #suppress_axes_lines(ax) ax.set_xticks([]) ax.set_yticks([]) for bix, cl_id in enumerate(class_ids): label = cf.class_id2label[cl_id] text_x, text_y = x2[bix] -20, y1[bix] +5 id_text = class_ids[bix] if cf.plot_class_ids else "" text_str = '{}{}{:.0f}'.format(id_text, " | ", scores[bix] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color bbox = mpatches.Rectangle((x1[bix], y1[bix]), width[bix], height[bix], linewidth=1., edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) if out_file is not None: plt.savefig(out_file) if close: plt.close() if __name__=="__main__": cluster_exp_root = "/mnt/E132-Cluster-Projects" #dataset="prostate/" dataset = "lidc/" exp_name = "ms13_mrcnnal3d_rg_bs8_480k" #exp_dir = os.path.join("datasets", dataset, "experiments", exp_name) # exp_dir = os.path.join(cluster_exp_root, dataset, "experiments", exp_name) # log_dir = os.path.join(exp_dir, "logs") # sys.path.append(exp_dir) # from configs import Configs # cf = configs() # # #print("logdir", log_dir) # #out_dir = os.path.join(cf.source_dir, log_dir.replace("/", "_")) # #print("outdir", out_dir) # log_dir = os.path.join(cf.source_dir, log_dir) # plot_tboard_logs(cf, log_dir, tag_filters=["train/lesion_avp", "val/lesion_ap", "val/lesion_avp", "val/patient_lesion_avp"], smooth=2.2, out_dir=log_dir, # y_ranges=([0,900], [0,0.8]), # twin_axes=[1], y_labels=["counts",""], x_label="epoch") #plot_box_legend(cf, out_dir=exp_dir) diff --git a/predictor.py b/predictor.py index 370d2ce..b69f821 100644 --- a/predictor.py +++ b/predictor.py @@ -1,1000 +1,1003 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os from multiprocessing import Pool import pickle import time import numpy as np import torch from scipy.stats import norm from collections import OrderedDict import plotting as plg import utils.model_utils as mutils +import utils.exp_utils as utils def get_mirrored_patch_crops(patch_crops, org_img_shape): mirrored_patch_crops = [] mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3], ii[4], ii[5]] for ii in patch_crops]) mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) return mirrored_patch_crops def get_mirrored_patch_crops_ax_dep(patch_crops, org_img_shape, mirror_axes): mirrored_patch_crops = [] for ax_ix, axes in enumerate(mirror_axes): if isinstance(axes, (int, float)) and int(axes) == 0: mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3], ii[4], ii[5]] for ii in patch_crops]) elif isinstance(axes, (int, float)) and int(axes) == 1: mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) elif hasattr(axes, "__iter__") and (tuple(axes) == (0, 1) or tuple(axes) == (1, 0)): mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) else: raise Exception("invalid mirror axes {} in get mirrored patch crops".format(axes)) return mirrored_patch_crops def apply_wbc_to_patient(inputs): """ wrapper around prediction box consolidation: weighted box clustering (wbc). processes a single patient. loops over batch elements in patient results (1 in 3D, slices in 2D) and foreground classes, aggregates and stores results in new list. :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. :return. pid: string. patient id. """ regress_flag, in_patient_results_list, pid, class_dict, clustering_iou, n_ens = inputs out_patient_results_list = [[] for _ in range(len(in_patient_results_list))] for bix, b in enumerate(in_patient_results_list): for cl in list(class_dict.keys()): boxes = [(ix, box) for ix, box in enumerate(b) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] box_coords = np.array([b[1]['box_coords'] for b in boxes]) box_scores = np.array([b[1]['box_score'] for b in boxes]) box_center_factor = np.array([b[1]['box_patch_center_factor'] for b in boxes]) box_n_overlaps = np.array([b[1]['box_n_overlaps'] for b in boxes]) try: box_patch_id = np.array([b[1]['patch_id'] for b in boxes]) except KeyError: #backward compatibility for already saved pred results ... omg box_patch_id = np.array([b[1]['ens_ix'] for b in boxes]) box_regressions = np.array([b[1]['regression'] for b in boxes]) if regress_flag else None box_rg_bins = np.array([b[1]['rg_bin'] if 'rg_bin' in b[1].keys() else float('NaN') for b in boxes]) box_rg_uncs = np.array([b[1]['rg_uncertainty'] if 'rg_uncertainty' in b[1].keys() else float('NaN') for b in boxes]) if 0 not in box_scores.shape: keep_scores, keep_coords, keep_n_missing, keep_regressions, keep_rg_bins, keep_rg_uncs = \ weighted_box_clustering(box_coords, box_scores, box_center_factor, box_n_overlaps, box_rg_bins, box_rg_uncs, box_regressions, box_patch_id, clustering_iou, n_ens) for boxix in range(len(keep_scores)): clustered_box = {'box_type': 'det', 'box_coords': keep_coords[boxix], 'box_score': keep_scores[boxix], 'cluster_n_missing': keep_n_missing[boxix], 'box_pred_class_id': cl} if regress_flag: clustered_box.update({'regression': keep_regressions[boxix], 'rg_uncertainty': keep_rg_uncs[boxix], 'rg_bin': keep_rg_bins[boxix]}) out_patient_results_list[bix].append(clustered_box) # add gt boxes back to new output list. out_patient_results_list[bix].extend([box for box in b if box['box_type'] == 'gt']) return [out_patient_results_list, pid] def weighted_box_clustering(box_coords, scores, box_pc_facts, box_n_ovs, box_rg_bins, box_rg_uncs, box_regress, box_patch_id, thresh, n_ens): """Consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal ensembling. clusters predictions together with iou > thresh (like in NMS). Output score and coordinate for one cluster are the average weighted by individual patch center factors (how trustworthy is this candidate measured by how centered its position within the patch is) and the size of the corresponding box. The number of expected predictions at a position is n_data_aug * n_temp_ens * n_overlaps_at_position (1 prediction per unique patch). Missing predictions at a cluster position are defined as the number of unique patches in the cluster, which did not contribute any predict any boxes. :param dets: (n_dets, (y1, x1, y2, x2, (z1), (z2), scores, box_pc_facts, box_n_ovs). :param box_coords: y1, x1, y2, x2, (z1), (z2). :param scores: confidence scores. :param box_pc_facts: patch-center factors from position on patch tiles. :param box_n_ovs: number of patch overlaps at box position. :param box_rg_bins: regression bin predictions. :param box_rg_uncs: (n_dets,) regression uncertainties (from model mrcnn_aleatoric). :param box_regress: (n_dets, n_regression_features). :param box_patch_id: ensemble index. :param thresh: threshold for iou_matching. :param n_ens: number of models, that are ensembled. (-> number of expected predictions per position). :return: keep_scores: (n_keep) new scores of boxes to be kept. :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept. """ dim = 2 if box_coords.shape[1] == 4 else 3 y1 = box_coords[:,0] x1 = box_coords[:,1] y2 = box_coords[:,2] x2 = box_coords[:,3] areas = (y2 - y1 + 1) * (x2 - x1 + 1) if dim == 3: z1 = box_coords[:, 4] z2 = box_coords[:, 5] areas *= (z2 - z1 + 1) # order is the sorted index. maps order to index o[1] = 24 (rank1, ix 24) order = scores.argsort()[::-1] keep_scores = [] keep_coords = [] keep_n_missing = [] keep_regress = [] keep_rg_bins = [] keep_rg_uncs = [] while order.size > 0: i = order[0] # highest scoring element yy1 = np.maximum(y1[i], y1[order]) xx1 = np.maximum(x1[i], x1[order]) yy2 = np.minimum(y2[i], y2[order]) xx2 = np.minimum(x2[i], x2[order]) w = np.maximum(0, xx2 - xx1 + 1) h = np.maximum(0, yy2 - yy1 + 1) inter = w * h if dim == 3: zz1 = np.maximum(z1[i], z1[order]) zz2 = np.minimum(z2[i], z2[order]) d = np.maximum(0, zz2 - zz1 + 1) inter *= d # overlap between currently highest scoring box and all boxes. ovr = inter / (areas[i] + areas[order] - inter) ovr_fl = inter.astype('float64') / (areas[i] + areas[order] - inter.astype('float64')) assert np.all(ovr==ovr_fl), "ovr {}\n ovr_float {}".format(ovr, ovr_fl) # get all the predictions that match the current box to build one cluster. matches = np.nonzero(ovr > thresh)[0] match_n_ovs = box_n_ovs[order[matches]] match_pc_facts = box_pc_facts[order[matches]] match_patch_id = box_patch_id[order[matches]] match_ov_facts = ovr[matches] match_areas = areas[order[matches]] match_scores = scores[order[matches]] # weight all scores in cluster by patch factors, and size. match_score_weights = match_ov_facts * match_areas * match_pc_facts match_scores *= match_score_weights # for the weighted average, scores have to be divided by the number of total expected preds at the position # of the current cluster. 1 Prediction per patch is expected. therefore, the number of ensembled models is # multiplied by the mean overlaps of patches at this position (boxes of the cluster might partly be # in areas of different overlaps). n_expected_preds = n_ens * np.mean(match_n_ovs) # the number of missing predictions is obtained as the number of patches, # which did not contribute any prediction to the current cluster. n_missing_preds = np.max((0, n_expected_preds - np.unique(match_patch_id).shape[0])) # missing preds are given the mean weighting # (expected prediction is the mean over all predictions in cluster). denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights) # compute weighted average score for the cluster avg_score = np.sum(match_scores) / denom # compute weighted average of coordinates for the cluster. now only take existing # predictions into account. avg_coords = [np.sum(y1[order[matches]] * match_scores) / np.sum(match_scores), np.sum(x1[order[matches]] * match_scores) / np.sum(match_scores), np.sum(y2[order[matches]] * match_scores) / np.sum(match_scores), np.sum(x2[order[matches]] * match_scores) / np.sum(match_scores)] if dim == 3: avg_coords.append(np.sum(z1[order[matches]] * match_scores) / np.sum(match_scores)) avg_coords.append(np.sum(z2[order[matches]] * match_scores) / np.sum(match_scores)) if box_regress is not None: # compute wt. avg. of regression vectors (component-wise average) avg_regress = np.sum(box_regress[order[matches]] * match_scores[:, np.newaxis], axis=0) / np.sum( match_scores) avg_rg_bins = np.round(np.sum(box_rg_bins[order[matches]] * match_scores) / np.sum(match_scores)) avg_rg_uncs = np.sum(box_rg_uncs[order[matches]] * match_scores) / np.sum(match_scores) else: avg_regress = np.array(float('NaN')) avg_rg_bins = np.array(float('NaN')) avg_rg_uncs = np.array(float('NaN')) # some clusters might have very low scores due to high amounts of missing predictions. # filter out the with a conservative threshold, to speed up evaluation. if avg_score > 0.01: keep_scores.append(avg_score) keep_coords.append(avg_coords) keep_n_missing.append((n_missing_preds / n_expected_preds * 100)) # relative keep_regress.append(avg_regress) keep_rg_uncs.append(avg_rg_uncs) keep_rg_bins.append(avg_rg_bins) # get index of all elements that were not matched and discard all others. inds = np.nonzero(ovr <= thresh)[0] inds_where = np.where(ovr<=thresh)[0] assert np.all(inds == inds_where), "inds_nonzero {} \ninds_where {}".format(inds, inds_where) order = order[inds] return keep_scores, keep_coords, keep_n_missing, keep_regress, keep_rg_bins, keep_rg_uncs def apply_nms_to_patient(inputs): in_patient_results_list, pid, class_dict, iou_thresh = inputs out_patient_results_list = [] # collect box predictions over batch dimension (slices) and store slice info as slice_ids. for batch in in_patient_results_list: batch_el_boxes = [] for cl in list(class_dict.keys()): det_boxes = [box for box in batch if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] box_coords = np.array([box['box_coords'] for box in det_boxes]) box_scores = np.array([box['box_score'] for box in det_boxes]) if 0 not in box_scores.shape: keep_ix = mutils.nms_numpy(box_coords, box_scores, iou_thresh) else: keep_ix = [] batch_el_boxes += [det_boxes[ix] for ix in keep_ix] batch_el_boxes += [box for box in batch if box['box_type'] == 'gt'] out_patient_results_list.append(batch_el_boxes) assert len(in_patient_results_list) == len(out_patient_results_list), "batch dim needs to be maintained, in: {}, out {}".format(len(in_patient_results_list), len(out_patient_results_list)) return [out_patient_results_list, pid] def nms_2to3D(dets, thresh): """ Merges 2D boxes to 3D cubes. For this purpose, boxes of all slices are regarded as lying in one slice. An adaptation of Non-maximum suppression is applied where clusters are found (like in NMS) with the extra constraint that suppressed boxes have to have 'connected' z coordinates w.r.t the core slice (cluster center, highest scoring box, the prevailing box). 'connected' z-coordinates are determined as the z-coordinates with predictions until the first coordinate for which no prediction is found. example: a cluster of predictions was found overlap > iou thresh in xy (like NMS). The z-coordinate of the highest scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57. Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not because nothing was found in 47, so 47 is a 'hole', which interrupts the connection). Only the boxes corresponding to these coordinates are suppressed. All others are kept for building of further clusters. This algorithm works better with a certain min_confidence of predictions, because low confidence (e.g. noisy/cluttery) predictions can break the relatively strong assumption of defining cubes' z-boundaries at the first 'hole' in the cluster. :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id) :param thresh: iou matchin threshold (like in NMS). :return: keep: (n_keep,) 1D tensor of indices to be kept. :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes. """ y1 = dets[:, 0] x1 = dets[:, 1] y2 = dets[:, 2] x2 = dets[:, 3] assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: where maximum is taken needs to be the lower coordinate""" scores = dets[:, -2] slice_id = dets[:, -1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort()[::-1] keep = [] keep_z = [] while order.size > 0: # order is the sorted index. maps order to index: order[1] = 24 means (rank1, ix 24) i = order[0] # highest scoring element yy1 = np.maximum(y1[i], y1[order]) # highest scoring element still in >order<, is compared to itself: okay? xx1 = np.maximum(x1[i], x1[order]) yy2 = np.minimum(y2[i], y2[order]) xx2 = np.minimum(x2[i], x2[order]) h = np.maximum(0.0, yy2 - yy1 + 1) w = np.maximum(0.0, xx2 - xx1 + 1) inter = h * w iou = inter / (areas[i] + areas[order] - inter) matches = np.argwhere( iou > thresh) # get all the elements that match the current box and have a lower score slice_ids = slice_id[order[matches]] core_slice = slice_id[int(i)] upper_holes = [ii for ii in np.arange(core_slice, np.max(slice_ids)) if ii not in slice_ids] lower_holes = [ii for ii in np.arange(np.min(slice_ids), core_slice) if ii not in slice_ids] max_valid_slice_id = np.min(upper_holes) if len(upper_holes) > 0 else np.max(slice_ids) min_valid_slice_id = np.max(lower_holes) if len(lower_holes) > 0 else np.min(slice_ids) z_matches = matches[(slice_ids <= max_valid_slice_id) & (slice_ids >= min_valid_slice_id)] # expand by one z voxel since box content is surrounded w/o overlap, i.e., z-content computed as z2-z1 z1 = np.min(slice_id[order[z_matches]]) - 1 z2 = np.max(slice_id[order[z_matches]]) + 1 keep.append(i) keep_z.append([z1, z2]) order = np.delete(order, z_matches, axis=0) return keep, keep_z def apply_2d_3d_merging_to_patient(inputs): """ wrapper around 2Dto3D merging operation. Processes a single patient. Takes 2D patient results (slices in batch dimension) and returns 3D patient results (dummy batch dimension of 1). Applies an adaption of Non-Maximum Surpression (Detailed methodology is described in nms_2to3D). :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. :return. pid: string. patient id. """ in_patient_results_list, pid, class_dict, merge_3D_iou = inputs out_patient_results_list = [] for cl in list(class_dict.keys()): det_boxes, slice_ids = [], [] # collect box predictions over batch dimension (slices) and store slice info as slice_ids. for batch_ix, batch in enumerate(in_patient_results_list): batch_element_det_boxes = [(ix, box) for ix, box in enumerate(batch) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] det_boxes += batch_element_det_boxes slice_ids += [batch_ix] * len(batch_element_det_boxes) box_coords = np.array([batch[1]['box_coords'] for batch in det_boxes]) box_scores = np.array([batch[1]['box_score'] for batch in det_boxes]) slice_ids = np.array(slice_ids) if 0 not in box_scores.shape: keep_ix, keep_z = nms_2to3D( np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou) else: keep_ix, keep_z = [], [] # store kept predictions in new results list and add corresponding z-dimension info to coordinates. for kix, kz in zip(keep_ix, keep_z): keep_box = det_boxes[kix][1] keep_box['box_coords'] = list(keep_box['box_coords']) + kz out_patient_results_list.append(keep_box) gt_boxes = [box for b in in_patient_results_list for box in b if box['box_type'] == 'gt'] if len(gt_boxes) > 0: assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D." out_patient_results_list += gt_boxes return [[out_patient_results_list], pid] # additional list wrapping is extra batch dim. class Predictor: """ Prediction pipeline: - receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader. - forwards patches through model in chunks of batch_size. (method: batch_tiling_forward) - unmolds predictions (boxes and segmentations) to original patient coordinates. (method: spatial_tiling_forward) Ensembling (mode == 'test'): - for inference, forwards 4 mirrored versions of image to through model and unmolds predictions afterwards accordingly (method: data_aug_forward) - for inference, loads multiple parameter-sets of the trained model corresponding to different epochs. for each parameter-set loops over entire test set, runs prediction pipeline for each patient. (method: predict_test_set) Consolidation of predictions: - consolidates a patient's predictions (boxes, segmentations) collected over patches, data_aug- and temporal ensembling, performs clustering and weighted averaging (external function: apply_wbc_to_patient) to obtain consistent outptus. - for 2D networks, consolidates box predictions to 3D cubes via clustering (adaption of non-maximum surpression). (external function: apply_2d_3d_merging_to_patient) Ground truth handling: - dissmisses any ground truth boxes returned by the model (happens in validation mode, patch-based groundtruth) - if provided by data loader, adds patient-wise ground truth to the final predictions to be passed to the evaluator. """ def __init__(self, cf, net, logger, mode): self.cf = cf self.batch_size = cf.batch_size self.logger = logger self.mode = mode self.net = net self.n_ens = 1 self.rank_ix = '0' self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks]) if self.cf.merge_2D_to_3D_preds: assert self.cf.dim == 2, "Merge 2Dto3D only valid for 2D preds, but current dim is {}.".format(self.cf.dim) if self.mode == 'test': try: self.epoch_ranking = np.load(os.path.join(self.cf.fold_dir, 'epoch_ranking.npy'))[:cf.test_n_epochs] except: raise RuntimeError('no epoch ranking file in fold directory. ' 'seems like you are trying to run testing without prior training...') self.n_ens = cf.test_n_epochs if self.cf.test_aug_axes is not None: self.n_ens *= (len(self.cf.test_aug_axes)+1) self.example_plot_dir = os.path.join(cf.test_dir, "example_plots") os.makedirs(self.example_plot_dir, exist_ok=True) def batch_tiling_forward(self, batch): """ calls the actual network forward method. in patch-based prediction, the batch dimension might be overladed with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are discarded). test mode calls the test forward method, no ground truth required / involved. :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - loss / class_loss (only in validation mode) """ img = batch['data'] if img.shape[0] <= self.batch_size: if self.mode == 'val': # call training method to monitor losses results_dict = self.net.train_forward(batch, is_validation=True) # discard returned ground-truth boxes (also training info boxes). results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] elif self.mode == 'test': results_dict = self.net.test_forward(batch, return_masks=self.cf.return_masks_in_test) else: # needs batch tiling split_ixs = np.split(np.arange(img.shape[0]), np.arange(img.shape[0])[::self.batch_size]) chunk_dicts = [] for chunk_ixs in split_ixs[1:]: # first split is elements before 0, so empty b = {k: batch[k][chunk_ixs] for k in batch.keys() if (isinstance(batch[k], np.ndarray) and batch[k].shape[0] == img.shape[0])} if self.mode == 'val': chunk_dicts += [self.net.train_forward(b, is_validation=True)] else: chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)] results_dict = {} # flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...]) results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']] results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']]) if self.mode == 'val': # if hasattr(self.cf, "losses_to_monitor"): # loss_names = self.cf.losses_to_monitor # else: # loss_names = {name for dic in chunk_dicts for name in dic if 'loss' in name} # estimate patient loss by mean over batch_chunks. Most similar to training loss. results_dict['torch_loss'] = torch.mean(torch.cat([d['torch_loss'] for d in chunk_dicts])) results_dict['class_loss'] = np.mean([d['class_loss'] for d in chunk_dicts]) # discard returned ground-truth boxes (also training info boxes). results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] return results_dict def spatial_tiling_forward(self, batch, patch_crops = None, n_aug='0'): """ forwards batch to batch_tiling_forward method and receives and returns a dictionary with results. if patch-based prediction, the results received from batch_tiling_forward will be on a per-patch-basis. this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates. Patch-origin information of all box-predictions will be needed for consolidation, hence it is stored as 'patch_id', which is a unique string for each patch (also takes current data aug and temporal epoch instances into account). all box predictions get additional information about the amount overlapping patches at the respective position (used for consolidation). :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - monitor_values (only in validation mode) returned dict is a flattened version with 1 batch instance (3D) or slices (2D) """ if patch_crops is not None: #print("patch_crops not None, applying patch center factor") patches_dict = self.batch_tiling_forward(batch) results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]} #bc of ohe--> channel dim of seg has size num_classes out_seg_shape = list(batch['original_img_shape']) out_seg_shape[1] = patches_dict["seg_preds"].shape[1] out_seg_preds = np.zeros(out_seg_shape, dtype=np.float16) patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8') for pix, pc in enumerate(patch_crops): if self.cf.dim == 3: out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix] patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1 elif self.cf.dim == 2: out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix] patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1 out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0] results_dict['seg_preds'] = out_seg_preds for pix, pc in enumerate(patch_crops): patch_boxes = patches_dict['boxes'][pix] for box in patch_boxes: # add unique patch id for consolidation of predictions. box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix) # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers. # hence they will be down-weighted for consolidation, using the 'box_patch_center_factor', which is # obtained by a gaussian distribution over positions in the patch and average over spatial dimensions. # Also the info 'box_n_overlaps' is stored for consolidation, which represents the amount of # overlapping patches at the box's position. c = box['box_coords'] #box_centers = np.array([(c[ii] + c[ii+2])/2 for ii in range(len(c)//2)]) box_centers = [(c[ii] + c[ii + 2]) / 2 for ii in range(2)] if self.cf.dim == 3: box_centers.append((c[4] + c[5]) / 2) box['box_patch_center_factor'] = np.mean( [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in zip(box_centers, np.array(self.cf.patch_size) / 2)]) if self.cf.dim == 3: c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]]) int_c = [int(np.floor(ii)) if ix%2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)] box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, int_c[1]:int_c[3], int_c[0]:int_c[2], int_c[4]:int_c[5]]) results_dict['boxes'][0].append(box) else: c += np.array([pc[0], pc[2], pc[0], pc[2]]) int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)] box['box_n_overlaps'] = np.mean( patch_overlap_map[pc[4], :, int_c[1]:int_c[3], int_c[0]:int_c[2]]) results_dict['boxes'][pc[4]].append(box) if self.mode == 'val': results_dict['torch_loss'] = patches_dict['torch_loss'] results_dict['class_loss'] = patches_dict['class_loss'] else: results_dict = self.batch_tiling_forward(batch) for b in results_dict['boxes']: for box in b: box['box_patch_center_factor'] = 1 box['box_n_overlaps'] = 1 box['patch_id'] = self.rank_ix + '_' + n_aug return results_dict def data_aug_forward(self, batch): """ in val_mode: passes batch through to spatial_tiling method without data_aug. in test_mode: if cf.test_aug is set in configs, createst 4 mirrored versions of the input image, passes all of them to the next processing step (spatial_tiling method) and re-transforms returned predictions to original image version. :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - loss / class_loss (only in validation mode) """ patch_crops = batch['patch_crop_coords'] if self.patched_patient else None results_list = [self.spatial_tiling_forward(batch, patch_crops)] org_img_shape = batch['original_img_shape'] if self.mode == 'test' and self.cf.test_aug_axes is not None: if isinstance(self.cf.test_aug_axes, (int, float)): self.cf.test_aug_axes = (self.cf.test_aug_axes,) #assert np.all(np.array(self.cf.test_aug_axes)= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] chunk_dict['boxes'][ix][boxix]['box_coords'] = coords # re-transform segmentation predictions. chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=axis) elif hasattr(sp_axis, "__iter__") and tuple(sp_axis)==(0,1) or tuple(sp_axis)==(1,0): #NEED: mirrored patch crops are given as [(y-axis), (x-axis), (y-,x-axis)], obey this order! # mirroring along two axes at same time batch['data'] = np.flip(np.flip(img, axis=axis[0]), axis=axis[1]).copy() chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[n_aug], n_aug=str(n_aug)) # re-transform coordinates. for ix in range(len(chunk_dict['boxes'])): for boxix in range(len(chunk_dict['boxes'][ix])): coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy() coords[sp_axis[0]] = org_img_shape[axis[0]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[0]+2] coords[sp_axis[0]+2] = org_img_shape[axis[0]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[0]] coords[sp_axis[1]] = org_img_shape[axis[1]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[1]+2] coords[sp_axis[1]+2] = org_img_shape[axis[1]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[1]] assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] chunk_dict['boxes'][ix][boxix]['box_coords'] = coords # re-transform segmentation predictions. chunk_dict['seg_preds'] = np.flip(np.flip(chunk_dict['seg_preds'], axis=axis[0]), axis=axis[1]).copy() else: raise Exception("Invalid axis type {} in test augs".format(type(axis))) results_list.append(chunk_dict) batch['data'] = img # aggregate all boxes/seg_preds per batch element from data_aug predictions. results_dict = {} results_dict['boxes'] = [[item for d in results_list for item in d['boxes'][batch_instance]] for batch_instance in range(org_img_shape[0])] # results_dict['seg_preds'] = np.array([[item for d in results_list for item in d['seg_preds'][batch_instance]] # for batch_instance in range(org_img_shape[0])]) results_dict['seg_preds'] = np.stack([dic['seg_preds'] for dic in results_list], axis=1) # needs segs probs in seg_preds entry: results_dict['seg_preds'] = np.sum(results_dict['seg_preds'], axis=1) #add up seg probs from different augs per class if self.mode == 'val': results_dict['torch_loss'] = results_list[0]['torch_loss'] results_dict['class_loss'] = results_list[0]['class_loss'] return results_dict def load_saved_predictions(self): """loads raw predictions saved by self.predict_test_set. aggregates and/or merges 2D boxes to 3D cubes for evaluation (if model predicts 2D but evaluation is run in 3D), according to settings config. :return: list_of_results_per_patient: list over patient results. each entry is a dict with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'batch_dices': dice scores as recorded in raw prediction results. - 'seg_preds': not implemented yet. could replace dices by seg preds to have raw seg info available, however would consume critically large memory amount. todo evaluation of instance/semantic segmentation. """ results_file = 'pred_results.pkl' if not self.cf.held_out_test_set else 'pred_results_held_out.pkl' if not self.cf.held_out_test_set or self.cf.eval_test_fold_wise: self.logger.info("loading saved predictions of fold {}".format(self.cf.fold)) with open(os.path.join(self.cf.fold_dir, results_file), 'rb') as handle: results_list = pickle.load(handle) box_results_list = [(res_dict["boxes"], pid) for res_dict, pid in results_list] da_factor = len(self.cf.test_aug_axes)+1 if self.cf.test_aug_axes is not None else 1 self.n_ens = self.cf.test_n_epochs * da_factor self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format( len(results_list), self.n_ens)) else: self.logger.info("loading saved predictions of hold-out test set") fold_dirs = sorted([os.path.join(self.cf.exp_dir, f) for f in os.listdir(self.cf.exp_dir) if os.path.isdir(os.path.join(self.cf.exp_dir, f)) and f.startswith("fold")]) results_list = [] folds_loaded = 0 for fold in range(self.cf.n_cv_splits): fold_dir = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold)) if fold_dir in fold_dirs: with open(os.path.join(fold_dir, results_file), 'rb') as handle: fold_list = pickle.load(handle) results_list += fold_list folds_loaded += 1 else: self.logger.info("Skipping fold {} since no saved predictions found.".format(fold)) box_results_list = [] for res_dict, pid in results_list: #without filtering gt out: box_results_list.append((res_dict['boxes'], pid)) #it's usually not right to filter out gts here, is it? da_factor = len(self.cf.test_aug_axes)+1 if self.cf.test_aug_axes is not None else 1 self.n_ens = self.cf.test_n_epochs * da_factor * folds_loaded # -------------- aggregation of boxes via clustering ----------------- if self.cf.clustering == "wbc": self.logger.info('applying WBC to test-set predictions with iou {} and n_ens {} over {} patients'.format( self.cf.clustering_iou, self.n_ens, len(box_results_list))) mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens] for ii in box_results_list] del box_results_list pool = Pool(processes=self.cf.n_workers) box_results_list = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs elif self.cf.clustering == "nms": self.logger.info('applying standard NMS to test-set predictions with iou {} over {} patients.'.format( self.cf.clustering_iou, len(box_results_list))) pool = Pool(processes=self.cf.n_workers) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in box_results_list] del box_results_list box_results_list = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs if self.cf.merge_2D_to_3D_preds: self.logger.info('applying 2Dto3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou)) pool = Pool(processes=self.cf.n_workers) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in box_results_list] box_results_list = pool.map(apply_2d_3d_merging_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs for ix in range(len(results_list)): assert np.all(results_list[ix][1] == box_results_list[ix][1]), "pid mismatch between loaded and aggregated results" results_list[ix][0]["boxes"] = box_results_list[ix][0] return results_list # holds (results_dict, pid) def predict_patient(self, batch): """ predicts one patient. called either directly via loop over validation set in exec.py (mode=='val') or from self.predict_test_set (mode=='test). in val mode: adds 3D ground truth info to predictions and runs consolidation and 2Dto3D merging of predictions. in test mode: returns raw predictions (ground truth addition, consolidation, 2D to 3D merging are done in self.predict_test_set, because patient predictions across several epochs might be needed to be collected first, in case of temporal ensembling). :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - loss / class_loss (only in validation mode) """ if self.mode=="test": self.logger.info('predicting patient {} for fold {} '.format(np.unique(batch['pid']), self.cf.fold)) # True if patient is provided in patches and predictions need to be tiled. self.patched_patient = 'patch_crop_coords' in list(batch.keys()) # forward batch through prediction pipeline. results_dict = self.data_aug_forward(batch) #has seg probs in entry 'seg_preds' if self.mode == 'val': for b in range(batch['patient_bb_target'].shape[0]): for t in range(len(batch['patient_bb_target'][b])): gt_box = {'box_type': 'gt', 'box_coords': batch['patient_bb_target'][b][t], 'class_targets': batch['patient_class_targets'][b][t]} for name in self.cf.roi_items: gt_box.update({name : batch['patient_'+name][b][t]}) results_dict['boxes'][b].append(gt_box) if 'dice' in self.cf.metrics: if self.patched_patient: assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape." results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'], batch["patient_seg"] if self.patched_patient else batch['seg'], self.cf.num_seg_classes, convert_to_ohe=True) if self.patched_patient and self.cf.clustering == "wbc": wbc_input = [self.regress_flag, results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.clustering_iou, self.n_ens] results_dict['boxes'] = apply_wbc_to_patient(wbc_input)[0] elif self.patched_patient: nms_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.clustering_iou] results_dict['boxes'] = apply_nms_to_patient(nms_inputs)[0] if self.cf.merge_2D_to_3D_preds: results_dict['2D_boxes'] = results_dict['boxes'] merge_dims_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.merge_3D_iou] results_dict['boxes'] = apply_2d_3d_merging_to_patient(merge_dims_inputs)[0] return results_dict def predict_test_set(self, batch_gen, return_results=True): """ wrapper around test method, which loads multiple (or one) epoch parameters (temporal ensembling), loops through the test set and collects predictions per patient. Also flattens the results per patient and epoch and adds optional ground truth boxes for evaluation. Saves out the raw result list for later analysis and optionally consolidates and returns predictions immediately. :return: (optionally) list_of_results_per_patient: list over patient results. each entry is a dict with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': not implemented yet. todo evaluation of instance/semantic segmentation. """ # -------------- raw predicting ----------------- dict_of_patients_results = OrderedDict() set_of_result_types = set() # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling). weight_paths = [os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch)) for epoch in self.epoch_ranking] for rank_ix, weight_path in enumerate(weight_paths): self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path))) self.net.load_state_dict(torch.load(weight_path)) self.net.eval() self.rank_ix = str(rank_ix) with torch.no_grad(): plot_batches = np.random.choice(np.arange(batch_gen['n_test']), size=self.cf.n_test_plots, replace=False) for i in range(batch_gen['n_test']): batch = next(batch_gen['test']) pid = np.unique(batch['pid']) assert len(pid)==1 pid = pid[0] if not pid in dict_of_patients_results.keys(): # store batch info in patient entry of results dict. dict_of_patients_results[pid] = {} dict_of_patients_results[pid]['results_dicts'] = [] dict_of_patients_results[pid]['patient_bb_target'] = batch['patient_bb_target'] for name in self.cf.roi_items: dict_of_patients_results[pid]["patient_"+name] = batch["patient_"+name] stime = time.time() results_dict = self.predict_patient(batch) #only holds "boxes", "seg_preds" # needs ohe seg probs in seg_preds entry: results_dict['seg_preds'] = np.argmax(results_dict['seg_preds'], axis=1)[:,np.newaxis] self.logger.info("predicting patient {} with weight rank {} (progress: {}/{}) took {:.2f}s".format( str(pid), rank_ix, (rank_ix)*batch_gen['n_test']+(i+1), len(weight_paths)*batch_gen['n_test'], time.time()-stime)) if i in plot_batches and (not self.patched_patient or 'patient_data' in batch.keys()): try: # view qualitative results of random test case self.logger.time("test_plot") out_file = os.path.join(self.example_plot_dir, 'batch_example_test_{}_rank_{}.png'.format(self.cf.fold, rank_ix)) - plg.view_batch(self.cf, batch, res_dict=results_dict, out_file=out_file, - show_seg_ids='dice' in self.cf.metrics, - has_colorchannels=self.cf.has_colorchannels, show_gt_labels=True) - self.logger.info("generated example test plot {} in {:.2f}s".format(os.path.basename(out_file), self.logger.time("test_plot"))) + utils.split_off_process(plg.view_batch, self.cf, batch, results_dict, + has_colorchannels=self.cf.has_colorchannels, + show_gt_labels=True, show_seg_ids='dice' in self.cf.metrics, + get_time="test-example plot", out_file=out_file) + self.logger.info("split-off example test plot {} in {:.2f}s".format( + os.path.basename(out_file), self.logger.time("test_plot"))) except Exception as e: self.logger.info("WARNING: error in view_batch: {}".format(e)) if 'dice' in self.cf.metrics: if self.patched_patient: assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape." results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'], batch["patient_seg"] if self.patched_patient else batch['seg'], self.cf.num_seg_classes, convert_to_ohe=True) dict_of_patients_results[pid]['results_dicts'].append({k:v for k,v in results_dict.items() if k in ["boxes", "batch_dices"]}) # collect result types to know which ones to look for when saving set_of_result_types.update(dict_of_patients_results[pid]['results_dicts'][-1].keys()) # -------------- re-order, save raw results ----------------- self.logger.info('finished predicting test set. starting aggregation of predictions.') results_per_patient = [] for pid, p_dict in dict_of_patients_results.items(): # dict_of_patients_results[pid]['results_list'] has length batch['n_test'] results_dict = {} # collect all boxes/seg_preds of same batch_instance over temporal instances. b_size = len(p_dict['results_dicts'][0]["boxes"]) for res_type in [rtype for rtype in set_of_result_types if rtype in ["boxes", "batch_dices"]]:#, "seg_preds"]]: if not 'batch' in res_type: #assume it's results on batch-element basis results_dict[res_type] = [[item for rank_dict in p_dict['results_dicts'] for item in rank_dict[res_type][batch_instance]] for batch_instance in range(b_size)] else: results_dict[res_type] = [] for dict in p_dict['results_dicts']: if 'dice' in res_type: item = dict[res_type] #dict['batch_dices'] has shape (num_seg_classes,) assert len(item) == self.cf.num_seg_classes, \ "{}, {}".format(len(item), self.cf.num_seg_classes) else: raise NotImplementedError results_dict[res_type].append(item) # rdict[dice] shape (n_rank_epochs (n_saved_ranks), nsegclasses) # calc mean over test epochs so inline with shape from sampling results_dict[res_type] = np.mean(results_dict[res_type], axis=0) #maybe error type with other than dice if not hasattr(self.cf, "eval_test_separately") or not self.cf.eval_test_separately: # add unpatched 2D or 3D (if dim==3 or merge_2D_to_3D) ground truth boxes for evaluation. for b in range(p_dict['patient_bb_target'].shape[0]): for targ in range(len(p_dict['patient_bb_target'][b])): gt_box = {'box_type': 'gt', 'box_coords':p_dict['patient_bb_target'][b][targ], 'class_targets': p_dict['patient_class_targets'][b][targ]} for name in self.cf.roi_items: gt_box.update({name: p_dict["patient_"+name][b][targ]}) results_dict['boxes'][b].append(gt_box) results_per_patient.append([results_dict, pid]) out_string = 'pred_results_held_out' if self.cf.held_out_test_set else 'pred_results' with open(os.path.join(self.cf.fold_dir, '{}.pkl'.format(out_string)), 'wb') as handle: pickle.dump(results_per_patient, handle) if return_results: # -------------- results processing, clustering, etc. ----------------- final_patient_box_results = [ (res_dict["boxes"], pid) for res_dict,pid in results_per_patient ] if self.cf.clustering == "wbc": self.logger.info('applying WBC to test-set predictions with iou = {} and n_ens = {}.'.format( self.cf.clustering_iou, self.n_ens)) mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens] for ii in final_patient_box_results] del final_patient_box_results pool = Pool(processes=self.cf.n_workers) final_patient_box_results = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs elif self.cf.clustering == "nms": self.logger.info('applying standard NMS to test-set predictions with iou = {}.'.format(self.cf.clustering_iou)) pool = Pool(processes=self.cf.n_workers) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in final_patient_box_results] del final_patient_box_results final_patient_box_results = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs if self.cf.merge_2D_to_3D_preds: self.logger.info('applying 2D-to-3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou)) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in final_patient_box_results] del final_patient_box_results pool = Pool(processes=self.cf.n_workers) final_patient_box_results = pool.map(apply_2d_3d_merging_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs # final_patient_box_results holds [avg_boxes, pid] if wbc for ix in range(len(results_per_patient)): assert results_per_patient[ix][1] == final_patient_box_results[ix][1], "should be same pid" results_per_patient[ix][0]["boxes"] = final_patient_box_results[ix][0] # results_per_patient = [(res_dict["boxes"] = boxes, pid) for (boxes,pid) in final_patient_box_results] return results_per_patient # holds list of (results_dict, pid) diff --git a/utils/dataloader_utils.py b/utils/dataloader_utils.py index eb53a44..cdd92e1 100644 --- a/utils/dataloader_utils.py +++ b/utils/dataloader_utils.py @@ -1,742 +1,743 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import plotting as plg import os from multiprocessing import Pool, Lock import pickle import warnings import numpy as np import pandas as pd from batchgenerators.transforms.abstract_transforms import AbstractTransform from scipy.ndimage.measurements import label as lb from torch.utils.data import Dataset as torchDataset from batchgenerators.dataloading.data_loader import SlimDataLoaderBase import utils.exp_utils as utils import data_manager as dmanager for msg in ["This figure includes Axes that are not compatible with tight_layout", "Data has no positive values, and therefore cannot be log-scaled."]: warnings.filterwarnings("ignore", msg) class AttributeDict(dict): __getattr__ = dict.__getitem__ __setattr__ = dict.__setitem__ ################################## # data loading, organisation # ################################## class fold_generator: """ generates splits of indices for a given length of a dataset to perform n-fold cross-validation. splits each fold into 3 subsets for training, validation and testing. This form of cross validation uses an inner loop test set, which is useful if test scores shall be reported on a statistically reliable amount of patients, despite limited size of a dataset. If hold out test set is provided and hence no inner loop test set needed, just add test_idxs to the training data in the dataloader. This creates straight-forward train-val splits. :returns names list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix. """ def __init__(self, seed, n_splits, len_data): """ :param seed: Random seed for splits. :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation :param len_data: number of elements in the dataset. """ self.tr_ix = [] self.val_ix = [] self.te_ix = [] self.slicer = None self.missing = 0 self.fold = 0 self.len_data = len_data self.n_splits = n_splits self.myseed = seed self.boost_val = 0 def init_indices(self): t = list(np.arange(self.l)) # round up to next splittable data amount. split_length = int(np.ceil(len(t) / float(self.n_splits))) self.slicer = split_length self.mod = len(t) % self.n_splits if self.mod > 0: # missing is the number of folds, in which the new splits are reduced to account for missing data. self.missing = self.n_splits - self.mod self.te_ix = t[:self.slicer] self.tr_ix = t[self.slicer:] self.val_ix = self.tr_ix[:self.slicer] self.tr_ix = self.tr_ix[self.slicer:] def new_fold(self): slicer = self.slicer if self.fold < self.missing : slicer = self.slicer - 1 temp = self.te_ix # catch exception mod == 1: test set collects 1+ data since walk through both roudned up splits. # account for by reducing last fold split by 1. if self.fold == self.n_splits-2 and self.mod ==1: temp += self.val_ix[-1:] self.val_ix = self.val_ix[:-1] self.te_ix = self.val_ix self.val_ix = self.tr_ix[:slicer] self.tr_ix = self.tr_ix[slicer:] + temp def get_fold_names(self): names_list = [] rgen = np.random.RandomState(self.myseed) cv_names = np.arange(self.len_data) rgen.shuffle(cv_names) self.l = len(cv_names) self.init_indices() for split in range(self.n_splits): train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix] names_list.append([train_names, val_names, test_names, self.fold]) self.new_fold() self.fold += 1 return names_list class FoldGenerator(): r"""takes a set of elements (identifiers) and randomly splits them into the specified amt of subsets. """ def __init__(self, identifiers, seed, n_splits=5): self.ids = np.array(identifiers) self.n_splits = n_splits self.seed = seed def generate_splits(self, n_splits=None): if n_splits is None: n_splits = self.n_splits rgen = np.random.RandomState(self.seed) rgen.shuffle(self.ids) self.splits = list(np.array_split(self.ids, n_splits, axis=0)) # already returns list, but to be sure return self.splits class Dataset(torchDataset): r"""Parent Class for actual Dataset classes to inherit from! """ def __init__(self, cf, data_sourcedir=None): super(Dataset, self).__init__() self.cf = cf self.data_sourcedir = cf.data_sourcedir if data_sourcedir is None else data_sourcedir self.data_dir = cf.data_dir if hasattr(cf, 'data_dir') else self.data_sourcedir self.data_dest = cf.data_dest if hasattr(cf, "data_dest") else self.data_sourcedir self.data = {} self.set_ids = [] def copy_data(self, cf, file_subset, keep_packed=False, del_after_unpack=False): if os.path.normpath(self.data_sourcedir) != os.path.normpath(self.data_dest): self.data_sourcedir = os.path.join(self.data_sourcedir, '') args = AttributeDict({ "source" : self.data_sourcedir, "destination" : self.data_dest, "recursive" : True, "cp_only_npz" : False, "keep_packed" : keep_packed, "del_after_unpack" : del_after_unpack, "threads" : 16 if self.cf.server_env else os.cpu_count() }) dmanager.copy(args, file_subset=file_subset) self.data_dir = self.data_dest def __len__(self): return len(self.data) def __getitem__(self, id): """Return a sample of the dataset, i.e.,the dict of the id """ return self.data[id] def __iter__(self): return self.data.__iter__() def init_FoldGenerator(self, seed, n_splits): self.fg = FoldGenerator(self.set_ids, seed=seed, n_splits=n_splits) def generate_splits(self, check_file): if not os.path.exists(check_file): self.fg.generate_splits() with open(check_file, 'wb') as handle: pickle.dump(self.fg.splits, handle) else: with open(check_file, 'rb') as handle: self.fg.splits = pickle.load(handle) def calc_statistics(self, subsets=None, plot_dir=None, overall_stats=True): if self.df is None: self.df = pd.DataFrame() balance_t = self.cf.balance_target if hasattr(self.cf, "balance_target") else "class_targets" self.df._metadata.append(balance_t) if balance_t=="class_targets": mapper = lambda cl_id: self.cf.class_id2label[cl_id] labels = self.cf.class_id2label.values() elif balance_t=="rg_bin_targets": mapper = lambda rg_bin: self.cf.bin_id2label[rg_bin] labels = self.cf.bin_id2label.values() # elif balance_t=="regression_targets": # # todo this wont work # mapper = lambda rg_val: AttributeDict({"name":rg_val}) #self.cf.bin_id2label[self.cf.rg_val_to_bin_id(rg_val)] # labels = self.cf.bin_id2label.values() elif balance_t=="lesion_gleasons": mapper = lambda gs: self.cf.gs2label[gs] labels = self.cf.gs2label.values() else: mapper = lambda x: AttributeDict({"name":x}) labels = None for pid, subj_data in self.data.items(): unique_ts, counts = np.unique(subj_data[balance_t], return_counts=True) self.df = self.df.append(pd.DataFrame({"pid": [pid], **{mapper(unique_ts[i]).name: [counts[i]] for i in range(len(unique_ts))}}), ignore_index=True, sort=True) self.df = self.df.fillna(0) if overall_stats: df = self.df.drop("pid", axis=1) df = df.reindex(sorted(df.columns), axis=1).astype('uint32') print("Overall dataset roi counts per target kind:"); print(df.sum()) if subsets is not None: self.df["subset"] = np.nan self.df["display_order"] = np.nan for ix, (subset, pids) in enumerate(subsets.items()): self.df.loc[self.df.pid.isin(pids), "subset"] = subset self.df.loc[self.df.pid.isin(pids), "display_order"] = ix df = self.df.groupby("subset").agg("sum").drop("pid", axis=1, errors='ignore').astype('int64') df = df.sort_values(by=['display_order']).drop('display_order', axis=1) df = df.reindex(sorted(df.columns), axis=1) print("Fold {} dataset roi counts per target kind:".format(self.cf.fold)); print(df) if plot_dir is not None: os.makedirs(plot_dir, exist_ok=True) if subsets is not None: plg.plot_fold_stats(self.cf, df, labels, os.path.join(plot_dir, "data_stats_fold_" + str(self.cf.fold))+".pdf") if overall_stats: plg.plot_data_stats(self.cf, df, labels, os.path.join(plot_dir, 'data_stats_overall.pdf')) return df, labels def get_class_balanced_patients(all_pids, class_targets, batch_size, num_classes, random_ratio=0): ''' samples towards equilibrium of classes (on basis of total RoI counts). for highly imbalanced dataset, this might be a too strong requirement. :param class_targets: dic holding {patient_specifier : ROI class targets}, list position of ROI target corresponds to respective seg label - 1 :param batch_size: :param num_classes: :return: ''' # assert len(all_pids)>=batch_size, "not enough eligible pids {} to form a single batch of size {}".format(len(all_pids), batch_size) class_counts = {k: 0 for k in range(1,num_classes+1)} not_picked = np.array(all_pids) batch_patients = np.empty((batch_size,), dtype=not_picked.dtype) rarest_class = np.random.randint(1,num_classes+1) for ix in range(batch_size): if len(not_picked) == 0: warnings.warn("Dataset too small to generate batch with unique samples; => recycling.") not_picked = np.array(all_pids) np.random.shuffle(not_picked) #this could actually go outside(above) the loop. pick = not_picked[0] for cand in not_picked: if np.count_nonzero(class_targets[cand] == rarest_class) > 0: pick = cand cand_rarest_class = np.argmin([np.count_nonzero(class_targets[cand] == cl) for cl in range(1,num_classes+1)])+1 # if current batch already bigger than the batch random ratio, then # check that weakest class in this patient is not the weakest in current batch (since needs to be boosted) # also that at least one roi of this patient belongs to weakest class. If True, keep patient, else keep looking. if (cand_rarest_class != rarest_class and np.count_nonzero(class_targets[cand] == rarest_class) > 0) \ or ix < int(batch_size * random_ratio): break for c in range(1,num_classes+1): class_counts[c] += np.count_nonzero(class_targets[pick] == c) if not ix < int(batch_size * random_ratio) and class_counts[rarest_class] == 0: # means searched thru whole set without finding rarest class print("Class {} not represented in current dataset.".format(rarest_class)) rarest_class = np.argmin(([class_counts[c] for c in range(1,num_classes+1)]))+1 batch_patients[ix] = pick not_picked = not_picked[not_picked != pick] # removes pick return batch_patients class BatchGenerator(SlimDataLoaderBase): """ create the training/validation batch generator. Randomly sample batch_size patients from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array. :param data: data dictionary as provided by 'load_dataset' :param img_modalities: list of strings ['adc', 'b1500'] from config :param batch_size: number of patients to sample for the batch :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.) :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array. """ def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, n_threads=None, seed=0): if n_threads is None: n_threads = cf.n_workers super(BatchGenerator, self).__init__(data, cf.batch_size, number_of_threads_in_multithreaded=n_threads) self.cf = cf self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.plot_dir = os.path.join(self.cf.plot_dir, 'train_generator') self.max_batches = max_batches self.raise_stop = raise_stop_iteration self.thread_id = 0 self.batches_produced = 0 self.dataset_length = len(self._data) self.dataset_pids = list(self._data.keys()) self.rgen = np.random.RandomState(seed=seed) self.eligible_pids = self.rgen.permutation(self.dataset_pids.copy()) self.eligible_pids = np.array_split(self.eligible_pids, self.number_of_threads_in_multithreaded) self.eligible_pids = sorted(self.eligible_pids, key=len, reverse=True) self.sample_pids_w_replace = sample_pids_w_replace if not self.sample_pids_w_replace: assert len(self.dataset_pids) / self.number_of_threads_in_multithreaded >= self.batch_size, \ "at least one batch needed per thread. dataset size: {}, n_threads: {}, batch_size: {}.".format( len(self.dataset_pids), self.number_of_threads_in_multithreaded, self.batch_size) self.lock = Lock() self.stats = {"roi_counts": np.zeros((self.cf.num_classes,), dtype='uint32'), "empty_counts": np.zeros((self.cf.num_classes,), dtype='uint32')} if hasattr(cf, "balance_target"): # WARNING: "balance targets are only implemented for 1-d targets (or 1-component vectors)" self.balance_target = cf.balance_target else: self.balance_target = "class_targets" self.targets = {k:v[self.balance_target] for (k,v) in self._data.items()} def set_thread_id(self, thread_id): self.thread_ids = self.eligible_pids[thread_id] self.thread_id = thread_id def reset(self): self.batches_produced = 0 self.thread_ids = self.rgen.permutation(self.eligible_pids[self.thread_id]) - def sample_targets_to_weights(self, targets): - weights = targets * self.fg_bg_weights + @staticmethod + def sample_targets_to_weights(targets, fg_bg_weights): + weights = targets * fg_bg_weights return weights def balance_target_distribution(self, plot=False): """Impose a drawing distribution over samples. Distribution should be designed so that classes' fg and bg examples are (as good as possible) shown in equal frequency. Since we are dealing with rois, fg/bg weights count a sample (e.g., a patient) with **at least** one occurrence as fg, otherwise bg. For fg weights among classes, each RoI counts. :param all_pids: :param self.targets: dic holding {patient_specifier : patient-wise-unique ROI targets} :return: probability distribution over all pids. draw without replace from this. """ self.unique_ts = np.unique([v for pat in self.targets.values() for v in pat]) self.sample_stats = pd.DataFrame(columns=[str(ix)+suffix for ix in self.unique_ts for suffix in ["", "_bg"]], index=list(self.targets.keys())) for pid in self.sample_stats.index: for targ in self.unique_ts: fg_count = np.count_nonzero(self.targets[pid] == targ) self.sample_stats.loc[pid, str(targ)] = int(fg_count > 0) self.sample_stats.loc[pid, str(targ)+"_bg"] = int(fg_count == 0) self.targ_stats = self.sample_stats.agg( ("sum", lambda col: col.sum() / len(self._data)), axis=0, sort=False).rename({"": "relative"}) anchor = 1. - self.targ_stats.loc["relative"].iloc[0] self.fg_bg_weights = anchor / self.targ_stats.loc["relative"] cum_weights = anchor * len(self.fg_bg_weights) self.fg_bg_weights /= cum_weights - self.p_probs = self.sample_stats.apply(self.sample_targets_to_weights, axis=1).sum(axis=1) + self.p_probs = self.sample_stats.apply(self.sample_targets_to_weights, args=(self.fg_bg_weights,), axis=1).sum(axis=1) self.p_probs = self.p_probs / self.p_probs.sum() if plot: print("Applying class-weights:\n {}".format(self.fg_bg_weights)) if len(self.sample_stats.columns) == 2: # assert that probs are calc'd correctly: # (self.p_probs * self.sample_stats["1"]).sum() == (self.p_probs * self.sample_stats["1_bg"]).sum() # only works if one label per patient (multi-label expectations depend on multi-label occurences). expectations = [] for targ in self.sample_stats.columns: expectations.append((self.p_probs * self.sample_stats[targ]).sum()) assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format(expectations) # get unique foreground targets per patient, assign -1 to an "empty" patient (has no foreground) #patient_ts = [np.unique(lst) if len([t for t in lst if np.any(t>0)])>0 else [-1] for lst in self.targets.values()] #bg_mask = np.array([np.all(lst == [-1]) for lst in patient_ts]) #unique_ts, t_counts = np.unique([t for lst in patient_ts for t in lst if t!=-1], return_counts=True) # t_probs = t_counts.sum() / t_counts # t_probs /= t_probs.sum() # t_probs = {t : t_probs[ix] for ix, t in enumerate(unique_ts)} # t_probs[-1] = 0. # # fail if balance target is not a number (i.e., a vector) # self.p_probs = np.array([ max([t_probs[t] for t in lst]) for lst in patient_ts ]) # #normalize # self.p_probs /= self.p_probs.sum() # rescale probs of empty samples # if not 0 == self.p_probs[bg_mask].shape[0]: # #rescale_f = (1 - self.cf.empty_samples_ratio) / self.p_probs[~bg_mask].sum() # rescale_f = 1 / self.p_probs[~bg_mask].sum() # self.p_probs *= rescale_f # self.p_probs[bg_mask] = 0. #self.cf.empty_samples_ratio/self.p_probs[bg_mask].shape[0] #self.unique_ts = unique_ts if plot: os.makedirs(self.plot_dir, exist_ok=True) plg.plot_batchgen_distribution(self.cf, self.dataset_pids, self.p_probs, self.balance_target, out_file=os.path.join(self.plot_dir, "train_gen_distr_"+str(self.cf.fold)+".png")) return self.p_probs def get_batch_pids(self): if self.max_batches is not None and self.batches_produced * self.number_of_threads_in_multithreaded \ + self.thread_id >= self.max_batches: self.reset() raise StopIteration if self.sample_pids_w_replace: # fully random patients batch_pids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_pids += list(np.random.choice( self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) else: with self.lock: if len(self.thread_ids) == 0: if self.raise_stop: self.reset() raise StopIteration else: self.thread_ids = self.rgen.permutation(self.eligible_pids[self.thread_id]) batch_pids = self.thread_ids[:self.batch_size] # batch_pids = np.random.choice(self.thread_ids, size=self.batch_size, replace=False) self.thread_ids = [pid for pid in self.thread_ids if pid not in batch_pids] self.batches_produced += 1 return batch_pids def generate_train_batch(self): # to be overriden by child # everything done in here is per batch # print statements in here get confusing due to multithreading raise NotImplementedError def print_stats(self, logger=None, file=None, plot_file=None, plot=True): print_f = utils.CombinedPrinter(logger, file) print_f('\n***Final Training Stats***') total_count = np.sum(self.stats['roi_counts']) for tix, count in enumerate(self.stats['roi_counts']): #name = self.cf.class_dict[tix] if self.balance_target=="class_targets" else str(self.unique_ts[tix]) name=str(self.unique_ts[tix]) print_f('{}: {} rois seen ({:.1f}%).'.format(name, count, count / total_count * 100)) total_samples = self.cf.num_epochs*self.cf.num_train_batches*self.cf.batch_size empties = [ '{}: {} ({:.1f}%)'.format(str(name), self.stats['empty_counts'][tix], self.stats['empty_counts'][tix]/total_samples*100) for tix, name in enumerate(self.unique_ts) ] empties = ", ".join(empties) print_f('empty samples seen: {}\n'.format(empties)) if plot: if plot_file is None: plot_file = os.path.join(self.plot_dir, "train_gen_stats_{}.png".format(self.cf.fold)) os.makedirs(self.plot_dir, exist_ok=True) plg.plot_batchgen_stats(self.cf, self.stats, empties, self.balance_target, self.unique_ts, plot_file) class PatientBatchIterator(SlimDataLoaderBase): """ creates a val/test generator. Step through the dataset and return dictionaries per patient. 2D is a special case of 3D patching with patch_size[2] == 1 (slices) Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets. Appends patient targets anyway for evaluation. For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions. This iterator/these batches are not intended to go through MTaugmenter afterwards """ def __init__(self, cf, data): super(PatientBatchIterator, self).__init__(data, 0) self.cf = cf self.dataset_length = len(self._data) self.dataset_pids = list(self._data.keys()) def generate_train_batch(self, pid=None): # to be overriden by child return ################################### # transforms, image manipulation # ################################### def get_patch_crop_coords(img, patch_size, min_overlap=30): """ _:param img (y, x, (z)) _:param patch_size: list of len 2 (2D) or 3 (3D). _:param min_overlap: minimum required overlap of patches. If too small, some areas are poorly represented only at edges of single patches. _:return ndarray: shape (n_patches, 2*dim). crop coordinates for each patch. """ crop_coords = [] for dim in range(len(img.shape)): n_patches = int(np.ceil(img.shape[dim] / patch_size[dim])) # no crops required in this dimension, add image shape as coordinates. if n_patches == 1: crop_coords.append([(0, img.shape[dim])]) continue # fix the two outside patches to coords patchsize/2 and interpolate. center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1) if (patch_size[dim] - center_dists) < min_overlap: n_patches += 1 center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1) patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)]) dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers] crop_coords.append(dim_crop_coords) coords_mesh_grid = [] for ymin, ymax in crop_coords[0]: for xmin, xmax in crop_coords[1]: if len(crop_coords) == 3 and patch_size[2] > 1: for zmin, zmax in crop_coords[2]: coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax]) elif len(crop_coords) == 3 and patch_size[2] == 1: for zmin in range(img.shape[2]): coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1]) else: coords_mesh_grid.append([ymin, ymax, xmin, xmax]) return np.array(coords_mesh_grid).astype(int) def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None): """ one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee :param image: nd image. can be anything :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape) Example: image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh? image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768). :param mode: see np.pad for documentation :param return_slicer: if True then this function will also return what coords you will need to use when cropping back to original shape :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is divisibly by that number (can also be a list with an entry for each axis). Whatever is missing to match that will be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None) :param kwargs: see np.pad for documentation """ if kwargs is None: kwargs = {} if new_shape is not None: old_shape = np.array(image.shape[-len(new_shape):]) else: assert shape_must_be_divisible_by is not None assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)) new_shape = image.shape[-len(shape_must_be_divisible_by):] old_shape = new_shape num_axes_nopad = len(image.shape) - len(new_shape) new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))] if not isinstance(new_shape, np.ndarray): new_shape = np.array(new_shape) if shape_must_be_divisible_by is not None: if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)): shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape) else: assert len(shape_must_be_divisible_by) == len(new_shape) for i in range(len(new_shape)): if new_shape[i] % shape_must_be_divisible_by[i] == 0: new_shape[i] -= shape_must_be_divisible_by[i] new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))]) difference = new_shape - old_shape pad_below = difference // 2 pad_above = difference // 2 + difference % 2 pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) res = np.pad(image, pad_list, mode, **kwargs) if not return_slicer: return res else: pad_list = np.array(pad_list) pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1] slicer = list(slice(*i) for i in pad_list) return res, slicer def convert_seg_to_bounding_box_coordinates(data_dict, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False): '''adapted from batchgenerators :param data_dict: seg: segmentation with labels indicating roi_count (get_rois_from_seg=False) or classes (get_rois_from_seg=True), class_targets: list where list index corresponds to roi id (roi_count) :param dim: :param roi_item_keys: keys of the roi-wise items in data_dict to process :param n_rg_feats: nr of regression vector features :param get_rois_from_seg: :return: coords (y1,x1,y2,x2 (,z1,z2)) where the segmentation GT is framed by +1 voxel, i.e., for an object with z-extensions z1=0 through z2=5, bbox target coords will be z1=-1, z2=6. (analogically for x,y). data_dict['roi_masks']: (b, n(b), c, h(n), w(n) (z(n))) list like roi_labels but with arrays (masks) inplace of integers. c==1 if segmentation not one-hot encoded. ''' bb_target = [] roi_masks = [] roi_items = {name:[] for name in roi_item_keys} out_seg = np.copy(data_dict['seg']) for b in range(data_dict['seg'].shape[0]): p_coords_list = [] #p for patient? p_roi_masks_list = [] p_roi_items_lists = {name:[] for name in roi_item_keys} if np.sum(data_dict['seg'][b] != 0) > 0: if get_rois_from_seg: clusters, n_cands = lb(data_dict['seg'][b]) data_dict['class_targets'][b] = [data_dict['class_targets'][b]] * n_cands else: n_cands = int(np.max(data_dict['seg'][b])) rois = np.array( [(data_dict['seg'][b] == ii) * 1 for ii in range(1, n_cands + 1)], dtype='uint8') # separate clusters for rix, r in enumerate(rois): if np.sum(r != 0) > 0: # check if the roi survived slicing (3D->2D) and data augmentation (cropping etc.) seg_ixs = np.argwhere(r != 0) coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, np.max(seg_ixs[:, 2]) + 1] if dim == 3: coord_list.extend([np.min(seg_ixs[:, 3]) - 1, np.max(seg_ixs[:, 3]) + 1]) p_coords_list.append(coord_list) p_roi_masks_list.append(r) # add background class = 0. rix is a patient wide index of lesions. since 'class_targets' is # also patient wide, this assignment is not dependent on patch occurrences. for name in roi_item_keys: p_roi_items_lists[name].append(data_dict[name][b][rix]) assert data_dict["class_targets"][b][rix]>=1, "convertsegtobbox produced bg roi w cl targ {} and unique roi seg {}".format(data_dict["class_targets"][b][rix], np.unique(r)) if class_specific_seg: out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_targets'][b][rix] if not class_specific_seg: out_seg[b][data_dict['seg'][b] > 0] = 1 bb_target.append(np.array(p_coords_list)) roi_masks.append(np.array(p_roi_masks_list)) for name in roi_item_keys: roi_items[name].append(np.array(p_roi_items_lists[name])) else: bb_target.append([]) roi_masks.append(np.zeros_like(data_dict['seg'][b], dtype='uint8')[None]) for name in roi_item_keys: roi_items[name].append(np.array([])) if get_rois_from_seg: data_dict.pop('class_targets', None) data_dict['bb_target'] = np.array(bb_target) data_dict['roi_masks'] = np.array(roi_masks) data_dict['seg'] = out_seg for name in roi_item_keys: data_dict[name] = np.array(roi_items[name]) return data_dict class ConvertSegToBoundingBoxCoordinates(AbstractTransform): """ Converts segmentation masks into bounding box coordinates. """ def __init__(self, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False): self.dim = dim self.roi_item_keys = roi_item_keys self.get_rois_from_seg = get_rois_from_seg self.class_specific_seg = class_specific_seg def __call__(self, **data_dict): return convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.roi_item_keys, self.get_rois_from_seg, self.class_specific_seg) ############################# # data packing / unpacking # not used, data_manager.py used instead ############################# def get_case_identifiers(folder): case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")] return case_identifiers def convert_to_npy(npz_file): if not os.path.isfile(npz_file[:-3] + "npy"): a = np.load(npz_file)['data'] np.save(npz_file[:-3] + "npy", a) def unpack_dataset(folder, threads=8): case_identifiers = get_case_identifiers(folder) p = Pool(threads) npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers] p.map(convert_to_npy, npz_files) p.close() p.join() def delete_npy(folder): case_identifiers = get_case_identifiers(folder) npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers] npy_files = [i for i in npy_files if os.path.isfile(i)] for n in npy_files: os.remove(n) \ No newline at end of file diff --git a/utils/exp_utils.py b/utils/exp_utils.py index 0bf2988..e528993 100644 --- a/utils/exp_utils.py +++ b/utils/exp_utils.py @@ -1,671 +1,679 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== # import plotting as plg import sys import os import subprocess +from multiprocessing import Process import threading import pickle import importlib.util import psutil import time import logging from torch.utils.tensorboard import SummaryWriter from collections import OrderedDict import numpy as np import pandas as pd import torch def import_module(name, path): """ correct way of importing a module dynamically in python 3. :param name: name given to module instance. :param path: path to module. :return: module: returned module instance. """ spec = importlib.util.spec_from_file_location(name, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module def save_obj(obj, name): """Pickle a python object.""" with open(name + '.pkl', 'wb') as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) def IO_safe(func, *args, _tries=5, _raise=True, **kwargs): """ Wrapper calling function func with arguments args and keyword arguments kwargs to catch input/output errors on cluster. :param func: function to execute (intended to be read/write operation to a problematic cluster drive, but can be any function). :param args: positional args of func. :param kwargs: kw args of func. :param _tries: how many attempts to make executing func. """ for _try in range(_tries): try: return func(*args, **kwargs) except OSError as e: # to catch cluster issues with network drives if _raise: raise e else: print("After attempting execution {} time{}, following error occurred:\n{}".format(_try + 1, "" if _try == 0 else "s", e)) continue +def split_off_process(target, *args, **kwargs): + """Start a process that won't block parent script. + No join(), no return value. Before parent exits, it waits for this to finish. + """ + p = Process(target=target, args=tuple(args), kwargs=kwargs, daemon=False) + p.start() + def query_nvidia_gpu(device_id, d_keyword=None, no_units=False): """ :param device_id: :param d_keyword: -d, --display argument (keyword(s) for selective display), all are selected if None :return: dict of gpu-info items """ cmd = ['nvidia-smi', '-i', str(device_id), '-q'] if d_keyword is not None: cmd += ['-d', d_keyword] outp = subprocess.check_output(cmd).strip().decode('utf-8').split("\n") outp = [x for x in outp if len(x) > 0] headers = [ix for ix, item in enumerate(outp) if len(item.split(":")) == 1] + [len(outp)] out_dict = {} for lix, hix in enumerate(headers[:-1]): head = outp[hix].strip().replace(" ", "_").lower() out_dict[head] = {} for lix2 in range(hix, headers[lix + 1]): try: key, val = [x.strip().lower() for x in outp[lix2].split(":")] if no_units: val = val.split()[0] out_dict[head][key] = val except: pass return out_dict class CombinedPrinter(object): """combined print function. prints to logger and/or file if given, to normal print if non given. """ def __init__(self, logger=None, file=None): if logger is None and file is None: self.out = [print] elif logger is None: self.out = [file.write] elif file is None: self.out = [logger.info] else: self.out = [logger.info, file.write] def __call__(self, string): for fct in self.out: fct(string) class Nvidia_GPU_Logger(object): def __init__(self): self.count = None def get_vals(self): cmd = ['nvidia-settings', '-t', '-q', 'GPUUtilization'] gpu_util = subprocess.check_output(cmd).strip().decode('utf-8').split(",") gpu_util = dict([f.strip().split("=") for f in gpu_util]) cmd[-1] = 'UsedDedicatedGPUMemory' gpu_used_mem = subprocess.check_output(cmd).strip().decode('utf-8') current_vals = {"gpu_mem_alloc": gpu_used_mem, "gpu_graphics_util": int(gpu_util['graphics']), "gpu_mem_util": gpu_util['memory'], "time": time.time()} return current_vals def loop(self, interval): i = 0 while True: self.get_vals() self.log["time"].append(time.time()) self.log["gpu_util"].append(self.current_vals["gpu_graphics_util"]) if self.count is not None: i += 1 if i == self.count: exit(0) time.sleep(self.interval) def start(self, interval=1.): self.interval = interval self.start_time = time.time() self.log = {"time": [], "gpu_util": []} if self.interval is not None: thread = threading.Thread(target=self.loop) thread.daemon = True thread.start() class CombinedLogger(object): """Combine console and tensorboard logger and record system metrics. """ def __init__(self, name, log_dir, server_env=True, fold="all", sysmetrics_interval=2): self.pylogger = logging.getLogger(name) self.tboard = SummaryWriter(log_dir=os.path.join(log_dir, "tboard")) self.times = {} self.log_dir = log_dir self.fold = str(fold) self.server_env = server_env self.pylogger.setLevel(logging.DEBUG) self.log_file = os.path.join(log_dir, "fold_"+self.fold, 'exec.log') os.makedirs(os.path.dirname(self.log_file), exist_ok=True) self.pylogger.addHandler(logging.FileHandler(self.log_file)) if not server_env: self.pylogger.addHandler(ColorHandler()) else: self.pylogger.addHandler(logging.StreamHandler()) self.pylogger.propagate = False # monitor system metrics (cpu, mem, ...) if not server_env and sysmetrics_interval > 0: self.sysmetrics = pd.DataFrame( columns=["global_step", "rel_time", r"CPU (%)", "mem_used (GB)", r"mem_used (%)", r"swap_used (GB)", r"gpu_utilization (%)"], dtype="float16") for device in range(torch.cuda.device_count()): self.sysmetrics[ "mem_allocd (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan self.sysmetrics[ "mem_cached (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan self.sysmetrics_start(sysmetrics_interval) pass else: print("NOT logging sysmetrics") def __getattr__(self, attr): """delegate all undefined method requests to objects of this class in order pylogger, tboard (first find first serve). E.g., combinedlogger.add_scalars(...) should trigger self.tboard.add_scalars(...) """ for obj in [self.pylogger, self.tboard]: if attr in dir(obj): return getattr(obj, attr) print("logger attr not found") #raise AttributeError("CombinedLogger has no attribute {}".format(attr)) def set_logfile(self, fold=None, log_file=None): if fold is not None: self.fold = str(fold) if log_file is None: self.log_file = os.path.join(self.log_dir, "fold_"+self.fold, 'exec.log') else: self.log_file = log_file os.makedirs(os.path.dirname(self.log_file), exist_ok=True) for hdlr in self.pylogger.handlers: hdlr.close() self.pylogger.handlers = [] self.pylogger.addHandler(logging.FileHandler(self.log_file)) if not self.server_env: self.pylogger.addHandler(ColorHandler()) else: self.pylogger.addHandler(logging.StreamHandler()) def time(self, name, toggle=None): """record time-spans as with a stopwatch. :param name: :param toggle: True^=On: start time recording, False^=Off: halt rec. if None determine from current status. :return: either start-time or last recorded interval """ if toggle is None: if name in self.times.keys(): toggle = not self.times[name]["toggle"] else: toggle = True if toggle: if not name in self.times.keys(): self.times[name] = {"total": 0, "last": 0} elif self.times[name]["toggle"] == toggle: self.info("restarting running stopwatch") self.times[name]["last"] = time.time() self.times[name]["toggle"] = toggle return time.time() else: if toggle == self.times[name]["toggle"]: self.info("WARNING: tried to stop stopped stop watch: {}.".format(name)) self.times[name]["last"] = time.time() - self.times[name]["last"] self.times[name]["total"] += self.times[name]["last"] self.times[name]["toggle"] = toggle return self.times[name]["last"] def get_time(self, name=None, kind="total", format=None, reset=False): """ :param name: :param kind: 'total' or 'last' :param format: None for float, "hms"/"ms" for (hours), mins, secs as string :param reset: reset time after retrieving :return: """ if name is None: times = self.times if reset: self.reset_time() return times else: if self.times[name]["toggle"]: self.time(name, toggle=False) time = self.times[name][kind] if format == "hms": m, s = divmod(time, 60) h, m = divmod(m, 60) time = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(m), int(s)) elif format == "ms": m, s = divmod(time, 60) time = "{:02d}m:{:02d}s".format(int(m), int(s)) if reset: self.reset_time(name) return time def reset_time(self, name=None): if name is None: self.times = {} else: del self.times[name] def sysmetrics_update(self, global_step=None): if global_step is None: global_step = time.strftime("%x_%X") mem = psutil.virtual_memory() mem_used = (mem.total - mem.available) gpu_vals = self.gpu_logger.get_vals() rel_time = time.time() - self.sysmetrics_start_time self.sysmetrics.loc[len(self.sysmetrics)] = [global_step, rel_time, psutil.cpu_percent(), mem_used / 1024 ** 3, mem_used / mem.total * 100, psutil.swap_memory().used / 1024 ** 3, int(gpu_vals['gpu_graphics_util']), *[torch.cuda.memory_allocated(d) / 1024 ** 3 for d in range(torch.cuda.device_count())], *[torch.cuda.memory_cached(d) / 1024 ** 3 for d in range(torch.cuda.device_count())] ] return self.sysmetrics.loc[len(self.sysmetrics) - 1].to_dict() def sysmetrics2tboard(self, metrics=None, global_step=None, suptitle=None): tag = "per_time" if metrics is None: metrics = self.sysmetrics_update(global_step=global_step) tag = "per_epoch" if suptitle is not None: suptitle = str(suptitle) elif self.fold != "": suptitle = "Fold_" + str(self.fold) if suptitle is not None: self.tboard.add_scalars(suptitle + "/System_Metrics/" + tag, {k: v for (k, v) in metrics.items() if (k != "global_step" and k != "rel_time")}, global_step) def sysmetrics_loop(self): try: os.nice(-19) self.info("Logging system metrics with superior process priority.") except: self.info("Logging system metrics without superior process priority.") while True: metrics = self.sysmetrics_update() self.sysmetrics2tboard(metrics, global_step=metrics["rel_time"]) # print("thread alive", self.thread.is_alive()) time.sleep(self.sysmetrics_interval) def sysmetrics_start(self, interval): if interval is not None and interval > 0: self.sysmetrics_interval = interval self.gpu_logger = Nvidia_GPU_Logger() self.sysmetrics_start_time = time.time() self.thread = threading.Thread(target=self.sysmetrics_loop) self.thread.daemon = True self.thread.start() def sysmetrics_save(self, out_file): self.sysmetrics.to_pickle(out_file) def metrics2tboard(self, metrics, global_step=None, suptitle=None): """ :param metrics: {'train': dataframe, 'val':df}, df as produced in evaluator.py.evaluate_predictions """ # print("metrics", metrics) if global_step is None: global_step = len(metrics['train'][list(metrics['train'].keys())[0]]) - 1 if suptitle is not None: suptitle = str(suptitle) else: suptitle = "Fold_" + str(self.fold) for key in ['train', 'val']: # series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k} loss_series = {} unc_series = {} bin_stat_series = {} mon_met_series = {} for tag, val in metrics[key].items(): val = val[-1] # maybe remove list wrapping, recording in evaluator? if 'bin_stats' in tag.lower() and not np.isnan(val): bin_stat_series["{}".format(tag.split("/")[-1])] = val elif 'uncertainty' in tag.lower() and not np.isnan(val): unc_series["{}".format(tag)] = val elif 'loss' in tag.lower() and not np.isnan(val): loss_series["{}".format(tag)] = val elif not np.isnan(val): mon_met_series["{}".format(tag)] = val self.tboard.add_scalars(suptitle + "/Binary_Statistics/{}".format(key), bin_stat_series, global_step) self.tboard.add_scalars(suptitle + "/Uncertainties/{}".format(key), unc_series, global_step) self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step) self.tboard.add_scalars(suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series, global_step) self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step) return def batchImgs2tboard(self, batch, results_dict, cmap, boxtype2color, img_bg=False, global_step=None): raise NotImplementedError("not up-to-date, problem with importing plotting-file, torchvision dependency.") if len(batch["seg"].shape) == 5: # 3D imgs slice_ix = np.random.randint(batch["seg"].shape[-1]) seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :, slice_ix], cmap) seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :, slice_ix], cmap) mod_img = plg.mod_to_rgb(batch["data"][:, 0, :, :, slice_ix]) if img_bg else None elif len(batch["seg"].shape) == 4: seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :], cmap) seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :], cmap) mod_img = plg.mod_to_rgb(batch["data"][:, 0]) if img_bg else None else: raise Exception("batch content has wrong format: {}".format(batch["seg"].shape)) # from here on only works in 2D seg_gt = np.transpose(seg_gt, axes=(0, 3, 1, 2)) # previous shp: b,x,y,c seg_pred = np.transpose(seg_pred, axes=(0, 3, 1, 2)) seg = np.concatenate((seg_gt, seg_pred), axis=0) # todo replace torchvision (tv) dependency seg = tv.utils.make_grid(torch.from_numpy(seg), nrow=2) self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.", seg, global_step=global_step) if img_bg: bg_img = np.transpose(mod_img, axes=(0, 3, 1, 2)) else: bg_img = seg_gt box_imgs = plg.draw_boxes_into_batch(bg_img, results_dict["boxes"], boxtype2color) box_imgs = tv.utils.make_grid(torch.from_numpy(box_imgs), nrow=4) self.tboard.add_image("Batch bboxes", box_imgs, global_step=global_step) return def __del__(self): # otherwise might produce multiple prints e.g. in ipython console for hdlr in self.pylogger.handlers: hdlr.close() self.pylogger.handlers = [] del self.pylogger self.tboard.close() def get_logger(exp_dir, server_env=False, sysmetrics_interval=2): log_dir = os.path.join(exp_dir, "logs") logger = CombinedLogger('Reg R-CNN', log_dir, server_env=server_env, sysmetrics_interval=sysmetrics_interval) print("logging to {}".format(logger.log_file)) return logger def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True): """ I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir. This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime. Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone). Provides robust structure for cloud deployment. :param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp) :param exp_path: path to experiment directory. :param server_env: boolean flag. pass to configs script for cloud deployment. :param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing experiment directory, else creates experiment directory on the fly using configs/model scripts from source code. :param is_training: boolean flag. distinguishes train vs. inference mode. :return: configs object. """ if is_training: if use_stored_settings: cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) cf = cf_file.Configs(server_env) # in this mode, previously saved model and backbone need to be found in exp dir. if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \ not os.path.isfile(os.path.join(exp_path, 'backbone.py')): raise Exception( "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.") cf.model_path = os.path.join(exp_path, 'model.py') cf.backbone_path = os.path.join(exp_path, 'backbone.py') else: # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model os.makedirs(exp_path, exist_ok=True) # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.) subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), shell=True) subprocess.call( 'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), shell=True) cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py')) cf = cf_file.Configs(server_env) subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True) subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True) if os.path.isfile(os.path.join(exp_path, "fold_ids.pickle")): subprocess.call('rm {}'.format(os.path.join(exp_path, "fold_ids.pickle")), shell=True) else: # testing, use model and backbone stored in exp dir. cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) cf = cf_file.Configs(server_env) cf.model_path = os.path.join(exp_path, 'model.py') cf.backbone_path = os.path.join(exp_path, 'backbone.py') cf.exp_dir = exp_path cf.test_dir = os.path.join(cf.exp_dir, 'test') cf.plot_dir = os.path.join(cf.exp_dir, 'plots') if not os.path.exists(cf.test_dir): os.mkdir(cf.test_dir) if not os.path.exists(cf.plot_dir): os.mkdir(cf.plot_dir) cf.experiment_name = exp_path.split("/")[-1] cf.dataset_name = dataset_path cf.server_env = server_env cf.created_fold_id_pickle = False return cf class ModelSelector: ''' saves a checkpoint after each epoch as 'last_state' (can be loaded to continue interrupted training). saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be ensembled to improve performance. ''' def __init__(self, cf, logger): self.cf = cf self.saved_epochs = [-1] * cf.save_n_models self.logger = logger def run_model_selection(self, net, optimizer, monitor_metrics, epoch): """rank epoch via weighted mean from self.cf.model_selection_criteria: {criterion : weight} :param net: :param optimizer: :param monitor_metrics: :param epoch: :return: """ crita = self.cf.model_selection_criteria # shorter alias non_nan_scores = {} for criterion in crita.keys(): # exclude first entry bc its dummy None entry non_nan_scores[criterion] = [0 if (ii is None or np.isnan(ii)) else ii for ii in monitor_metrics['val'][criterion]][1:] n_epochs = len(non_nan_scores[criterion]) epochs_scores = [] for e_ix in range(n_epochs): epochs_scores.append(np.sum([weight * non_nan_scores[criterion][e_ix] for criterion, weight in crita.items()]) / len(crita.keys())) # ranking of epochs according to model_selection_criterion epoch_ranking = np.argsort(epochs_scores)[::-1] + 1 # epochs start at 1 # if set in configs, epochs < min_save_thresh are discarded from saving process. epoch_ranking = epoch_ranking[epoch_ranking >= self.cf.min_save_thresh] # check if current epoch is among the top-k epchs. if epoch in epoch_ranking[:self.cf.save_n_models]: if self.cf.server_env: IO_safe(torch.save, net.state_dict(), os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch))) # save epoch_ranking to keep info for inference. IO_safe(np.save, os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models]) else: torch.save(net.state_dict(), os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch))) np.save(os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models]) self.logger.info( "saving current epoch {} at rank {}".format(epoch, np.argwhere(epoch_ranking == epoch))) # delete params of the epoch that just fell out of the top-k epochs. for se in [int(ii.split('_')[0]) for ii in os.listdir(self.cf.fold_dir) if 'best_params' in ii]: if se in epoch_ranking[self.cf.save_n_models:]: subprocess.call('rm {}'.format(os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(se))), shell=True) self.logger.info('deleting epoch {} at rank {}'.format(se, np.argwhere(epoch_ranking == se))) state = { 'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), } if self.cf.server_env: IO_safe(torch.save, state, os.path.join(self.cf.fold_dir, 'last_state.pth')) else: torch.save(state, os.path.join(self.cf.fold_dir, 'last_state.pth')) def load_checkpoint(checkpoint_path, net, optimizer): checkpoint = torch.load(checkpoint_path) net.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) return checkpoint['epoch'] def prepare_monitoring(cf): """ creates dictionaries, where train/val metrics are stored. """ metrics = {} # first entry for loss dict accounts for epoch starting at 1. metrics['train'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) metrics['val'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) metric_classes = [] if 'rois' in cf.report_score_level: metric_classes.extend([v for k, v in cf.class_dict.items()]) if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: metric_classes.extend([v for k, v in cf.bin_dict.items()]) if 'patient' in cf.report_score_level: metric_classes.extend(['patient_' + cf.class_dict[cf.patient_class_of_interest]]) if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: metric_classes.extend(['patient_' + cf.bin_dict[cf.patient_bin_of_interest]]) for cl in metric_classes: for m in cf.metrics: metrics['train'][cl + '_' + m] = [np.nan] metrics['val'][cl + '_' + m] = [np.nan] return metrics class _AnsiColorizer(object): """ A colorizer is an object that loosely wraps around a stream, allowing callers to write text to the stream in a particular color. Colorizer classes must implement C{supported()} and C{write(text, color)}. """ _colors = dict(black=30, red=31, green=32, yellow=33, blue=34, magenta=35, cyan=36, white=37, default=39) def __init__(self, stream): self.stream = stream @classmethod def supported(cls, stream=sys.stdout): """ A class method that returns True if the current platform supports coloring terminal output using this method. Returns False otherwise. """ if not stream.isatty(): return False # auto color only on TTYs try: import curses except ImportError: return False else: try: try: return curses.tigetnum("colors") > 2 except curses.error: curses.setupterm() return curses.tigetnum("colors") > 2 except: raise # guess false in case of error return False def write(self, text, color): """ Write the given text to the stream in the given color. @param text: Text to be written to the stream. @param color: A string label for a color. e.g. 'red', 'white'. """ color = self._colors[color] self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text)) class ColorHandler(logging.StreamHandler): def __init__(self, stream=sys.stdout): super(ColorHandler, self).__init__(_AnsiColorizer(stream)) def emit(self, record): msg_colors = { logging.DEBUG: "green", logging.INFO: "default", logging.WARNING: "red", logging.ERROR: "red" } color = msg_colors.get(record.levelno, "blue") self.stream.write(record.msg + "\n", color)