diff --git a/datasets/cityscapes/configs.py b/datasets/cityscapes/configs.py index ed2cdab..46808c9 100644 --- a/datasets/cityscapes/configs.py +++ b/datasets/cityscapes/configs.py @@ -1,434 +1,434 @@ __author__ = '' #credit Paul F. Jaeger ######################### # Example Config # ######################### import os import sys import numpy as np from collections import namedtuple sys.path.append('../') from default_configs import DefaultConfigs class Configs(DefaultConfigs): def __init__(self, server_env=None): super(Configs, self).__init__(server_env) self.dim = 2 ######################### # I/O # ######################### self.data_sourcedir = "/mnt/HDD2TB/Documents/data/cityscapes/cs_20190715/" if server_env: #self.source_dir = '/home/ramien/medicaldetectiontoolkit/' self.data_sourcedir = '/datasets/data_ramien/cityscapes/cs_20190715_npz/' #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/cityscapes/cs_6c_inst_only/" self.datapath = "leftImg8bit/" self.targetspath = "gtFine/" self.cities = {'train':['dusseldorf', 'aachen', 'bochum', 'cologne', 'erfurt', 'hamburg', 'hanover', 'jena', 'krefeld', 'monchengladbach', 'strasbourg', 'stuttgart', 'tubingen', 'ulm', 'weimar', 'zurich'], 'val':['frankfurt', 'munster'], 'test':['bremen', 'darmstadt', 'lindau'] } self.set_splits = ["train", "val", "test"] # for training and val, mixed up # test cities are not held out self.info_dict_name = 'city_info.pkl' self.info_dict_path = os.path.join(self.data_sourcedir, self.info_dict_name) self.config_path = os.path.realpath(__file__) self.backbone_path = 'models/backbone.py' # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'detection_fpn']. self.model = 'retina_unet' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) self.select_prototype_subset = None ######################### # Preprocessing # ######################### self.prepro = { 'data_dir': '/mnt/HDD2TB/Documents/data/cityscapes_raw/', #raw files (input), needs to end with "/" 'targettype': "gtFine_instanceIds", 'set_splits': ["train", "val", "test"], 'img_target_size': np.array([256, 512])*4, #y,x 'output_directory': self.data_sourcedir, 'center_of_mass_crop': True, #not implemented #'pre_crop_size': , #z,y,x 'normalization': {'percentiles':[1., 99.]},#not implemented 'interpolation': 'nearest', #not implemented 'info_dict_path': self.info_dict_path, 'npz_dir' : self.data_sourcedir[:-1]+"_npz" #if not None: convert to npz, copy data here } ######################### # Architecture # ######################### # 'class', 'regression', 'regression_ken_gal' # 'class': standard object classification per roi, pairwise combinable with each of below tasks. # 'class' is only option implemented for CityScapes data set. self.prediction_tasks = ['class',] self.start_filts = 52 self.end_filts = self.start_filts * 4 self.res_architecture = 'resnet101' # 'resnet101' , 'resnet50' self.weight_init = None # 'kaiming', 'xavier' or None for pytorch default self.norm = 'instance_norm' # 'batch_norm' # one of 'None', 'instance_norm', 'batch_norm' self.relu = 'relu' ######################### # Data Loader # ######################### self.seed = 17 self.n_workers = 16 if server_env else os.cpu_count() self.batch_size = 8 self.n_cv_splits = 10 #at least 2 (train, val) self.num_classes = None #set below #for instance classification (excl background) self.num_seg_classes = None #set below #incl background self.create_bounding_box_targets = True self.class_specific_seg = True self.channels = [0,1,2] self.pre_crop_size = self.prepro['img_target_size'] # y,x self.crop_margin = [10,10] #has to be smaller than respective patch_size//2 self.patch_size_2D = [256, 512] #self.pre_crop_size #would be better to save as tuple since should not be altered self.patch_size_3D = self.patch_size_2D + [1] self.patch_size = self.patch_size_2D self.balance_target = "class_targets" # ratio of fully random patients drawn during batch generation # resulting batch random count is rounded down to closest integer self.batch_random_ratio = 0.2 self.observables_patient = [] self.observables_rois = [] ######################### # Data Augmentation # ######################### #the angle rotations are implemented incorrectly in batchgenerators! in 2D, #the x-axis angle controls the z-axis angle. self.do_aug = True self.da_kwargs = { 'mirror': True, 'mirror_axes': (1,), #image axes, (batch and channel are ignored, i.e., actual tensor dims are +2) 'random_crop': True, 'rand_crop_dist': (self.patch_size[0] / 2., self.patch_size[1] / 2.), 'do_elastic_deform': True, 'alpha': (0., 1000.), 'sigma': (28., 30.), 'do_rotation': True, 'angle_x': (-np.pi / 8., np.pi / 8.), 'angle_y': (0.,0.), 'angle_z': (0.,0.), 'do_scale': True, 'scale': (0.6, 1.4), 'border_mode_data': 'constant', 'gamma_range': (0.6, 1.4) } ################################# # Schedule / Selection / Optim # ################################# #mrcnn paper: ~2.56m samples seen during coco-dataset training self.num_epochs = 400 self.num_train_batches = 600 self.do_validation = True # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_sampling' # one of 'val_sampling', 'val_patient' # if 'all' iterates over entire val_set once. self.num_val_batches = "all" # for val_sampling self.save_n_models = 3 self.min_save_thresh = 1 # in epochs self.model_selection_criteria = {"human_ap": 1., "vehicle_ap": 0.9} self.warm_up = 0 self.learning_rate = [5*1e-4] * self.num_epochs self.dynamic_lr_scheduling = True #with scheduler set in exec self.lr_decay_factor = 0.5 self.scheduling_patience = int(self.num_epochs//10) self.weight_decay = 1e-6 self.clip_norm = None # number or None ######################### # Colors and Legends # ######################### self.plot_frequency = 5 #colors self.color_palette = [self.red, self.blue, self.green, self.orange, self.aubergine, self.yellow, self.gray, self.cyan, self.black] #legends Label = namedtuple( 'Label' , [ 'name' , # The identifier of this label, e.g. 'car', 'person', ... . # We use them to uniquely name a class 'ppId' , # An integer ID that is associated with this label. # The IDs are used to represent the label in ground truth images # An ID of -1 means that this label does not have an ID and thus # is ignored when creating ground truth images (e.g. license plate). # Do not modify these IDs, since exactly these IDs are expected by the # evaluation server. 'id' , # Feel free to modify these IDs as suitable for your method. # Max value is 255! 'category' , # The name of the category that this label belongs to 'categoryId' , # The ID of this category. Used to create ground truth images # on category level. 'hasInstances', # Whether this label distinguishes between single instances or not 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored # during evaluations or not 'color' , # The color of this label ] ) segLabel = namedtuple( "segLabel", ["name", "id", "color"]) boxLabel = namedtuple( 'boxLabel', [ "name", "color"]) self.labels = [ # name ppId id category catId hasInstances ignoreInEval color Label( 'ignore' , 0 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), Label( 'ego vehicle' , 1 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), Label( 'rectification border' , 2 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), Label( 'out of roi' , 3 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), Label( 'static' , 4 , 0 , 'void' , 0 , False , True , ( 0., 0., 0., 1.) ), Label( 'dynamic' , 5 , 0 , 'void' , 0 , False , True , (0.44, 0.29, 0., 1.) ), Label( 'ground' , 6 , 0 , 'void' , 0 , False , True , ( 0.32, 0., 0.32, 1.) ), Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (0.5, 0.25, 0.5, 1.) ), Label( 'sidewalk' , 8 , 0 , 'flat' , 1 , False , False , (0.96, 0.14, 0.5, 1.) ), Label( 'parking' , 9 , 0 , 'flat' , 1 , False , True , (0.98, 0.67, 0.63, 1.) ), Label( 'rail track' , 10 , 0 , 'flat' , 1 , False , True , ( 0.9, 0.59, 0.55, 1.) ), Label( 'building' , 11 , 0 , 'construction' , 2 , False , False , ( 0.27, 0.27, 0.27, 1.) ), Label( 'wall' , 12 , 0 , 'construction' , 2 , False , False , (0.4,0.4,0.61, 1.) ), Label( 'fence' , 13 , 0 , 'construction' , 2 , False , False , (0.75,0.6,0.6, 1.) ), Label( 'guard rail' , 14 , 0 , 'construction' , 2 , False , True , (0.71,0.65,0.71, 1.) ), Label( 'bridge' , 15 , 0 , 'construction' , 2 , False , True , (0.59,0.39,0.39, 1.) ), Label( 'tunnel' , 16 , 0 , 'construction' , 2 , False , True , (0.59,0.47, 0.35, 1.) ), Label( 'pole' , 17 , 0 , 'object' , 3 , False , False , (0.6,0.6,0.6, 1.) ), Label( 'polegroup' , 18 , 0 , 'object' , 3 , False , True , (0.6,0.6,0.6, 1.) ), Label( 'traffic light' , 19 , 0 , 'object' , 3 , False , False , (0.98,0.67, 0.12, 1.) ), Label( 'traffic sign' , 20 , 0 , 'object' , 3 , False , False , (0.86,0.86, 0., 1.) ), Label( 'vegetation' , 21 , 0 , 'nature' , 4 , False , False , (0.42,0.56, 0.14, 1.) ), Label( 'terrain' , 22 , 0 , 'nature' , 4 , False , False , (0.6, 0.98,0.6, 1.) ), Label( 'sky' , 23 , 0 , 'sky' , 5 , False , False , (0.27,0.51,0.71, 1.) ), Label( 'person' , 24 , 1 , 'human' , 6 , True , False , (0.86, 0.08, 0.24, 1.) ), Label( 'rider' , 25 , 1 , 'human' , 6 , True , False , (1., 0., 0., 1.) ), Label( 'car' , 26 , 2 , 'vehicle' , 7 , True , False , ( 0., 0.,0.56, 1.) ), Label( 'truck' , 27 , 2 , 'vehicle' , 7 , True , False , ( 0., 0., 0.27, 1.) ), Label( 'bus' , 28 , 2 , 'vehicle' , 7 , True , False , ( 0., 0.24,0.39, 1.) ), Label( 'caravan' , 29 , 2 , 'vehicle' , 7 , True , True , ( 0., 0., 0.35, 1.) ), Label( 'trailer' , 30 , 2 , 'vehicle' , 7 , True , True , ( 0., 0.,0.43, 1.) ), Label( 'train' , 31 , 2 , 'vehicle' , 7 , True , False , ( 0., 0.31,0.39, 1.) ), Label( 'motorcycle' , 32 , 2 , 'vehicle' , 7 , True , False , ( 0., 0., 0.9, 1.) ), Label( 'bicycle' , 33 , 2 , 'vehicle' , 7 , True , False , (0.47, 0.04, 0.13, 1.) ), Label( 'license plate' , -1 , 0 , 'vehicle' , 7 , False , True , ( 0., 0., 0.56, 1.) ), Label( 'background' , -1 , 0 , 'void' , 0 , False , True , ( 0., 0., 0.0, 0.) ), Label( 'vehicle' , 33 , 2 , 'vehicle' , 7 , True , False , (*self.aubergine, 1.) ), Label( 'human' , 25 , 1 , 'human' , 6 , True , False , (*self.blue, 1.) ) ] # evtl problem: class-ids (trainIds) don't start with 0 for the first class, 0 is bg. #WONT WORK: class ids need to start at 0 (excluding bg!) and be consecutively numbered self.ppId2id = { label.ppId : label.id for label in self.labels} self.class_id2label = { label.id : label for label in self.labels} self.class_cmap = {label.id : label.color for label in self.labels} self.class_dict = {label.id : label.name for label in self.labels if label.id!=0} #c_dict: only for evaluation, remove bg class. self.box_type2label = {label.name : label for label in self.box_labels} self.box_color_palette = {label.name:label.color for label in self.box_labels} if self.class_specific_seg: self.seg_labels = [label for label in self.class_id2label.values()] else: self.seg_labels = [ # name id color segLabel( "bg" , 0, (1.,1.,1.,0.) ), segLabel( "fg" , 1, (*self.orange, .8)) ] self.seg_id2label = {label.id : label for label in self.seg_labels} self.cmap = {label.id : label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.has_colorchannels = True self.plot_class_ids = True self.num_classes = len(self.class_dict) self.num_seg_classes = len(self.seg_labels) ######################### # Testing # ######################### self.test_aug_axes = None #None or list: choices are 2,3,(2,3) - self.held_out_test_set = False + self.hold_out_test_set = False self.max_test_patients = 'all' # 'all' for all self.report_score_level = ['rois',] # choose list from 'patient', 'rois' self.patient_class_of_interest = 1 self.metrics = ['ap', 'dice'] self.ap_match_ious = [0.1] # threshold(s) for considering a prediction as true positive # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!) for regarding two preds as concerning the same ROI self.clustering_iou = 0.1 # has to be larger than desired possible overlap iou of model predictions self.min_det_thresh = 0.06 self.merge_2D_to_3D_preds = False self.n_test_plots = 1 #per fold and rankself.ap_match_ious = [0.1] #threshold(s) for considering a prediction as true positive self.test_n_epochs = self.save_n_models ######################### # shared model settings # ######################### # max number of roi candidates to identify per image and class (slice in 2D, volume in 3D) self.n_roi_candidates = 100 ######################### # Add model specifics # ######################### {'mrcnn': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs }[self.model]() def add_mrcnn_configs(self): self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 # seg_classes here refers to the first stage classifier (RPN) reallY? # feed +/- n neighbouring slices into channel dimension. set to None for no context. self.n_3D_context = None self.frcnn_mode = False self.detect_while_training = True # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask outputs are optional. self.return_masks_in_train = True self.return_masks_in_val = True self.return_masks_in_test = True # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 64 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1., 2.] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = 0.7 # loss sampling settings. self.rpn_train_anchors_per_image = 8 self.train_rois_per_image = 10 # per batch_instance self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.8 # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), # where k<=#positive examples. self.shem_poolsize = 3 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) # y1,x1,y2,x2,z1,z2 if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] self.plot_y_max = 1.5 self.n_plot_rpn_props = 5 # per batch_instance (slice in 2D / patient in 3D) # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 3000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage as one "batch". self.roi_batch_size = 2500 self.post_nms_rois_training = 500 self.post_nms_rois_inference = 500 # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 50 # per batch element and class. self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.05 # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 self.n_rpn_features = 256 if self.dim == 2 else 64 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = 10000 if self.dim == 2 else 30000 # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.5 if self.model == 'retina_unet': self.operate_stride1 = True \ No newline at end of file diff --git a/datasets/cityscapes/data_loader.py b/datasets/cityscapes/data_loader.py index 01a1a45..a799de3 100644 --- a/datasets/cityscapes/data_loader.py +++ b/datasets/cityscapes/data_loader.py @@ -1,452 +1,452 @@ import sys sys.path.append('../') #works on cluster indep from where sbatch job is started import plotting as plg import warnings import os import time import pickle import numpy as np import pandas as pd from PIL import Image as pil import torch import torch.utils.data # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform from batchgenerators.transforms.color_transforms import GammaTransform #from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates sys.path.append(os.path.dirname(os.path.realpath(__file__))) import utils.exp_utils as utils import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates from configs import Configs cf= configs() warnings.filterwarnings("ignore", message="This figure includes Axes.*") def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) def save_to_npy(arr_out, array): np.save(arr_out+".npy", array) print("Saved binary .npy-file to {}".format(arr_out)) return arr_out+".npy" def shape_small_first(shape): if len(shape)<=2: #no changing dimensions if channel-dim is missing return shape smallest_dim = np.argmin(shape) if smallest_dim!=0: #assume that smallest dim is color channel new_shape = np.array(shape) #to support mask indexing new_shape = (new_shape[smallest_dim], *new_shape[(np.arange(len(shape),dtype=int)!=smallest_dim)]) return new_shape else: return shape class Dataset(dutils.Dataset): def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None): super(Dataset, self).__init__(cf, data_sourcedir=data_sourcedir) info_dict = load_obj(cf.info_dict_path) if subset_ids is not None: img_ids = subset_ids if logger is None: print('subset: selected {} instances from df'.format(len(pids))) else: logger.info('subset: selected {} instances from df'.format(len(pids))) else: img_ids = list(info_dict.keys()) #evtly copy data from data_rootdir to data_dir if cf.server_env and not hasattr(cf, "data_dir"): file_subset = [info_dict[img_id]['img'][:-3]+"*" for img_id in img_ids] file_subset+= [info_dict[img_id]['seg'][:-3]+"*" for img_id in img_ids] file_subset+= [cf.info_dict_path] self.copy_data(cf, file_subset=file_subset) cf.data_dir = self.data_dir img_paths = [os.path.join(self.data_dir, info_dict[img_id]['img']) for img_id in img_ids] seg_paths = [os.path.join(self.data_dir, info_dict[img_id]['seg']) for img_id in img_ids] # load all subject files self.data = {} for i, img_id in enumerate(img_ids): subj_data = {'img_id':img_id} subj_data['img'] = img_paths[i] subj_data['seg'] = seg_paths[i] if 'class' in self.cf.prediction_tasks: subj_data['class_targets'] = np.array(info_dict[img_id]['roi_classes']) else: subj_data['class_targets'] = np.ones_like(np.array(info_dict[img_id]['roi_classes'])) self.data[img_id] = subj_data cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if 'regression' in cf.prediction_tasks: cf.roi_items += ['regression_targets'] self.set_ids = list(self.data.keys()) self.df = None class BatchGenerator(dutils.BatchGenerator): """ create the training/validation batch generator. Randomly sample batch_size patients from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array. :param data: data dictionary as provided by 'load_dataset' :param img_modalities: list of strings ['adc', 'b1500'] from config :param batch_size: number of patients to sample for the batch :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.) :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array. """ def __init__(self, cf, data, n_batches=None, sample_pids_w_replace=True): super(BatchGenerator, self).__init__(cf, data, n_batches) self.dataset_length = len(self._data) self.cf = cf self.sample_pids_w_replace = sample_pids_w_replace self.eligible_pids = list(self._data.keys()) self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" self.p_fg = 0.5 self.empty_samples_max_ratio = 0.33 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.balance_target_distribution(plot=sample_pids_w_replace) self.stats = {"roi_counts" : np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count" : 0} def generate_train_batch(self): #everything done in here is per batch #print statements in here get confusing due to multithreading if self.sample_pids_w_replace: # fully random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice( self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) else: batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size, replace=False) if self.sample_pids_w_replace == False: self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids] if len(self.eligible_pids) < self.batch_size: self.eligible_pids = self.dataset_pids batch_data, batch_segs, batch_class_targets = [], [], [] # record roi count of classes in batch batch_roi_counts, empty_samples_count = np.zeros((self.cf.num_classes,), dtype='uint32'), 0 for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] data = np.load(patient["img"], mmap_mode="r") seg = np.load(patient['seg'], mmap_mode="r") (c,y,x) = data.shape spatial_shp = data[0].shape assert spatial_shp==seg.shape, "spatial shape incongruence betw. data {} and seg {}".format(spatial_shp, seg.shape) if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] data = dutils.pad_nd_image(data, (len(data), *new_shape)) seg = dutils.pad_nd_image(seg, new_shape) #eventual cropping to pre_crop_size: with prob self.p_fg sample pixel from random ROI and shift center, #if possible, to that pixel, so that img still contains ROI after pre-cropping dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] if np.any(dim_cropflags): #sample crop center regardless of ROIs, not guaranteed to be empty def get_cropped_centercoords(dim): return np.random.randint(low=self.cf.pre_crop_size[dim]//2, high=spatial_shp[dim] - self.cf.pre_crop_size[dim]//2) sample_seg_center = {} for dim in np.where(dim_cropflags)[0]: sample_seg_center[dim] = get_cropped_centercoords(dim) min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim]//2) max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim]//2) data = np.take(data, indices=range(min_, max_), axis=dim+1) #+1 for channeldim seg = np.take(seg, indices=range(min_, max_), axis=dim) batch_data.append(data) batch_segs.append(seg[np.newaxis]) batch_class_targets.append(patient['class_targets']) for cl in range(self.cf.num_classes): batch_roi_counts[cl] += np.count_nonzero(patient['class_targets'][np.unique(seg[seg>0]) - 1] == cl) if not np.any(seg): empty_samples_count += 1 batch = {'data': np.array(batch_data).astype('float32'), 'seg': np.array(batch_segs).astype('uint8'), 'pid': batch_patient_ids, 'class_targets': np.array(batch_class_targets), 'roi_counts': batch_roi_counts, 'empty_samples_count': empty_samples_count} return batch class PatientBatchIterator(dutils.PatientBatchIterator): """ creates a val/test generator. Step through the dataset and return dictionaries per patient. For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions. This iterator/these batches are not intended to go through MTaugmenter afterwards """ def __init__(self, cf, data): super(PatientBatchIterator, self).__init__(cf, data) self.patch_size = cf.patch_size self.patient_ix = 0 # running index over all patients in set def generate_train_batch(self, pid=None): if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 if pid is None: pid = self.dataset_pids[self.patient_ix] # + self.thread_id patient = self._data[pid] batch_class_targets = np.array([patient['class_targets']]) data = np.load(patient["img"], mmap_mode="r")[np.newaxis] seg = np.load(patient['seg'], mmap_mode="r")[np.newaxis, np.newaxis] (b, c, y, x) = data.shape spatial_shp = data.shape[2:] assert spatial_shp == seg.shape[2:], "spatial shape incongruence betw. data {} and seg {}".format(spatial_shp, seg.shape) if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] data = dutils.pad_nd_image(data, (len(data), *new_shape)) seg = dutils.pad_nd_image(seg, new_shape) batch = {'data': data, 'seg': seg, 'class_targets': batch_class_targets} converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) batch = converter(**batch) batch.update({'patient_bb_target': batch['bb_target'], 'patient_class_targets': batch['class_targets'], 'original_img_shape': data.shape, 'pid': np.array([pid] * len(data))}) # eventual tiling into patches spatial_shp = batch["data"].shape[2:] if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]): patient_batch = batch print("patientiterator produced patched batch!") patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) new_img_batch, new_seg_batch = [], [] for c in patch_crop_coords_list: new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3]]) seg_patch = seg[:, c[0]:c[1], c[2]: c[3]] new_seg_batch.append(seg_patch) shps = [] for arr in new_img_batch: shps.append(arr.shape) data = np.array(new_img_batch) # (patches, c, x, y, z) seg = np.array(new_seg_batch) batch_class_targets = np.repeat(batch_class_targets, len(patch_crop_coords_list), axis=0) patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), 'class_targets': batch_class_targets, 'pid': np.array([pid] * data.shape[0])} patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] patch_batch['patient_class_targets'] = patient_batch['patient_class_targets'] patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) patch_batch = converter(**patch_batch) batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return batch def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True): """ create mutli-threaded train/val/test batch generation and augmentation pipeline. :param patient_data: dictionary containing one dictionary per patient in the train/test subset :param test_pids: (optional) list of test patient ids, calls the test generator. :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace) my_transforms = [] if do_aug: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop'], border_mode_data=cf.da_kwargs['border_mode_data']) my_transforms.append(spatial_transform) gamma_transform = GammaTransform(gamma_range=cf.da_kwargs["gamma_range"], invert_image=False, per_channel=False, retain_stats=False) my_transforms.append(gamma_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) if cf.create_bounding_box_targets: my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) #batch receives entry 'bb_target' w bbox coordinates as [y1,x1,y2,x2,z1,z2]. #my_transforms.append(ConvertSegToOnehotTransform(classes=range(cf.num_seg_classes))) all_transforms = Compose(my_transforms) #MTAugmenter creates iterator from data iterator data_gen after applying the composed transform all_transforms multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=np.random.randint(0,cf.n_workers*2,size=cf.n_workers)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=True): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators need to select cv folds on patient level, but be able to include both breasts of each patient. """ dataset = Dataset(cf) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) - if cf.held_out_test_set: + if cf.hold_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=os.path.join(cf.plot_dir, "data_stats_fold_"+str(cf.fold))) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=True) batch_gen[cf.val_mode] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False) batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches!="all" else len(val_data) return batch_gen def get_test_generator(cf, logger): """ if get_test_generators is called multiple times in server env, every time of Dataset initiation rsync will check for copying the data; this should be okay since rsync will not copy if files already exist in destination. """ - if cf.held_out_test_set: + if cf.hold_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_set = Dataset(cf, test_ids, data_sourcedir=sourcedir) logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator(cf, test_set.data) batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else min(cf.max_test_patients, len(test_set.set_ids)) return batch_gen def main(): total_stime = time.time() times = {} CUDA = torch.cuda.is_available() print("CUDA available: ", CUDA) #cf.server_env = True #cf.data_dir = "experiments/dev_data" cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir+"plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] #for i in range(train_loader.dataset_length): # print("batch", i) stime = time.time() ex_batch = next(train_loader) # plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_extrainbatch.png", has_colorchannels=True, isRGB=True) times["train_batch"] = time.time()-stime val_loader = gens['val_sampling'] stime = time.time() ex_batch = next(val_loader) times["val_batch"] = time.time()-stime stime = time.time() plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", has_colorchannels=True, isRGB=True, show_gt_boxes=False) times["val_plot"] = time.time()-stime test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = next(test_loader) times["test_batch"] = time.time()-stime #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_expatientbatch.png", has_colorchannels=True, isRGB=True) print(ex_batch["data"].shape) print("Times recorded throughout:") for (k,v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) if __name__=="__main__": start_time = time.time() main() print("Program runtime in s: ", '{:.2f}'.format(time.time()-start_time)) \ No newline at end of file diff --git a/datasets/lidc/configs.py b/datasets/lidc/configs.py index 0b828af..c774902 100644 --- a/datasets/lidc/configs.py +++ b/datasets/lidc/configs.py @@ -1,445 +1,445 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os from collections import namedtuple sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/../..") from default_configs import DefaultConfigs # legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope Label = namedtuple("Label", ['id', 'name', 'color', 'm_scores']) # m_scores = malignancy scores binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'm_scores', 'bin_vals']) class Configs(DefaultConfigs): def __init__(self, server_env=None): super(Configs, self).__init__(server_env) ######################### # Preprocessing # ######################### self.root_dir = '/home/gregor/networkdrives/E130-Personal/Goetz/Datenkollektive/Lungendaten/Nodules_LIDC_IDRI' self.raw_data_dir = '{}/new_nrrd'.format(self.root_dir) self.pp_dir = '/media/gregor/HDD2TB/data/lidc/pp_20200309_dev' # 'merged' for one gt per image, 'single_annotator' for four gts per image. self.gts_to_produce = ["single_annotator", "merged"] self.target_spacing = (0.7, 0.7, 1.25) ######################### # I/O # ######################### # path to preprocessed data. self.pp_name = 'pp_20190805' self.input_df_name = 'info_df.pickle' self.data_sourcedir = '/media/gregor/HDD2TB/data/lidc/{}/'.format(self.pp_name) #self.data_sourcedir = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc/data/{}/'.format(self.pp_name) # settings for deployment on cluster. if server_env: # path to preprocessed data. self.data_sourcedir = '/datasets/datasets_ramien/lidc/data/{}_npz/'.format(self.pp_name) # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn']. self.model = 'detection_fpn' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) ######################### # Architecture # ######################### # dimension the model operates in. one out of [2, 3]. self.dim = 3 # 'class': standard object classification per roi, pairwise combinable with each of below tasks. # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN. # 'regression': regress some vector per each roi # 'regression_ken_gal': use kendall-gal uncertainty sigma # 'regression_bin': classify each roi into a bin related to a regression scale self.prediction_tasks = ['class'] self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' self.norm = "instance_norm" # one of None, 'instance_norm', 'batch_norm' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') self.weight_init = None self.regression_n_features = 1 ######################### # Data Loader # ######################### # distorted gt experiments: train on single-annotator gts in a random fashion to investigate network's # handling of noisy gts. # choose 'merged' for single, merged gt per image, or 'single_annotator' for four gts per image. # validation is always performed on same gt kind as training, testing always on merged gt. self.training_gts = "merged" # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.pre_crop_size_2D = [320, 320] self.patch_size_2D = [320, 320] self.pre_crop_size_3D = [160, 160, 96] self.patch_size_3D = [160, 160, 96] self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D # ratio of free sampled batch elements before class balancing is triggered self.batch_random_ratio = 0.1 self.balance_target = "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets' # set 2D network to match 3D gt boxes. self.merge_2D_to_3D_preds = self.dim==2 self.observables_rois = [] #self.rg_map = {1:1, 2:2, 3:3, 4:4, 5:5} ######################### # Colors and Legends # ######################### self.plot_frequency = 5 binary_cl_labels = [Label(1, 'benign', (*self.dark_green, 1.), (1, 2)), Label(2, 'malignant', (*self.red, 1.), (3, 4, 5))] quintuple_cl_labels = [Label(1, 'MS1', (*self.dark_green, 1.), (1,)), Label(2, 'MS2', (*self.dark_yellow, 1.), (2,)), Label(3, 'MS3', (*self.orange, 1.), (3,)), Label(4, 'MS4', (*self.bright_red, 1.), (4,)), Label(5, 'MS5', (*self.red, 1.), (5,))] # choose here if to do 2-way or 5-way regression-bin classification task_spec_cl_labels = quintuple_cl_labels self.class_labels = [ # #id #name #color #malignancy score Label( 0, 'bg', (*self.gray, 0.), (0,))] if "class" in self.prediction_tasks: self.class_labels += task_spec_cl_labels else: self.class_labels += [Label(1, 'lesion', (*self.orange, 1.), (1,2,3,4,5))] if any(['regression' in task for task in self.prediction_tasks]): self.bin_labels = [binLabel(0, 'MS0', (*self.gray, 1.), (0,), (0,))] self.bin_labels += [binLabel(cll.id, cll.name, cll.color, cll.m_scores, tuple([ms for ms in cll.m_scores])) for cll in task_spec_cl_labels] self.bin_id2label = {label.id: label for label in self.bin_labels} self.ms2bin_label = {ms: label for label in self.bin_labels for ms in label.m_scores} bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] if self.class_specific_seg: self.seg_labels = self.class_labels else: self.seg_labels = [ # id #name #color Label(0, 'bg', (*self.gray, 0.)), Label(1, 'fg', (*self.orange, 1.)) ] self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be # evaluated in debugging self.class_cmap = {label.id: label.color for label in self.class_labels} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.has_colorchannels = False self.plot_class_ids = True self.num_classes = len(self.class_dict) # for instance classification (excl background) self.num_seg_classes = len(self.seg_labels) # incl background ######################### # Data Augmentation # ######################### self.da_kwargs={ 'mirror': True, 'mirror_axes': tuple(np.arange(0, self.dim, 1)), 'do_elastic_deform': True, 'alpha':(0., 1500.), 'sigma':(30., 50.), 'do_rotation':True, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': True, 'scale':(0.8, 1.1), 'random_crop':False, 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1} if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!! self.da_kwargs['angle_z'] = (0., 2 * np.pi) ################################# # Schedule / Selection / Optim # ################################# self.num_epochs = 130 if self.dim == 2 else 150 self.num_train_batches = 200 if self.dim == 2 else 200 self.batch_size = 20 if self.dim == 2 else 8 # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_sampling' # only 'val_sampling', 'val_patient' not implemented if self.val_mode == 'val_patient': raise NotImplementedError if self.val_mode == 'val_sampling': self.num_val_batches = 70 self.save_n_models = 4 # set a minimum epoch number for saving in case of instabilities in the first phase of training. self.min_save_thresh = 0 if self.dim == 2 else 0 # criteria to average over for saving epochs, 'criterion':weight. if "class" in self.prediction_tasks: # 'criterion': weight if len(self.class_labels)==3: self.model_selection_criteria = {"benign_ap": 0.5, "malignant_ap": 0.5} elif len(self.class_labels)==6: self.model_selection_criteria = {str(label.name)+"_ap": 1./5 for label in self.class_labels if label.id!=0} elif any("regression" in task for task in self.prediction_tasks): self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8} self.weight_decay = 1e-5 self.exclude_from_wd = [] self.clip_norm = 200 if 'regression_ken_gal' in self.prediction_tasks else None # number or None # int in [0, dataset_size]. select n patients from dataset for prototyping. If None, all data is used. self.select_prototype_subset = None #self.batch_size ######################### # Testing # ######################### # set the top-n-epochs to be saved for temporal averaging in testing. self.test_n_epochs = self.save_n_models self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x). - self.held_out_test_set = False + self.hold_out_test_set = False self.max_test_patients = "all" # "all" or number self.report_score_level = ['rois', 'patient'] # choose list from 'patient', 'rois' self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1 self.metrics = ['ap', 'auc'] if any(['regression' in task for task in self.prediction_tasks]): self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] if 'aleatoric' in self.model: self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] self.evaluate_fold_means = True self.ap_match_ious = [0.1] # list of ious to be evaluated for ap-scoring. self.min_det_thresh = 0.1 # minimum confidence value to select predictions for evaluation. # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!) for regarding two preds as concerning the same ROI self.clustering_iou = 0.1 # has to be larger than desired possible overlap iou of model predictions self.plot_prediction_histograms = True self.plot_stat_curves = False self.n_test_plots = 1 ######################### # Assertions # ######################### if not 'class' in self.prediction_tasks: assert self.num_classes == 1 ######################### # Add model specifics # ######################### {'detection_fpn': self.add_det_fpn_configs, 'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, }[self.model]() def rg_val_to_bin_id(self, rg_val): return float(np.digitize(np.mean(rg_val), self.bin_edges)) def add_det_fpn_configs(self): self.learning_rate = [3e-4] * self.num_epochs self.dynamic_lr_scheduling = False # RoI score assigned to aggregation from pixel prediction (connected component). One of ['max', 'median']. self.score_det = 'max' # max number of roi candidates to identify per batch element and class. self.n_roi_candidates = 10 if self.dim == 2 else 30 # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' # if <1, false positive predictions in foreground are penalized less. self.fp_dice_weight = 1 if self.dim == 2 else 1 if len(self.class_labels)==3: self.wce_weights = [1., 1., 1.] if self.seg_loss_mode=="dice_wce" else [0.1, 1., 1.] elif len(self.class_labels)==6: self.wce_weights = [1., 1., 1., 1., 1., 1.] if self.seg_loss_mode == "dice_wce" else [0.1, 1., 1., 1., 1., 1.] else: raise Exception("mismatch loss weights & nr of classes") self.detection_min_confidence = self.min_det_thresh self.head_classes = self.num_seg_classes def add_mrcnn_configs(self): # learning rate is a list with one entry per epoch. self.learning_rate = [3e-4] * self.num_epochs self.dynamic_lr_scheduling = False # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask-outputs are optional. self.return_masks_in_train = False self.return_masks_in_val = True self.return_masks_in_test = False # set number of proposal boxes to plot after each epoch. self.n_plot_rpn_props = 5 if self.dim == 2 else 30 # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 self.frcnn_mode = False # feature map strides per pyramid level are inferred from architecture. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 128 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1, 2] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7 # loss sampling settings. self.rpn_train_anchors_per_image = 64 #per batch element self.train_rois_per_image = 6 #per batch element self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.7 # factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples. self.shem_poolsize = 10 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 3000 if self.dim == 2 else 6000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage in as one "batch". self.roi_chunk_size = 2500 if self.dim == 2 else 600 self.post_nms_rois_training = 500 if self.dim == 2 else 75 self.post_nms_rois_inference = 500 # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30 # per batch element and class. self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.1 if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': self.focal_loss = True # implement extra anchor-scales according to retina-net publication. self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 self.n_rpn_features = 256 if self.dim == 2 else 128 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.5 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/datasets/lidc/data_loader.py b/datasets/lidc/data_loader.py index e9cd2a3..4d8c79e 100644 --- a/datasets/lidc/data_loader.py +++ b/datasets/lidc/data_loader.py @@ -1,1026 +1,1026 @@ # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== ''' Data Loader for the LIDC data set. This dataloader expects preprocessed data in .npy or .npz files per patient and a pandas dataframe containing the meta info e.g. file paths, and some ground-truth info like labels, foreground slice ids. LIDC 4-fold annotations storage capacity problem: keep segmentation gts compressed (npz), unpack at each batch generation. ''' import plotting as plg import os import pickle import time from multiprocessing import Pool import numpy as np import pandas as pd from collections import OrderedDict # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates from utils.dataloader_utils import BatchGenerator as BatchGeneratorParent def save_obj(obj, name): """Pickle a python object.""" with open(name + '.pkl', 'wb') as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) def vector(item): """ensure item is vector-like (list or array or tuple) :param item: anything """ if not isinstance(item, (list, tuple, np.ndarray)): item = [item] return item class Dataset(dutils.Dataset): r"""Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dest. :param cf: config object. :param logger: logger. :param subset_ids: subset of patient/sample identifiers to load from whole set. :param data_sourcedir: directory in which to find data, defaults to cf.data_sourcedir if None. :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None, mode='train'): super(Dataset,self).__init__(cf, data_sourcedir) if mode == 'train' and not cf.training_gts == "merged": self.gt_dir = "patient_gts_sa" self.gt_kind = cf.training_gts else: self.gt_dir = "patient_gts_merged" self.gt_kind = "merged" if logger is not None: logger.info("loading {} ground truths for {}".format(self.gt_kind, 'training and validation' if mode=='train' else 'testing')) p_df = pd.read_pickle(os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)) #exclude_pids = ["0305a", "0447a"] # due to non-bg segmentation but bg mal label in nodules 5728, 8840 #p_df = p_df[~p_df.pid.isin(exclude_pids)] if subset_ids is not None: p_df = p_df[p_df.pid.isin(subset_ids)] if logger is not None: logger.info('subset: selected {} instances from df'.format(len(p_df))) if cf.select_prototype_subset is not None: prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset] p_df = p_df[p_df.pid.isin(prototype_pids)] if logger is not None: logger.warning('WARNING: using prototyping data subset of length {}!!!'.format(len(p_df))) pids = p_df.pid.tolist() # evtly copy data from data_sourcedir to data_dest if cf.server_env and not hasattr(cf, 'data_dir') and hasattr(cf, "data_dest"): # copy and unpack images file_subset = ["{}_img.npz".format(pid) for pid in pids if not os.path.isfile(os.path.join(cf.data_dest,'{}_img.npy'.format(pid)))] file_subset += [os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)] self.copy_data(cf, file_subset=file_subset, keep_packed=False, del_after_unpack=True) # copy and do not unpack segmentations file_subset = [os.path.join(self.gt_dir, "{}_rois.np*".format(pid)) for pid in pids] keep_packed = not cf.training_gts == "merged" self.copy_data(cf, file_subset=file_subset, keep_packed=keep_packed, del_after_unpack=(not keep_packed)) else: cf.data_dir = self.data_sourcedir ext = 'npy' if self.gt_kind == "merged" else 'npz' imgs = [os.path.join(self.data_dir, '{}_img.npy'.format(pid)) for pid in pids] segs = [os.path.join(self.data_dir, self.gt_dir, '{}_rois.{}'.format(pid, ext)) for pid in pids] orig_class_targets = p_df['class_target'].tolist() data = OrderedDict() if self.gt_kind == 'merged': for ix, pid in enumerate(pids): data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} data[pid]['fg_slices'] = np.array(p_df['fg_slices'].tolist()[ix]) if 'class' in cf.prediction_tasks: if len(cf.class_labels)==3: # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) data[pid]['class_targets'] = np.array([2 if ii >= 3 else 1 for ii in orig_class_targets[ix]], dtype='uint8') elif len(cf.class_labels)==6: # classify each malignancy score data[pid]['class_targets'] = np.array([1 if ii==0.5 else np.round(ii) for ii in orig_class_targets[ix]], dtype='uint8') else: raise Exception("mismatch class labels and data-loading implementations.") else: data[pid]['class_targets'] = np.ones_like(np.array(orig_class_targets[ix]), dtype='uint8') if any(['regression' in task for task in cf.prediction_tasks]): data[pid]["regression_targets"] = np.array([vector(v) for v in orig_class_targets[ix]], dtype='float16') data[pid]["rg_bin_targets"] = np.array( [cf.rg_val_to_bin_id(v) for v in data[pid]["regression_targets"]], dtype='uint8') else: for ix, pid in enumerate(pids): data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} data[pid]['fg_slices'] = np.array(p_df['fg_slices'].values[ix]) if 'class' in cf.prediction_tasks: # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) raise NotImplementedError # todo need to consider bg # data[pid]['class_targets'] = np.array( # [[2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]]) else: data[pid]['class_targets'] = np.array( [[1 if ii > 0 else 0 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='uint8') if any(['regression' in task for task in cf.prediction_tasks]): data[pid]["regression_targets"] = np.array( [[vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='float16') data[pid]["rg_bin_targets"] = np.array( [[cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in data[pid]["regression_targets"]], dtype='uint8') cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if any(['regression' in task for task in cf.prediction_tasks]): cf.roi_items += ['regression_targets'] cf.roi_items += ['rg_bin_targets'] self.data = data self.set_ids = np.array(list(self.data.keys())) self.df = None # merged GTs class BatchGenerator_merged(dutils.BatchGenerator): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ def __init__(self, cf, data, name="train"): super(BatchGenerator_merged, self).__init__(cf, data) self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.class_targets = {k: v["class_targets"] for (k, v) in self._data.items()} self.balance_target_distribution(plot=name=="train") def generate_train_batch(self): # samples patients towards equilibrium of foreground classes on a roi-level after sampling a random ratio # fully random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False, p=self.p_probs)) batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count of classes in batch batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') # empty count for full bg samples (empty slices in 2D/patients in 3D) per class for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) batch_pids.append(patient['pid']) (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1)<=self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0])-1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) if len(elig_slices)>0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] # pad data if smaller than pre_crop_size. if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] data = dutils.pad_nd_image(data, new_shape, mode='constant') seg = dutils.pad_nd_image(seg, new_shape, mode='constant') # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] if len(crop_dims) > 0: if self.cf.dim == 3: choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio)\ or np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg)[1:] for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) break roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] assert seg[tuple(roi_anchor_pixel)] > 0 # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by # distance to the desired ROI < patch_size /2. # (here final patch size to account for center_crop after data augmentation). sample_seg_center = {} for ii in crop_dims: low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) # happens if lesion on the edge of the image. dont care about roi anymore, # just make sure pre-crop is inside image. if low >= high: low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) sample_seg_center[ii] = np.random.randint(low=low, high=high) else: # not guaranteed to be empty. probability of emptiness depends on the data. sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} for ii in crop_dims: min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item batch_roi_items[o].append(patient[o]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero==0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 data = np.array(batch_data).astype(np.float16) seg = np.array(batch_segs).astype(np.uint8) batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'roi_counts':batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) batch[key] = np.array(val) return batch class PatientBatchIterator_merged(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actualy evaluation (done in 3D), if willing to accept speed-loss during training. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . """ def __init__(self, cf, data): # threads in augmenter super(PatientBatchIterator_merged, self).__init__(cf, data) self.patient_ix = 0 self.patch_size = cf.patch_size + [1] if cf.dim == 2 else cf.patch_size def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) # pad data if smaller than patch_size seen during training. if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: out_data = data[np.newaxis, np.newaxis] out_seg = seg[np.newaxis, np.newaxis] batch_3D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_3D[o] = np.array([patient[o]]) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, x, y ) out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis] batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[o]]), out_data.shape[0], axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension. # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): patient_batch = out_batch patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) new_img_batch, new_seg_batch = [], [] for cix, c in enumerate(patch_crop_coords_list): seg_patch = seg[c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) tmp_c_5 = c[5] new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) seg = np.array(new_seg_batch)[:, np.newaxis] # (n_patches, 1, x, y, z) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), 'pid': np.array([patient['pid']] * data.shape[0])} for o in self.cf.roi_items: patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_" + o] = patient_batch['patient_' + o] if self.cf.dim == 2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_" + o + "_2d"] = patient_batch[o] # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 # and enables calculating test-dice/viewing patient-wise results in test # remove, but also remove dice from metrics, when like to save memory patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim == 2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) patch_batch = converter(**patch_batch) out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch # single-annotator GTs class BatchGenerator_sa(BatchGeneratorParent): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ # noinspection PyMethodOverriding def balance_target_distribution(self, rater, plot=False): """ :param rater: for which rater slot to generate the distribution :param self.targets: dic holding {patient_specifier : patient-wise-unique ROI targets} :param plot: whether to plot the generated patient distributions :return: probability distribution over all pids. draw without replace from this. """ # todo limit bg weights unique_ts = np.unique([v[rater] for pat in self.targets.values() for v in pat]) sample_stats = pd.DataFrame(columns=[str(ix) + suffix for ix in unique_ts for suffix in ["", "_bg"]], index=list(self.targets.keys())) for pid in sample_stats.index: for targ in unique_ts: fg_count = 0 if len(self.targets[pid]) == 0 else np.count_nonzero(self.targets[pid][:, rater] == targ) sample_stats.loc[pid, str(targ)] = int(fg_count > 0) sample_stats.loc[pid, str(targ) + "_bg"] = int(fg_count == 0) target_stats = sample_stats.agg( ("sum", lambda col: col.sum() / len(self._data)), axis=0, sort=False).rename({"": "relative"}) anchor = 1. - target_stats.loc["relative"].iloc[0] fg_bg_weights = anchor / target_stats.loc["relative"] cum_weights = anchor * len(fg_bg_weights) fg_bg_weights /= cum_weights p_probs = sample_stats.apply(self.sample_targets_to_weights, args=(fg_bg_weights,), axis=1).sum(axis=1) p_probs = p_probs / p_probs.sum() if plot: print("Rater: {}. Applying class-weights:\n {}".format(rater, fg_bg_weights)) if len(sample_stats.columns) == 2: # assert that probs are calc'd correctly: # (p_probs * sample_stats["1"]).sum() == (p_probs * sample_stats["1_bg"]).sum() # only works if one label per patient (multi-label expectations depend on multi-label occurences). for rater in range(self.rater_bsize): expectations = [] for targ in sample_stats.columns: expectations.append((p_probs[rater] * sample_stats[targ]).sum()) assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format( expectations) if plot: plg.plot_batchgen_distribution(self.cf, self.dataset_pids, p_probs, self.balance_target, out_file=os.path.join(self.plot_dir, "train_gen_distr_"+str(self.cf.fold)+"_rater"+str(rater)+".png")) return p_probs, unique_ts, sample_stats def __init__(self, cf, data, name="train"): super(BatchGenerator_sa, self).__init__(cf, data) self.name = name self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.rater_bsize = 4 unique_ts_total = set() self.p_probs = [] self.sample_stats = [] # todo resolve pickling error # p = Pool(processes=min(self.rater_bsize, cf.n_workers)) # mp_res = p.starmap(self.balance_target_distribution, [(r, name=="train") for r in range(self.rater_bsize)]) # p.close() # p.join() # for r, res in enumerate(mp_res): # p_probs, unique_ts, sample_stats = res # self.p_probs.append(p_probs) # self.sample_stats.append(sample_stats) # unique_ts_total.update(unique_ts) for r in range(self.rater_bsize): # todo multiprocess. takes forever p_probs, unique_ts, sample_stats = self.balance_target_distribution(r, plot=name == "train") self.p_probs.append(p_probs) self.sample_stats.append(sample_stats) unique_ts_total.update(unique_ts) self.unique_ts = sorted(list(unique_ts_total)) self.stats = {"roi_counts": np.zeros(len(self.unique_ts,), dtype='uint32'), "empty_counts": np.zeros(len(self.unique_ts,), dtype='uint32')} def generate_train_batch(self): rater = np.random.randint(self.rater_bsize) # samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio batch_random_ratio). # random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False, p=self.p_probs[rater])) batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count of classes in batch batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') # empty count for full bg samples (empty slices in 2D/patients in 3D) for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] patient_balance_ts = np.array([roi[rater] for roi in patient[self.balance_target]]) data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] seg = np.load(patient['seg'], mmap_mode='r') seg = np.transpose(seg[list(seg.keys())[0]][rater], axes=(1, 2, 0)) batch_pids.append(patient['pid']) (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1) <= self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient_balance_ts[np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'][rater]) if len(elig_slices) > 0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] # pad data if smaller than pre_crop_size. if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] data = dutils.pad_nd_image(data, new_shape, mode='constant') seg = dutils.pad_nd_image(seg, new_shape, mode='constant') # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] if len(crop_dims) > 0: if self.cf.dim == 3: choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg[seg>0]) assert np.all(patient_balance_ts[available_roi_ids-1]>0), "trying to choose roi with rating 0" for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[ patient_balance_ts[available_roi_ids-1] == self.unique_ts[tix] ] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] break assert seg[tuple(roi_anchor_pixel)] > 0, "roi_anchor_pixel not inside roi: {}, pb_ts {}, elig ids {}".format(tuple(roi_anchor_pixel), patient_balance_ts, elig_roi_ids) # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by # distance to the desired ROI < patch_size /2. # (here final patch size to account for center_crop after data augmentation). sample_seg_center = {} for ii in crop_dims: low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) # happens if lesion on the edge of the image. dont care about roi anymore, # just make sure pre-crop is inside image. if low >= high: low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) sample_seg_center[ii] = np.random.randint(low=low, high=high) else: # not guaranteed to be empty. probability of emptiness depends on the data. sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} for ii in crop_dims: min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item batch_roi_items[o].append([roi[rater] for roi in patient[o]]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero==0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 data = np.array(batch_data).astype('float16') seg = np.array(batch_segs).astype('uint8') batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'rater_id': rater, 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) batch[key] = np.array(val) return batch class PatientBatchIterator_sa(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actual evaluation (done in 3D), if willing to accept speed loss during training. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . This is the data & gt loader for the 4-fold single-annotator GTs: each data input has separate annotations of 4 annotators. the way the pipeline is currently setup, the single-annotator GTs are only used if training with validation mode val_patient; during testing the Iterator with the merged GTs is used. # todo mode val_patient not implemented yet (since very slow). would need to sample from all available rater GTs. """ def __init__(self, cf, data): #threads in augmenter super(PatientBatchIterator_sa, self).__init__(cf, data) self.cf = cf self.patient_ix = 0 self.dataset_pids = list(self._data.keys()) self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size self.rater_bsize = 4 def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) # all gts are 4-fold and npz! seg = np.load(patient['seg'], mmap_mode='r') seg = np.transpose(seg[list(seg.keys())[0]], axes=(0, 2, 3, 1)) # pad data if smaller than patch_size seen during training. if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: out_data = data[np.newaxis, np.newaxis] out_seg = seg[:, np.newaxis] batch_3D = {'data': out_data, 'seg': out_seg} for item in self.cf.roi_items: batch_3D[item] = [] for r in range(self.rater_bsize): for item in self.cf.roi_items: batch_3D[item].append(np.array([roi[r] for roi in patient[item]])) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, y, x ) out_seg = np.transpose(seg, axes=(0, 3, 1, 2))[:, :, np.newaxis] # (n_raters, z, 1, y,x) batch_2D = {'data': out_data} for item in ["seg", "bb_target"]+self.cf.roi_items: batch_2D[item] = [] converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) for r in range(self.rater_bsize): tmp_batch = {"seg": out_seg[r]} for item in self.cf.roi_items: tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), out_data.shape[0], axis=0) tmp_batch = converter(**tmp_batch) for item in ["seg", "bb_target"]+self.cf.roi_items: batch_2D[item].append(tmp_batch[item]) # for item in ["seg", "bb_target"]+self.cf.roi_items: # batch_2D[item] = np.array(batch_2D[item]) if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * out_data.shape[0])}) # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension. # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): patient_batch = out_batch patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) new_img_batch = [] new_seg_batch = [] for cix, c in enumerate(patch_crop_coords_list): seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) tmp_c_5 = c[5] new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) seg = np.transpose(np.array(new_seg_batch), axes=(1,0,2,3,4))[:,:,np.newaxis] # (n_raters, n_patches, x, y, z) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'pid': np.array([patient['pid']] * data.shape[0])} # for o in self.cf.roi_items: # patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) for item in ["seg", "bb_target"]+self.cf.roi_items: patch_batch[item] = [] # coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, # IndexError: index 2 is out of bounds for axis 1 with size 2 for r in range(self.rater_bsize): tmp_batch = {"seg": seg[r]} for item in self.cf.roi_items: tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), len(patch_crop_coords_list), axis=0) tmp_batch = converter(**tmp_batch) for item in ["seg", "bb_target"]+self.cf.roi_items: patch_batch[item].append(tmp_batch[item]) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_" + o] = patient_batch['patient_'+o] if self.cf.dim==2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_"+o+"_2d"] = patient_batch[o] # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 # and enables calculating test-dice/viewing patient-wise results in test # remove, but also remove dice from metrics, if you like to save memory patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim==2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, is_training=True): """ create multi-threaded train/val/test batch generation and augmentation pipeline. :param cf: configs object. :param patient_data: dictionary containing one dictionary per patient in the train/test subset. :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ BG_name = "train" if is_training else "val" data_gen = BatchGenerator_merged(cf, patient_data, name=BG_name) if cf.training_gts=='merged' else \ BatchGenerator_sa(cf, patient_data, name=BG_name) # add transformations to pipeline. my_transforms = [] if is_training: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) if cf.create_bounding_box_targets: my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) all_transforms = Compose(my_transforms) multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads, seeds=range(data_gen.n_filled_threads)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=True): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators. selects patients according to cv folds (generated by first run/fold of experiment): splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) - If cf.held_out_test_set is True, adds the test split to the training data. + If cf.hold_out_test_set is True, adds the test split to the training data. """ dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) - if cf.held_out_test_set: + if cf.hold_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, is_training=True) batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, is_training=False) if cf.val_mode == 'val_patient': assert cf.training_gts == 'merged', 'val_patient not yet implemented for sa gts' batch_gen['val_patient'] = PatientBatchIterator_merged(cf, val_data) if cf.training_gts=='merged' \ else PatientBatchIterator_sa(cf, val_data) batch_gen['n_val'] = len(val_data) if cf.max_val_patients=="all" else min(len(val_data), cf.max_val_patients) else: batch_gen['n_val'] = cf.num_val_batches return batch_gen def get_test_generator(cf, logger): """ wrapper function for creating the test batch generator pipeline. selects patients according to cv folds (generated by first run/fold of experiment) - If cf.held_out_test_set is True, gets the data from an external folder instead. + If cf.hold_out_test_set is True, gets the data from an external folder instead. """ - if cf.held_out_test_set: + if cf.hold_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_data = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode="test").data logger.info("data set loaded with: {} test patients".format(len(test_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator_merged(cf, test_data) batch_gen['n_test'] = len(test_ids) if cf.max_test_patients == "all" else min(cf.max_test_patients, len(test_ids)) return batch_gen if __name__ == "__main__": import sys sys.path.append('../') import plotting as plg import utils.exp_utils as utils from configs import Configs cf = Configs() cf.batch_size = 3 #dataset_path = os.path.dirname(os.path.realpath(__file__)) #exp_path = os.path.join(dataset_path, "experiments/dev") #cf = utils.prep_exp(dataset_path, exp_path, server_env=False, use_stored_settings=False, is_training=True) cf.created_fold_id_pickle = False total_stime = time.time() times = {} # cf.server_env = True # cf.data_dir = "experiments/dev_data" # dataset = Dataset(cf) # patient = dataset['Master_00018'] cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] for i in range(1): stime = time.time() #ex_batch = next(train_loader) print("train batch", i) times["train_batch"] = time.time() - stime #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True) # # # with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file: # # train_loader.generator.print_stats(logger, file) # val_loader = gens['val_sampling'] stime = time.time() ex_batch = next(val_loader) times["val_batch"] = time.time() - stime stime = time.time() #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, # show_info=False) times["val_plot"] = time.time() - stime # test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch() times["test_batch"] = time.time() - stime stime = time.time() plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", get_time=False)#, sample_picks=[0,1,2,3]) times["test_patchbatch_plot"] = time.time() - stime # ex_batch['data'] = ex_batch['patient_data'] # ex_batch['seg'] = ex_batch['patient_seg'] # ex_batch['bb_target'] = ex_batch['patient_bb_target'] # for item in cf.roi_items: # ex_batch[] # stime = time.time() # #ex_batch = next(test_loader) # ex_batch = next(test_loader) # plg.view_batch(cf, ex_batch, show_gt_labels=False, show_gt_boxes=True, patient_items=True,# vol_slice_picks=[146,148, 218,220], # out_file="experiments/dev/dev_expatientbatch.png") # , sample_picks=[0,1,2,3]) # times["test_patient_batch_plot"] = time.time() - stime print("Times recorded throughout:") for (k, v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) diff --git a/datasets/prostate/configs.py b/datasets/prostate/configs.py index 2de02f3..4171cbe 100644 --- a/datasets/prostate/configs.py +++ b/datasets/prostate/configs.py @@ -1,588 +1,588 @@ __author__ = '' #credit Paul F. Jaeger ######################### # Example Config # ######################### import os import sys import pickle import numpy as np import torch from collections import namedtuple from default_configs import DefaultConfigs def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) # legends, nested classes are not handled well in multiprocessing! hence, Label class def in outer scope Label = namedtuple("Label", ['id', 'name', 'color', 'gleasons']) binLabel = namedtuple("Label", ['id', 'name', 'color', 'gleasons', 'bin_vals']) class Configs(DefaultConfigs): #todo change to Configs def __init__(self, server_env=None): ######################### # General # ######################### super(Configs, self).__init__(server_env) ######################### # I/O # ######################### self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_di_250519_ps384_gs6071/" #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071/" #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_analysis/" if server_env: self.data_sourcedir = "/datasets/data_ramien/prostate/data_di_250519_ps384_gs6071_npz/" #self.data_sourcedir = '/datasets/data_ramien/prostate/data_t2_250519_ps384_gs6071_npz/' #self.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_di_ana_151118_ps384_gs60/" self.histo_dir = os.path.join(self.data_sourcedir,"histos/") self.info_dict_name = 'master_info.pkl' self.info_dict_path = os.path.join(self.data_sourcedir, self.info_dict_name) self.config_path = os.path.realpath(__file__) # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_fpn']. self.model = 'detection_fpn' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir,self.model_path) self.select_prototype_subset = None ######################### # Preprocessing # ######################### self.missing_pz_subjects = [#189, 196, 198, 205, 211, 214, 215, 217, 218, 219, 220, #223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, #234, 235, 236, 237, 238, 239, 240, 241, 242, 244, 258, #261, 262, 264, 267, 268, 269, 270, 271, 273, 275, 276, #277, 278, 283 ] self.no_bval_radval_subjects = [57] #this guy has master id 222 self.prepro = { 'data_dir': '/home/gregor/networkdrives/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master/Daten/', 'dir_spec': 'Master', #'images': {'t2': 'T2TRA', 'adc': 'ADC1500', 'b50': 'BVAL50', 'b500': 'BVAL500', # 'b1000': 'BVAL1000', 'b1500': 'BVAL1500'}, #'images': {'adc': 'ADC1500', 'b50': 'BVAL50', 'b500': 'BVAL500', 'b1000': 'BVAL1000', 'b1500': 'BVAL1500'}, 'images': {'t2': 'T2TRA'}, 'anatomical_masks': ['seg_T2_PRO'], # try: 'seg_T2_PRO','seg_T2_PZ', 'seg_ADC_PRO', 'seg_ADC_PZ', 'merge_mode' : 'union', #if registered data w/ two gts: take 'union' or 'adc' or 't2' of gt 'rename_tags': {'seg_ADC_PRO':"pro", 'seg_T2_PRO':"pro", 'seg_ADC_PZ':"pz", 'seg_T2_PZ':"pz"}, 'lesion_postfix': '_Re', #lesion files are tagged seg_MOD_LESx 'img_postfix': "_resampled2", #"_resampled2_registered", 'overall_postfix': ".nrrd", #including filetype ending! 'histo_dir': '/home/gregor/networkdrives/E132-Projekte/Move_to_E132-Rohdaten/Prisma_Master/Dokumente/', 'histo_dir_out': self.histo_dir, 'histo_lesion_based': 'MasterHistoAll.csv', 'histo_patient_based': 'MasterPatientbasedAll_clean.csv', 'histo_id_column_name': 'Master_ID', 'histo_pb_id_column_name': 'Master_ID_Short', #for patient histo 'excluded_prisma_subjects': [], 'excluded_radval_subjects': self.no_bval_radval_subjects, 'excluded_master_subjects': self.missing_pz_subjects, 'seg_labels': {'tz': 0, 'pz': 0, 'lesions':'roi'}, #set as hard label or 'roi' to have seg labels represent obj instance count #if not given 'lesions' are numbered highest seg label +lesion-nr-in-histofile 'class_labels': {'lesions':'gleason'}, #0 is not bg, but first fg class! #i.e., prepro labels are shifted by -1 towards later training labels in gt, legends, dicts, etc. #evtly set lesions to 'gleason' and check gleason remap in prepro #'gleason_thresh': 71, 'gleason_mapping': {0: -1, 60:0, 71:1, 72:1, 80:1, 90:1, 91:1, 92:1}, 'gleason_map': self.gleason_map, #see below 'color_palette': [self.green, self.red], 'output_directory': self.data_sourcedir, 'modalities2concat' : "all", #['t2', 'adc','b50','b500','b1000','b1500'], #will be concatenated on colorchannel 'center_of_mass_crop': True, 'mod_scaling' : (1,1,1), #z,y,x 'pre_crop_size': [20, 384, 384], #z,y,x, z-cropping and non-square not implemented atm!! 'swap_yx_to_xy': False, #change final spatial shape from z,y,x to z,x,y 'normalization': {'percentiles':[1., 99.]}, 'interpolation': 'nearest', 'observables_patient': ['Original_ID', 'GSBx', 'PIRADS2', 'PSA'], 'observables_rois': ['lesion_gleasons'], 'info_dict_path': self.info_dict_path, 'npz_dir' : self.data_sourcedir[:-1]+"_npz" #if not None: convert to npz, copy data here } if self.prepro["modalities2concat"] == "all": self.prepro["modalities2concat"] = list(self.prepro["images"].keys()) ######################### # Architecture # ######################### # dimension the model operates in. one out of [2, 3]. self.dim = 2 # 'class': standard object classification per roi, pairwise combinable with each of below tasks. # if 'class' is omitted from tasks, object classes will be fg/bg (1/0) from RPN. # 'regression': regress some vector per each roi # 'regression_ken_gal': use kendall-gal uncertainty sigma # 'regression_bin': classify each roi into a bin related to a regression scale self.prediction_tasks = ['class',] self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' or 'resnet50' self.weight_init = None #'kaiming_normal' #, 'xavier' or None-->pytorch standard, self.norm = None #'instance_norm' # one of 'None', 'instance_norm', 'batch_norm' self.relu = 'relu' # 'relu' or 'leaky_relu' self.regression_n_features = 1 #length of regressor target vector (always 1D) ######################### # Data Loader # ######################### self.seed = 17 self.n_workers = 16 if server_env else os.cpu_count() self.batch_size = 10 if self.dim == 2 else 6 self.channels = [1, 2, 3, 4] # modalities2load, see prepo self.n_channels = len(self.channels) # for compatibility, but actually redundant # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels self.plot_bg_chan = 0 self.pre_crop_size = list(np.array(self.prepro['pre_crop_size'])[[1, 2, 0]]) # now y,x,z self.crop_margin = [20, 20, 1] # has to be smaller than respective patch_size//2 self.patch_size_2D = self.pre_crop_size[:2] #[288, 288] self.patch_size_3D = self.pre_crop_size[:2] + [8] # only numbers divisible by 2 multiple times # (at least 5 times for x,y, at least 3 for z)! # otherwise likely to produce error in crop fct or net self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D self.balance_target = "class_targets" if 'class' in self.prediction_tasks else 'rg_bin_targets' # ratio of fully random patients drawn during batch generation # resulting batch random count is rounded down to closest integer self.batch_random_ratio = 0.2 if self.dim==2 else 0.4 self.observables_patient = ['Original_ID', 'GSBx', 'PIRADS2'] self.observables_rois = ['lesion_gleasons'] self.regression_target = "lesion_gleasons" # name of the info_dict entry holding regression targets # linear mapping self.rg_map = {0: 0, 60: 1, 71: 2, 72: 3, 80: 4, 90: 5, 91: 6, 92: 7, None: 0} # non-linear mapping #self.rg_map = {0: 0, 60: 1, 71: 6, 72: 7.5, 80: 9, 90: 10, 91: 10, 92: 10, None: 0} ######################### # Colors and Legends # ######################### self.plot_frequency = 5 # colors self.gravity_col_palette = [self.green, self.yellow, self.orange, self.bright_red, self.red, self.dark_red] self.gs_labels = [ Label(0, 'bg', self.gray, (0,)), Label(60, 'GS60', self.dark_green, (60,)), Label(71, 'GS71', self.dark_yellow, (71,)), Label(72, 'GS72', self.orange, (72,)), Label(80, 'GS80', self.brighter_red,(80,)), Label(90, 'GS90', self.bright_red, (90,)), Label(91, 'GS91', self.red, (91,)), Label(92, 'GS92', self.dark_red, (92,)) ] self.gs2label = {label.id: label for label in self.gs_labels} binary_cl_labels = [Label(1, 'benign', (*self.green, 1.), (60,)), Label(2, 'malignant', (*self.red, 1.), (71,72,80,90,91,92)), #Label(3, 'pz', (*self.blue, 1.), (None,)), #Label(4, 'tz', (*self.aubergine, 1.), (None,)) ] self.class_labels = [ #id #name #color #gleason score Label( 0, 'bg', (*self.gray, 0.), (0,))] if "class" in self.prediction_tasks: self.class_labels += binary_cl_labels # self.class_labels += [Label(cl, cl_dic["name"], cl_dic["color"], tuple(cl_dic["gleasons"])) # for cl, cl_dic in # load_obj(os.path.join(self.data_sourcedir, "pp_class_labels.pkl")).items()] else: self.class_labels += [Label( 1, 'lesion', (*self.red, 1.), (60,71,72,80,90,91,92))] if any(['regression' in task for task in self.prediction_tasks]): self.bin_labels = [binLabel(0, 'bg', (*self.gray, 0.), (0,), (0,))] self.bin_labels += [binLabel(cl, cl_dic["name"], cl_dic["color"], tuple(cl_dic["gleasons"]), tuple([self.rg_map[gs] for gs in cl_dic["gleasons"]])) for cl, cl_dic in sorted(load_obj(os.path.join(self.data_sourcedir, "pp_class_labels.pkl")).items())] self.bin_id2label = {label.id: label for label in self.bin_labels} self.gs2bin_label = {gs: label for label in self.bin_labels for gs in label.gleasons} bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} self.bin_edges = [(bins[i][1] + bins[i+1][0]) / 2 for i in range(len(bins)-1)] self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0} if self.class_specific_seg: self.seg_labels = self.class_labels else: self.seg_labels = [ # id #name #color Label(0, 'bg', (*self.white, 0.)), Label(1, 'fg', (*self.orange, 1.)) ] self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} # class_dict is used in evaluator / ap, auc, etc. statistics, and class 0 (bg) only needs to be # evaluated in debugging self.class_cmap = {label.id: label.color for label in self.class_labels} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.plot_class_ids = True self.num_classes = len(self.class_dict) # for instance classification (excl background) self.num_seg_classes = len(self.seg_labels) # incl background ######################### # Data Augmentation # ######################### #the angle rotations are implemented incorrectly in batchgenerators! in 2D, #the x-axis angle controls the z-axis angle. if self.dim == 2: angle_x = (-np.pi / 3., np.pi / 3.) angle_z = (0.,0.) rcd = (self.patch_size[0] / 2., self.patch_size[1] / 2.) else: angle_x = (0.,0.) angle_z = (-np.pi / 2., np.pi / 2.) rcd = (self.patch_size[0] / 2., self.patch_size[1] / 2., self.patch_size[2] / 2.) self.do_aug = True # DA settings for DWI self.da_kwargs = { 'mirror': True, 'mirror_axes': tuple(np.arange(0, self.dim, 1)), 'random_crop': True, 'rand_crop_dist': rcd, 'do_elastic_deform': self.dim==2, 'alpha': (0., 1500.), 'sigma': (25., 50.), 'do_rotation': True, 'angle_x': angle_x, 'angle_y': (0., 0.), 'angle_z': angle_z, 'do_scale': True, 'scale': (0.7, 1.3), 'border_mode_data': 'constant', 'gamma_transform': True, 'gamma_range': (0.5, 2.) } # for T2 # self.da_kwargs = { # 'mirror': True, # 'mirror_axes': tuple(np.arange(0, self.dim, 1)), # 'random_crop': False, # 'rand_crop_dist': rcd, # 'do_elastic_deform': False, # 'alpha': (0., 1500.), # 'sigma': (25., 50.), # 'do_rotation': True, # 'angle_x': angle_x, # 'angle_y': (0., 0.), # 'angle_z': angle_z, # 'do_scale': False, # 'scale': (0.7, 1.3), # 'border_mode_data': 'constant', # 'gamma_transform': False, # 'gamma_range': (0.5, 2.) # } ################################# # Schedule / Selection / Optim # ################################# # good guess: train for n_samples = 1.1m = epochs*n_train_bs*b_size self.num_epochs = 270 self.num_train_batches = 120 if self.dim == 2 else 140 self.val_mode = 'val_patient' # one of 'val_sampling', 'val_patient' # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is more accurate, while the latter is faster (depending on volume size) self.num_val_batches = 200 if self.dim==2 else 40 # for val_sampling, number or "all" self.max_val_patients = "all" #for val_patient, "all" takes whole split self.save_n_models = 6 self.min_save_thresh = 3 if self.dim == 2 else 4 #=wait time in epochs if "class" in self.prediction_tasks: # 'criterion': weight self.model_selection_criteria = {"benign_ap": 0.2, "malignant_ap": 0.8} elif any("regression" in task for task in self.prediction_tasks): self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8} #self.model_selection_criteria = {"GS71-92_ap": 0.9, "GS60_ap": 0.1} # 'criterion':weight #self.model_selection_criteria = {"lesion_ap": 0.2, "lesion_avp": 0.8} #self.model_selection_criteria = {label.name+"_ap": 1. for label in self.class_labels if label.id!=0} self.scan_det_thresh = False self.warm_up = 0 self.optimizer = "ADAM" self.weight_decay = 1e-5 self.clip_norm = None #number or None self.learning_rate = [1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.lr_decay_factor = 0.5 self.scheduling_patience = int(self.num_epochs / 6) ######################### # Testing # ######################### self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) (0==spatial y, 1== spatial x). - self.held_out_test_set = False + self.hold_out_test_set = False self.max_test_patients = "all" # "all" or number self.report_score_level = ['rois', 'patient'] # 'patient' or 'rois' (incl) self.patient_class_of_interest = 2 if 'class' in self.prediction_tasks else 1 self.eval_bins_separately = "additionally" if not 'class' in self.prediction_tasks else False self.patient_bin_of_interest = 2 self.metrics = ['ap', 'auc', 'dice'] if any(['regression' in task for task in self.prediction_tasks]): self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] if 'aleatoric' in self.model: self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] self.evaluate_fold_means = True self.min_det_thresh = 0.02 self.ap_match_ious = [0.1] # threshold(s) for considering a prediction as true positive # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!) for regarding two preds as concerning the same ROI self.clustering_iou = 0.1 # has to be larger than desired possible overlap iou of model predictions # 2D-3D merging is applied independently from clustering setting. self.merge_2D_to_3D_preds = True if self.dim == 2 else False self.merge_3D_iou = 0.1 self.n_test_plots = 1 # per fold and rank self.test_n_epochs = self.save_n_models # should be called n_test_ens, since is number of models to ensemble over during testing # is multiplied by n_test_augs if test_aug ######################### # shared model settings # ######################### # max number of roi candidates to identify per image and class (slice in 2D, volume in 3D) self.n_roi_candidates = 10 if self.dim == 2 else 15 ######################### # assertions # ######################### if not 'class' in self.prediction_tasks: assert self.num_classes == 1 for mod in self.prepro['modalities2concat']: assert mod in self.prepro['images'].keys(), "need to adapt mods2concat to chosen images" ######################### # Add model specifics # ######################### {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, 'mrcnn_gan': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, 'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs }[self.model]() def gleason_map(self, GS): """gleason to class id :param GS: gleason score as in histo file """ if "gleason_thresh" in self.prepro.keys(): assert "gleason_mapping" not in self.prepro.keys(), "cant define both, thresh and map, for GS to classes" # -1 == bg, 0 == benign, 1 == malignant # before shifting, i.e., 0!=bg, but 0==first class remapping = 0 if GS >= self.prepro["gleason_thresh"] else -1 return remapping elif "gleason_mapping" in self.prepro.keys(): return self.prepro["gleason_mapping"][GS] else: raise Exception("Need to define some remapping, at least GS 0 -> background (class -1)") def rg_val_to_bin_id(self, rg_val): return float(np.digitize(rg_val, self.bin_edges)) def add_det_fpn_configs(self): self.scheduling_criterion = 'torch_loss' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1]*self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1] # if <1, false positive predictions in foreground are penalized less. self.fp_dice_weight = 1 if self.dim == 2 else 1 self.detection_min_confidence = 0.05 #how to determine score of roi: 'max' or 'median' self.score_det = 'max' self.cuda_benchmark = self.dim==3 def add_det_unet_configs(self): self.scheduling_criterion = "torch_loss" self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1] # if <1, false positive predictions in foreground are penalized less. self.fp_dice_weight = 1 if self.dim == 2 else 1 self.detection_min_confidence = 0.05 #how to determine score of roi: 'max' or 'median' self.score_det = 'max' self.init_filts = 32 self.kernel_size = 3 #ks for horizontal, normal convs self.kernel_size_m = 2 #ks for max pool self.pad = "same" # "same" or integer, padding of horizontal convs self.cuda_benchmark = True def add_mrcnn_configs(self): self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 # # feed +/- n neighbouring slices into channel dimension. set to None for no context. self.n_3D_context = None if self.n_3D_context is not None and self.dim == 2: self.n_channels *= (self.n_3D_context * 2 + 1) self.frcnn_mode = False # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask outputs are optional. self.return_masks_in_train = True self.return_masks_in_val = True self.return_masks_in_test = True # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 128 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5,1.,2.] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7 # loss sampling settings. self.rpn_train_anchors_per_image = 6 self.train_rois_per_image = 6 #per batch_instance self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.7 # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), # where k<=#positive examples. self.shem_poolsize = 3 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) #y1,x1,y2,x2,z1,z2 if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] self.plot_y_max = 1.5 self.n_plot_rpn_props = 5 if self.dim == 2 else 30 #per batch_instance (slice in 2D / patient in 3D) # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 3000 if self.dim == 2 else 6000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage in as one "batch". self.roi_chunk_size = 2000 if self.dim == 2 else 400 self.post_nms_rois_training = 250 * (self.head_classes-1) if self.dim == 2 else 500 self.post_nms_rois_inference = 250 * (self.head_classes-1) # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = self.n_roi_candidates # per batch element and class. # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py, otherwise all predictions are one cluster. self.detection_nms_threshold = 1e-5 # detection score threshold in refine_detections() self.model_min_confidence = 0.05 #self.min_det_thresh/2 if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) self.operate_stride1 = False if self.model == 'retina_net' or self.model == 'retina_unet': self.cuda_benchmark = self.dim == 3 #implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 self.n_rpn_features = 256 if self.dim == 2 else 64 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = (1000 if self.dim == 2 else 6250) * self.batch_size # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.5 if self.model == 'retina_unet': self.operate_stride1 = True \ No newline at end of file diff --git a/datasets/prostate/data_loader.py b/datasets/prostate/data_loader.py index 23797a3..e67a734 100644 --- a/datasets/prostate/data_loader.py +++ b/datasets/prostate/data_loader.py @@ -1,716 +1,716 @@ __author__ = '' #credit derives from Paul Jaeger, Simon Kohl import os import time import warnings from collections import OrderedDict import pickle import numpy as np import pandas as pd # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.augmentations.utils import resize_image_by_padding, center_crop_2D_image, center_crop_3D_image from batchgenerators.dataloading.data_loader import SlimDataLoaderBase from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.dataloading import SingleThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform #from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates from batchgenerators.transforms import AbstractTransform from batchgenerators.transforms.color_transforms import GammaTransform #sys.path.append(os.path.dirname(os.path.realpath(__file__))) #import utils.exp_utils as utils import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates import data_manager as dmanager def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) def id_to_spec(id, base_spec): """Construct subject specifier from base string and an integer subject number.""" num_zeros = 5 - len(str(id)) assert num_zeros>=0, "id_to_spec: patient id too long to fit into 5 figures" return base_spec + '_' + ('').join(['0'] * num_zeros) + str(id) def convert_3d_to_2d_generator(data_dict, shape="bcxyz"): """Fold/Shape z-dimension into color-channel. :param shape: bcxyz or bczyx :return: shape b(c*z)xy or b(c*z)yx """ if shape=="bcxyz": data_dict['data'] = np.transpose(data_dict['data'], axes=(0,1,4,3,2)) data_dict['seg'] = np.transpose(data_dict['seg'], axes=(0,1,4,3,2)) elif shape=="bczyx": pass else: raise Exception("unknown datashape {} in 3d_to_2d transform converter".format(shape)) shp = data_dict['data'].shape data_dict['orig_shape_data'] = shp seg_shp = data_dict['seg'].shape data_dict['orig_shape_seg'] = seg_shp data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4])) data_dict['seg'] = data_dict['seg'].reshape((seg_shp[0], seg_shp[1] * seg_shp[2], seg_shp[3], seg_shp[4])) return data_dict def convert_2d_to_3d_generator(data_dict, shape="bcxyz"): """Unfold z-dimension from color-channel. data needs to be in shape bcxy or bcyx, x,y dims won't be swapped relative to each other. :param shape: target shape, bcxyz or bczyx """ shp = data_dict['orig_shape_data'] cur_shape = data_dict['data'].shape seg_shp = data_dict['orig_shape_seg'] cur_shape_seg = data_dict['seg'].shape data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1], shp[2], cur_shape[-2], cur_shape[-1])) data_dict['seg'] = data_dict['seg'].reshape((seg_shp[0], seg_shp[1], seg_shp[2], cur_shape_seg[-2], cur_shape_seg[-1])) if shape=="bcxyz": data_dict['data'] = np.transpose(data_dict['data'], axes=(0,1,4,3,2)) data_dict['seg'] = np.transpose(data_dict['seg'], axes=(0,1,4,3,2)) return data_dict class Convert3DTo2DTransform(AbstractTransform): def __init__(self): pass def __call__(self, **data_dict): return convert_3d_to_2d_generator(data_dict) class Convert2DTo3DTransform(AbstractTransform): def __init__(self): pass def __call__(self, **data_dict): return convert_2d_to_3d_generator(data_dict) def vector(item): """ensure item is vector-like (list or array or tuple) :param item: anything """ if not isinstance(item, (list, tuple, np.ndarray)): item = [item] return item class Dataset(dutils.Dataset): r"""Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dest. :param cf: config file :param data_dir: directory in which to find data, defaults to cf.data_dir if None. :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None): super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir) info_dict = load_obj(cf.info_dict_path) if subset_ids is not None: pids = subset_ids if logger is None: print('subset: selected {} instances from df'.format(len(pids))) else: logger.info('subset: selected {} instances from df'.format(len(pids))) else: pids = list(info_dict.keys()) #evtly copy data from data_rootdir to data_dir if cf.server_env and not hasattr(cf, "data_dir"): file_subset = [info_dict[pid]['img'][:-3]+"*" for pid in pids] file_subset+= [info_dict[pid]['seg'][:-3]+"*" for pid in pids] file_subset += [cf.info_dict_path] self.copy_data(cf, file_subset=file_subset) cf.data_dir = self.data_dir img_paths = [os.path.join(self.data_dir, info_dict[pid]['img']) for pid in pids] seg_paths = [os.path.join(self.data_dir, info_dict[pid]['seg']) for pid in pids] # load all subject files self.data = OrderedDict() for i, pid in enumerate(pids): subj_spec = id_to_spec(pid, cf.prepro['dir_spec']) subj_data = {'pid':pid, "spec":subj_spec} subj_data['img'] = img_paths[i] subj_data['seg'] = seg_paths[i] #read, add per-roi labels for obs in cf.observables_patient+cf.observables_rois: subj_data[obs] = np.array(info_dict[pid][obs]) if 'class' in self.cf.prediction_tasks: subj_data['class_targets'] = np.array(info_dict[pid]['roi_classes'], dtype='uint8') + 1 else: subj_data['class_targets'] = np.ones_like(np.array(info_dict[pid]['roi_classes']), dtype='uint8') if any(['regression' in task for task in self.cf.prediction_tasks]): if hasattr(cf, "rg_map"): subj_data["regression_targets"] = np.array([vector(cf.rg_map[v]) for v in info_dict[pid][cf.regression_target]], dtype='float16') else: subj_data["regression_targets"] = np.array([vector(v) for v in info_dict[pid][cf.regression_target]], dtype='float16') subj_data["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in subj_data["regression_targets"]], dtype='uint8') subj_data['fg_slices'] = info_dict[pid]['fg_slices'] self.data[pid] = subj_data cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if any(['regression' in task for task in self.cf.prediction_tasks]): cf.roi_items += ['regression_targets'] cf.roi_items += ['rg_bin_targets'] #cf.patient_items = cf.observables_patient[:] #patient-wise items not used currently self.set_ids = np.array(list(self.data.keys())) self.df = None class BatchGenerator(dutils.BatchGenerator): """ create the training/validation batch generator. Randomly sample batch_size patients from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array. :param data: data dictionary as provided by 'load_dataset' :param img_modalities: list of strings ['adc', 'b1500'] from config :param batch_size: number of patients to sample for the batch :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.) :param sample_pids_w_replace: whether to randomly draw pids from dataset for batch generation. if False, step through whole dataset before repition. :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array. """ def __init__(self, cf, data, n_batches=None, sample_pids_w_replace=True): super(BatchGenerator, self).__init__(cf, data, n_batches) self.dataset_length = len(self._data) self.cf = cf self.sample_pids_w_replace = sample_pids_w_replace self.eligible_pids = list(self._data.keys()) self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.balance_target_distribution(plot=sample_pids_w_replace) self.stats = {"roi_counts" : np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count" : 0} def generate_train_batch(self): #everything done in here is per batch #print statements in here get confusing due to multithreading if self.sample_pids_w_replace: # fully random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice( self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) else: batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size, replace=False) if self.sample_pids_w_replace == False: self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids] if len(self.eligible_pids) < self.batch_size: self.eligible_pids = self.dataset_pids batch_data, batch_segs, batch_patient_specs = [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} #record roi count of classes in batch batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0 #empty count for full bg samples (empty slices in 2D/patients in 3D) for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] #swap dimensions from (c,)z,y,x to (c,)y,x,z or h,w,d to ease 2D/3D-case handling data = np.transpose(np.load(patient['img'], mmap_mode='r'), axes=(0, 2, 3, 1))[self.chans] seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) (c,y,x,z) = data.shape #original data is 3D MRIs, so need to pick (e.g. randomly) single slice to make it 2D, #consider batch roi-class balance if self.cf.dim == 2: elig_slices, choose_fg = [], False if self.sample_pids_w_replace and len(patient['fg_slices']) > 0: if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand( 1) <= self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) if len(elig_slices) == 0: elig_slices = z sl_pick_ix = np.random.choice(elig_slices, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] spatial_shp = data[0].shape assert spatial_shp==seg.shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] data = dutils.pad_nd_image(data, (len(data), *new_shape)) seg = dutils.pad_nd_image(seg, new_shape) #eventual cropping to pre_crop_size: with prob self.p_fg sample pixel from random ROI and shift center, #if possible, to that pixel, so that img still contains ROI after pre-cropping dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] if np.any(dim_cropflags): print("dim crop applied") # sample pixel from random ROI and shift center, if possible, to that pixel if self.cf.dim==3: choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg if self.sample_pids_w_replace and choose_fg and np.any(seg): available_roi_ids = np.unique(seg)[1:] for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[ patient[self.balance_target][available_roi_ids - 1] == self.unique_ts[tix]] if len(elig_roi_ids) > 0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) break roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] assert seg[tuple(roi_anchor_pixel)] > 0 # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and # distance to the selected ROI < patch_size /2 def get_cropped_centercoords(dim): low = np.max((self.cf.pre_crop_size[dim]//2, roi_anchor_pixel[dim] - (self.cf.patch_size[dim]//2 - self.cf.crop_margin[dim]))) high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim]//2, roi_anchor_pixel[dim] + (self.cf.patch_size[dim]//2 - self.cf.crop_margin[dim]))) if low >= high: #happens if lesion on the edge of the image. #print('correcting low/high:', low, high, spatial_shp, roi_anchor_pixel, dim) low = self.cf.pre_crop_size[dim] // 2 high = spatial_shp[dim] - self.cf.pre_crop_size[dim]//2 assert low0]) - 1] == self.unique_ts[tix]) if not np.any(seg): empty_samples_count += 1 #self.stats['roi_counts'] += batch_roi_counts #DOESNT WORK WITH MULTITHREADING! do outside #self.stats['empty_samples_count'] += empty_samples_count batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'), 'pid': batch_patient_ids, 'spec': batch_patient_specs, 'roi_counts':batch_roi_counts, 'empty_samples_count': empty_samples_count} for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) batch[key] = np.array(val) return batch class PatientBatchIterator(dutils.PatientBatchIterator): """ creates a val/test generator. Step through the dataset and return dictionaries per patient. 2D is a special case of 3D patching with patch_size[2] == 1 (slices) Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets. Appends patient targets anyway for evaluation. For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions. This iterator/these batches are not intended to go through MTaugmenter afterwards """ def __init__(self, cf, data): super(PatientBatchIterator, self).__init__(cf, data) self.patient_ix = 0 #running index over all patients in set self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" def generate_train_batch(self, pid=None): if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 if pid is None: pid = self.dataset_pids[self.patient_ix] # + self.thread_id patient = self._data[pid] #swap dimensions from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling data = np.transpose(np.load(patient['img'], mmap_mode='r'), axes=(0, 2, 3, 1)) seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] data_shp_raw = data.shape plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None data = data[self.chans] discarded_chans = len( [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan]) spatial_shp = data[0].shape # spatial dims need to be in order x,y,z assert spatial_shp==seg[0].shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]): new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) if plot_bg is not None: plot_bg = dutils.pad_nd_image(plot_bg, new_shape) if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: #adds the batch dim here bc won't go through MTaugmenter out_data = data[np.newaxis] out_seg = seg[np.newaxis] if plot_bg is not None: out_plot_bg = plot_bg[np.newaxis] #data and seg shape: (1,c,x,y,z), where c=1 for seg batch_3D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_3D[o] = np.array([patient[o]]) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(3,0,1,2)) #(c,y,x,z) to (b=z,c,x,y), use z=b as batchdim out_seg = np.transpose(seg, axes=(3,0,1,2)).astype('uint8') #(c,y,x,z) to (b=z,c,x,y) batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[o]]), len(out_data), axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) if plot_bg is not None: out_plot_bg = np.transpose(plot_bg, axes=(2,0,1)).astype('float32') if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D['patient_'+o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data)), 'spec':np.array([patient['spec']] * len(out_data))}) if self.cf.plot_bg_chan in self.chans and discarded_chans>0: assert plot_bg is None plot_bg = int(self.cf.plot_bg_chan - discarded_chans) out_plot_bg = plot_bg if plot_bg is not None: out_batch['plot_bg'] = out_plot_bg #eventual tiling into patches spatial_shp = out_batch["data"].shape[2:] if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]): patient_batch = out_batch #print("patientiterator produced patched batch!") patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) new_img_batch, new_seg_batch = [], [] for c in patch_crop_coords_list: new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]]) seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) shps = [] for arr in new_img_batch: shps.append(arr.shape) data = np.array(new_img_batch) # (patches, c, x, y, z) seg = np.array(new_seg_batch) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data, 'seg': seg.astype('uint8'), 'pid': np.array([patient['pid']] * data.shape[0]), 'spec':np.array([patient['spec']] * data.shape[0])} for o in self.cf.roi_items: patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_"+o] = patient_batch['patient_'+o] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] #patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels'] patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] if plot_bg is not None: patch_batch['patient_plot_bg'] = patient_batch['plot_bg'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) patch_batch = converter(**patch_batch) out_batch = patch_batch self.patient_ix += 1 # todo raise stopiteration when in test mode if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True): """ create mutli-threaded train/val/test batch generation and augmentation pipeline. :param patient_data: dictionary containing one dictionary per patient in the train/test subset :param test_pids: (optional) list of test patient ids, calls the test generator. :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace) my_transforms = [] if do_aug: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) if cf.da_kwargs["gamma_transform"]: gamma_transform = GammaTransform(gamma_range=cf.da_kwargs["gamma_range"], invert_image=False, per_channel=False, retain_stats=True) my_transforms.append(gamma_transform) if cf.dim == 3: # augmentations with desired effect on z-dimension spatial_transform = SpatialTransform(patch_size=cf.patch_size, patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=False, do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop'], border_mode_data=cf.da_kwargs['border_mode_data']) my_transforms.append(spatial_transform) # augmentations that are only meant to affect x-y my_transforms.append(Convert3DTo2DTransform()) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:2], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=False, do_scale=False, random_crop=False, border_mode_data=cf.da_kwargs['border_mode_data']) my_transforms.append(spatial_transform) my_transforms.append(Convert2DTo3DTransform()) else: spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop'], border_mode_data=cf.da_kwargs['border_mode_data']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) if cf.create_bounding_box_targets: my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) #batch receives entry 'bb_target' w bbox coordinates as [y1,x1,y2,x2,z1,z2]. #my_transforms.append(ConvertSegToOnehotTransform(classes=range(cf.num_seg_classes))) all_transforms = Compose(my_transforms) #MTAugmenter creates iterator from data iterator data_gen after applying the composed transform all_transforms multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads, seeds=range(data_gen.n_filled_threads)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=True): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators need to select cv folds on patient level, but be able to include both breasts of each patient. """ dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold-1) train_ids = np.concatenate(set_splits, axis=0) - if cf.held_out_test_set: + if cf.hold_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train":train_ids, "val":val_ids, "test":test_ids}, plot_dir=os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug) batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False) if cf.val_mode == 'val_patient': batch_gen['val_patient'] = PatientBatchIterator(cf, val_data) batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else cf.max_val_patients elif cf.val_mode == 'val_sampling': batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches!="all" else len(val_ids) return batch_gen def get_test_generator(cf, logger): """ if get_test_generators is called multiple times in server env, every time of Dataset initiation rsync will check for copying the data; this should be okay since rsync will not copy if files already exist in destination. """ - if cf.held_out_test_set: + if cf.hold_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_set = Dataset(cf, logger, test_ids, data_sourcedir=sourcedir) logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator(cf, test_set.data) batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else min(cf.max_test_patients, len(test_set.set_ids)) return batch_gen if __name__=="__main__": import sys sys.path.append('../') # works on cluster indep from where sbatch job is started import plotting as plg import utils.exp_utils as utils from configs import Configs cf = configs() total_stime = time.time() times = {} #cf.server_env = True #cf.data_dir = "experiments/dev_data" #dataset = Dataset(cf) #patient = dataset['Master_00018'] cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir+"plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] #for i in range(train_loader.dataset_length): # print("batch", i) stime = time.time() ex_batch = next(train_loader) #ex_batch = next(train_loader) times["train_batch"] = time.time()-stime plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True) #with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file: # train_loader.generator.print_stats(logger, file) val_loader = gens['val_sampling'] stime = time.time() ex_batch = next(val_loader) times["val_batch"] = time.time()-stime stime = time.time() plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, show_info=False) times["val_plot"] = time.time()-stime test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch() print(ex_batch["data"].shape) times["test_batch"] = time.time()-stime stime = time.time() plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/ex_patchbatch.png", show_gt_boxes=False, show_info=False, dpi=400, sample_picks=[2,5], plot_mods=False) times["test_patchbatch_plot"] = time.time()-stime #stime = time.time() #ex_batch['data'] = ex_batch['patient_data'] #ex_batch['seg'] = ex_batch['patient_seg'] #if 'patient_plot_bg' in ex_batch.keys(): # ex_batch['plot_bg'] = ex_batch['patient_plot_bg'] #plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png") #times["test_patientbatch_plot"] = time.time() - stime #print("patch batch keys", ex_batch.keys()) #print("patch batch les gle", ex_batch["lesion_gleasons"].shape) #print("patch batch gsbx", ex_batch["GSBx"].shape) #print("patch batch class_targ", ex_batch["class_targets"].shape) #print("patient b roi labels", ex_batch["patient_roi_labels"].shape) #print("patient les gleas", ex_batch["patient_lesion_gleasons"].shape) #print("patch&patient batch pid", ex_batch["pid"], len(ex_batch["pid"])) #print("unique patient_seg", np.unique(ex_batch["patient_seg"])) #print("pb patient roi labels", len(ex_batch["patient_roi_labels"]), ex_batch["patient_roi_labels"]) #print("pid", ex_batch["pid"]) #patient_batch = {k[len("patient_"):]:v for (k,v) in ex_batch.items() if k.lower().startswith("patient")} #patient_batch["pid"] = ex_batch["pid"] #stime = time.time() #plg.view_batch(cf, patient_batch, out_file="experiments/dev_expatientbatch") #times["test_plot"] = time.time()-stime print("Times recorded throughout:") for (k,v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py index 6780d22..c37535e 100644 --- a/datasets/toy/configs.py +++ b/datasets/toy/configs.py @@ -1,491 +1,491 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np from default_configs import DefaultConfigs from collections import namedtuple boxLabel = namedtuple('boxLabel', ["name", "color"]) Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion']) binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals']) class Configs(DefaultConfigs): def __init__(self, server_env=None): super(Configs, self).__init__(server_env) ######################### # Prepro # ######################### self.pp_rootdir = os.path.join('/home/gregor/datasets/toy', "cyl1ps_dev") self.pp_npz_dir = self.pp_rootdir+"_npz" self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now) self.min_2d_radius = 6 #in pixels self.n_train_samples, self.n_test_samples = 1200, 1000 # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes. self.pp_create_ohe_seg = False self.pp_empty_samples_ratio = 0.1 self.pp_place_radii_mid_bin = True self.pp_only_distort_2d = True # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing. # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity. self.pp_blur_min_intensity = 0.2 self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d) self.max_instances_per_class = self.max_instances_per_sample # how many max instances per image per class self.noise_scale = 0. # std-dev of gaussian noise self.ambigs_sampling = "gaussian" #"gaussian" or "uniform" """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from image by distinguishing it from the rest of the object. blurring width around edge will be shifted so that symmetric rel to orig radius. blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius. if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2. """ self.ambiguities = { #set which classes to apply which ambs to below in class labels #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'. #kind #probability #scale (gaussian std, relative to unperturbed value) #"outer_radius": (1., 0.5), #"outer_radius_xy": (1., 0.5), #"inner_radius": (0.5, 0.1), #"radii_relations": (0.5, 0.1), "radius_calib": (1., 1./6) } # shape choices: 'cylinder', 'block' # id, name, shape, radius, color, regression, ambiguities, gt_distortion self.pp_classes = [Label(1, 'cylinder', 'cylinder', ((6,6,1),(40,40,8)), (*self.blue, 1.), "radius_2d", (), ()), #Label(2, 'block', 'block', ((6,6,1),(40,40,8)), (*self.aubergine,1.), "radii_2d", (), ('radius_calib',)) ] ######################### # I/O # ######################### self.data_sourcedir = '/home/gregor/datasets/toy/cyl1ps_dev' if server_env: self.data_sourcedir = '/datasets/datasets_ramien/toy/data/cyl1ps_dev_npz' self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test') self.data_sourcedir = os.path.join(self.data_sourcedir, "train") self.info_df_name = 'info_df.pickle' # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn']. self.model = 'mrcnn' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) ######################### # Architecture # ######################### # one out of [2, 3]. dimension the model operates in. self.dim = 2 # 'class', 'regression', 'regression_bin', 'regression_ken_gal' # currently only tested mode is a single-task at a time (i.e., only one task in below list) # but, in principle, tasks could be combined (e.g., object classes and regression per class) self.prediction_tasks = ['class', ] self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm' self.relu = 'relu' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') self.weight_init = None self.regression_n_features = 1 # length of regressor target vector ######################### # Data Loader # ######################### self.num_epochs = 24 self.num_train_batches = 100 if self.dim == 2 else 180 self.batch_size = 20 if self.dim == 2 else 8 self.n_cv_splits = 4 # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels self.plot_bg_chan = 0 self.crop_margin = [20, 20, 1] # has to be smaller than respective patch_size//2 self.patch_size_2D = self.pre_crop_size[:2] self.patch_size_3D = self.pre_crop_size[:2]+[8] # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D # ratio of free sampled batch elements before class balancing is triggered # (>0 to include "empty"/background patches.) self.batch_random_ratio = 0.2 self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets" self.observables_patient = [] self.observables_rois = [] self.seed = 3 #for generating folds ############################# # Colors, Classes, Legends # ############################# self.plot_frequency = 4 binary_bin_labels = [binLabel(1, 'r<=25', (*self.green, 1.), (1,25)), binLabel(2, 'r>25', (*self.red, 1.), (25,))] quintuple_bin_labels = [binLabel(1, 'r2-10', (*self.green, 1.), (2,10)), binLabel(2, 'r10-20', (*self.yellow, 1.), (10,20)), binLabel(3, 'r20-30', (*self.orange, 1.), (20,30)), binLabel(4, 'r30-40', (*self.bright_red, 1.), (30,40)), binLabel(5, 'r>40', (*self.red, 1.), (40,))] # choose here if to do 2-way or 5-way regression-bin classification task_spec_bin_labels = quintuple_bin_labels self.class_labels = [ # regression: regression-task label, either value or "(x,y,z)_radius" or "radii". # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables! # gt_distortion: name of ambig to apply to gt only; needs to be iterable! # #id #name #shape #radius #color #regression #ambiguities #gt_distortion Label( 0, 'bg', None, (0, 0, 0), (*self.white, 0.), (0, 0, 0), (), ())] if "class" in self.prediction_tasks: self.class_labels += self.pp_classes else: self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))] if any(['regression' in task for task in self.prediction_tasks]): self.bin_labels = [binLabel(0, 'bg', (*self.white, 1.), (0,))] self.bin_labels += task_spec_bin_labels self.bin_id2label = {label.id: label for label in self.bin_labels} bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0} if self.class_specific_seg: self.seg_labels = self.class_labels self.box_type2label = {label.name: label for label in self.box_labels} self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.has_colorchannels = False self.plot_class_ids = True self.num_classes = len(self.class_dict) self.num_seg_classes = len(self.seg_labels) ######################### # Data Augmentation # ######################### self.do_aug = True self.da_kwargs = { 'mirror': True, 'mirror_axes': tuple(np.arange(0, self.dim, 1)), 'do_elastic_deform': False, 'alpha': (500., 1500.), 'sigma': (40., 45.), 'do_rotation': False, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': False, 'scale': (0.8, 1.1), 'random_crop': False, 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1 } if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) # must be 0!! self.da_kwargs['angle_z'] = (0., 2 * np.pi) ######################### # Schedule / Selection # ######################### # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient' if self.val_mode == 'val_patient': self.max_val_patients = 220 # if 'all' iterates over entire val_set once. if self.val_mode == 'val_sampling': self.num_val_batches = 35 if self.dim==2 else 25 self.save_n_models = 2 self.min_save_thresh = 1 if self.dim == 2 else 1 # =wait time in epochs if "class" in self.prediction_tasks: self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()} elif any("regression" in task for task in self.prediction_tasks): self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()} self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()}) self.lr_decay_factor = 0.25 self.scheduling_patience = np.ceil(3600 / (self.num_train_batches * self.batch_size)) self.weight_decay = 3e-5 self.exclude_from_wd = [] self.clip_norm = None # number or None ######################### # Testing / Plotting # ######################### self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) - self.held_out_test_set = True + self.hold_out_test_set = True self.max_test_patients = "all" # number or "all" for all self.test_against_exact_gt = True # only True implemented self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario. self.report_score_level = ['rois'] # 'patient' or 'rois' (incl) self.patient_class_of_interest = 1 self.patient_bin_of_interest = 2 self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False self.metrics = ['ap', 'auc', 'dice'] if any(['regression' in task for task in self.prediction_tasks]): self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] if 'aleatoric' in self.model: self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] self.evaluate_fold_means = True self.ap_match_ious = [0.5] # threshold(s) for considering a prediction as true positive self.min_det_thresh = 0.3 self.model_max_iou_resolution = 0.2 # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!) for regarding two preds as concerning the same ROI self.clustering_iou = self.model_max_iou_resolution # has to be larger than desired possible overlap iou of model predictions self.merge_2D_to_3D_preds = self.dim==2 self.merge_3D_iou = self.model_max_iou_resolution self.n_test_plots = 1 # per fold and rank self.test_n_epochs = self.save_n_models # should be called n_test_ens, since is number of models to ensemble over during testing # is multiplied by (1 + nr of test augs) ######################### # Assertions # ######################### if not 'class' in self.prediction_tasks: assert self.num_classes == 1 ######################### # Add model specifics # ######################### {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, 'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs }[self.model]() def rg_val_to_bin_id(self, rg_val): #only meant for isotropic radii!! # only 2D radii (x and y dims) or 1D (x or y) are expected return np.round(np.digitize(rg_val, self.bin_edges).mean()) def add_det_fpn_configs(self): self.learning_rate = [1 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = 'torch_loss' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' self.n_roi_candidates = 4 if self.dim == 2 else 6 # max number of roi candidates to identify per image (slice in 2D, volume in 3D) # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1] self.fp_dice_weight = 1 if self.dim == 2 else 1 # if <1, false positive predictions in foreground are penalized less. self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' def add_det_unet_configs(self): self.learning_rate = [3 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = "torch_loss" self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # max number of roi candidates to identify per image (slice in 2D, volume in 3D) self.n_roi_candidates = 4 if self.dim == 2 else 6 # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1] # if <1, false positive predictions in foreground are penalized less. self.fp_dice_weight = 1 if self.dim == 2 else 1 self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' self.init_filts = 32 self.kernel_size = 3 # ks for horizontal, normal convs self.kernel_size_m = 2 # ks for max pool self.pad = "same" # "same" or integer, padding of horizontal convs def add_mrcnn_configs(self): self.learning_rate = [3e-4] * self.num_epochs self.dynamic_lr_scheduling = True # with scheduler set in exec self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2 # feed +/- n neighbouring slices into channel dimension. set to None for no context. self.n_3D_context = None if self.n_3D_context is not None and self.dim == 2: self.n_channels *= (self.n_3D_context * 2 + 1) self.detect_while_training = True # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask outputs are optional. self.return_masks_in_train = True self.return_masks_in_val = True self.return_masks_in_test = True # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 64 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1., 2.] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = max(0.7, self.model_max_iou_resolution) # loss sampling settings. self.rpn_train_anchors_per_image = 32 self.train_rois_per_image = 6 # per batch_instance self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.8 # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), # where k<=#positive examples. self.shem_poolsize = 6 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) # y1,x1,y2,x2,z1,z2 if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] self.plot_y_max = 1.5 self.n_plot_rpn_props = 5 if self.dim == 2 else 30 # per batch_instance (slice in 2D / patient in 3D) # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 2000 if self.dim == 2 else 4000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage as one "batch". self.roi_chunk_size = 1300 if self.dim == 2 else 500 self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400 self.post_nms_rois_inference = 200 * (self.head_classes-1) # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class. self.detection_nms_threshold = self.model_max_iou_resolution # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.2 # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': # whether to use focal loss or SHEM for loss-sample selection self.focal_loss = False # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.7 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/datasets/toy/data_loader.py b/datasets/toy/data_loader.py index f4bf28f..e5c4c8c 100644 --- a/datasets/toy/data_loader.py +++ b/datasets/toy/data_loader.py @@ -1,595 +1,595 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys sys.path.append('../') # works on cluster indep from where sbatch job is started import plotting as plg import numpy as np import os from multiprocessing import Lock from collections import OrderedDict import pandas as pd import pickle import time # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform sys.path.append(os.path.dirname(os.path.realpath(__file__))) import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) class Dataset(dutils.Dataset): r""" Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dir. :param cf: config file :param folds: number of folds out of @params n_cv folds to include :param n_cv: number of total folds :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'): super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir) load_exact_gts = (mode=='test' or cf.val_mode=="val_patient") and self.cf.test_against_exact_gt p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name)) if subset_ids is not None: p_df = p_df[p_df.pid.isin(subset_ids)] logger.info('subset: selected {} instances from df'.format(len(p_df))) pids = p_df.pid.tolist() #evtly copy data from data_sourcedir to data_dest if cf.server_env and not hasattr(cf, "data_dir"): file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids] file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids] file_subset += [cf.info_df_name] if load_exact_gts: file_subset += [os.path.join(self.data_dir, '{}_exact_seg.*'.format(pid)) for pid in pids] self.copy_data(cf, file_subset=file_subset) img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids] seg_paths = [os.path.join(self.data_dir, '{}_seg.npy'.format(pid)) for pid in pids] if load_exact_gts: exact_seg_paths = [os.path.join(self.data_dir, '{}_exact_seg.npy'.format(pid)) for pid in pids] class_targets = p_df['class_ids'].tolist() rg_targets = p_df['regression_vectors'].tolist() if load_exact_gts: exact_rg_targets = p_df['undistorted_rg_vectors'].tolist() fg_slices = p_df['fg_slices'].tolist() self.data = OrderedDict() for ix, pid in enumerate(pids): self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid, 'fg_slices': np.array(fg_slices[ix])} if load_exact_gts: self.data[pid]['exact_seg'] = exact_seg_paths[ix] if 'class' in self.cf.prediction_tasks: self.data[pid]['class_targets'] = np.array(class_targets[ix], dtype='uint8') else: self.data[pid]['class_targets'] = np.ones_like(np.array(class_targets[ix]), dtype='uint8') if load_exact_gts: self.data[pid]['exact_class_targets'] = self.data[pid]['class_targets'] if any(['regression' in task for task in self.cf.prediction_tasks]): self.data[pid]['regression_targets'] = np.array(rg_targets[ix], dtype='float16') self.data[pid]["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in rg_targets[ix]], dtype='uint8') if load_exact_gts: self.data[pid]['exact_regression_targets'] = np.array(exact_rg_targets[ix], dtype='float16') self.data[pid]["exact_rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in exact_rg_targets[ix]], dtype='uint8') cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if any(['regression' in task for task in self.cf.prediction_tasks]): cf.roi_items += ['regression_targets'] cf.roi_items += ['rg_bin_targets'] self.set_ids = np.array(list(self.data.keys())) self.df = None class BatchGenerator(dutils.BatchGenerator): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, seed=0): super(BatchGenerator, self).__init__(cf, data, sample_pids_w_replace=sample_pids_w_replace, max_batches=max_batches, raise_stop_iteration=raise_stop_iteration, seed=seed) self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.balance_target_distribution(plot=sample_pids_w_replace) def generate_train_batch(self): # everything done in here is per batch # print statements in here get confusing due to multithreading batch_pids = self.get_batch_pids() batch_data, batch_segs, batch_patient_targets = [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count and empty count of classes in batch # empty count for no presence of resp. class in whole sample (empty slices in 2D/patients in 3D) batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') for b in range(len(batch_pids)): patient = self._data[batch_pids[b]] data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis] seg = np.load(patient['seg'], mmap_mode='r').astype('uint8') (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or np.random.rand( 1) <= self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) if len(elig_slices) > 0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] spatial_shp = data[0].shape assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] data = dutils.pad_nd_image(data, (len(data), *new_shape)) seg = dutils.pad_nd_image(seg, new_shape) # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center, # if possible, to that pixel, so that img still contains ROI after pre-cropping dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] if np.any(dim_cropflags): # sample pixel from random ROI and shift center, if possible, to that pixel if self.cf.dim==3: choose_fg = np.any(batch_empty_counts/self.batch_size>=self.empty_samples_max_ratio) or \ np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg)[1:] for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) break roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] assert seg[tuple(roi_anchor_pixel)] > 0 # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and # distance to the selected ROI < patch_size /2 def get_cropped_centercoords(dim): low = np.max((self.cf.pre_crop_size[dim] // 2, roi_anchor_pixel[dim] - ( self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim]))) high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2, roi_anchor_pixel[dim] + ( self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim]))) if low >= high: # happens if lesion on the edge of the image. low = self.cf.pre_crop_size[dim] // 2 high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2 assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format( dim, spatial_shp, patient['pid'], low, high) return np.random.randint(low=low, high=high) else: # sample crop center regardless of ROIs, not guaranteed to be empty def get_cropped_centercoords(dim): return np.random.randint(low=self.cf.pre_crop_size[dim] // 2, high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2) sample_seg_center = {} for dim in np.where(dim_cropflags)[0]: sample_seg_center[dim] = get_cropped_centercoords(dim) min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2) max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2) data = np.take(data, indices=range(min_, max_), axis=dim + 1) # +1 for channeldim seg = np.take(seg, indices=range(min_, max_), axis=dim) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable batch_roi_items[o].append(patient[o]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero==0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'), 'pid': batch_pids, 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic batch[key] = np.array(val) return batch class PatientBatchIterator(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D), if willing to accept speed-loss during training. Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . """ def __init__(self, cf, data, mode='test'): super(PatientBatchIterator, self).__init__(cf, data) self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" if (mode=="validation" and hasattr(self.cf, 'val_against_exact_gt') and self.cf.val_against_exact_gt) or \ (mode == 'test' and self.cf.test_against_exact_gt): self.gt_prefix = 'exact_' print("PatientIterator: Loading exact Ground Truths.") else: self.gt_prefix = '' self.patient_ix = 0 # running index over all patients in set def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis] seg = np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis] data_shp_raw = data.shape plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None data = data[self.chans] discarded_chans = len( [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan]) spatial_shp = data[0].shape # spatial dims need to be in order x,y,z assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]): new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) if plot_bg is not None: plot_bg = dutils.pad_nd_image(plot_bg, new_shape) if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: # adds the batch dim here bc won't go through MTaugmenter out_data = data[np.newaxis] out_seg = seg[np.newaxis] if plot_bg is not None: out_plot_bg = plot_bg[np.newaxis] # data and seg shape: (1,c,x,y,z), where c=1 for seg batch_3D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_3D[o] = np.array([patient[self.gt_prefix+o]]) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32') # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8') # (c,y,x,z) to (b=z,c,x,y) batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) if plot_bg is not None: out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32') if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) if self.cf.plot_bg_chan in self.chans and discarded_chans > 0: # len(self.chans[:self.cf.plot_bg_chan]) self.patch_size[ix] for ix in range(len(spatial_shp))]): patient_batch = out_batch print("patientiterator produced patched batch!") patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) new_img_batch, new_seg_batch = [], [] for c in patch_crop_coords_list: new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]]) seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) shps = [] for arr in new_img_batch: shps.append(arr.shape) data = np.array(new_img_batch) # (patches, c, x, y, z) seg = np.array(new_seg_batch) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), 'pid': np.array([patient['pid']] * data.shape[0])} for o in self.cf.roi_items: patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0) #patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_"+o] = patient_batch["patient_"+o] if self.cf.dim == 2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_" + o + "_2d"] = patient_batch[o] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim == 2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] if plot_bg is not None: patch_batch['patient_plot_bg'] = patient_batch['plot_bg'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False, class_specific_seg=self.cf.class_specific_seg) patch_batch = converter(**patch_batch) out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs): """ create mutli-threaded train/val/test batch generation and augmentation pipeline. :param patient_data: dictionary containing one dictionary per patient in the train/test subset. :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ # create instance of batch generator as first element in pipeline. data_gen = BatchGenerator(cf, patient_data, **kwargs) my_transforms = [] if do_aug: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) all_transforms = Compose(my_transforms) # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms) multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads, seeds=range(data_gen.n_filled_threads)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=False): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators. selects patients according to cv folds (generated by first run/fold of experiment): splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) If cf.hold_out_test_set is True, adds the test split to the training data. """ dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) - if cf.held_out_test_set: + if cf.hold_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir= os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True) if cf.val_mode == 'val_patient': batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation') batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else min(len(val_ids), cf.max_val_patients) elif cf.val_mode == 'val_sampling': batch_gen['n_val'] = int(np.ceil(len(val_data)/cf.batch_size)) if cf.num_val_batches == "all" else cf.num_val_batches # in current setup, val loader is used like generator. with max_batches being applied in train routine. batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False, max_batches=None, raise_stop_iteration=False) return batch_gen def get_test_generator(cf, logger): """ if get_test_generators is possibly called multiple times in server env, every time of Dataset initiation rsync will check for copying the data; this should be okay since rsync will not copy if files already exist in destination. """ - if cf.held_out_test_set: + if cf.hold_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test') logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator(cf, test_set.data) batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \ min(cf.max_test_patients, len(test_set.set_ids)) return batch_gen if __name__=="__main__": import utils.exp_utils as utils from datasets.toy.configs import Configs cf = Configs() total_stime = time.time() times = {} # cf.server_env = True # cf.data_dir = "experiments/dev_data" cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] for i in range(0): stime = time.time() print("producing training batch nr ", i) ex_batch = next(train_loader) times["train_batch"] = time.time() - stime #experiments/dev/dev_exbatch_{}.png".format(i) plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False) val_loader = gens['val_sampling'] stime = time.time() for i in range(1): ex_batch = next(val_loader) times["val_batch"] = time.time() - stime stime = time.time() #"experiments/dev/dev_exvalbatch_{}.png" plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True) times["val_plot"] = time.time() - stime # test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch(pid=None) times["test_batch"] = time.time() - stime stime = time.time() plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0) times["test_patchbatch_plot"] = time.time() - stime print("Times recorded throughout:") for (k, v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/datasets/toy_mdt/configs.py b/datasets/toy_mdt/configs.py index 2333921..9902d26 100644 --- a/datasets/toy_mdt/configs.py +++ b/datasets/toy_mdt/configs.py @@ -1,355 +1,355 @@ #!/usr/bin/env python # Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np from collections import namedtuple from default_configs import DefaultConfigs Label = namedtuple("Label", ['id', 'name', 'color']) class Configs(DefaultConfigs): def __init__(self, server_env=None): ######################### # Preprocessing # ######################### self.root_dir = '/home/gregor/datasets/toy_mdt' ######################### # I/O # ######################### # one out of [2, 3]. dimension the model operates in. self.dim = 2 DefaultConfigs.__init__(self, server_env, self.dim) # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn']. self.model = 'retina_unet' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) # int [0 < dataset_size]. select n patients from dataset for prototyping. self.select_prototype_subset = None - self.held_out_test_set = True + self.hold_out_test_set = True # including val set. will be 3/4 train, 1/4 val. self.n_train_val_data = 2500 # choose one of the 3 toy experiments described in https://arxiv.org/pdf/1811.08661.pdf # one of ['donuts_shape', 'donuts_pattern', 'circles_scale']. toy_mode = 'donuts_shape_noise' # path to preprocessed data. self.info_df_name = 'info_df.pickle' self.pp_name = os.path.join(toy_mode, 'train') self.data_sourcedir = os.path.join(self.root_dir, self.pp_name) self.pp_test_name = os.path.join(toy_mode, 'test') self.test_data_sourcedir = os.path.join(self.root_dir, self.pp_test_name) # settings for deployment in cloud. if server_env: # path to preprocessed data. pp_root_dir = '/datasets/datasets_ramien/toy_exp/data' self.pp_name = os.path.join(toy_mode, 'train') self.data_sourcedir = os.path.join(pp_root_dir, self.pp_name) self.pp_test_name = os.path.join(toy_mode, 'test') self.test_data_sourcedir = os.path.join(pp_root_dir, self.pp_test_name) self.select_prototype_subset = None ######################### # Data Loader # ######################### # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) self.plot_bg_chan = 0 # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.pre_crop_size_2D = [320, 320] self.patch_size_2D = [320, 320] self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D # ratio of free sampled batch elements before class balancing is triggered # (>0 to include "empty"/background patches.) self.batch_random_ratio = 0.2 # set 2D network to operate in 3D images. self.merge_2D_to_3D_preds = False ######################### # Architecture # ######################### self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' self.norm = "instance_norm" # one of None, 'instance_norm', 'batch_norm' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') self.weight_init = "xavier_uniform" # compatibility self.regression_n_features = 1 self.num_classes = 2 # excluding bg self.num_seg_classes = 3 # incl bg ######################### # Schedule / Selection # ######################### self.num_epochs = 26 self.num_train_batches = 100 if self.dim == 2 else 200 self.batch_size = 20 if self.dim == 2 else 8 self.do_validation = True # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_patient' # one of 'val_sampling' , 'val_patient' if self.val_mode == 'val_patient': self.max_val_patients = "all" # if 'None' iterates over entire val_set once. if self.val_mode == 'val_sampling': self.num_val_batches = 50 self.optimizer = "ADAMW" # set dynamic_lr_scheduling to True to apply LR scheduling with below settings. self.dynamic_lr_scheduling = True self.lr_decay_factor = 0.25 self.scheduling_patience = np.ceil(4800 / (self.num_train_batches * self.batch_size)) self.scheduling_criterion = 'donuts_ap' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' self.weight_decay = 1e-5 self.exclude_from_wd = ["norm"] self.clip_norm = 200 ######################### # Testing / Plotting # ######################### - self.eval_test_fold_wise = True + self.ensemble_folds = False # set the top-n-epochs to be saved for temporal averaging in testing. self.save_n_models = 5 self.test_n_epochs = 5 self.test_aug_axes = (0, 1, (0, 1)) self.n_test_plots = 2 self.clustering = "wbc" self.clustering_iou = 1e-5 # set a minimum epoch number for saving in case of instabilities in the first phase of training. self.min_save_thresh = 0 if self.dim == 2 else 0 self.report_score_level = ['patient', 'rois'] # choose list from 'patient', 'rois' self.class_labels = [Label(0, 'bg', (*self.white, 0.)), Label(1, 'circles', (*self.orange, .9)), Label(2, 'donuts', (*self.blue, .9)),] if self.class_specific_seg: self.seg_labels = self.class_labels self.box_type2label = {label.name: label for label in self.box_labels} self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.metrics = ["ap", "auc", "dice"] self.patient_class_of_interest = 2 # patient metrics are only plotted for one class. self.ap_match_ious = [0.1] # list of ious to be evaluated for ap-scoring. self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()}# criteria to average over for saving epochs. self.min_det_thresh = 0.1 # minimum confidence value to select predictions for evaluation. self.plot_prediction_histograms = True self.plot_stat_curves = False self.plot_class_ids = True ######################### # Data Augmentation # ######################### self.do_aug = False self.da_kwargs={ 'do_elastic_deform': True, 'alpha':(0., 1500.), 'sigma':(30., 50.), 'do_rotation':True, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': True, 'scale':(0.8, 1.1), 'random_crop':False, 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1 } if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!! self.da_kwargs['angle_z'] = (0., 2 * np.pi) ######################### # Add model specifics # ######################### {'detection_fpn': self.add_det_fpn_configs, 'mrcnn': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, }[self.model]() def add_det_fpn_configs(self): self.learning_rate = [3 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = 'torch_loss' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' self.n_roi_candidates = 4 if self.dim == 2 else 6 # max number of roi candidates to identify per image (slice in 2D, volume in 3D) # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1., 1.] self.fp_dice_weight = 1 if self.dim == 2 else 1 # if <1, false positive predictions in foreground are penalized less. self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' def add_mrcnn_configs(self): # learning rate is a list with one entry per epoch. self.learning_rate = [3e-4] * self.num_epochs # disable mask head loss. (e.g. if no pixelwise annotations available) self.frcnn_mode = False # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask-outputs are optional. self.return_masks_in_val = True self.return_masks_in_test = False # set number of proposal boxes to plot after each epoch. self.n_plot_rpn_props = 2 if self.dim == 2 else 2 # number of classes for head networks: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 # feature map strides per pyramid level are inferred from architecture. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 128 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1., 2.] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7 # loss sampling settings. self.rpn_train_anchors_per_image = 64 #per batch element self.train_rois_per_image = 2 #per batch element self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.7 # factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples. self.shem_poolsize = 4 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1]]) if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 3000 if self.dim == 2 else 6000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage in as one "batch". self.roi_chunk_size = 800 if self.dim == 2 else 600 self.post_nms_rois_training = 500 if self.dim == 2 else 75 self.post_nms_rois_inference = 500 # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30 # per batch element and class. self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.1 if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': # whether to use focal loss or SHEM for loss-sample selection self.focal_loss = False # implement extra anchor-scales according to retina-net publication. self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 self.n_rpn_features = 256 if self.dim == 2 else 64 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = 10000 if self.dim == 2 else 50000 # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.5 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/datasets/toy_mdt/data_loader.py b/datasets/toy_mdt/data_loader.py index bee309d..166e348 100644 --- a/datasets/toy_mdt/data_loader.py +++ b/datasets/toy_mdt/data_loader.py @@ -1,379 +1,379 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys sys.path.append('../') # works on cluster indep from where sbatch job is started import plotting as plg import numpy as np import os from multiprocessing import Lock from collections import OrderedDict import pandas as pd import pickle import time # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform sys.path.append(os.path.dirname(os.path.realpath(__file__))) import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) class Dataset(dutils.Dataset): r""" Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dir. :param cf: config file :param folds: number of folds out of @params n_cv folds to include :param n_cv: number of total folds :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'): super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir) p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name)) if subset_ids is not None: p_df = p_df[p_df.pid.isin(subset_ids)] logger.info('subset: selected {} instances from df'.format(len(p_df))) pids = p_df.pid.tolist() #evtly copy data from data_sourcedir to data_dest if cf.server_env and not hasattr(cf, "data_dir"): file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids] file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids] file_subset += [cf.info_df_name] self.copy_data(cf, file_subset=file_subset) img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids] seg_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids] class_targets = p_df['class_id'].tolist() self.data = OrderedDict() for ix, pid in enumerate(pids): self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid} self.data[pid]['class_targets'] = np.array([class_targets[ix]], dtype='uint8') + 1 cf.roi_items = ['class_targets'] self.set_ids = np.array(list(self.data.keys())) self.df = None class BatchGenerator(dutils.BatchGenerator): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, seed=0): super(BatchGenerator, self).__init__(cf, data, sample_pids_w_replace=sample_pids_w_replace, max_batches=max_batches, raise_stop_iteration=raise_stop_iteration, seed=seed) self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.balance_target_distribution(plot=sample_pids_w_replace) def generate_train_batch(self): # everything done in here is per batch # print statements in here get confusing due to multithreading batch_pids = self.get_batch_pids() batch_data, batch_segs, batch_patient_targets = [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count and empty count of classes in batch # empty count for no presence of resp. class in whole sample (empty slices in 2D/patients in 3D) batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') for b in range(len(batch_pids)): patient = self._data[batch_pids[b]] all_data = np.load(patient['data'], mmap_mode='r') data = all_data[0].astype('float16')[np.newaxis] seg = all_data[1].astype('uint8') spatial_shp = data[0].shape assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] data = dutils.pad_nd_image(data, (len(data), *new_shape)) seg = dutils.pad_nd_image(seg, new_shape) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable batch_roi_items[o].append(patient[o]) for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'), 'pid': batch_pids, 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic batch[key] = np.array(val) return batch class PatientBatchIterator(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D), if willing to accept speed-loss during training. Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . """ def __init__(self, cf, data, mode='test'): super(PatientBatchIterator, self).__init__(cf, data) self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" self.patient_ix = 0 # running index over all patients in set def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling all_data = np.load(patient['data'], mmap_mode='r') data = all_data[0].astype('float16')[np.newaxis] seg = all_data[1].astype('uint8')[np.newaxis] data_shp_raw = data.shape data = data[self.chans] spatial_shp = data[0].shape # spatial dims need to be in order x,y,z assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg" out_data = data[None] out_seg = seg[None] batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[o]]), len(out_data), axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs): """ create mutli-threaded train/val/test batch generation and augmentation pipeline. :param patient_data: dictionary containing one dictionary per patient in the train/test subset. :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ # create instance of batch generator as first element in pipeline. data_gen = BatchGenerator(cf, patient_data, **kwargs) my_transforms = [] if do_aug: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) all_transforms = Compose(my_transforms) # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms) multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads, seeds=range(data_gen.n_filled_threads)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=False): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators. selects patients according to cv folds (generated by first run/fold of experiment): splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) If cf.hold_out_test_set is True, adds the test split to the training data. """ dataset = Dataset(cf, logger) assert cf.n_train_val_data <= len(dataset.set_ids), \ "requested {} train val samples, but dataset only has {} train val samples.".format( cf.n_train_val_data, len(dataset.set_ids)) train_ids = dataset.set_ids[:int(2*cf.n_train_val_data//3)] val_ids = dataset.set_ids[int(np.ceil(2*cf.n_train_val_data//3)):cf.n_train_val_data] train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids} logger.info("data set loaded with: {} train / {} val patients".format(len(train_ids), len(val_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids}, plot_dir= os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True) if cf.val_mode == 'val_patient': batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation') batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else min(len(val_ids), cf.max_val_patients) elif cf.val_mode == 'val_sampling': batch_gen['n_val'] = int(np.ceil(len(val_data)/cf.batch_size)) if cf.num_val_batches == "all" else cf.num_val_batches # in current setup, val loader is used like generator. with max_batches being applied in train routine. batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False, max_batches=None, raise_stop_iteration=False) return batch_gen def get_test_generator(cf, logger): """ if get_test_generators is possibly called multiple times in server env, every time of Dataset initiation rsync will check for copying the data; this should be okay since rsync will not copy if files already exist in destination. """ - if cf.held_out_test_set: + if cf.hold_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test') logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator(cf, test_set.data) batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \ min(cf.max_test_patients, len(test_set.set_ids)) return batch_gen if __name__=="__main__": import utils.exp_utils as utils from datasets.toy.configs import Configs cf = Configs() total_stime = time.time() times = {} # cf.server_env = True # cf.data_dir = "experiments/dev_data" cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] for i in range(0): stime = time.time() print("producing training batch nr ", i) ex_batch = next(train_loader) times["train_batch"] = time.time() - stime #experiments/dev/dev_exbatch_{}.png".format(i) plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False) val_loader = gens['val_sampling'] stime = time.time() for i in range(1): ex_batch = next(val_loader) times["val_batch"] = time.time() - stime stime = time.time() #"experiments/dev/dev_exvalbatch_{}.png" plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True) times["val_plot"] = time.time() - stime # test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch(pid=None) times["test_batch"] = time.time() - stime stime = time.time() plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0) times["test_patchbatch_plot"] = time.time() - stime print("Times recorded throughout:") for (k, v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/default_configs.py b/default_configs.py index c415e98..fd4740b 100644 --- a/default_configs.py +++ b/default_configs.py @@ -1,204 +1,206 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Default Configurations script. Avoids changing configs of all experiments if general settings are to be changed.""" import os from collections import namedtuple boxLabel = namedtuple('boxLabel', ["name", "color"]) class DefaultConfigs: def __init__(self, server_env=None, dim=2): self.server_env = server_env self.cuda_benchmark = True self.sysmetrics_interval = 2 # set > 0 to record system metrics to tboard with this time span in seconds. ######################### # I/O # ######################### self.dim = dim # int [0 < dataset_size]. select n patients from dataset for prototyping. self.select_prototype_subset = None # some default paths. self.source_dir = os.path.dirname(os.path.realpath(__file__)) # current dir. self.backbone_path = os.path.join(self.source_dir, 'models/backbone.py') self.input_df_name = 'info_df.pickle' if server_env: self.select_prototype_subset = None ######################### # Colors/legends # ######################### # in part from solarized theme. self.black = (0.1, 0.05, 0.) self.gray = (0.514, 0.580, 0.588) self.beige = (1., 1., 0.85) self.white = (0.992, 0.965, 0.890) self.green = (0.659, 0.792, 0.251) # [168, 202, 64] self.dark_green = (0.522, 0.600, 0.000) # [133.11, 153. , 0. ] self.cyan = (0.165, 0.631, 0.596) # [ 42.075, 160.905, 151.98 ] self.bright_blue = (0.85, 0.95, 1.) self.blue = (0.149, 0.545, 0.824) # [ 37.995, 138.975, 210.12 ] self.dkfz_blue = (0, 75. / 255, 142. / 255) self.dark_blue = (0.027, 0.212, 0.259) # [ 6.885, 54.06 , 66.045] self.purple = (0.424, 0.443, 0.769) # [108.12 , 112.965, 196.095] self.aubergine = (0.62, 0.21, 0.44) # [ 157, 53 , 111] self.magenta = (0.827, 0.212, 0.510) # [210.885, 54.06 , 130.05 ] self.coral = (1., 0.251, 0.4) # [255,64,102] self.bright_red = (1., 0.15, 0.1) # [255, 38.25, 25.5] self.brighter_red = (0.863, 0.196, 0.184) # [220.065, 49.98 , 46.92 ] self.red = (0.87, 0.05, 0.01) # [ 223, 13, 2] self.dark_red = (0.6, 0.04, 0.005) self.orange = (0.91, 0.33, 0.125) # [ 232.05 , 84.15 , 31.875] self.dark_orange = (0.796, 0.294, 0.086) #[202.98, 74.97, 21.93] self.yellow = (0.95, 0.9, 0.02) # [ 242.25, 229.5 , 5.1 ] self.dark_yellow = (0.710, 0.537, 0.000) # [181.05 , 136.935, 0. ] self.color_palette = [self.blue, self.dark_blue, self.aubergine, self.green, self.yellow, self.orange, self.red, self.cyan, self.black] self.box_labels = [ # name color boxLabel("det", self.blue), boxLabel("prop", self.gray), boxLabel("pos_anchor", self.cyan), boxLabel("neg_anchor", self.cyan), boxLabel("neg_class", self.green), boxLabel("pos_class", self.aubergine), boxLabel("gt", self.red) ] # neg and pos in a medical sense, i.e., pos=positive diagnostic finding self.box_type2label = {label.name: label for label in self.box_labels} self.box_color_palette = {label.name: label.color for label in self.box_labels} # whether the input data is mono-channel or RGB/rgb self.has_colorchannels = False ######################### # Data Loader # ######################### #random seed for fold_generator and batch_generator. self.seed = 0 #number of threads for multithreaded tasks like batch generation, wcs, merge2dto3d self.n_workers = 16 if server_env else os.cpu_count() self.create_bounding_box_targets = True self.class_specific_seg = True # False if self.model=="mrcnn" else True self.max_val_patients = "all" ######################### # Architecture # ######################### self.prediction_tasks = ["class"] # 'class', 'regression_class', 'regression_kendall', 'regression_feindt' self.weight_decay = 0.0 # nonlinearity to be applied after convs with nonlinearity. one of 'relu' or 'leaky_relu' self.relu = 'relu' # if True initializes weights as specified in model script. else use default Pytorch init. self.weight_init = None # if True adds high-res decoder levels to feature pyramid: P1 + P0. (e.g. set to true in retina_unet configs) self.operate_stride1 = False ######################### # Optimization # ######################### self.optimizer = "ADAMW" # "ADAMW" or "SGD" or implemented additionals ######################### # Schedule # ######################### # number of folds in cross validation. self.n_cv_splits = 5 ######################### # Testing / Plotting # ######################### # perform mirroring at test time. (only XY. Z not done to not blow up predictions times). self.test_aug = True # if True, test data lies in a separate folder and is not part of the cross validation. - self.held_out_test_set = False - # if hold-out test set: eval each fold's parameters separately on the test set - self.eval_test_fold_wise = True + self.hold_out_test_set = False + # if hold-out test set: if ensemble_folds is True, predictions of all folds on the common hold-out test set + # are aggregated (like ensemble members). if False, each fold's parameters are evaluated separately on the test + # set and the evaluations are aggregated (like normal cross-validation folds). + self.ensemble_folds = False - # if held_out_test_set provided, ensemble predictions over models of all trained cv-folds. + # if hold_out_test_set provided, ensemble predictions over models of all trained cv-folds. self.ensemble_folds = False # what metrics to evaluate self.metrics = ['ap'] # whether to evaluate fold means when evaluating over more than one fold self.evaluate_fold_means = False # how often (in nr of epochs) to plot example batches during train/val self.plot_frequency = 1 # color specifications for all box_types in prediction_plot. self.box_color_palette = {'det': 'b', 'gt': 'r', 'neg_class': 'purple', 'prop': 'w', 'pos_class': 'g', 'pos_anchor': 'c', 'neg_anchor': 'c'} # scan over confidence score in evaluation to optimize it on the validation set. self.scan_det_thresh = False # plots roc-curves / prc-curves in evaluation. self.plot_stat_curves = False # if True: evaluate average precision per patient id and average over per-pid results, # instead of computing one ap over whole data set. self.per_patient_ap = False # threshold for clustering 2D box predictions to 3D Cubes. Overlap is computed in XY. self.merge_3D_iou = 0.1 # number or "all" for all self.max_test_patients = "all" ######################### # MRCNN # ######################### # if True, mask loss is not applied. used for data sets, where no pixel-wise annotations are provided. self.frcnn_mode = False self.return_masks_in_train = False # if True, unmolds masks in Mask R-CNN to full-res for plotting/monitoring. self.return_masks_in_val = False self.return_masks_in_test = False # needed if doing instance segmentation. evaluation not yet implemented. # add P6 to Feature Pyramid Network. self.sixth_pooling = False ######################### # RetinaNet # ######################### self.focal_loss = False self.focal_loss_gamma = 2. diff --git a/evaluator.py b/evaluator.py index 26d2be5..8d41079 100644 --- a/evaluator.py +++ b/evaluator.py @@ -1,983 +1,983 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os from multiprocessing import Pool import pickle import time import numpy as np import pandas as pd from sklearn.metrics import roc_auc_score, average_precision_score from sklearn.metrics import roc_curve, precision_recall_curve from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score import torch import utils.model_utils as mutils import plotting as plg import warnings def get_roi_ap_from_df(inputs): ''' :param df: data frame. :param det_thresh: min_threshold for filtering out low confidence predictions. :param per_patient_ap: boolean flag. evaluate average precision per patient id and average over per-pid results, instead of computing one ap over whole data set. :return: average_precision (float) ''' df, det_thresh, per_patient_ap = inputs if per_patient_ap: pids_list = df.pid.unique() aps = [] for match_iou in df.match_iou.unique(): iou_df = df[df.match_iou == match_iou] for pid in pids_list: pid_df = iou_df[iou_df.pid == pid] all_p = len(pid_df[pid_df.class_label == 1]) pid_df = pid_df[(pid_df.det_type == 'det_fp') | (pid_df.det_type == 'det_tp')].sort_values('pred_score', ascending=False) pid_df = pid_df[pid_df.pred_score > det_thresh] if (len(pid_df) ==0 and all_p == 0): pass elif (len(pid_df) > 0 and all_p == 0): aps.append(0) else: aps.append(compute_roi_ap(pid_df, all_p)) return np.mean(aps) else: aps = [] for match_iou in df.match_iou.unique(): iou_df = df[df.match_iou == match_iou] # it's important to not apply the threshold before counting all_p in order to not lose the fn! all_p = len(iou_df[(iou_df.det_type == 'det_tp') | (iou_df.det_type == 'det_fn')]) # sorting out all entries that are not fp or tp or have confidence(=pred_score) <= detection_threshold iou_df = iou_df[(iou_df.det_type == 'det_fp') | (iou_df.det_type == 'det_tp')].sort_values('pred_score', ascending=False) iou_df = iou_df[iou_df.pred_score > det_thresh] if all_p>0: aps.append(compute_roi_ap(iou_df, all_p)) return np.mean(aps) def compute_roi_ap(df, all_p): """ adapted from: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py :param df: dataframe containing class labels of predictions sorted in descending manner by their prediction score. :param all_p: number of all ground truth objects. (for denominator of recall.) :return: """ tp = df.class_label.values fp = (tp == 0) * 1 #recall thresholds, where precision will be measured R = np.linspace(0., 1., np.round((1. - 0.) / .01).astype(int) + 1, endpoint=True) tp_sum = np.cumsum(tp) fp_sum = np.cumsum(fp) n_dets = len(tp) rc = tp_sum / all_p pr = tp_sum / (fp_sum + tp_sum) # initialize precision array over recall steps (q=queries). q = [0. for _ in range(len(R))] # numpy is slow without cython optimization for accessing elements # python array gets significant speed improvement pr = pr.tolist() for i in range(n_dets - 1, 0, -1): if pr[i] > pr[i - 1]: pr[i - 1] = pr[i] #--> pr[i]<=pr[i-1] for all i since we want to consider the maximum #precision value for a queried interval # discretize empiric recall steps with given bins. assert np.all(rc[:-1]<=rc[1:]), "recall not sorted ascendingly" inds = np.searchsorted(rc, R, side='left') try: for rc_ix, pr_ix in enumerate(inds): q[rc_ix] = pr[pr_ix] except IndexError: #now q is filled with pr values up to first non-available index pass return np.mean(q) def roi_avp(inputs): ''' :param df: data frame. :param det_thresh: min_threshold for filtering out low confidence predictions. :param per_patient_ap: boolean flag. evaluate average precision per patient id and average over per-pid results, instead of computing one ap over whole data set. :return: average_precision (float) ''' df, det_thresh, per_patient_ap = inputs if per_patient_ap: pids_list = df.pid.unique() aps = [] for match_iou in df.match_iou.unique(): iou_df = df[df.match_iou == match_iou] for pid in pids_list: pid_df = iou_df[iou_df.pid == pid] all_p = len(pid_df[pid_df.class_label == 1]) mask = ((pid_df.rg_bins == pid_df.rg_bin_target) & (pid_df.det_type == 'det_tp')) | (pid_df.det_type == 'det_fp') pid_df = pid_df[mask].sort_values('pred_score', ascending=False) pid_df = pid_df[pid_df.pred_score > det_thresh] if (len(pid_df) ==0 and all_p == 0): pass elif (len(pid_df) > 0 and all_p == 0): aps.append(0) else: aps.append(compute_roi_ap(pid_df, all_p)) return np.mean(aps) else: aps = [] for match_iou in df.match_iou.unique(): iou_df = df[df.match_iou == match_iou] #it's important to not apply the threshold before counting all_positives! all_p = len(iou_df[(iou_df.det_type == 'det_tp') | (iou_df.det_type == 'det_fn')]) # filtering out tps which don't match rg_bin target at this point is same as reclassifying them as fn. # also sorting out all entries that are not fp or have confidence(=pred_score) <= detection_threshold mask = ((iou_df.rg_bins == iou_df.rg_bin_target) & (iou_df.det_type == 'det_tp')) | (iou_df.det_type == 'det_fp') iou_df = iou_df[mask].sort_values('pred_score', ascending=False) iou_df = iou_df[iou_df.pred_score > det_thresh] if all_p>0: aps.append(compute_roi_ap(iou_df, all_p)) return np.mean(aps) def compute_prc(df): """compute precision-recall curve with maximum precision per recall interval. :param df: :param all_p: # of all positive samples in data. :return: array: [precisions, recall query values] """ assert (df.class_label==1).any(), "cannot compute prc when no positives in data." all_p = len(df[(df.det_type == 'det_tp') | (df.det_type == 'det_fn')]) df = df[(df.det_type=="det_tp") | (df.det_type=="det_fp")] df = df.sort_values("pred_score", ascending=False) # recall thresholds, where precision will be measured scores = df.pred_score.values labels = df.class_label.values n_dets = len(scores) pr = np.zeros((n_dets,)) rc = pr.copy() for rank in range(n_dets): tp = np.count_nonzero(labels[:rank+1]==1) fp = np.count_nonzero(labels[:rank+1]==0) pr[rank] = tp/(tp+fp) rc[rank] = tp/all_p #after obj detection convention/ coco-dataset template: take maximum pr within intervals: # --> pr[i]<=pr[i-1] for all i since we want to consider the maximum # precision value for a queried interval for i in range(n_dets - 1, 0, -1): if pr[i] > pr[i - 1]: pr[i - 1] = pr[i] R = np.linspace(0., 1., np.round((1. - 0.) / .01).astype(int) + 1, endpoint=True)#precision queried at R points inds = np.searchsorted(rc, R, side='left') queries = np.zeros((len(R),)) try: for q_ix, rank in enumerate(inds): queries[q_ix] = pr[rank] except IndexError: pass return np.array((queries, R)) def RMSE(y_true, y_pred, weights=None): if len(y_true)>0: return np.sqrt(mean_squared_error(y_true, y_pred, sample_weight=weights)) else: return np.nan def MAE_w_std(y_true, y_pred, weights=None): if len(y_true)>0: y_true, y_pred = np.array(y_true), np.array(y_pred) deltas = np.abs(y_true-y_pred) mae = np.average(deltas, weights=weights, axis=0).item() skmae = mean_absolute_error(y_true, y_pred, sample_weight=weights) assert np.allclose(mae, skmae, atol=1e-6), "mae {}, sklearn mae {}".format(mae, skmae) std = np.std(weights*deltas) return mae, std else: return np.nan, np.nan def MAE(y_true, y_pred, weights=None): if len(y_true)>0: return mean_absolute_error(y_true, y_pred, sample_weight=weights) else: return np.nan def accuracy(y_true, y_pred, weights=None): if len(y_true)>0: return accuracy_score(y_true, y_pred, sample_weight=weights) else: return np.nan # noinspection PyCallingNonCallable class Evaluator(): """ Evaluates given results dicts. Can return results as updated monitor_metrics. Can save test data frames to file. """ def __init__(self, cf, logger, mode='test'): """ :param mode: either 'train', 'val_sampling', 'val_patient' or 'test'. handles prediction lists of different forms. """ self.cf = cf self.logger = logger self.mode = mode self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks]) self.plot_dir = self.cf.test_dir if self.mode == "test" else self.cf.plot_dir if self.cf.plot_prediction_histograms: self.hist_dir = os.path.join(self.plot_dir, 'histograms') os.makedirs(self.hist_dir, exist_ok=True) if self.cf.plot_stat_curves: self.curves_dir = os.path.join(self.plot_dir, 'stat_curves') os.makedirs(self.curves_dir, exist_ok=True) def eval_losses(self, batch_res_dicts): if hasattr(self.cf, "losses_to_monitor"): loss_names = self.cf.losses_to_monitor else: loss_names = {name for b_res_dict in batch_res_dicts for name in b_res_dict if 'loss' in name} self.epoch_losses = {l_name: torch.tensor([b_res_dict[l_name] for b_res_dict in batch_res_dicts if l_name in b_res_dict.keys()]).mean().item() for l_name in loss_names} def eval_segmentations(self, batch_res_dicts, pid_list): batch_dices = [b_res_dict['batch_dices'] for b_res_dict in batch_res_dicts if 'batch_dices' in b_res_dict.keys()] # shape (n_batches, n_seg_classes) if len(batch_dices) > 0: batch_dices = np.array(batch_dices) # dims n_batches x 1 in sampling / n_test_epochs x n_classes assert batch_dices.shape[1] == self.cf.num_seg_classes, "bdices shp {}, n seg cl {}, pid lst len {}".format( batch_dices.shape, self.cf.num_seg_classes, len(pid_list)) self.seg_df = pd.DataFrame() for seg_id in range(batch_dices.shape[1]): self.seg_df[self.cf.seg_id2label[seg_id].name + "_dice"] = batch_dices[:, seg_id] # one row== one batch, one column== one class # self.seg_df[self.cf.seg_id2label[seg_id].name+"_dice"] = np.concatenate(batch_dices[:,:,seg_id]) self.seg_df['fold'] = self.cf.fold if self.mode == "val_patient" or self.mode == "test": # need to make it more conform between sampling and patient-mode self.seg_df["pid"] = [pid for pix, pid in enumerate(pid_list)] # for b_inst in batch_inst_boxes[pix]] else: self.seg_df["pid"] = np.nan def eval_boxes(self, batch_res_dicts, pid_list, obj_cl_dict, obj_cl_identifiers={"gt":'class_targets', "pred":'box_pred_class_id'}): """ :param batch_res_dicts: :param pid_list: [pid_0, pid_1, ...] :return: """ if self.mode == 'train' or self.mode == 'val_sampling': # one pid per batch element # batch_size > 1, with varying patients across batch: # [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] # -> [results_0, results_1, ..] batch_inst_boxes = [b_res_dict['boxes'] for b_res_dict in batch_res_dicts] # len: nr of batches in epoch batch_inst_boxes = [[b_inst_boxes] for whole_batch_boxes in batch_inst_boxes for b_inst_boxes in whole_batch_boxes] # len: batch instances of whole epoch assert np.all(len(b_boxes_list) == self.cf.batch_size for b_boxes_list in batch_inst_boxes) elif self.mode == "val_patient" or self.mode == "test": # patient processing, one element per batch = one patient. # [[results_0, pid_0], [results_1, pid_1], ...] -> [results_0, results_1, ..] # in patientbatchiterator there is only one pid per batch batch_inst_boxes = [b_res_dict['boxes'] for b_res_dict in batch_res_dicts] # in patient mode not actually per batch instance, but per whole batch! if hasattr(self.cf, "eval_test_separately") and self.cf.eval_test_separately: """ you could write your own routines to add GTs to raw predictions for evaluation. implemented standard is: cf.eval_test_separately = False or not set --> GTs are saved at same time and in same file as raw prediction results. """ raise NotImplementedError assert len(batch_inst_boxes) == len(pid_list) df_list_preds = [] df_list_labels = [] df_list_class_preds = [] df_list_pids = [] df_list_type = [] df_list_match_iou = [] df_list_n_missing = [] df_list_regressions = [] df_list_rg_targets = [] df_list_rg_bins = [] df_list_rg_bin_targets = [] df_list_rg_uncs = [] for match_iou in self.cf.ap_match_ious: self.logger.info('evaluating with ap_match_iou: {}'.format(match_iou)) for cl in list(obj_cl_dict.keys()): for pix, pid in enumerate(pid_list): len_df_list_before_patient = len(df_list_pids) # input of each batch element is a list of boxes, where each box is a dictionary. for b_inst_ix, b_boxes_list in enumerate(batch_inst_boxes[pix]): b_tar_boxes = [] b_cand_boxes, b_cand_scores, b_cand_n_missing = [], [], [] if self.regress_flag: b_tar_regs, b_tar_rg_bins = [], [] b_cand_regs, b_cand_rg_bins, b_cand_rg_uncs = [], [], [] for box in b_boxes_list: # each box is either gt or detection or proposal/anchor # we need all gts in the same order & all dets in same order if box['box_type'] == 'gt' and box[obj_cl_identifiers["gt"]] == cl: b_tar_boxes.append(box["box_coords"]) if self.regress_flag: b_tar_regs.append(np.array(box['regression_targets'], dtype='float32')) b_tar_rg_bins.append(box['rg_bin_targets']) if box['box_type'] == 'det' and box[obj_cl_identifiers["pred"]] == cl: b_cand_boxes.append(box["box_coords"]) b_cand_scores.append(box["box_score"]) b_cand_n_missing.append(box["cluster_n_missing"] if 'cluster_n_missing' in box.keys() else np.nan) if self.regress_flag: b_cand_regs.append(box["regression"]) b_cand_rg_bins.append(box["rg_bin"]) b_cand_rg_uncs.append(box["rg_uncertainty"] if 'rg_uncertainty' in box.keys() else np.nan) b_tar_boxes = np.array(b_tar_boxes) b_cand_boxes, b_cand_scores, b_cand_n_missing = np.array(b_cand_boxes), np.array(b_cand_scores), np.array(b_cand_n_missing) if self.regress_flag: b_tar_regs, b_tar_rg_bins = np.array(b_tar_regs), np.array(b_tar_rg_bins) b_cand_regs, b_cand_rg_bins, b_cand_rg_uncs = np.array(b_cand_regs), np.array(b_cand_rg_bins), np.array(b_cand_rg_uncs) # check if predictions and ground truth boxes exist and match them according to match_iou. if not 0 in b_cand_boxes.shape and not 0 in b_tar_boxes.shape: assert np.all(np.round(b_cand_scores,6) <= 1.), "there is a box score>1: {}".format(b_cand_scores[~(b_cand_scores<=1.)]) #coords_check = np.array([len(coords)==self.cf.dim*2 for coords in b_cand_boxes]) #assert np.all(coords_check), "cand box with wrong bcoords dim: {}, mode: {}".format(b_cand_boxes[~coords_check], self.mode) expected_dim = len(b_cand_boxes[0]) assert np.all([len(coords) == expected_dim for coords in b_tar_boxes]), \ "gt/cand box coords mismatch, expected dim: {}.".format(expected_dim) # overlaps: shape len(cand_boxes) x len(tar_boxes) overlaps = mutils.compute_overlaps(b_cand_boxes, b_tar_boxes) # match_cand_ixs: shape (nr_of_matches,) # theses indices are the indices of b_cand_boxes match_cand_ixs = np.argwhere(np.max(overlaps, axis=1) > match_iou)[:, 0] non_match_cand_ixs = np.argwhere(np.max(overlaps, 1) <= match_iou)[:, 0] # the corresponding gt assigned to the pred boxes by highest iou overlap, # i.e., match_gt_ixs holds index into b_tar_boxes for each entry in match_cand_ixs, # i.e., gt_ixs and cand_ixs are paired via their position in their list # (cand_ixs[j] corresponds to gt_ixs[j]) match_gt_ixs = np.argmax(overlaps[match_cand_ixs, :], axis=1) if \ not 0 in match_cand_ixs.shape else np.array([]) assert len(match_gt_ixs)==len(match_cand_ixs) #match_gt_ixs: shape (nr_of_matches,) or 0 non_match_gt_ixs = np.array( [ii for ii in np.arange(b_tar_boxes.shape[0]) if ii not in match_gt_ixs]) unique, counts = np.unique(match_gt_ixs, return_counts=True) # check for double assignments, i.e. two predictions having been assigned to the same gt. # according to the COCO-metrics, only one prediction counts as true positive, the rest counts as # false positive. This case is supposed to be avoided by the model itself by, # e.g. using a low enough NMS threshold. if np.any(counts > 1): double_match_gt_ixs = unique[np.argwhere(counts > 1)[:, 0]] keep_max = [] double_match_list = [] for dg in double_match_gt_ixs: double_match_cand_ixs = match_cand_ixs[np.argwhere(match_gt_ixs == dg)] keep_max.append(double_match_cand_ixs[np.argmax(b_cand_scores[double_match_cand_ixs])]) double_match_list += [ii for ii in double_match_cand_ixs] fp_ixs = np.array([ii for ii in match_cand_ixs if (ii in double_match_list and ii not in keep_max)]) # count as fp: boxes that match gt above match_iou threshold but have not highest class confidence score match_gt_ixs = np.array([gt_ix for ii, gt_ix in enumerate(match_gt_ixs) if match_cand_ixs[ii] not in fp_ixs]) match_cand_ixs = np.array([cand_ix for cand_ix in match_cand_ixs if cand_ix not in fp_ixs]) assert len(match_gt_ixs) == len(match_cand_ixs) df_list_preds += [ii for ii in b_cand_scores[fp_ixs]] df_list_labels += [0] * fp_ixs.shape[0] # means label==gt==0==bg for all these fp_ixs df_list_class_preds += [cl] * fp_ixs.shape[0] df_list_n_missing += [n for n in b_cand_n_missing[fp_ixs]] if self.regress_flag: df_list_regressions += [r for r in b_cand_regs[fp_ixs]] df_list_rg_bins += [r for r in b_cand_rg_bins[fp_ixs]] df_list_rg_uncs += [r for r in b_cand_rg_uncs[fp_ixs]] df_list_rg_targets += [[0.]*self.cf.regression_n_features] * fp_ixs.shape[0] df_list_rg_bin_targets += [0.] * fp_ixs.shape[0] df_list_pids += [pid] * fp_ixs.shape[0] df_list_type += ['det_fp'] * fp_ixs.shape[0] # matched/tp: if not 0 in match_cand_ixs.shape: df_list_preds += list(b_cand_scores[match_cand_ixs]) df_list_labels += [1] * match_cand_ixs.shape[0] df_list_class_preds += [cl] * match_cand_ixs.shape[0] df_list_n_missing += list(b_cand_n_missing[match_cand_ixs]) if self.regress_flag: df_list_regressions += list(b_cand_regs[match_cand_ixs]) df_list_rg_bins += list(b_cand_rg_bins[match_cand_ixs]) df_list_rg_uncs += list(b_cand_rg_uncs[match_cand_ixs]) assert len(match_cand_ixs)==len(match_gt_ixs) df_list_rg_targets += list(b_tar_regs[match_gt_ixs]) df_list_rg_bin_targets += list(b_tar_rg_bins[match_gt_ixs]) df_list_pids += [pid] * match_cand_ixs.shape[0] df_list_type += ['det_tp'] * match_cand_ixs.shape[0] # rest fp: if not 0 in non_match_cand_ixs.shape: df_list_preds += list(b_cand_scores[non_match_cand_ixs]) df_list_labels += [0] * non_match_cand_ixs.shape[0] df_list_class_preds += [cl] * non_match_cand_ixs.shape[0] df_list_n_missing += list(b_cand_n_missing[non_match_cand_ixs]) if self.regress_flag: df_list_regressions += list(b_cand_regs[non_match_cand_ixs]) df_list_rg_bins += list(b_cand_rg_bins[non_match_cand_ixs]) df_list_rg_uncs += list(b_cand_rg_uncs[non_match_cand_ixs]) df_list_rg_targets += [[0.]*self.cf.regression_n_features] * non_match_cand_ixs.shape[0] df_list_rg_bin_targets += [0.] * non_match_cand_ixs.shape[0] df_list_pids += [pid] * non_match_cand_ixs.shape[0] df_list_type += ['det_fp'] * non_match_cand_ixs.shape[0] # fn: if not 0 in non_match_gt_ixs.shape: df_list_preds += [0] * non_match_gt_ixs.shape[0] df_list_labels += [1] * non_match_gt_ixs.shape[0] df_list_class_preds += [cl] * non_match_gt_ixs.shape[0] df_list_n_missing += [np.nan] * non_match_gt_ixs.shape[0] if self.regress_flag: df_list_regressions += [[0.]*self.cf.regression_n_features] * non_match_gt_ixs.shape[0] df_list_rg_bins += [0.] * non_match_gt_ixs.shape[0] df_list_rg_uncs += [np.nan] * non_match_gt_ixs.shape[0] df_list_rg_targets += list(b_tar_regs[non_match_gt_ixs]) df_list_rg_bin_targets += list(b_tar_rg_bins[non_match_gt_ixs]) df_list_pids += [pid] * non_match_gt_ixs.shape[0] df_list_type += ['det_fn'] * non_match_gt_ixs.shape[0] # only fp: if not 0 in b_cand_boxes.shape and 0 in b_tar_boxes.shape: # means there is no gt in all samples! any preds have to be fp. df_list_preds += list(b_cand_scores) df_list_labels += [0] * b_cand_boxes.shape[0] df_list_class_preds += [cl] * b_cand_boxes.shape[0] df_list_n_missing += list(b_cand_n_missing) if self.regress_flag: df_list_regressions += list(b_cand_regs) df_list_rg_bins += list(b_cand_rg_bins) df_list_rg_uncs += list(b_cand_rg_uncs) df_list_rg_targets += [[0.]*self.cf.regression_n_features] * b_cand_boxes.shape[0] df_list_rg_bin_targets += [0.] * b_cand_boxes.shape[0] df_list_pids += [pid] * b_cand_boxes.shape[0] df_list_type += ['det_fp'] * b_cand_boxes.shape[0] # only fn: if 0 in b_cand_boxes.shape and not 0 in b_tar_boxes.shape: df_list_preds += [0] * b_tar_boxes.shape[0] df_list_labels += [1] * b_tar_boxes.shape[0] df_list_class_preds += [cl] * b_tar_boxes.shape[0] df_list_n_missing += [np.nan] * b_tar_boxes.shape[0] if self.regress_flag: df_list_regressions += [[0.]*self.cf.regression_n_features] * b_tar_boxes.shape[0] df_list_rg_bins += [0.] * b_tar_boxes.shape[0] df_list_rg_uncs += [np.nan] * b_tar_boxes.shape[0] df_list_rg_targets += list(b_tar_regs) df_list_rg_bin_targets += list(b_tar_rg_bins) df_list_pids += [pid] * b_tar_boxes.shape[0] df_list_type += ['det_fn'] * b_tar_boxes.shape[0] # empty patient with 0 detections needs empty patient score, in order to not disappear from stats. # filtered out for roi-level evaluation later. During training (and val_sampling), # tn are assigned per sample independently of associated patients. # i.e., patient_tn is also meant as sample_tn if a list of samples is evaluated instead of whole patient if len(df_list_pids) == len_df_list_before_patient: df_list_preds += [0] df_list_labels += [0] df_list_class_preds += [cl] df_list_n_missing += [np.nan] if self.regress_flag: df_list_regressions += [[0.]*self.cf.regression_n_features] df_list_rg_bins += [0.] df_list_rg_uncs += [np.nan] df_list_rg_targets += [[0.]*self.cf.regression_n_features] df_list_rg_bin_targets += [0.] df_list_pids += [pid] df_list_type += ['patient_tn'] # true negative: no ground truth boxes, no detections. df_list_match_iou += [match_iou] * (len(df_list_preds) - len(df_list_match_iou)) self.test_df = pd.DataFrame() self.test_df['pred_score'] = df_list_preds self.test_df['class_label'] = df_list_labels # class labels are gt, 0,1, only indicate neg/pos (or bg/fg) remapped from all classes self.test_df['pred_class'] = df_list_class_preds # can be diff than 0,1 self.test_df['pid'] = df_list_pids self.test_df['det_type'] = df_list_type self.test_df['fold'] = self.cf.fold self.test_df['match_iou'] = df_list_match_iou self.test_df['cluster_n_missing'] = df_list_n_missing if self.regress_flag: self.test_df['regressions'] = df_list_regressions self.test_df['rg_targets'] = df_list_rg_targets self.test_df['rg_uncertainties'] = df_list_rg_uncs self.test_df['rg_bins'] = df_list_rg_bins # super weird error: pandas does not properly add an attribute if column is named "rg_bin_targets" ... ?!? self.test_df['rg_bin_target'] = df_list_rg_bin_targets assert hasattr(self.test_df, "rg_bin_target") #fn_df = self.test_df[self.test_df["det_type"] == "det_fn"] pass def evaluate_predictions(self, results_list, monitor_metrics=None): """ Performs the matching of predicted boxes and ground truth boxes. Loops over list of matching IoUs and foreground classes. Resulting info of each prediction is stored as one line in an internal dataframe, with the keys: det_type: 'tp' (true positive), 'fp' (false positive), 'fn' (false negative), 'tn' (true negative) pred_class: foreground class which the object predicts. pid: corresponding patient-id. pred_score: confidence score [0, 1] fold: corresponding fold of CV. match_iou: utilized IoU for matching. :param results_list: list of model predictions. Either from train/val_sampling (patch processing) for monitoring with form: [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] Or from val_patient/testing (patient processing), with form: [[results_0, pid_0], [results_1, pid_1], ...]) :param monitor_metrics (optional): dict of dicts with all metrics of previous epochs. :return monitor_metrics: if provided (during training), return monitor_metrics now including results of current epoch. """ # gets results_list = [[batch_instances_box_lists], [batch_instances_pids]]*n_batches # we want to evaluate one batch_instance (= 2D or 3D image) at a time. self.logger.info('evaluating in mode {}'.format(self.mode)) batch_res_dicts = [batch[0] for batch in results_list] # len: nr of batches in epoch if self.mode == 'train' or self.mode=='val_sampling': # one pid per batch element # [[[results_0, ...], [pid_0, ...]], [[results_n, ...], [pid_n, ...]], ...] # -> [pid_0, pid_1, ...] # additional list wrapping to make conform with below per-patient batches, where one pid is linked to more than one batch instance pid_list = [batch_instance_pid for batch in results_list for batch_instance_pid in batch[1]] elif self.mode == "val_patient" or self.mode=="test": # [[results_0, pid_0], [results_1, pid_1], ...] -> [pid_0, pid_1, ...] # in patientbatchiterator there is only one pid per batch pid_list = [np.unique(batch[1]) for batch in results_list] assert np.all([len(pid)==1 for pid in pid_list]), "pid list in patient-eval mode, should only contain a single scalar per patient: {}".format(pid_list) pid_list = [pid[0] for pid in pid_list] else: raise Exception("undefined run mode encountered") self.eval_losses(batch_res_dicts) self.eval_segmentations(batch_res_dicts, pid_list) self.eval_boxes(batch_res_dicts, pid_list, self.cf.class_dict) if monitor_metrics is not None: # return all_stats, updated monitor_metrics return self.return_metrics(self.test_df, self.cf.class_dict, monitor_metrics) def return_metrics(self, df, obj_cl_dict, monitor_metrics=None, boxes_only=False): """ Calculates metric scores for internal data frame. Called directly from evaluate_predictions during training for monitoring, or from score_test_df during inference (for single folds or aggregated test set). Loops over foreground classes and score_levels ('roi' and/or 'patient'), gets scores and stores them. Optionally creates plots of prediction histograms and ROC/PR curves. :param df: Data frame that holds evaluated predictions. :param obj_cl_dict: Dict linking object-class ids to object-class names. E.g., {1: "bikes", 2 : "cars"}. Set in configs as cf.class_dict. :param monitor_metrics: dict of dicts with all metrics of previous epochs. This function adds metrics for current epoch and returns the same object. :param boxes_only: whether to produce metrics only for the boxes, not the segmentations. :return: all_stats: list. Contains dicts with resulting scores for each combination of foreground class and score_level. :return: monitor_metrics """ # -------------- monitoring independent of class, score level ------------ if monitor_metrics is not None: for l_name in self.epoch_losses: monitor_metrics[l_name] = [self.epoch_losses[l_name]] # -------------- metrics calc dependent on class, score level ------------ all_stats = [] # all_stats: one entry per score_level per class for cl in list(obj_cl_dict.keys()): # bg eval is neglected cl_name = obj_cl_dict[cl] cl_df = df[df.pred_class == cl] if hasattr(self, "seg_df") and not boxes_only: dice_col = self.cf.seg_id2label[cl].name+"_dice" seg_cl_df = self.seg_df.loc[:,['pid', dice_col, 'fold']] for score_level in self.cf.report_score_level: stats_dict = {} stats_dict['name'] = 'fold_{} {} {}'.format(self.cf.fold, score_level, cl_name) # -------------- RoI-based ----------------- if score_level == 'rois': stats_dict['auc'] = np.nan stats_dict['roc'] = np.nan if monitor_metrics is not None: tn = len(cl_df[cl_df.det_type == "patient_tn"]) tp = len(cl_df[(cl_df.det_type == "det_tp")&(cl_df.pred_score>self.cf.min_det_thresh)]) fp = len(cl_df[(cl_df.det_type == "det_fp")&(cl_df.pred_score>self.cf.min_det_thresh)]) fn = len(cl_df[cl_df.det_type == "det_fn"]) sens = np.divide(tp, (fn + tp)) monitor_metrics.update({"Bin_Stats/" + cl_name + "_fp": [fp], "Bin_Stats/" + cl_name + "_tp": [tp], "Bin_Stats/" + cl_name + "_fn": [fn], "Bin_Stats/" + cl_name + "_tn": [tn], "Bin_Stats/" + cl_name + "_sensitivity": [sens]}) # list wrapping only needed bc other metrics are recorded over all epochs; spec_df = cl_df[cl_df.det_type != 'patient_tn'] if self.regress_flag: # filter false negatives out for regression-only eval since regressor didn't predict truncd_df = spec_df[(((spec_df.det_type == "det_fp") | ( spec_df.det_type == "det_tp")) & spec_df.pred_score > self.cf.min_det_thresh)] truncd_df_tp = truncd_df[truncd_df.det_type == "det_tp"] weights, weights_tp = truncd_df.pred_score.tolist(), truncd_df_tp.pred_score.tolist() y_true, y_pred = truncd_df.rg_targets.tolist(), truncd_df.regressions.tolist() stats_dict["rg_RMSE"] = RMSE(y_true, y_pred) stats_dict["rg_MAE"] = MAE(y_true, y_pred) stats_dict["rg_RMSE_weighted"] = RMSE(y_true, y_pred, weights) stats_dict["rg_MAE_weighted"] = MAE(y_true, y_pred, weights) y_true, y_pred = truncd_df_tp.rg_targets.tolist(), truncd_df_tp.regressions.tolist() stats_dict["rg_MAE_weighted_tp"] = MAE(y_true, y_pred, weights_tp) stats_dict["rg_MAE_w_std_weighted_tp"] = MAE_w_std(y_true, y_pred, weights_tp) y_true, y_pred = truncd_df.rg_bin_target.tolist(), truncd_df.rg_bins.tolist() stats_dict["rg_bin_accuracy"] = accuracy(y_true, y_pred) stats_dict["rg_bin_accuracy_weighted"] = accuracy(y_true, y_pred, weights) y_true, y_pred = truncd_df_tp.rg_bin_target.tolist(), truncd_df_tp.rg_bins.tolist() stats_dict["rg_bin_accuracy_weighted_tp"] = accuracy(y_true, y_pred, weights_tp) if np.any(~truncd_df.rg_uncertainties.isna()): # det_fn are expected to be NaN so they drop out in means stats_dict.update({"rg_uncertainty": truncd_df.rg_uncertainties.mean(), "rg_uncertainty_tp": truncd_df_tp.rg_uncertainties.mean(), "rg_uncertainty_tp_weighted": (truncd_df_tp.rg_uncertainties * truncd_df_tp.pred_score).sum() / truncd_df_tp.pred_score.sum() }) if (spec_df.class_label==1).any(): stats_dict['ap'] = get_roi_ap_from_df((spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap)) stats_dict['prc'] = precision_recall_curve(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) if self.regress_flag: stats_dict['avp'] = roi_avp((spec_df, self.cf.min_det_thresh, self.cf.per_patient_ap)) else: stats_dict['ap'] = np.nan stats_dict['prc'] = np.nan stats_dict['avp'] = np.nan # np.nan is formattable by __format__ as a float, None-type is not if hasattr(self, "seg_df") and not boxes_only: stats_dict["dice"] = seg_cl_df.loc[:,dice_col].mean() # mean per all rois in this epoch stats_dict["dice_std"] = seg_cl_df.loc[:,dice_col].std() # for the aggregated test set case, additionally get the scores of averaging over fold results. if self.cf.evaluate_fold_means and len(df.fold.unique()) > 1: aps = [] for fold in df.fold.unique(): fold_df = spec_df[spec_df.fold == fold] if (fold_df.class_label==1).any(): aps.append(get_roi_ap_from_df((fold_df, self.cf.min_det_thresh, self.cf.per_patient_ap))) stats_dict['ap_folds_mean'] = np.mean(aps) if len(aps)>0 else np.nan stats_dict['ap_folds_std'] = np.std(aps) if len(aps)>0 else np.nan stats_dict['auc_folds_mean'] = np.nan stats_dict['auc_folds_std'] = np.nan if self.regress_flag: avps, accuracies, MAEs = [], [], [] for fold in df.fold.unique(): fold_df = spec_df[spec_df.fold == fold] if (fold_df.class_label == 1).any(): avps.append(roi_avp((fold_df, self.cf.min_det_thresh, self.cf.per_patient_ap))) truncd_df_tp = fold_df[((fold_df.det_type == "det_tp") & fold_df.pred_score > self.cf.min_det_thresh)] weights_tp = truncd_df_tp.pred_score.tolist() y_true, y_pred = truncd_df_tp.rg_bin_target.tolist(), truncd_df_tp.rg_bins.tolist() accuracies.append(accuracy(y_true, y_pred, weights_tp)) y_true, y_pred = truncd_df_tp.rg_targets.tolist(), truncd_df_tp.regressions.tolist() MAEs.append(MAE_w_std(y_true, y_pred, weights_tp)) stats_dict['avp_folds_mean'] = np.mean(avps) if len(avps) > 0 else np.nan stats_dict['avp_folds_std'] = np.std(avps) if len(avps) > 0 else np.nan stats_dict['rg_bin_accuracy_weighted_tp_folds_mean'] = np.mean(accuracies) if len(accuracies) > 0 else np.nan stats_dict['rg_bin_accuracy_weighted_tp_folds_std'] = np.std(accuracies) if len(accuracies) > 0 else np.nan stats_dict['rg_MAE_w_std_weighted_tp_folds_mean'] = np.mean(MAEs, axis=0) if len(MAEs) > 0 else np.nan stats_dict['rg_MAE_w_std_weighted_tp_folds_std'] = np.std(MAEs, axis=0) if len(MAEs) > 0 else np.nan if hasattr(self, "seg_df") and not boxes_only and self.cf.evaluate_fold_means and len(seg_cl_df.fold.unique()) > 1: fold_means = seg_cl_df.groupby(['fold'], as_index=True).agg({dice_col:"mean"}) stats_dict["dice_folds_mean"] = float(fold_means.mean()) stats_dict["dice_folds_std"] = float(fold_means.std()) # -------------- patient-based ----------------- # on patient level, aggregate predictions per patient (pid): The patient predicted score is the highest # confidence prediction for this class. The patient class label is 1 if roi of this class exists in patient, else 0. if score_level == 'patient': #this is the critical part in patient scoring: only the max gt and max pred score are taken per patient! #--> does mix up values from separate detections spec_df = cl_df.groupby(['pid'], as_index=False) agg_args = {'class_label': 'max', 'pred_score': 'max', 'fold': 'first'} if self.regress_flag: # pandas throws error if aggregated value is np.array, not if is list. agg_args.update({'regressions': lambda series: list(series.iloc[np.argmax(series.apply(np.linalg.norm).values)]), 'rg_targets': lambda series: list(series.iloc[np.argmax(series.apply(np.linalg.norm).values)]), 'rg_bins': 'max', 'rg_bin_target': 'max', 'rg_uncertainties': 'max' }) if hasattr(cl_df, "cluster_n_missing"): agg_args.update({'cluster_n_missing': 'mean'}) spec_df = spec_df.agg(agg_args) if len(spec_df.class_label.unique()) > 1: stats_dict['auc'] = roc_auc_score(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) stats_dict['roc'] = roc_curve(spec_df.class_label.tolist(), spec_df.pred_score.tolist()) else: stats_dict['auc'] = np.nan stats_dict['roc'] = np.nan if (spec_df.class_label == 1).any(): patient_cl_labels = spec_df.class_label.tolist() stats_dict['ap'] = average_precision_score(patient_cl_labels, spec_df.pred_score.tolist()) stats_dict['prc'] = precision_recall_curve(patient_cl_labels, spec_df.pred_score.tolist()) if self.regress_flag: avp_scores = spec_df[spec_df.rg_bins == spec_df.rg_bin_target].pred_score.tolist() avp_scores += [0.] * (len(patient_cl_labels) - len(avp_scores)) stats_dict['avp'] = average_precision_score(patient_cl_labels, avp_scores) else: stats_dict['ap'] = np.nan stats_dict['prc'] = np.nan stats_dict['avp'] = np.nan if self.regress_flag: y_true, y_pred = spec_df.rg_targets.tolist(), spec_df.regressions.tolist() stats_dict["rg_RMSE"] = RMSE(y_true, y_pred) stats_dict["rg_MAE"] = MAE(y_true, y_pred) stats_dict["rg_bin_accuracy"] = accuracy(spec_df.rg_bin_target.tolist(), spec_df.rg_bins.tolist()) stats_dict["rg_uncertainty"] = spec_df.rg_uncertainties.mean() if hasattr(self, "seg_df") and not boxes_only: seg_cl_df = seg_cl_df.groupby(['pid'], as_index=False).agg( {dice_col: "mean", "fold": "first"}) # mean of all rois per patient in this epoch stats_dict["dice"] = seg_cl_df.loc[:,dice_col].mean() #mean of all patients stats_dict["dice_std"] = seg_cl_df.loc[:, dice_col].std() # for the aggregated test set case, additionally get the scores for averaging over fold results. if self.cf.evaluate_fold_means and len(df.fold.unique()) > 1 and self.mode in ["test", "analysis"]: aucs = [] aps = [] for fold in df.fold.unique(): fold_df = spec_df[spec_df.fold == fold] if (fold_df.class_label==1).any(): aps.append( average_precision_score(fold_df.class_label.tolist(), fold_df.pred_score.tolist())) if len(fold_df.class_label.unique())>1: aucs.append(roc_auc_score(fold_df.class_label.tolist(), fold_df.pred_score.tolist())) stats_dict['auc_folds_mean'] = np.mean(aucs) stats_dict['auc_folds_std'] = np.std(aucs) stats_dict['ap_folds_mean'] = np.mean(aps) stats_dict['ap_folds_std'] = np.std(aps) if hasattr(self, "seg_df") and not boxes_only and self.cf.evaluate_fold_means and len(seg_cl_df.fold.unique()) > 1: fold_means = seg_cl_df.groupby(['fold'], as_index=True).agg({dice_col:"mean"}) stats_dict["dice_folds_mean"] = float(fold_means.mean()) stats_dict["dice_folds_std"] = float(fold_means.std()) all_stats.append(stats_dict) # -------------- monitoring, visualisation ----------------- # fill new results into monitor_metrics dict. for simplicity, only one class (of interest) is monitored on patient level. patient_interests = [self.cf.class_dict[self.cf.patient_class_of_interest],] if hasattr(self.cf, "bin_dict"): patient_interests += [self.cf.bin_dict[self.cf.patient_bin_of_interest]] if monitor_metrics is not None and (score_level != 'patient' or cl_name in patient_interests): name = 'patient_'+cl_name if score_level == 'patient' else cl_name for metric in self.cf.metrics: if metric in stats_dict.keys(): monitor_metrics[name + '_'+metric].append(stats_dict[metric]) else: print("WARNING: skipped monitor metric {}_{} since not avail".format(name, metric)) # histograms if self.cf.plot_prediction_histograms: out_filename = os.path.join(self.hist_dir, 'pred_hist_{}_{}_{}_{}'.format( self.cf.fold, self.mode, score_level, cl_name)) plg.plot_prediction_hist(self.cf, spec_df, out_filename) # analysis of the hyper-parameter cf.min_det_thresh, for optimization on validation set. if self.cf.scan_det_thresh and "val" in self.mode: conf_threshs = list(np.arange(0.8, 1, 0.02)) pool = Pool(processes=self.cf.n_workers) mp_inputs = [[spec_df, ii, self.cf.per_patient_ap] for ii in conf_threshs] aps = pool.map(get_roi_ap_from_df, mp_inputs, chunksize=1) pool.close() pool.join() self.logger.info('results from scanning over det_threshs: {}'.format([[i, j] for i, j in zip(conf_threshs, aps)])) class_means = pd.DataFrame(columns=self.cf.report_score_level) for slevel in self.cf.report_score_level: level_stats = pd.DataFrame([stats for stats in all_stats if slevel in stats["name"]])[self.cf.metrics] class_means.loc[:, slevel] = level_stats.mean() all_stats.extend([{"name": 'fold_{} {} {}'.format(self.cf.fold, slevel, "class_means"), **level_means} for slevel, level_means in class_means.to_dict().items()]) if self.cf.plot_stat_curves: out_filename = os.path.join(self.curves_dir, '{}_{}_stat_curves'.format(self.cf.fold, self.mode)) plg.plot_stat_curves(self.cf, all_stats, out_filename) if self.cf.plot_prediction_histograms and hasattr(df, "cluster_n_missing") and df.cluster_n_missing.notna().any(): out_filename = os.path.join(self.hist_dir, 'n_missing_hist_{}_{}.png'.format(self.cf.fold, self.mode)) plg.plot_wbc_n_missing(self.cf, df, outfile=out_filename) return all_stats, monitor_metrics def score_test_df(self, max_fold=None, internal_df=True): """ Writes out resulting scores to text files: First checks for class-internal-df (typically current) fold, gets resulting scores, writes them to a text file and pickles data frame. Also checks if data-frame pickles of all folds of cross-validation exist in exp_dir. If true, loads all dataframes, aggregates test sets over folds, and calculates and writes out overall metrics. """ # this should maybe be extended to auc, ap stds. metrics_to_score = self.cf.metrics.copy() # + [ m+ext for m in self.cf.metrics if "dice" in m for ext in ["_std"]] if internal_df: self.test_df.to_pickle(os.path.join(self.cf.test_dir, '{}_test_df.pkl'.format(self.cf.fold))) if hasattr(self, "seg_df"): self.seg_df.to_pickle(os.path.join(self.cf.test_dir, '{}_test_seg_df.pkl'.format(self.cf.fold))) stats, _ = self.return_metrics(self.test_df, self.cf.class_dict) with open(os.path.join(self.cf.test_dir, 'results.txt'), 'a') as handle: handle.write('\n****************************\n') handle.write('\nresults for fold {}, {} \n'.format(self.cf.fold, time.strftime("%d/%m/%y %H:%M:%S"))) handle.write('\n****************************\n') handle.write('\nfold df shape {}\n \n'.format(self.test_df.shape)) for s in stats: for metric in metrics_to_score: if metric in s.keys(): #needed as long as no dice on patient level poss if "accuracy" in metric: handle.write('{} {:0.4f} '.format(metric, s[metric])) else: handle.write('{} {:0.3f} '.format(metric, s[metric])) else: print("WARNING: skipped metric {} since not avail".format(metric)) handle.write('{} \n'.format(s['name'])) if max_fold is None: max_fold = self.cf.n_cv_splits-1 if self.cf.fold == max_fold: print("max fold/overall stats triggered") self.cf.fold = 'overall' if self.cf.evaluate_fold_means: metrics_to_score += [m + ext for m in self.cf.metrics for ext in ("_folds_mean", "_folds_std")] - if not self.cf.held_out_test_set or self.cf.eval_test_fold_wise: + if not self.cf.hold_out_test_set or not self.cf.ensemble_folds: fold_df_paths = sorted([ii for ii in os.listdir(self.cf.test_dir) if 'test_df.pkl' in ii]) fold_seg_df_paths = sorted([ii for ii in os.listdir(self.cf.test_dir) if 'test_seg_df.pkl' in ii]) for paths in [fold_df_paths, fold_seg_df_paths]: assert len(paths) <= self.cf.n_cv_splits, "found {} > nr of cv splits results dfs in {}".format( len(paths), self.cf.test_dir) with open(os.path.join(self.cf.test_dir, 'results.txt'), 'a') as handle: dfs_list = [pd.read_pickle(os.path.join(self.cf.test_dir, ii)) for ii in fold_df_paths] seg_dfs_list = [pd.read_pickle(os.path.join(self.cf.test_dir, ii)) for ii in fold_seg_df_paths] self.test_df = pd.concat(dfs_list, sort=True) if len(seg_dfs_list)>0: self.seg_df = pd.concat(seg_dfs_list, sort=True) stats, _ = self.return_metrics(self.test_df, self.cf.class_dict) handle.write('\n****************************\n') handle.write('\nOVERALL RESULTS \n') handle.write('\n****************************\n') handle.write('\ndf shape \n \n'.format(self.test_df.shape)) for s in stats: for metric in metrics_to_score: if metric in s.keys(): handle.write('{} {:0.3f} '.format(metric, s[metric])) handle.write('{} \n'.format(s['name'])) results_table_path = os.path.join(self.cf.test_dir,"../../", 'results_table.csv') with open(results_table_path, 'a') as handle: #---column headers--- handle.write('\n{},'.format("Experiment Name")) handle.write('{},'.format("Time Stamp")) handle.write('{},'.format("Samples Seen")) handle.write('{},'.format("Spatial Dim")) handle.write('{},'.format("Patch Size")) handle.write('{},'.format("CV Folds")) handle.write('{},'.format("{}-clustering IoU".format(self.cf.clustering))) handle.write('{},'.format("Merge-2D-to-3D IoU")) if hasattr(self.cf, "test_against_exact_gt"): handle.write('{},'.format('Exact GT')) for s in stats: if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name'] or "mean" in s["name"]: for metric in metrics_to_score: if metric in s.keys() and not np.isnan(s[metric]): if metric=='ap': handle.write('{}_{} : {}_{},'.format(*s['name'].split(" ")[1:], metric, int(np.mean(self.cf.ap_match_ious)*100))) elif not "folds_std" in metric: handle.write('{}_{} : {},'.format(*s['name'].split(" ")[1:], metric)) else: print("WARNING: skipped metric {} since not avail".format(metric)) handle.write('\n') #--- columns content--- handle.write('{},'.format(self.cf.exp_dir.split(os.sep)[-1])) handle.write('{},'.format(time.strftime("%d%b%y %H:%M:%S"))) handle.write('{},'.format(self.cf.num_epochs*self.cf.num_train_batches*self.cf.batch_size)) handle.write('{}D,'.format(self.cf.dim)) handle.write('{},'.format("x".join([str(self.cf.patch_size[i]) for i in range(self.cf.dim)]))) handle.write('{},'.format(str(self.test_df.fold.unique().tolist()).replace(",", ""))) handle.write('{},'.format(self.cf.clustering_iou if self.cf.clustering else str("N/A"))) handle.write('{},'.format(self.cf.merge_3D_iou if self.cf.merge_2D_to_3D_preds else str("N/A"))) if hasattr(self.cf, "test_against_exact_gt"): handle.write('{},'.format(self.cf.test_against_exact_gt)) for s in stats: if self.cf.class_dict[self.cf.patient_class_of_interest] in s['name'] or "mean" in s["name"]: for metric in metrics_to_score: if metric in s.keys() and not np.isnan(s[metric]): # needed as long as no dice on patient level possible if "folds_mean" in metric: handle.write('{:0.3f}\u00B1{:0.3f}, '.format(s[metric], s["_".join((*metric.split("_")[:-1], "std"))])) elif not "folds_std" in metric: handle.write('{:0.3f}, '.format(s[metric])) handle.write('\n') with open(os.path.join(self.cf.test_dir, 'results_extr_scores.txt'), 'w') as handle: handle.write('\n****************************\n') handle.write('\nextremal scores for fold {} \n'.format(self.cf.fold)) handle.write('\n****************************\n') # want: pid & fold (&other) of highest scoring tp & fp in test_df for cl in self.cf.class_dict.keys(): print("\nClass {}".format(self.cf.class_dict[cl]), file=handle) cl_df = self.test_df[self.test_df.pred_class == cl] #.dropna(axis=1) for det_type in ['det_tp', 'det_fp']: filtered_df = cl_df[cl_df.det_type==det_type] print("\nHighest scoring {} of class {}".format(det_type, self.cf.class_dict[cl]), file=handle) if len(filtered_df)>0: print(filtered_df.loc[filtered_df.pred_score.idxmax()], file=handle) else: print("No detections of type {} for class {} in this df".format(det_type, self.cf.class_dict[cl]), file=handle) handle.write('\n****************************\n') diff --git a/exec.py b/exec.py index d6f00e2..8ae261d 100644 --- a/exec.py +++ b/exec.py @@ -1,344 +1,344 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ execution script. this where all routines come together and the only script you need to call. refer to parse args below to see options for execution. """ import plotting as plg import os import warnings import argparse import time import torch import utils.exp_utils as utils from evaluator import Evaluator from predictor import Predictor for msg in ["Attempting to set identical bottom==top results", "This figure includes Axes that are not compatible with tight_layout", "Data has no positive values, and therefore cannot be log-scaled.", ".*invalid value encountered in true_divide.*"]: warnings.filterwarnings("ignore", msg) def train(cf, logger): """ performs the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. logs to file and tensorboard. """ logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format( cf.dim, cf.fold, cf.exp_dir, cf.model)) logger.time("train_val") # -------------- inits and settings ----------------- net = model.net(cf, logger).cuda() if cf.optimizer == "ADAMW": optimizer = torch.optim.AdamW(utils.parse_params_for_optim(net, weight_decay=cf.weight_decay, exclude_from_wd=cf.exclude_from_wd), lr=cf.learning_rate[0]) elif cf.optimizer == "SGD": optimizer = torch.optim.SGD(utils.parse_params_for_optim(net, weight_decay=cf.weight_decay), lr=cf.learning_rate[0], momentum=0.3) if cf.dynamic_lr_scheduling: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, patience=cf.scheduling_patience) model_selector = utils.ModelSelector(cf, logger) starting_epoch = 1 if cf.resume: checkpoint_path = os.path.join(cf.fold_dir, "last_state.pth") starting_epoch, net, optimizer, model_selector = \ utils.load_checkpoint(checkpoint_path, net, optimizer, model_selector) logger.info('resumed from checkpoint {} to epoch {}'.format(checkpoint_path, starting_epoch)) # prepare monitoring monitor_metrics = utils.prepare_monitoring(cf) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) # -------------- training ----------------- for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs)) logger.time("train_epoch") net.train() train_results_list = [] train_evaluator = Evaluator(cf, logger, mode='train') for i in range(cf.num_train_batches): logger.time("train_batch_loadfw") batch = next(batch_gen['train']) batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts'] batch_gen['train'].generator.stats['empty_counts'] += batch['empty_counts'] logger.time("train_batch_loadfw") logger.time("train_batch_netfw") results_dict = net.train_forward(batch) logger.time("train_batch_netfw") logger.time("train_batch_bw") optimizer.zero_grad() results_dict['torch_loss'].backward() if cf.clip_norm: torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping optimizer.step() train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict if not cf.server_env: print("\rFinished training batch " + "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches, logger.get_time("train_batch_loadfw")+ logger.get_time("train_batch_netfw") +logger.time("train_batch_bw"), logger.get_time("train_batch_loadfw",reset=True), logger.get_time("train_batch_netfw", reset=True), logger.get_time("train_batch_bw", reset=True)), end="", flush=True) print() #--------------- train eval ---------------- if (epoch-1)%cf.plot_frequency==0: # view an example batch utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, get_time="train-example plot", out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold))) logger.time("evals") _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train']) logger.time("evals") logger.time("train_epoch", toggle=False) del train_results_list #----------- validation ------------ logger.info('starting validation in mode {}.'.format(cf.val_mode)) logger.time("val_epoch") with torch.no_grad(): net.eval() val_results_list = [] val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) val_predictor = Predictor(cf, net, logger, mode='val') for i in range(batch_gen['n_val']): logger.time("val_batch") batch = next(batch_gen[cf.val_mode]) if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) val_results_list.append([results_dict, batch["pid"]]) if not cf.server_env: print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch', i + 1, batch_gen['n_val'], logger.time("val_batch")), end="", flush=True) print() #------------ val eval ------------- if (epoch - 1) % cf.plot_frequency == 0: utils.split_off_process(plg.view_batch, cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, get_time="val-example plot", out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold))) logger.time("evals") _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val']) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) del val_results_list #----------- monitoring ------------- monitor_metrics.update({"lr": {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}}) logger.metrics2tboard(monitor_metrics, global_step=epoch) logger.time("evals") logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format( epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"), logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"), logger.get_time("val_epoch", reset=True)/batch_gen["n_val"])) logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True))) #-------------- scheduling ----------------- if cf.dynamic_lr_scheduling: scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) else: for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch-1] logger.time("train_val") logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True))) batch_gen['train'].generator.print_stats(logger, plot=True) def test(cf, logger, max_fold=None): """performs testing for a given fold (or held out set). saves stats in evaluator. """ logger.time("test_fold") logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir)) net = model.net(cf, logger).cuda() batch_gen = data_loader.get_test_generator(cf, logger) test_predictor = Predictor(cf, net, logger, mode='test') test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr( cf, "eval_test_separately") or not cf.eval_test_separately) if test_results_list is not None: test_evaluator = Evaluator(cf, logger, mode='test') test_evaluator.evaluate_predictions(test_results_list) test_evaluator.score_test_df(max_fold=max_fold) logger.info('Testing of fold {} took {}.\n'.format(cf.fold, logger.get_time("test_fold", reset=True, format="hms"))) if __name__ == '__main__': stime = time.time() parser = argparse.ArgumentParser() parser.add_argument('--dataset_name', type=str, default='toy', help="path to the dataset-specific code in source_dir/datasets") parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/regrcnn/datasets/toy/experiments/dev', help='path to experiment dir. will be created if non existent.') parser.add_argument('-m', '--mode', type=str, default='train_test', help='one out of: create_exp, analysis, train, train_test, or test') parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.') parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.') parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config") parser.add_argument('--use_stored_settings', default=False, action='store_true', help='load configs from existing exp_dir instead of source dir. always done for testing, ' 'but can be set to true to do the same for training. useful in job scheduler environment, ' 'where source code might change before the job actually runs.') parser.add_argument('--resume', action="store_true", default=False, help='if given, resume from checkpoint(s) of the specified folds.') parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything") args = parser.parse_args() args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name folds = args.folds resume = None if args.resume in ['None', 'none'] else args.resume if args.mode == 'create_exp': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, -1) logger.info('created experiment directory at {}'.format(args.exp_dir)) elif args.mode == 'train' or args.mode == 'train_test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings) if args.dev: folds = [0,1] cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 2, 0, 2 cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1 cf.test_n_epochs, cf.max_test_patients = cf.save_n_models, 2 torch.backends.cudnn.benchmark = cf.dim==3 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) if folds is None: folds = range(cf.n_cv_splits) for fold in folds: """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation, one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the splits. k==folds, fold in [0,folds) says which split is used for testing. """ cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold logger.set_logfile(fold=fold) cf.resume = resume if not os.path.exists(cf.fold_dir): os.mkdir(cf.fold_dir) train(cf, logger) cf.resume = None if args.mode == 'train_test': test(cf, logger) elif args.mode == 'test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if folds is None: folds = range(cf.n_cv_splits) if args.dev: folds = folds[:2] cf.max_test_patients, cf.test_n_epochs = 2, 2 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark for fold in folds: cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)); cf.fold = fold logger.set_logfile(fold=fold) if cf.fold_dir in fold_dirs: test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs])) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation. elif args.mode == 'analysis': """ analyse already saved predictions. """ cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) - if cf.held_out_test_set and not cf.eval_test_fold_wise: + if cf.hold_out_test_set and cf.ensemble_folds: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() logger.info('starting evaluation...') cf.fold = "overall" evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) evaluator.score_test_df(max_fold=cf.fold) else: fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if args.dev: cf.test_n_epochs = 2 fold_dirs = fold_dirs[:1] if folds is None: folds = range(cf.n_cv_splits) for fold in folds: cf.fold = fold; cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) logger.set_logfile(fold=fold) if cf.fold_dir in fold_dirs: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x logger.info('starting evaluation...') evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) max_fold = max([int(f[-1]) for f in fold_dirs]) evaluator.score_test_df(max_fold=max_fold) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) else: raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode)) mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t)) del logger torch.cuda.empty_cache() diff --git a/graphics_generation.py b/graphics_generation.py index 6c59a0c..42b298b 100644 --- a/graphics_generation.py +++ b/graphics_generation.py @@ -1,1932 +1,1932 @@ """ Created at 07/03/19 11:42 @author: gregor """ import plotting as plg import matplotlib.lines as mlines import os import sys import multiprocessing from copy import deepcopy import logging import time import numpy as np import pandas as pd from scipy.stats import norm from sklearn.metrics import confusion_matrix import utils.exp_utils as utils import utils.model_utils as mutils import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates import predictor as predictor_file import evaluator as evaluator_file class NoDaemonProcess(multiprocessing.Process): # make 'daemon' attribute always return False def _get_daemon(self): return False def _set_daemon(self, value): pass daemon = property(_get_daemon, _set_daemon) # We sub-class multiprocessing.pool.Pool instead of multiprocessing.Pool # because the latter is only a wrapper function, not a proper class. class NoDaemonProcessPool(multiprocessing.pool.Pool): Process = NoDaemonProcess class AttributeDict(dict): __getattr__ = dict.__getitem__ __setattr__ = dict.__setitem__ def get_cf(dataset_name, exp_dir=""): cf_path = os.path.join('datasets', dataset_name, exp_dir, "configs.py") cf_file = utils.import_module('configs', cf_path) return cf_file.Configs() def prostate_results_static(plot_dir=None): cf = get_cf('prostate', '') if plot_dir is None: plot_dir = os.path.join('datasets', 'prostate', 'misc') text_fs = 18 fig = plg.plt.figure(figsize=(6, 3)) #w,h grid = plg.plt.GridSpec(1, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c groups = ["b values", "ADC + b values", "T2"] splits = ["Det. U-Net", "Mask R-CNN", "Faster R-CNN+"] values = {"detu": [(0.296, 0.031), (0.312, 0.045), (0.090, 0.040)], "mask": [(0.393, 0.051), (0.382, 0.047), (0.136, 0.016)], "fast": [(0.424, 0.083), (0.390, 0.086), (0.036, 0.013)]} bar_values = [[v[0] for v in split] for split in values.values()] errors = [[v[1] for v in split] for split in values.values()] ax = fig.add_subplot(grid[0,0]) colors = [cf.aubergine, cf.blue, cf.dark_blue] plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=True, title="Prostate Main Results (3D)", ylabel=r"Performance as $\mathrm{AP}_{10}$", xlabel="Input Modalities") plg.plt.tight_layout() plg.plt.savefig(os.path.join(plot_dir, 'prostate_main_results.png'), dpi=600) def prostate_GT_examples(exp_dir='', plot_dir=None, pid=8., z_ix=None): import datasets.prostate.data_loader as dl cf = get_cf('prostate', exp_dir) cf.exp_dir = exp_dir cf.fold = 0 cf.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_di_250519_ps384_gs6071/" dataset = dl.Dataset(cf) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) - if cf.held_out_test_set: + if cf.hold_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] print("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if plot_dir is None: plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc') text_fs = 18 fig = plg.plt.figure(figsize=(10, 7.7)) #w,h grid = plg.plt.GridSpec(3, 4, wspace=0.0, hspace=0.0, figure=fig) #r,c text_x, text_y = 0.1, 0.8 # ------- DWI ------- if z_ix is None: z_ix_dwi = np.random.choice(dataset[pid]["fg_slices"]) img = np.load(dataset[pid]["img"])[:,z_ix_dwi] # mods, z,y,x seg = np.load(dataset[pid]["seg"])[z_ix_dwi] # z,y,x ax = fig.add_subplot(grid[0,0]) ax.imshow(img[0], cmap='gray') ax.text(text_x, text_y, "ADC", size=text_fs, color=cf.white, transform=ax.transAxes, bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) ax.axis('off') ax = fig.add_subplot(grid[0,1]) ax.imshow(img[0], cmap='gray') cmap = cf.class_cmap for r_ix in np.unique(seg[seg>0]): seg[seg==r_ix] = dataset[pid]["class_targets"][r_ix-1] ax.imshow(plg.to_rgba(seg, cmap), alpha=1) ax.text(text_x, text_y, "DWI GT", size=text_fs, color=cf.white, transform=ax.transAxes, bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) ax.axis('off') for b_ix, b in enumerate([50,500,1000,1500]): ax = fig.add_subplot(grid[1, b_ix]) ax.imshow(img[b_ix+1], cmap='gray') ax.text(text_x, text_y, r"{}{}".format("$b=$" if b_ix == 0 else "", b), size=text_fs, color=cf.white, transform=ax.transAxes, bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) ax.axis('off') # ----- T2 ----- cf.data_sourcedir = "/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071/" dataset = dl.Dataset(cf) if z_ix is None: if z_ix_dwi in dataset[pid]["fg_slices"]: z_ix_t2 = z_ix_dwi else: z_ix_t2 = np.random.choice(dataset[pid]["fg_slices"]) img = np.load(dataset[pid]["img"])[:,z_ix_t2] # mods, z,y,x seg = np.load(dataset[pid]["seg"])[z_ix_t2] # z,y,x ax = fig.add_subplot(grid[2,0]) ax.imshow(img[0], cmap='gray') ax.text(text_x, text_y, "T2w", size=text_fs, color=cf.white, transform=ax.transAxes, bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) ax.axis('off') ax = fig.add_subplot(grid[2,1]) ax.imshow(img[0], cmap='gray') cmap = cf.class_cmap for r_ix in np.unique(seg[seg>0]): seg[seg==r_ix] = dataset[pid]["class_targets"][r_ix-1] ax.imshow(plg.to_rgba(seg, cmap), alpha=1) ax.text(text_x, text_y, "T2 GT", size=text_fs, color=cf.white, transform=ax.transAxes, bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) ax.axis('off') #grid.tight_layout(fig) plg.plt.tight_layout() plg.plt.savefig(os.path.join(plot_dir, 'prostate_gt_examples.png'), dpi=600) def prostate_dataset_stats(exp_dir='', plot_dir=None, show_splits=True,): import datasets.prostate.data_loader as dl cf = get_cf('prostate', exp_dir) cf.exp_dir = exp_dir cf.fold = 0 dataset = dl.Dataset(cf) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) - if cf.held_out_test_set: + if cf.hold_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] print("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) df, labels = dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=None) if plot_dir is None: plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc') if show_splits: fig = plg.plt.figure(figsize=(6, 6)) # w, h grid = plg.plt.GridSpec(2, 2, wspace=0.05, hspace=0.15, figure=fig) # rows, cols else: fig = plg.plt.figure(figsize=(6, 3.)) grid = plg.plt.GridSpec(1, 1, wspace=0.0, hspace=0.15, figure=fig) ax = fig.add_subplot(grid[0,0]) ax = plg.plot_data_stats(cf, df, labels, ax=ax) ax.set_xlabel("") ax.set_xticklabels(df.columns, rotation='horizontal', fontsize=11) ax.set_title("") if show_splits: ax.text(0.05,0.95, 'a)', horizontalalignment='center', verticalalignment='center', transform = ax.transAxes, weight='bold') ax.text(0, 25, "GS$=6$", horizontalalignment='center', verticalalignment='center', bbox=dict(facecolor=(*cf.white, 0.8), edgecolor=cf.dark_green, pad=3)) ax.text(1, 25, "GS$\geq 7a$", horizontalalignment='center', verticalalignment='center', bbox=dict(facecolor=(*cf.white, 0.8), edgecolor=cf.red, pad=3)) ax.margins(y=0.1) if show_splits: ax = fig.add_subplot(grid[:, 1]) ax = plg.plot_fold_stats(cf, df, labels, ax=ax) ax.set_xlabel("") ax.set_title("") ax.text(0.05, 0.98, 'c)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold') ax.yaxis.tick_right() ax.yaxis.set_label_position("right") ax.margins(y=0.1) ax = fig.add_subplot(grid[1, 0]) cf.balance_target = "lesion_gleasons" dataset.df = None df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True) ax = plg.plot_data_stats(cf, df, labels, ax=ax) ax.set_xlabel("") ax.set_title("") ax.text(0.05, 0.95, 'b)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold') ax.margins(y=0.1) # rename GS according to names in thesis renamer = {'GS60':'GS 6', 'GS71':'GS 7a', 'GS72':'GS 7b', 'GS80':'GS 8', 'GS90': 'GS 9', 'GS91':'GS 9a', 'GS92':'GS 9b'} x_ticklabels = [str(l.get_text()) for l in ax.xaxis.get_ticklabels()] ax.xaxis.set_ticklabels([renamer[l] for l in x_ticklabels]) plg.plt.tight_layout() plg.plt.savefig(os.path.join(plot_dir, 'data_stats_prostate.png'), dpi=600) return def lidc_merged_sa_joint_plot(exp_dir='', plot_dir=None): import datasets.lidc.data_loader as dl cf = get_cf('lidc', exp_dir) cf.balance_target = "regression_targets" if plot_dir is None: plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc') cf.training_gts = 'merged' dataset = dl.Dataset(cf, mode='train') df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True) fig = plg.plt.figure(figsize=(4, 5.6)) #w, h # fig.subplots_adjust(hspace=0, wspace=0) grid = plg.plt.GridSpec(3, 1, wspace=0.0, hspace=0.7, figure=fig) #rows, cols fs = 9 ax = fig.add_subplot(grid[0, 0]) labels = [AttributeDict({ 'name': rg_val, 'color': cf.bin_id2label[cf.rg_val_to_bin_id(rg_val)].color}) for rg_val in df.columns] ax = plg.plot_data_stats(cf, df, labels, ax=ax, fs=fs) ax.set_xlabel("averaged multi-rater malignancy scores (ms)", fontsize=fs) ax.set_title("") ax.text(0.05, 0.91, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold', fontsize=fs) ax.margins(y=0.2) #----- single annotator ------- cf.training_gts = 'sa' dataset = dl.Dataset(cf, mode='train') df, labels = dataset.calc_statistics(plot_dir=None, overall_stats=True) ax = fig.add_subplot(grid[1, 0]) labels = [AttributeDict({ 'name': '{:.0f}'.format(rg_val), 'color': cf.bin_id2label[cf.rg_val_to_bin_id(rg_val)].color}) for rg_val in df.columns] mapper = {rg_val:'{:.0f}'.format(rg_val) for rg_val in df.columns} df = df.rename(mapper, axis=1) ax = plg.plot_data_stats(cf, df, labels, ax=ax, fs=fs) ax.set_xlabel("unaggregrated single-rater malignancy scores (ms)", fontsize=fs) ax.set_title("") ax.text(0.05, 0.91, 'b)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold', fontsize=fs) ax.margins(y=0.45) #------ binned dissent ----- #cf.balance_target = "regression_targets" all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()] non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0] mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions]) distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions] distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion]) binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()] for l_ix, mean_std in enumerate(mean_std_per_lesion): bin_id = cf.rg_val_to_bin_id(mean_std[0]) bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix]) binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max)) ax = fig.add_subplot(grid[2, 0]) plg.plot_binned_rater_dissent(cf, binned_stats, ax=ax, fs=fs) ax.set_title("") ax.text(0.05, 0.91, 'c)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold', fontsize=fs) ax.margins(y=0.2) plg.plt.savefig(os.path.join(plot_dir, 'data_stats_lidc_solarized.png'), bbox_inches='tight', dpi=600) return def lidc_dataset_stats(exp_dir='', plot_dir=None): import datasets.lidc.data_loader as dl cf = get_cf('lidc', exp_dir) cf.data_rootdir = cf.pp_data_path cf.balance_target = "regression_targets" dataset = dl.Dataset(cf, data_dir=cf.data_rootdir) if plot_dir is None: plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc') df, labels = dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True) return df, labels def lidc_sa_dataset_stats(exp_dir='', plot_dir=None): import datasets.lidc_sa.data_loader as dl cf = get_cf('lidc_sa', exp_dir) #cf.data_rootdir = cf.pp_data_path cf.balance_target = "regression_targets" dataset = dl.Dataset(cf) if plot_dir is None: plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc_sa', 'misc') dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True) all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()] empty_patients = [pid for (pid, lesions) in all_patients if len(lesions) == 0] non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0] full_consent_patients = [(pid, lesions) for (pid, lesions) in non_empty_patients if np.all([np.unique(roi).size == 1 for roi in lesions])] all_lesions = [roi for (pid, lesions) in non_empty_patients for roi in lesions] two_vote_min = [roi for (pid, lesions) in non_empty_patients for roi in lesions if np.count_nonzero(roi) > 1] three_vote_min = [roi for (pid, lesions) in non_empty_patients for roi in lesions if np.count_nonzero(roi) > 2] mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions]) avg_mean_std_pl = np.mean(mean_std_per_lesion, axis=0) # call std dev per lesion disconsent from now on disconsent_std = np.std(mean_std_per_lesion[:, 1]) distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions] distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion]) mean_max_delta = abs(mean_std_per_lesion[:, 0] - distribution_max_per_lesion) binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()] for l_ix, mean_std in enumerate(mean_std_per_lesion): bin_id = cf.rg_val_to_bin_id(mean_std[0]) bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix]) binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max)) plg.plot_binned_rater_dissent(cf, binned_stats, out_file=os.path.join(plot_dir, "binned_dissent.png")) mean_max_bin_divergence = [[] for bin_id in cf.bin_id2rg_val.keys()] for bin_id, bin_stats in enumerate(binned_stats): mean_max_bin_divergence[bin_id].append([roi for roi in bin_stats if roi[3] != 0]) mean_max_bin_divergence[bin_id].insert(0,len(mean_max_bin_divergence[bin_id][0])) return def lidc_annotator_confusion(exp_dir='', plot_dir=None, normalize=None, dataset=None, plot=True): """ :param exp_dir: :param plot_dir: :param normalize: str or None. str in ['truth', 'pred'] :param dataset: :param plot: :return: """ if dataset is None: import datasets.lidc.data_loader as dl cf = get_cf('lidc', exp_dir) # cf.data_rootdir = cf.pp_data_path cf.training_gts = "sa" cf.balance_target = "regression_targets" dataset = dl.Dataset(cf) else: cf = dataset.cf if plot_dir is None: plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'lidc', 'misc') dataset.calc_statistics(plot_dir=plot_dir, overall_stats=True) all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()] non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0] y_true, y_pred = [], [] for (pid, lesions) in non_empty_patients: for roi in lesions: true_bin = cf.rg_val_to_bin_id(np.mean(roi)) y_true.extend([true_bin] * len(roi)) y_pred.extend(roi) cm = confusion_matrix(y_true, y_pred) if normalize in ["truth", "row"]: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] elif normalize in ["pred", "prediction", "column", "col"]: cm = cm.astype('float') / cm.sum(axis=0)[:, np.newaxis] if plot: plg.plot_confusion_matrix(cf, cm, out_file=os.path.join(plot_dir, "annotator_confusion.pdf")) return cm def plot_lidc_dissent_and_example(confusion_matrix=True, bin_stds=False, plot_dir=None, numbering=True, example_title="Example"): import datasets.lidc.data_loader as dl dataset_name = 'lidc' exp_dir1 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rg_bs8' exp_dir2 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_bs8' #exp_dir1 = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rg_bs8' #exp_dir2 = '/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8' cf = get_cf(dataset_name, exp_dir1) #file_names = [f_name for f_name in os.listdir(os.path.join(exp_dir, 'inference_analysis')) if f_name.endswith('.pkl')] # file_names = [os.path.join(exp_dir, "inference_analysis", f_name) for f_name in file_names] file_names = ["bytes_merged_boxes_fold_0_pid_0811a.pkl",] z_ics = [194,] plot_files = [ {'files': [os.path.join(exp_dir, "inference_analysis", f_name) for exp_dir in [exp_dir1, exp_dir2]], 'z_ix': z_ix} for (f_name, z_ix) in zip(file_names, z_ics) ] cf.training_gts = 'sa' info_df_path = '/mnt/HDD2TB/Documents/data/lidc/pp_20190805/patient_gts_{}/info_df.pickle'.format(cf.training_gts) info_df = pd.read_pickle(info_df_path) cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois text_fs = 14 title_fs = text_fs text_x, text_y = 0.06, 0.92 fig = plg.plt.figure(figsize=(8.6, 3)) #w, h #fig.subplots_adjust(hspace=0, wspace=0) grid = plg.plt.GridSpec(1, 4, wspace=0.0, hspace=0.0, figure=fig) #rows, cols cf.plot_class_ids = True f_ix = 0 z_ix = plot_files[f_ix]['z_ix'] for model_ix in range(2)[::-1]: print("f_ix, m_ix", f_ix, model_ix) plot_file = utils.load_obj(plot_files[f_ix]['files'][model_ix]) batch = plot_file["batch"] pid = batch["pid"][0] batch['patient_rg_bin_targets_sa'] = info_df[info_df.pid == pid]['class_target'].tolist() # apply same filter as with merged GTs: need at least two non-zero votes to consider a RoI. batch['patient_rg_bin_targets_sa'] = [[four_votes.astype("uint8") for four_votes in batch_el if np.count_nonzero(four_votes>0)>=2] for batch_el in batch['patient_rg_bin_targets_sa']] results_dict = plot_file["res_dict"] # pred ax = fig.add_subplot(grid[0, model_ix+2]) plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=False, sample_picks=None, fontsize=text_fs*1.3, vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.2, plot_mods=False, seg_cmap="rg", show_cl_ids=False, out_file=None, dpi=600, patient_items=True, return_fig=False, axes={'pred': ax}) #ax.set_title("{}".format("Reg R-CNN" if model_ix==0 else "Mask R-CNN"), size=title_fs) ax.set_title("") ax.set_xlabel("{}".format("Reg R-CNN" if model_ix == 0 else "Mask R-CNN"), size=title_fs) if numbering: ax.text(text_x, text_y, chr(model_ix+99)+")", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold', color=cf.white, fontsize=title_fs) #ax.axis("off") ax.axis("on") plg.suppress_axes_lines(ax) # GT if model_ix==0: ax.set_title(example_title, fontsize=title_fs) ax = fig.add_subplot(grid[0, 1]) # ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') # ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8) plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, fontsize=text_fs*1.3, vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.13, plot_mods=False, seg_cmap="rg", out_file=None, dpi=600, patient_items=True, return_fig=False, axes={'gt':ax}) if numbering: ax.text(text_x, text_y, "b)", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold', color=cf.white, fontsize=title_fs) #ax.set_title("Ground Truth", size=title_fs) ax.set_title("") ax.set_xlabel("Ground Truth", size=title_fs) plg.suppress_axes_lines(ax) #ax.axis('off') #----- annotator dissent plot(s) ------ cf.training_gts = 'sa' cf.balance_targets = 'rg_bin_targets' dataset = dl.Dataset(cf, mode='train') if bin_stds: #------ binned dissent ----- #cf = get_cf('lidc', "") #cf.balance_target = "regression_targets" all_patients = [(pid,patient['rg_bin_targets']) for pid, patient in dataset.data.items()] non_empty_patients = [(pid, lesions) for (pid, lesions) in all_patients if len(lesions) > 0] mean_std_per_lesion = np.array([(np.mean(roi), np.std(roi)) for (pid, lesions) in non_empty_patients for roi in lesions]) distribution_max_per_lesion = [np.unique(roi, return_counts=True) for (pid, lesions) in non_empty_patients for roi in lesions] distribution_max_per_lesion = np.array([uniq[cts.argmax()] for (uniq, cts) in distribution_max_per_lesion]) binned_stats = [[] for bin_id in cf.bin_id2rg_val.keys()] for l_ix, mean_std in enumerate(mean_std_per_lesion): bin_id = cf.rg_val_to_bin_id(mean_std[0]) bin_id_max = cf.rg_val_to_bin_id(distribution_max_per_lesion[l_ix]) binned_stats[int(bin_id)].append((*mean_std, distribution_max_per_lesion[l_ix], bin_id-bin_id_max)) ax = fig.add_subplot(grid[0, 0]) plg.plot_binned_rater_dissent(cf, binned_stats, ax=ax, fs=text_fs) if numbering: ax.text(text_x, text_y, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold', fontsize=title_fs) ax.margins(y=0.2) ax.set_xlabel("Malignancy-Score Bins", fontsize=title_fs) #ax.yaxis.set_label_position("right") #ax.yaxis.tick_right() ax.set_yticklabels([]) #ax.xaxis.set_label_position("top") #ax.xaxis.tick_top() ax.set_title("Average Rater Dissent", fontsize=title_fs) if confusion_matrix: #------ confusion matrix ------- cm = lidc_annotator_confusion(dataset=dataset, plot=False, normalize="truth") ax = fig.add_subplot(grid[0, 0]) cmap = plg.make_colormap([(1,1,1), cf.dkfz_blue]) plg.plot_confusion_matrix(cf, cm, ax=ax, fs=text_fs, color_bar=False, cmap=cmap )#plg.plt.cm.Purples) ax.set_xticks(np.arange(cm.shape[1])) if numbering: ax.text(-0.16, text_y, 'a)', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, weight='bold', fontsize=title_fs) ax.margins(y=0.2) ax.set_title("Annotator Dissent", fontsize=title_fs) #fig.suptitle(" Example", fontsize=title_fs) #fig.text(0.63, 1.03, "Example", va="center", ha="center", size=title_fs, transform=fig.transFigure) #fig_patches = fig_leg.get_patches() #patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if label.id!=0] #fig.legends.append(fig_leg) #plg.plt.figlegend(handles=patches, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0., # ncol=len(patches), bbox_transform=fig.transFigure, title="Binned Malignancy Score", fontsize= text_fs) plg.plt.tight_layout() if plot_dir is None: plot_dir = "datasets/lidc/misc" out_file = os.path.join(plot_dir, "regrcnn_lidc_diss_example.png") if out_file is not None: plg.plt.savefig(out_file, dpi=600, bbox_inches='tight') def lidc_annotator_dissent_images(exp_dir='', plot_dir=None): if plot_dir is None: plot_dir = "datasets/lidc/misc" import datasets.lidc.data_loader as dl cf = get_cf('lidc', exp_dir) cf.training_gts = "sa" dataset = dl.Dataset(cf, mode='train') pids = {'0069a': 132, '0493a':125, '1008a': 164}#, '0355b': 138, '0484a': 86} # pid : (z_ix to show) # add_pids = dataset.set_ids[65:80] # for pid in add_pids: # try: # # pids[pid] = int(np.median(dataset.data[pid]['fg_slices'][0])) # # except (IndexError, ValueError): # print("pid {} has no foreground".format(pid)) if not os.path.exists(plot_dir): os.mkdir(plot_dir) out_file = os.path.join(plot_dir, "lidc_example_rater_dissent.png") #cf.training_gts = 'sa' cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois title_fs = 14 text_fs = 14 fig = plg.plt.figure(figsize=(10, 5.9)) #w, h #fig.subplots_adjust(hspace=0, wspace=0) grid = plg.plt.GridSpec(len(pids.keys()), 5, wspace=0.0, hspace=0.0, figure=fig) #rows, cols cf.plot_class_ids = True cmap = {id : (label.color if id!=0 else (0.,0.,0.)) for id, label in cf.bin_id2label.items()} legend_handles = set() window_size = (250,250) for p_ix, (pid, z_ix) in enumerate(pids.items()): try: print("plotting pid, z_ix", pid, z_ix) patient = dataset[pid] img = np.load(patient['data'], mmap_mode='r')[z_ix] # z,y,x --> y,x seg = np.load(patient['seg'], mmap_mode='r')['seg'][:,z_ix] # rater,z,y,x --> rater,y,x rg_bin_targets = patient['rg_bin_targets'] contours = np.nonzero(seg[0]) center_y, center_x = np.median(contours[0]), np.median(contours[1]) #min_y, min_x = np.min(contours[0]), np.min(contours[1]) #max_y, max_x = np.max(contours[0]), np.max(contours[1]) #buffer_y, buffer_x = int(seg.shape[1]*0.5), int(seg.shape[2]*0.5) #y_range = np.arange(max(min_y-buffer_y, 0), min(min_y+buffer_y, seg.shape[1])) #x_range = np.arange(max(min_x-buffer_x, 0), min(min_x+buffer_x, seg.shape[2])) y_range = np.arange(max(int(center_y-window_size[0]/2), 0), min(int(center_y+window_size[0]/2), seg.shape[1])) min_x = int(center_x-window_size[1]/2) max_x = int(center_x+window_size[1]/2) if min_x<0: max_x += abs(min_x) elif max_x>seg.shape[2]: min_x -= max_x-seg.shape[2] x_range = np.arange(max(min_x, 0), min(max_x, seg.shape[2])) img = img[y_range][:,x_range] seg = seg[:, y_range][:,:,x_range] # data ax = fig.add_subplot(grid[p_ix, 0]) ax.imshow(img, cmap='gray') plg.suppress_axes_lines(ax) # key = "spec" if "spec" in batch.keys() else "pid" ylabel = str(pid) + "/" + str(z_ix) ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number if p_ix == 0: ax.set_title("Image", fontsize=title_fs) # raters for r_ix in range(seg.shape[0]): rater_bin_targets = rg_bin_targets[:,r_ix] for roi_ix, rating in enumerate(rater_bin_targets): seg[r_ix][seg[r_ix]==roi_ix+1] = rating ax = fig.add_subplot(grid[p_ix, r_ix+1]) ax.imshow(seg[r_ix], cmap='gray') ax.imshow(plg.to_rgba(seg[r_ix], cmap), alpha=0.8) ax.axis('off') if p_ix == 0: ax.set_title("Rating {}".format(r_ix+1), fontsize=title_fs) legend_handles.update([cf.bin_id2label[id] for id in np.unique(seg[r_ix]) if id!=0]) except: print("failed pid", pid) pass legend_handles = [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in legend_handles] legend_handles = sorted(legend_handles, key=lambda h: h._label) fig.suptitle("LIDC Single-Rater Annotations", fontsize=title_fs) #patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if label.id!=0] legend = fig.legend(handles=legend_handles, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0, fontsize=text_fs, bbox_transform=fig.transFigure, ncol=len(legend_handles), title="Malignancy Score") plg.plt.setp(legend.get_title(), fontsize=title_fs) #grid.tight_layout(fig) #plg.plt.tight_layout(rect=[0, 0.00, 1, 1.5]) if out_file is not None: plg.plt.savefig(out_file, dpi=600, bbox_inches='tight') return def lidc_results_static(xlabels=None, plot_dir=None, in_percent=True): cf = get_cf('lidc', '') if plot_dir is None: plot_dir = os.path.join('datasets', 'lidc', 'misc') text_fs = 18 fig = plg.plt.figure(figsize=(3, 2.5)) #w,h grid = plg.plt.GridSpec(2, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c #--- LIDC 3D ----- splits = ["Reg R-CNN", "Mask R-CNN"]#, "Reg R-CNN 2D", "Mask R-CNN 2D"] values = {"reg3d": [(0.259, 0.035), (0.628, 0.038), (0.477, 0.035)], "mask3d": [(0.235, 0.027), (0.622, 0.029), (0.411, 0.026)],} groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."] if in_percent: bar_values = [[v[0]*100 for v in split] for split in values.values()] errors = [[v[1]*100 for v in split] for split in values.values()] else: bar_values = [[v[0] for v in split] for split in values.values()] errors = [[v[1] for v in split] for split in values.values()] ax = fig.add_subplot(grid[0,0]) colors = [cf.blue, cf.dkfz_blue] plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}", title="LIDC Results", ylabel=r"3D Perf. (%)", xlabel="Metric", yticklabels=[], ylim=(0,80 if in_percent else 0.8)) #------ LIDC 2D ------- splits = ["Reg R-CNN", "Mask R-CNN"] values = {"reg2d": [(0.148, 0.046), (0.414, 0.052), (0.468, 0.057)], "mask2d": [(0.127, 0.034), (0.406, 0.040), (0.447, 0.018)]} groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."] if in_percent: bar_values = [[v[0]*100 for v in split] for split in values.values()] errors = [[v[1]*100 for v in split] for split in values.values()] else: bar_values = [[v[0] for v in split] for split in values.values()] errors = [[v[1] for v in split] for split in values.values()] ax = fig.add_subplot(grid[1,0]) colors = [cf.blue, cf.dkfz_blue] plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}", title="", ylabel=r"2D Perf.", xlabel="Metric", xticklabels=xlabels, yticklabels=[], ylim=(None,60 if in_percent else 0.6)) plg.plt.tight_layout() plg.plt.savefig(os.path.join(plot_dir, 'lidc_static_results.png'), dpi=700) def toy_results_static(xlabels=None, plot_dir=None, in_percent=True): cf = get_cf('toy', '') if plot_dir is None: plot_dir = os.path.join('datasets', 'toy', 'misc') text_fs = 18 fig = plg.plt.figure(figsize=(3, 2.5)) #w,h grid = plg.plt.GridSpec(2, 1, wspace=0.0, hspace=0.0, figure=fig) #r,c #--- Toy 3D ----- groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."] splits = ["Reg R-CNN", "Mask R-CNN"]#, "Reg R-CNN 2D", "Mask R-CNN 2D"] values = {"reg3d": [(0.881, 0.014), (0.998, 0.004), (0.887, 0.014)], "mask3d": [(0.822, 0.070), (1.0, 0.0), (0.826, 0.069)],} if in_percent: bar_values = [[v[0]*100 for v in split] for split in values.values()] errors = [[v[1]*100 for v in split] for split in values.values()] else: bar_values = [[v[0] for v in split] for split in values.values()] errors = [[v[1] for v in split] for split in values.values()] ax = fig.add_subplot(grid[0,0]) colors = [cf.blue, cf.dkfz_blue] plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=True, label_format="{:.1f}", title="Toy Results", ylabel=r"3D Perf. (%)", xlabel="Metric", yticklabels=[], ylim=(0,130 if in_percent else .3)) #------ Toy 2D ------- groups = [r"$\mathrm{AVP}_{10}$", "$\mathrm{AP}_{10}$", "Bin Acc."] splits = ["Reg R-CNN", "Mask R-CNN"] values = {"reg2d": [(0.859, 0.021), (1., 0.0), (0.860, 0.021)], "mask2d": [(0.748, 0.022), (1., 0.0), (0.748, 0.021)]} if in_percent: bar_values = [[v[0]*100 for v in split] for split in values.values()] errors = [[v[1]*100 for v in split] for split in values.values()] else: bar_values = [[v[0] for v in split] for split in values.values()] errors = [[v[1] for v in split] for split in values.values()] ax = fig.add_subplot(grid[1,0]) colors = [cf.blue, cf.dkfz_blue] plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, errors=errors, colors=colors, ax=ax, legend=False, label_format="{:.1f}", title="", ylabel=r"2D Perf.", xlabel="Metric", xticklabels=xlabels, yticklabels=[], ylim=(None,130 if in_percent else 1.3)) plg.plt.tight_layout() plg.plt.savefig(os.path.join(plot_dir, 'toy_static_results.png'), dpi=700) def analyze_test_df(dataset_name, exp_dir='', cf=None, logger=None, plot_dir=None): evaluator_file = utils.import_module('evaluator', "evaluator.py") if cf is None: cf = get_cf(dataset_name, exp_dir) cf.exp_dir = exp_dir cf.test_dir = os.path.join(exp_dir, 'test') if logger is None: logger = utils.get_logger(cf.exp_dir, False) evaluator = evaluator_file.Evaluator(cf, logger, mode='test') fold_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_df.pkl' in ii]) fold_seg_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_seg_df.pkl' in ii]) metrics_to_score = ['ap', 'auc']#, 'patient_ap', 'patient_auc', 'patient_dice'] #'rg_bin_accuracy_weighted_tp', 'rg_MAE_w_std_weighted_tp'] #cf.metrics if cf.evaluate_fold_means: means_to_score = [m for m in metrics_to_score] #+ ['rg_MAE_w_std_weighted_tp'] #metrics_to_score += ['rg_MAE_std'] metrics_to_score = [] cf.fold = 'overall' dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_df_paths] evaluator.test_df = pd.concat(dfs_list, sort=True) seg_dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_seg_df_paths] if len(seg_dfs_list) > 0: evaluator.seg_df = pd.concat(seg_dfs_list, sort=True) # stats, _ = evaluator.return_metrics(evaluator.test_df, cf.class_dict) # results_table_path = os.path.join(cf.exp_dir, "../", "semi_man_summary.csv") # # ---column headers--- # col_headers = ["Experiment Name", "CV Folds", "Spatial Dim", "Clustering Kind", "Clustering IoU", "Merge-2D-to-3D IoU"] # if hasattr(cf, "test_against_exact_gt"): # col_headers.append('Exact GT') # for s in stats: # assert "overall" in s['name'].split(" ")[0] # if cf.class_dict[cf.patient_class_of_interest] in s['name']: # for metric in metrics_to_score: # #if metric in s.keys() and not np.isnan(s[metric]): # col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], metric)) # for mean in means_to_score: # if mean == "rg_MAE_w_std_weighted_tp": # col_headers.append('(MAE_fold_mean\u00B1std_fold_mean)\u00B1fold_mean_std\u00B1fold_std_std)'.format(*s['name'].split(" ")[1:], mean)) # elif mean in s.keys() and not np.isnan(s[mean]): # col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], mean)) # else: # print("skipping {}".format(mean)) # with open(results_table_path, 'a') as handle: # with open(results_table_path, 'r') as doublehandle: # last_header = doublehandle.readlines() # if len(last_header)==0 or len(col_headers)!=len(last_header[1].split(",")[:-1]) or \ # not all([col_headers[ix]==lhix for ix, lhix in enumerate(last_header[1].split(",")[:-1])]): # handle.write('\n') # for head in col_headers: # handle.write(head+',') # handle.write('\n') # # # --- columns content--- # handle.write('{},'.format(cf.exp_dir.split(os.sep)[-1])) # handle.write('{},'.format(str(evaluator.test_df.fold.unique().tolist()).replace(",", ""))) # handle.write('{}D,'.format(cf.dim)) # handle.write('{},'.format(cf.clustering)) # handle.write('{},'.format(cf.clustering_iou if cf.clustering else str("N/A"))) # handle.write('{},'.format(cf.merge_3D_iou if cf.merge_2D_to_3D_preds else str("N/A"))) # if hasattr(cf, "test_against_exact_gt"): # handle.write('{},'.format(cf.test_against_exact_gt)) # for s in stats: # if cf.class_dict[cf.patient_class_of_interest] in s['name']: # for metric in metrics_to_score: # #if metric in s.keys() and not np.isnan(s[metric]): # needed as long as no dice on patient level poss # handle.write('{:0.3f}, '.format(s[metric])) # for mean in means_to_score: # #if metric in s.keys() and not np.isnan(s[metric]): # if mean=="rg_MAE_w_std_weighted_tp": # handle.write('({:0.3f}\u00B1{:0.3f})\u00B1({:0.3f}\u00B1{:0.3f}),'.format(*s[mean + "_folds_mean"], *s[mean + "_folds_std"])) # elif mean in s.keys() and not np.isnan(s[mean]): # handle.write('{:0.3f}\u00B1{:0.3f},'.format(s[mean+"_folds_mean"], s[mean+"_folds_std"])) # else: # print("skipping {}".format(mean)) # # handle.write('\n') return evaluator.test_df def cluster_results_to_df(dataset_name, exp_dir='', overall_df=None, cf=None, logger=None, plot_dir=None): evaluator_file = utils.import_module('evaluator', "evaluator.py") if cf is None: cf = get_cf(dataset_name, exp_dir) cf.exp_dir = exp_dir cf.test_dir = os.path.join(exp_dir, 'test') if logger is None: logger = utils.get_logger(cf.exp_dir, False) evaluator = evaluator_file.Evaluator(cf, logger, mode='test') cf.fold = 'overall' metrics_to_score = ['ap', 'auc']#, 'patient_ap', 'patient_auc', 'patient_dice'] #'rg_bin_accuracy_weighted_tp', 'rg_MAE_w_std_weighted_tp'] #cf.metrics if cf.evaluate_fold_means: means_to_score = [m for m in metrics_to_score] #+ ['rg_MAE_w_std_weighted_tp'] #metrics_to_score += ['rg_MAE_std'] metrics_to_score = [] # use passed overall_df or, if not given, read dfs from file if overall_df is None: fold_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_df.pkl' in ii]) fold_seg_df_paths = sorted([ii for ii in os.listdir(cf.test_dir) if 'test_seg_df.pkl' in ii]) for paths in [fold_df_paths, fold_seg_df_paths]: assert len(paths) <= cf.n_cv_splits, "found {} > nr of cv splits results dfs in {}".format(len(paths), cf.test_dir) dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_df_paths] evaluator.test_df = pd.concat(dfs_list, sort=True) # seg_dfs_list = [pd.read_pickle(os.path.join(cf.test_dir, ii)) for ii in fold_seg_df_paths] # if len(seg_dfs_list) > 0: # evaluator.seg_df = pd.concat(seg_dfs_list, sort=True) else: evaluator.test_df = overall_df # todo seg_df if desired stats, _ = evaluator.return_metrics(evaluator.test_df, cf.class_dict) # ---column headers--- col_headers = ["Experiment Name", "Model", "CV Folds", "Spatial Dim", "Clustering Kind", "Clustering IoU", "Merge-2D-to-3D IoU"] for s in stats: assert "overall" in s['name'].split(" ")[0] if cf.class_dict[cf.patient_class_of_interest] in s['name']: for metric in metrics_to_score: #if metric in s.keys() and not np.isnan(s[metric]): col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], metric)) for mean in means_to_score: if mean in s.keys() and not np.isnan(s[mean]): col_headers.append('{}_{} : {}'.format(*s['name'].split(" ")[1:], mean+"_folds_mean")) else: print("skipping {}".format(mean)) results_df = pd.DataFrame(columns=col_headers) # --- columns content--- row = [] row.append('{}'.format(cf.exp_dir.split(os.sep)[-1])) model = 'frcnn' if (cf.model=="mrcnn" and cf.frcnn_mode) else cf.model row.append('{}'.format(model)) row.append('{}'.format(str(evaluator.test_df.fold.unique().tolist()).replace(",", ""))) row.append('{}D'.format(cf.dim)) row.append('{}'.format(cf.clustering)) row.append('{}'.format(cf.clustering_iou if cf.clustering else "N/A")) row.append('{}'.format(cf.merge_3D_iou if cf.merge_2D_to_3D_preds else "N/A")) for s in stats: if cf.class_dict[cf.patient_class_of_interest] in s['name']: for metric in metrics_to_score: #if metric in s.keys() and not np.isnan(s[metric]): # needed as long as no dice on patient level poss row.append('{:0.3f} '.format(s[metric])) for mean in means_to_score: #if metric in s.keys() and not np.isnan(s[metric]): if mean+"_folds_mean" in s.keys() and not np.isnan(s[mean+"_folds_mean"]): row.append('{:0.3f}\u00B1{:0.3f}'.format(s[mean+"_folds_mean"], s[mean+"_folds_std"])) else: print("skipping {}".format(mean+"_folds_mean")) #print("row, clustering, iou, exp", row, cf.clustering, cf.clustering_iou, cf.exp_dir) results_df.loc[0] = row return results_df def multiple_clustering_results(dataset_name, exp_dir, plot_dir=None, plot_hist=False): print("Gathering exp {}".format(exp_dir)) cf = get_cf(dataset_name, exp_dir) cf.n_workers = 1 logger = logging.getLogger("dummy") logger.setLevel(logging.DEBUG) #logger.addHandler(logging.StreamHandler()) cf.exp_dir = exp_dir cf.test_dir = os.path.join(exp_dir, 'test') cf.plot_prediction_histograms = False if plot_dir is None: #plot_dir = os.path.join(cf.test_dir, 'histograms') plot_dir = os.path.join("datasets", dataset_name, "misc") os.makedirs(plot_dir, exist_ok=True) # fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if # os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) folds = range(cf.n_cv_splits) clusterings = {None: ['lol'], 'wbc': [0.0, 0.1, 0.2, 0.3, 0.4], 'nms': [0.0, 0.1, 0.2, 0.3, 0.4]} #clusterings = {'wbc': [0.1,], 'nms': [0.1,]} #clusterings = {None: ['lol']} if plot_hist: clusterings = {None: ['lol'], 'nms': [0.1, ], 'wbc': [0.1, ]} class_of_interest = cf.patient_class_of_interest try: if plot_hist: title_fs, text_fs = 16, 13 fig = plg.plt.figure(figsize=(11, 8)) #width, height grid = plg.plt.GridSpec(len(clusterings.keys()), max([len(v) for v in clusterings.values()])+1, wspace=0.0, hspace=0.0, figure=fig) #rows, cols plg.plt.suptitle("Faster R-CNN+", fontsize=title_fs, va='bottom', y=0.925) results_df = pd.DataFrame() for cl_ix, (clustering, ious) in enumerate(clusterings.items()): cf.clustering = clustering for iou_ix, iou in enumerate(ious): cf.clustering_iou = iou print(r"Producing Results for Clustering {} @ IoU {}".format(cf.clustering, cf.clustering_iou)) overall_test_df = pd.DataFrame() for fold in folds[:]: cf.fold = fold cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) predictor = predictor_file.Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() logger.info('starting evaluation...') evaluator = evaluator_file.Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) #evaluator.score_test_df(max_fold=100) overall_test_df = overall_test_df.append(evaluator.test_df) results_df = results_df.append(cluster_results_to_df(dataset_name, overall_df=overall_test_df,cf=cf, logger=logger)) if plot_hist: if clustering=='wbc' and iou_ix==len(ious)-1: # plot n_missing histogram for last wbc clustering only out_filename = os.path.join(plot_dir, 'analysis_n_missing_overall_hist_{}_{}.png'.format(clustering, iou)) ax = fig.add_subplot(grid[cl_ix, iou_ix+1]) plg.plot_wbc_n_missing(cf, overall_test_df, outfile=out_filename, fs=text_fs, ax=ax) ax.set_title("WBC Missing Predictions per Cluster.", fontsize=title_fs) #ax.set_ylabel(r"Average Missing Preds per Cluster (%)") ax.yaxis.tick_right() ax.yaxis.set_label_position("right") ax.text(0.07, 0.87, "{}) WBC".format(chr(len(clusterings.keys())*len(ious)+97)), transform=ax.transAxes, color=cf.white, fontsize=title_fs, bbox=dict(boxstyle='square', facecolor='black', edgecolor='none', alpha=0.9)) overall_test_df = overall_test_df[overall_test_df.pred_class == class_of_interest] overall_test_df = overall_test_df[overall_test_df.det_type!='patient_tn'] out_filename = "analysis_fold_overall_hist_{}_{}.png".format(clustering, iou) out_filename = os.path.join(plot_dir, out_filename) ax = fig.add_subplot(grid[cl_ix, iou_ix]) plg.plot_prediction_hist(cf, overall_test_df, out_filename, fs=text_fs, ax=ax) ax.text(0.11, 0.87, "{}) {}".format(chr((cl_ix+1)*len(ious)+96), clustering.upper() if clustering else "Raw Preds"), transform=ax.transAxes, color=cf.white, bbox=dict(boxstyle='square', facecolor='black', edgecolor='none', alpha=0.9), fontsize=title_fs) if cl_ix==0 and iou_ix==0: ax.set_title("Prediction Histograms Malignant Class", fontsize=title_fs) ax.legend(loc="best", fontsize=text_fs) else: ax.set_title("") #analyze_test_df(dataset_name, cf=cf, logger=logger) if plot_hist: #plg.plt.subplots_adjust(top=0.) plg.plt.savefig(os.path.join(plot_dir, "combined_hist_plot.pdf"), dpi=600, bbox_inches='tight') except FileNotFoundError as e: print("Ignoring exp dir {} due to\n{}".format(exp_dir, e)) logger.handlers = [] del cf; del logger return results_df def gather_clustering_results(dataset_name, exp_parent_dir, exps_filter=None, processes=os.cpu_count()//2): exp_dirs = [os.path.join(exp_parent_dir, i) for i in os.listdir(exp_parent_dir + "/") if os.path.isdir(os.path.join(exp_parent_dir, i))]#[:1] if exps_filter is not None: exp_dirs = [ed for ed in exp_dirs if not exps_filter in ed] # for debugging #exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6" #exp_dirs = [exp_dir,] #exp_dirs = ["/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn2d_cl_bs10",] results_df = pd.DataFrame() p = NoDaemonProcessPool(processes=processes) mp_inputs = [(dataset_name, exp_dir) for exp_dir in exp_dirs][:] results_dfs = p.starmap(multiple_clustering_results, mp_inputs) p.close() p.join() for df in results_dfs: results_df = results_df.append(df) results_df.to_csv(os.path.join(exp_parent_dir, "df_cluster_summary.csv"), index=False) return results_df def plot_cluster_results_grid(cf, res_df, ylim=None, out_file=None): """ :param cf: :param res_df: results over a single dimension setting (2D or 3D), over all clustering methods and ious. :param out_file: :return: """ is_2d = np.all(res_df["Spatial Dim"]=="2D") # pandas has problems with recognising "N/A" string --> replace by None #res_df['Merge-2D-to-3D IoU'].iloc[res_df['Merge-2D-to-3D IoU'] == "N/A"] = None n_rows = 3#4 if is_2d else 3 grid = plg.plt.GridSpec(n_rows, 5, wspace=0.4, hspace=0.3) fig = plg.plt.figure(figsize=(11,6)) splits = res_df["Model"].unique().tolist() # need to be model names for split in splits: assoc_exps = res_df[res_df["Model"]==split]["Experiment Name"].unique() if len(assoc_exps)>1: print("Model {} has multiple experiments:\n{}".format(split, assoc_exps)) #res_df = res_df.where(~(res_df["Model"] == split), res_df["Experiment Name"], axis=0) raise Exception("Multiple Experiments") sort_map = {'detection_fpn': 0, 'mrcnn':1, 'frcnn':2, 'retina_net':3, 'retina_unet':4} splits.sort(key=sort_map.__getitem__) #colors = [cf.color_palette[ix+3 % len(cf.color_palette)] for ix in range(len(splits))] color_map = {'detection_fpn': cf.magenta, 'mrcnn':cf.blue, 'frcnn': cf.dark_blue, 'retina_net': cf.aubergine, 'retina_unet': cf.purple} colors = [color_map[split] for split in splits] alphas = [0.9,] * len(splits) legend_handles = [] model_renamer = {'detection_fpn': "Detection U-Net", 'mrcnn': "Mask R-CNN", 'frcnn': "Faster R-CNN+", 'retina_net': "RetinaNet", 'retina_unet': "Retina U-Net"} for rix, c_kind in zip([0, 1],['wbc', 'nms']): kind_df = res_df[res_df['Clustering Kind'] == c_kind] groups = kind_df['Clustering IoU'].unique() #for cix, iou in enumerate(groups): assert np.all([split in splits for split in kind_df["Model"].unique()]) #need to be model names ax = fig.add_subplot(grid[rix,:]) bar_values = [kind_df[kind_df["Model"]==split]["rois_malignant : ap_folds_mean"] for split in splits] bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values] bar_values = [ [float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values ] xlabel='' if rix == 0 else "Clustering IoU" ylabel = str(c_kind.upper()) + " / AP" lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds, ax=ax, ylabel=ylabel, xlabel=xlabel) legend_handles.append(lh) if rix == 0: ax.axes.get_xaxis().set_ticks([]) #ax.spines['top'].set_visible(False) #ax.spines['right'].set_visible(False) ax.spines['bottom'].set_visible(False) #ax.spines['left'].set_visible(False) else: ax.spines['top'].set_visible(False) #ticklab = ax.xaxis.get_ticklabels() #trans = ticklab.get_transform() ax.xaxis.set_label_coords(0.05, -0.05) ax.set_ylim(0.,ylim) if is_2d: # only 2d-3d merging @ 0.1 ax = fig.add_subplot(grid[2, 1]) kind_df = res_df[(res_df['Clustering Kind'] == 'None') & ~(res_df['Merge-2D-to-3D IoU'].isna())] groups = kind_df['Clustering IoU'].unique() bar_values = [kind_df[kind_df["Model"] == split]["rois_malignant : ap_folds_mean"] for split in splits] bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values] bar_values = np.array([[float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values]) lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds, ax=ax, ylabel="2D-3D Merging\nOnly / AP") legend_handles.append(lh) ax.axes.get_xaxis().set_ticks([]) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.spines['left'].set_visible(False) ax.set_ylim(0., ylim) next_row = 2 next_col = 2 else: next_row = 2 next_col = 2 # No clustering at all ax = fig.add_subplot(grid[next_row, next_col]) kind_df = res_df[(res_df['Clustering Kind'] == 'None') & (res_df['Merge-2D-to-3D IoU'].isna())] groups = kind_df['Clustering IoU'].unique() bar_values = [kind_df[kind_df["Model"] == split]["rois_malignant : ap_folds_mean"] for split in splits] bar_stds = [[float(val.split('\u00B1')[1]) for val in split_vals] for split_vals in bar_values] bar_values = np.array([[float(val.split('\u00B1')[0]) for val in split_vals] for split_vals in bar_values]) lh = plg.plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=colors, alphas=alphas, errors=bar_stds, ax=ax, ylabel="No Clustering / AP") legend_handles.append(lh) #plg.suppress_axes_lines(ax) #ax = fig.add_subplot(grid[next_row, 0]) #ax.set_ylabel("No Clustering") #plg.suppress_axes_lines(ax) ax.axes.get_xaxis().set_ticks([]) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.spines['left'].set_visible(False) ax.set_ylim(0., ylim) ax = fig.add_subplot(grid[next_row, 3]) # awful hot fix: only legend_handles[0] used in order to have same order as in plots. legend_handles = [plg.mpatches.Patch(color=handle[0], alpha=handle[1], label=model_renamer[handle[2]]) for handle in legend_handles[0]] ax.legend(handles=legend_handles) ax.axis('off') fig.suptitle('Prostate {} Results over Clustering Settings'.format(res_df["Spatial Dim"].unique().item()), fontsize=14) if out_file is not None: plg.plt.savefig(out_file) return def get_plot_clustering_results(dataset_name, exp_parent_dir, res_from_file=True, exps_filter=None): if not res_from_file: results_df = gather_clustering_results(dataset_name, exp_parent_dir, exps_filter=exps_filter) else: results_df = pd.read_csv(os.path.join(exp_parent_dir, "df_cluster_summary.csv")) if os.path.isfile(os.path.join(exp_parent_dir, "df_cluster_summary_no_clustering_2D.csv")): results_df = results_df.append(pd.read_csv(os.path.join(exp_parent_dir, "df_cluster_summary_no_clustering_2D.csv"))) cf = get_cf(dataset_name) if np.count_nonzero(results_df["Spatial Dim"] == "3D") >0: # 3D plot_cluster_results_grid(cf, results_df[results_df["Spatial Dim"] == "3D"], ylim=0.52, out_file=os.path.join(exp_parent_dir, "cluster_results_3D.pdf")) if np.count_nonzero(results_df["Spatial Dim"] == "2D") > 0: # 2D plot_cluster_results_grid(cf, results_df[results_df["Spatial Dim"]=="2D"], ylim=0.4, out_file=os.path.join(exp_parent_dir, "cluster_results_2D.pdf")) def plot_single_results(cf, exp_dir, plot_files, res_df=None): out_file = os.path.join(exp_dir, "inference_analysis", "single_results.pdf") plot_files = utils.load_obj(plot_files) batch = plot_files["batch"] results_dict = plot_files["res_dict"] cf.roi_items = ['class_targets'] class_renamer = {1: "GS 6", 2: "GS $\geq 7$"} gs_renamer = {60: "6", 71: "7a"} if "adcb" in exp_dir: modality = "adcb" elif "t2" in exp_dir: modality = "t2" else: modality = "b" text_fs = 16 if modality=="t2": n_rows, n_cols = 2, 3 gt_col = 1 fig_w, fig_h = 14, 4 input_x, input_y = 0.05, 0.9 z_ix = 11 thresh = 0.22 input_title = "Input" elif modality=="b": n_rows, n_cols = 2, 6 gt_col = 2 # = gt_span fig_w, fig_h = 14, 4 input_x, input_y = 0.08, 0.8 z_ix = 8 thresh = 0.16 input_title = " Input" elif modality=="adcb": n_rows, n_cols = 2, 7 gt_col = 3 fig_w, fig_h = 14, 4 input_x, input_y = 0.08, 0.8 z_ix = 8 thresh = 0.16 input_title = "Input" fig_w, fig_h = 12, 3.87 fig = plg.plt.figure(figsize=(fig_w, fig_h)) grid = plg.plt.GridSpec(n_rows, n_cols, wspace=0.0, hspace=0.0, figure=fig) cf.plot_class_ids = True if modality=="t2": ax = fig.add_subplot(grid[:, 0]) ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') ax.set_title("Input", size=text_fs) ax.text(0.05, 0.9, "T2", size=text_fs, color=cf.white, transform=ax.transAxes, bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) ax.axis("off") elif modality=="b": for m_ix, b in enumerate([50, 500, 1000, 1500]): ax = fig.add_subplot(grid[int(np.round(m_ix/4+0.0001)), m_ix%2]) print(int(np.round(m_ix/4+0.0001)), m_ix%2) ax.imshow(batch['patient_data'][0, m_ix, :, :, z_ix], cmap='gray') ax.text(input_x, input_y, r"{}{}".format("$b=$" if m_ix==0 else "", b), size=text_fs, color=cf.white, transform=ax.transAxes, bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) ax.axis("off") if b==50: ax.set_title(input_title, size=text_fs) elif modality=="adcb": for m_ix, b in enumerate(["ADC", 50, 500, 1000, 1500]): p_ix = m_ix + 1 if m_ix>2 else m_ix ax = fig.add_subplot(grid[int(np.round(p_ix/6+0.0001)), p_ix%3]) print(int(np.round(p_ix/4+0.0001)), p_ix%2) ax.imshow(batch['patient_data'][0, m_ix, :, :, z_ix], cmap='gray') ax.text(input_x, input_y, r"{}{}".format("$b=$" if m_ix==1 else "", b), size=text_fs, color=cf.white, transform=ax.transAxes, bbox=dict(facecolor=cf.black, alpha=0.7, edgecolor=cf.white, clip_on=False, pad=7)) ax.axis("off") if b==50: ax.set_title(input_title, size=text_fs) ax_gt = fig.add_subplot(grid[:, gt_col:gt_col+2]) # GT ax_pred = fig.add_subplot(grid[:, gt_col+2:gt_col+4]) # Prediction #ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') #ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') #ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8) plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, patient_items=True, vol_slice_picks=[z_ix,], show_gt_labels=True, box_score_thres=thresh, plot_mods=True, out_file=None, dpi=600, return_fig=False, axes={'gt':ax_gt, 'pred':ax_pred}, fontsize=text_fs) ax_gt.set_title("Ground Truth", size=text_fs) ax_pred.set_title("Prediction", size=text_fs) texts = list(ax_gt.texts) ax_gt.texts = [] for text in texts: cl_id = int(text.get_text()) x, y = text.get_position() text_str = "GS="+str(gs_renamer[cf.class_id2label[cl_id].gleasons[0]]) ax_gt.text(x-4*text_fs//2, y, text_str, color=text.get_color(), fontsize=text_fs, bbox=dict(facecolor=text.get_bbox_patch().get_facecolor(), alpha=0.7, edgecolor='none', clip_on=True, pad=0)) texts = list(ax_pred.texts) ax_pred.texts = [] for text in texts: x, y = text.get_position() x -= 4 * text_fs // 2 try: cl_id = int(text.get_text()) text_str = class_renamer[cl_id] except ValueError: text_str = text.get_text() if text.get_bbox_patch().get_facecolor()[:3]==cf.dark_green: x -= 4* text_fs ax_pred.text(x, y, text_str, color=text.get_color(), fontsize=text_fs, bbox=dict(facecolor=text.get_bbox_patch().get_facecolor(), alpha=0.7, edgecolor='none', clip_on=True, pad=0)) ax_gt.axis("off") ax_pred.axis("off") plg.plt.tight_layout() if out_file is not None: plg.plt.savefig(out_file, dpi=600, bbox_inches='tight') return def find_suitable_examples(exp_dir1, exp_dir2): test_df1 = analyze_test_df('lidc',exp_dir1) test_df2 = analyze_test_df('lidc', exp_dir2) test_df1 = test_df1[test_df1.pred_score>0.3] test_df2 = test_df2[test_df2.pred_score > 0.3] tp_df1 = test_df1[test_df1.det_type == 'det_tp'] tp_pids = tp_df1.pid.unique() tp_fp_pids = test_df2[(test_df2.pid.isin(tp_pids)) & ((test_df2.regressions-test_df2.rg_targets).abs()>1)].pid.unique() cand_df = tp_df1[tp_df1.pid.isin(tp_fp_pids)] sorter = (cand_df.regressions - cand_df.rg_targets).abs().argsort() cand_df = cand_df.iloc[sorter] print("Good guesses for examples: ", cand_df.pid.unique()[:20]) return def plot_single_results_lidc(): dataset_name = 'lidc' exp_dir1 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rg_copiedparams' exp_dir2 = '/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/ms12345_mrcnn3d_rgbin_copiedparams' cf = get_cf(dataset_name, exp_dir1) #file_names = [f_name for f_name in os.listdir(os.path.join(exp_dir, 'inference_analysis')) if f_name.endswith('.pkl')] # file_names = [os.path.join(exp_dir, "inference_analysis", f_name) for f_name in file_names] file_names = ['bytes_merged_boxes_fold_0_pid_0296a.pkl', 'bytes_merged_boxes_fold_2_pid_0416a.pkl', 'bytes_merged_boxes_fold_1_pid_0635a.pkl', "bytes_merged_boxes_fold_0_pid_0811a.pkl", "bytes_merged_boxes_fold_0_pid_0969a.pkl", # 'bytes_merged_boxes_fold_0_pid_0484a.pkl', 'bytes_merged_boxes_fold_0_pid_0492a.pkl', # 'bytes_merged_boxes_fold_0_pid_0505a.pkl','bytes_merged_boxes_fold_2_pid_0164a.pkl', # 'bytes_merged_boxes_fold_3_pid_0594a.pkl', ] z_ics = [167, 159, 107, 194, 177, # 84, 145, # 212, 219, # 67 ] plot_files = [ {'files': [os.path.join(exp_dir, "inference_analysis", f_name) for exp_dir in [exp_dir1, exp_dir2]], 'z_ix': z_ix} for (f_name, z_ix) in zip(file_names, z_ics) ] info_df_path = '/mnt/HDD2TB/Documents/data/lidc/pp_20190318/patient_gts_{}/info_df.pickle'.format(cf.training_gts) info_df = pd.read_pickle(info_df_path) #cf.training_gts = 'sa' cf.roi_items = ['regression_targets', 'rg_bin_targets_sa'] #['class_targets'] + cf.observables_rois text_fs = 8 fig = plg.plt.figure(figsize=(6, 9.9)) #w, h #fig = plg.plt.figure(figsize=(6, 6.5)) #fig.subplots_adjust(hspace=0, wspace=0) grid = plg.plt.GridSpec(len(plot_files), 3, wspace=0.0, hspace=0.0, figure=fig) #rows, cols cf.plot_class_ids = True for f_ix, pack in enumerate(plot_files): z_ix = plot_files[f_ix]['z_ix'] for model_ix in range(2)[::-1]: print("f_ix, m_ix", f_ix, model_ix) plot_file = utils.load_obj(plot_files[f_ix]['files'][model_ix]) batch = plot_file["batch"] pid = batch["pid"][0] batch['patient_rg_bin_targets_sa'] = info_df[info_df.pid == pid]['class_target'].tolist() # apply same filter as with merged GTs: need at least two non-zero votes to consider a RoI. batch['patient_rg_bin_targets_sa'] = [[four_votes for four_votes in batch_el if np.count_nonzero(four_votes>0)>=2] for batch_el in batch['patient_rg_bin_targets_sa']] results_dict = plot_file["res_dict"] # pred ax = fig.add_subplot(grid[f_ix, model_ix+1]) plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.2, plot_mods=False, out_file=None, dpi=600, patient_items=True, return_fig=False, axes={'pred': ax}) if f_ix==0: ax.set_title("{}".format("Reg R-CNN" if model_ix==0 else "Mask R-CNN"), size=text_fs*1.3) else: ax.set_title("") ax.axis("off") #grid.tight_layout(fig) # GT if model_ix==0: ax = fig.add_subplot(grid[f_ix, 0]) # ax.imshow(batch['patient_data'][0, 0, :, :, z_ix], cmap='gray') # ax.imshow(plg.to_rgba(batch['patient_seg'][0,0,:,:,z_ix], cf.cmap), alpha=0.8) boxes_fig = plg.view_batch_thesis(cf, batch, res_dict=results_dict, legend=True, sample_picks=None, vol_slice_picks=[z_ix, ], show_gt_labels=True, box_score_thres=0.1, plot_mods=False, seg_cmap="rg", out_file=None, dpi=600, patient_items=True, return_fig=False, axes={'gt':ax}) ax.set_ylabel(r"$\mathbf{"+chr(f_ix+97)+")}$ " + ax.get_ylabel()) ax.set_ylabel("") if f_ix==0: ax.set_title("Ground Truth", size=text_fs*1.3) else: ax.set_title("") #fig_patches = fig_leg.get_patches() patches= [plg.mpatches.Patch(color=label.color, label="{:.10s}".format(label.name)) for label in cf.bin_id2label.values() if not label.id in [0,]] #fig.legends.append(fig_leg) plg.plt.figlegend(handles=patches, loc="lower center", bbox_to_anchor=(0.5, 0.0), borderaxespad=0., ncol=len(patches), bbox_transform=fig.transFigure, title="Binned Malignancy Score", fontsize= text_fs) plg.plt.tight_layout() out_file = os.path.join(exp_dir1, "inference_analysis", "lidc_example_results_solarized.pdf") if out_file is not None: plg.plt.savefig(out_file, dpi=600, bbox_inches='tight') def box_clustering(exp_dir='', plot_dir=None): import datasets.prostate.data_loader as dl cf = get_cf('prostate', exp_dir) if plot_dir is None: plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('datasets', 'prostate', 'misc') fig = plg.plt.figure(figsize=(10, 4)) #fig.subplots_adjust(hspace=0, wspace=0) grid = plg.plt.GridSpec(2, 3, wspace=0.0, hspace=0., figure=fig) fs = 14 xyA = (.9, 0.5) xyB = (0.05, .5) patch_size = np.array([200, 320]) clustering_iou = 0.1 img_y, img_x = patch_size boxes = [ {'box_coords': [img_y * 0.2, img_x * 0.04, img_y * 0.55, img_x * 0.31], 'box_score': 0.45, 'box_cl': 1, 'regression': 2., 'rg_bin': cf.rg_val_to_bin_id(1.), 'box_patch_center_factor': 1., 'ens_ix': 1, 'box_n_overlaps': 1.}, {'box_coords': [img_y*0.05, img_x*0.05, img_y*0.5, img_x*0.3], 'box_score': 0.85, 'box_cl': 2, 'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.), 'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.}, {'box_coords': [img_y * 0.1, img_x * 0.2, img_y * 0.4, img_x * 0.7], 'box_score': 0.95, 'box_cl': 2, 'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.), 'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.}, {'box_coords': [img_y * 0.80, img_x * 0.35, img_y * 0.95, img_x * 0.85], 'box_score': 0.6, 'box_cl': 2, 'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.), 'box_patch_center_factor': 1., 'ens_ix': 1, 'box_n_overlaps': 1.}, {'box_coords': [img_y * 0.85, img_x * 0.4, img_y * 0.93, img_x * 0.9], 'box_score': 0.85, 'box_cl': 2, 'regression': 1., 'rg_bin': cf.rg_val_to_bin_id(1.), 'box_patch_center_factor': 1., 'ens_ix':1, 'box_n_overlaps':1.}, ] for box in boxes: c = box['box_coords'] box_centers = np.array([(c[ii + 2] - c[ii]) / 2 for ii in range(len(c) // 2)]) box['box_patch_center_factor'] = np.mean( [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in zip(box_centers, patch_size / 2)]) print("pc fact", box['box_patch_center_factor']) box_coords = np.array([box['box_coords'] for box in boxes]) box_scores = np.array([box['box_score'] for box in boxes]) box_cl_ids = np.array([box['box_cl'] for box in boxes]) ax0 = fig.add_subplot(grid[:,:2]) plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_unclustered.png"), ax=ax0) ax0.text(*xyA, 'a) Raw ', horizontalalignment='right', verticalalignment='center', transform=ax0.transAxes, weight='bold', fontsize=fs) nms_boxes = [] for cl in range(1,3): cl_boxes = [box for box in boxes if box['box_cl'] == cl ] box_coords = np.array([box['box_coords'] for box in cl_boxes]) box_scores = np.array([box['box_score'] for box in cl_boxes]) if 0 not in box_scores.shape: keep_ix = mutils.nms_numpy(box_coords, box_scores, thresh=clustering_iou) else: keep_ix = [] nms_boxes += [cl_boxes[ix] for ix in keep_ix] box_coords = np.array([box['box_coords'] for box in nms_boxes]) box_scores = np.array([box['box_score'] for box in nms_boxes]) box_cl_ids = np.array([box['box_cl'] for box in nms_boxes]) ax1 = fig.add_subplot(grid[1, 2]) nms_color = cf.black plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_nms_iou_{}.png".format(clustering_iou)), ax=ax1) ax1.text(*xyB, ' c) NMS', horizontalalignment='left', verticalalignment='center', transform=ax1.transAxes, weight='bold', color=nms_color, fontsize=fs) #------ WBC ------------------- regress_flag = False wbc_boxes = [] for cl in range(1,3): cl_boxes = [box for box in boxes if box['box_cl'] == cl] box_coords = np.array([box['box_coords'] for box in cl_boxes]) box_scores = np.array([box['box_score'] for box in cl_boxes]) box_center_factor = np.array([b['box_patch_center_factor'] for b in cl_boxes]) box_n_overlaps = np.array([b['box_n_overlaps'] for b in cl_boxes]) box_ens_ix = np.array([b['ens_ix'] for b in cl_boxes]) box_regressions = np.array([b['regression'] for b in cl_boxes]) if regress_flag else None box_rg_bins = np.array([b['rg_bin'] if 'rg_bin' in b.keys() else float('NaN') for b in cl_boxes]) box_rg_uncs = np.array([b['rg_uncertainty'] if 'rg_uncertainty' in b.keys() else float('NaN') for b in cl_boxes]) if 0 not in box_scores.shape: keep_scores, keep_coords, keep_n_missing, keep_regressions, keep_rg_bins, keep_rg_uncs = \ predictor_file.weighted_box_clustering(box_coords, box_scores, box_center_factor, box_n_overlaps, box_rg_bins, box_rg_uncs, box_regressions, box_ens_ix, clustering_iou, n_ens=1) for boxix in range(len(keep_scores)): clustered_box = {'box_type': 'det', 'box_coords': keep_coords[boxix], 'box_score': keep_scores[boxix], 'cluster_n_missing': keep_n_missing[boxix], 'box_pred_class_id': cl} if regress_flag: clustered_box.update({'regression': keep_regressions[boxix], 'rg_uncertainty': keep_rg_uncs[boxix], 'rg_bin': keep_rg_bins[boxix]}) wbc_boxes.append(clustered_box) box_coords = np.array([box['box_coords'] for box in wbc_boxes]) box_scores = np.array([box['box_score'] for box in wbc_boxes]) box_cl_ids = np.array([box['box_pred_class_id'] for box in wbc_boxes]) ax2 = fig.add_subplot(grid[0, 2]) wbc_color = cf.black plg.plot_boxes(cf, box_coords, patch_size, box_scores, box_cl_ids, out_file=os.path.join(plot_dir, "demo_boxes_wbc_iou_{}.png".format(clustering_iou)), ax=ax2) ax2.text(*xyB, ' b) WBC', horizontalalignment='left', verticalalignment='center', transform=ax2.transAxes, weight='bold', color=wbc_color, fontsize=fs) # ax2.spines['bottom'].set_color(wbc_color) # ax2.spines['top'].set_color(wbc_color) # ax2.spines['right'].set_color(wbc_color) # ax2.spines['left'].set_color(wbc_color) from matplotlib.patches import ConnectionPatch con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA="axes fraction", coordsB="axes fraction", axesA=ax0, axesB=ax2, color=wbc_color, lw=1.5, arrowstyle='-|>') ax0.add_artist(con) con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA="axes fraction", coordsB="axes fraction", axesA=ax0, axesB=ax1, color=nms_color, lw=1.5, arrowstyle='-|>') ax0.add_artist(con) # ax0.text(0.5, 0.5, "Test", size=30, va="center", ha="center", rotation=30, # bbox=dict(boxstyle="angled,pad=0.5", alpha=0.2)) plg.plt.tight_layout() plg.plt.savefig(os.path.join(plot_dir, "box_clustering.pdf"), bbox_inches='tight') def sketch_AP_AUC(plot_dir=None, draw_auc=True): from sklearn.metrics import roc_curve, roc_auc_score from understanding_metrics import get_det_types import matplotlib.transforms as mtrans cf = get_cf('prostate', '') if plot_dir is None: plot_dir = cf.plot_dir if hasattr(cf, 'plot_dir') else os.path.join('.') if draw_auc: fig = plg.plt.figure(figsize=(7, 6)) #width, height # fig.subplots_adjust(hspace=0, wspace=0) grid = plg.plt.GridSpec(2, 2, wspace=0.23, hspace=.45, figure=fig) #rows, cols else: fig = plg.plt.figure(figsize=(12, 3)) #width, height # fig.subplots_adjust(hspace=0, wspace=0) grid = plg.plt.GridSpec(1, 3, wspace=0.23, hspace=.45, figure=fig) #rows, cols fs = 13 text_fs = 11 optim_color = cf.dark_green non_opt_color = cf.aubergine df = pd.DataFrame(columns=['pred_score', 'class_label', 'pred_class', 'det_type', 'match_iou']) df2 = df.copy() df["pred_score"] = [0,0.3,0.25,0.2, 0.8, 0.9, 0.9, 0.9, 0.9] df["class_label"] = [0,0,0,0, 1, 1, 1, 1, 1] df["det_type"] = get_det_types(df) df["match_iou"] = [0.1] * len(df) df2["pred_score"] = [0, 0.77, 0.5, 1., 0.5, 0.35, 0.3, 0., 0.7, 0.85, 0.9] df2["class_label"] = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] df2["det_type"] = get_det_types(df2) df2["match_iou"] = [0.1] * len(df2) #------ PRC ------- # optimal if draw_auc: ax = fig.add_subplot(grid[1, 0]) else: ax = fig.add_subplot(grid[0, 2]) pr, rc = evaluator_file.compute_prc(df) ax.plot(rc, pr, color=optim_color, label="Optimal Detection") ax.fill_between(rc, pr, alpha=0.33, color=optim_color) # suboptimal pr, rc = evaluator_file.compute_prc(df2) ax.plot(rc, pr, color=non_opt_color, label="Suboptimal") ax.fill_between(rc, pr, alpha=0.33, color=non_opt_color) #plt.title() #plt.legend(loc=3 if c == 'prc' else 4) ax.set_ylabel('precision', fontsize=text_fs) ax.set_ylim((0., 1.1)) ax.set_xlabel('recall', fontsize=text_fs) ax.set_title('Precision-Recall Curves', fontsize=fs) #ax.legend(ncol=2, loc='center')#, bbox_to_anchor=(0.5, 1.05)) #---- ROC curve if draw_auc: ax = fig.add_subplot(grid[1, 1]) roc = roc_curve(df.class_label.tolist(), df.pred_score.tolist()) ax.plot(roc[0], roc[1], color=optim_color) ax.fill_between(roc[0], roc[1], alpha=0.33, color=optim_color) ax.set_xlabel('false-positive rate', fontsize=text_fs) ax.set_ylim((0., 1.1)) ax.set_ylabel('recall', fontsize=text_fs) roc = roc_curve(df2.class_label.tolist(), df2.pred_score.tolist()) ax.plot(roc[0], roc[1], color=non_opt_color) ax.fill_between(roc[0], roc[1], alpha=0.33, color=non_opt_color) roc = ([0, 1], [0, 1]) ax.plot(roc[0], roc[1], color=cf.gray, linestyle='dashed', label="random predictor") ax.set_title('ROC Curves', fontsize=fs) ax.legend(ncol=2, loc='lower right', fontsize=text_fs) #--- hist optimal text_left = 0.05 ax = fig.add_subplot(grid[0, 0]) tn_count = df.det_type.tolist().count('det_tn') AUC = roc_auc_score(df.class_label, df.pred_score) df = df[(df.det_type=="det_tp") | (df.det_type=="det_fp") | (df.det_type=="det_fn")] labels = df.class_label.values preds = df.pred_score.values type_list = df.det_type.tolist() ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="FP") ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="FN at score 0 and TP") #ax.axvline(x=cf.min_det_thresh, alpha=0.4, color=cf.orange, linewidth=1.5, label="min det thresh") fp_count = type_list.count('det_fp') fn_count = type_list.count('det_fn') tp_count = type_list.count('det_tp') pos_count = fn_count + tp_count if draw_auc: text = "AP: {:.2f} ROC-AUC: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df, 0.0, False)), AUC) else: text = "AP: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df, 0.0, False))) text += 'TP: {} FP: {} FN: {} TN: {}\npositives: {}'.format(tp_count, fp_count, fn_count, tn_count, pos_count) ax.text(text_left,4, text, fontsize=text_fs) ax.set_yscale('log') ax.set_ylim(bottom=10**-2, top=10**2) ax.set_xlabel("prediction score", fontsize=text_fs) ax.set_ylabel("occurences", fontsize=text_fs) #autoAxis = ax.axis() # rec = plg.mpatches.Rectangle((autoAxis[0] - 0.7, autoAxis[2] - 0.2), (autoAxis[1] - autoAxis[0]) + 1, # (autoAxis[3] - autoAxis[2]) + 0.4, fill=False, lw=2) # rec = plg.mpatches.Rectangle((autoAxis[0] , autoAxis[2] ), (autoAxis[1] - autoAxis[0]) , # (autoAxis[3] - autoAxis[2]) , fill=False, lw=2, color=optim_color) # rec = ax.add_patch(rec) # rec.set_clip_on(False) plg.plt.setp(ax.spines.values(), color=optim_color, linewidth=2) ax.set_facecolor((*optim_color,0.1)) ax.set_title("Detection Histograms", fontsize=fs) ax = fig.add_subplot(grid[0, 1]) tn_count = df2.det_type.tolist().count('det_tn') AUC = roc_auc_score(df2.class_label, df2.pred_score) df2 = df2[(df2.det_type=="det_tp") | (df2.det_type=="det_fp") | (df2.det_type=="det_fn")] labels = df2.class_label.values preds = df2.pred_score.values type_list = df2.det_type.tolist() ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="FP") ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="FN at score 0 and TP") # ax.axvline(x=cf.min_det_thresh, alpha=0.4, color=cf.orange, linewidth=1.5, label="min det thresh") fp_count = type_list.count('det_fp') fn_count = type_list.count('det_fn') tp_count = type_list.count('det_tp') pos_count = fn_count + tp_count if draw_auc: text = "AP: {:.2f} ROC-AUC: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df2, 0.0, False)), AUC) else: text = "AP: {:.2f}\n".format(evaluator_file.get_roi_ap_from_df((df2, 0.0, False))) text += 'TP: {} FP: {} FN: {} TN: {}\npositives: {}'.format(tp_count, fp_count, fn_count, tn_count, pos_count) ax.text(text_left, 4*10**0, text, fontsize=text_fs) ax.set_yscale('log') ax.margins(y=10e2) ax.set_ylim(bottom=10**-2, top=10**2) ax.set_xlabel("prediction score", fontsize=text_fs) ax.set_yticks([]) plg.plt.setp(ax.spines.values(), color=non_opt_color, linewidth=2) ax.set_facecolor((*non_opt_color, 0.05)) ax.legend(ncol=2, loc='upper center', bbox_to_anchor=(0.5, 1.18), fontsize=text_fs) if draw_auc: # Draw a horizontal line line = plg.plt.Line2D([0.1, .9], [0.48, 0.48], transform=fig.transFigure, color="black") fig.add_artist(line) outfile = os.path.join(plot_dir, "metrics.png") print("Saving plot to {}".format(outfile)) plg.plt.savefig(outfile, bbox_inches='tight', dpi=600) return def draw_toy_cylinders(plot_dir=None): source_path = "datasets/toy" if plot_dir is None: plot_dir = os.path.join(source_path, "misc") #plot_dir = '/home/gregor/Dropbox/Thesis/Main/tmp' os.makedirs(plot_dir, exist_ok=True) cf = get_cf('toy', '') cf.pre_crop_size = [2200, 2200,1] #y,x,z; #cf.dim = 2 cf.ambiguities = {"radius_calib": (1., 1. / 6) } cf.pp_blur_min_intensity = 0.2 generate_toys = utils.import_module("generate_toys", os.path.join(source_path, 'generate_toys.py')) ToyGen = generate_toys.ToyGenerator(cf) fig = plg.plt.figure(figsize=(10, 8.2)) #width, height grid = plg.plt.GridSpec(4, 5, wspace=0.0, hspace=.0, figure=fig) #rows, cols fs, text_fs = 16, 14 text_x, text_y = 0.5, 0.85 true_gt_col, dist_gt_col = cf.dark_green, cf.blue true_cmap = {1:true_gt_col} img = np.random.normal(loc=0.0, scale=cf.noise_scale, size=ToyGen.sample_size) img[img < 0.] = 0. # one-hot-encoded seg seg = np.zeros((cf.num_classes + 1, *ToyGen.sample_size)).astype('uint8') undistorted_seg = np.copy(seg) applied_gt_distort = False class_id, shape = 1, 'cylinder' #all_radii = ToyGen.generate_sample_radii(class_ids, shapes) enlarge_f = 20 all_radii = np.array([np.mean(label.bin_vals) if label.id!=5 else label.bin_vals[0]+5 for label in cf.bin_labels if label.id!=0]) bins = [(min(label.bin_vals), max(label.bin_vals)) for label in cf.bin_labels] bin_edges = [(bins[i][1] + bins[i + 1][0])*enlarge_f / 2 for i in range(len(bins) - 1)] all_radii = [np.array([r*enlarge_f, r*enlarge_f, 1]) for r in all_radii] # extend to required 3D format regress_targets, undistorted_rg_targets = [], [] ics = np.argwhere(np.ones(seg[0].shape)) # indices ics equal positions within img/volume center = np.array([dim//2 for dim in img.shape]) # for illustrating GT distribution, keep scale same size #x = np.linspace(mu - 300, mu + 300, 100) x = np.linspace(0, 50*enlarge_f, 500) ax_gauss = fig.add_subplot(grid[3, :]) mus, sigmas = [], [] for roi_ix, radii in enumerate(all_radii): print('processing {} {}'.format(roi_ix, radii)) cur_img, cur_seg, cur_undistorted_seg, cur_regress_targets, cur_undistorted_rg_targets, cur_applied_gt_distort = \ ToyGen.draw_object(img.copy(), seg.copy(), undistorted_seg, ics, regress_targets, undistorted_rg_targets, applied_gt_distort, roi_ix, class_id, shape, np.copy(radii), center) ax = fig.add_subplot(grid[0,roi_ix]) ax.imshow(cur_img[...,0], cmap='gray', vmin=0) ax.set_title("r{}".format(roi_ix+1), fontsize=fs) if roi_ix==0: ax.set_ylabel(r"$\mathbf{a)}$ Input", fontsize=fs) plg.suppress_axes_lines(ax) else: ax.axis('off') ax = fig.add_subplot(grid[1, roi_ix]) ax.imshow(cur_img[..., 0], cmap='gray') ax.imshow(plg.to_rgba(np.argmax(cur_undistorted_seg[...,0], axis=0), true_cmap), alpha=0.8) ax.text(text_x, text_y, r"$r_{a}=$"+"{:.1f}".format(cur_undistorted_rg_targets[roi_ix][0]/enlarge_f), transform=ax.transAxes, color=cf.white, bbox=dict(facecolor=true_gt_col, alpha=0.7, edgecolor=cf.white, clip_on=False,pad=2.5), fontsize=text_fs, ha='center', va='center') if roi_ix==0: ax.set_ylabel(r"$\mathbf{b)}$ Exact GT", fontsize=fs) plg.suppress_axes_lines(ax) else: ax.axis('off') ax = fig.add_subplot(grid[2, roi_ix]) ax.imshow(cur_img[..., 0], cmap='gray') ax.imshow(plg.to_rgba(np.argmax(cur_seg[..., 0], axis=0), cf.cmap), alpha=0.7) ax.text(text_x, text_y, r"$r_{a}=$"+"{:.1f}".format(cur_regress_targets[roi_ix][0]/enlarge_f), transform=ax.transAxes, color=cf.white, bbox=dict(facecolor=cf.blue, alpha=0.7, edgecolor=cf.white, clip_on=False,pad=2.5), fontsize=text_fs, ha='center', va='center') if roi_ix == 0: ax.set_ylabel(r"$\mathbf{c)}$ Noisy GT", fontsize=fs) plg.suppress_axes_lines(ax) else: ax.axis('off') # GT distributions assert radii[0]==radii[1] mu, sigma = radii[0], radii[0] * cf.ambiguities["radius_calib"][1] ax_gauss.axvline(mu, color=true_gt_col) ax_gauss.text(mu, -0.003, "$r=${:.0f}".format(mu/enlarge_f), color=true_gt_col, fontsize=text_fs, ha='center', va='center', bbox = dict(facecolor='none', alpha=0.7, edgecolor=true_gt_col, clip_on=False, pad=2.5)) mus.append(mu); sigmas.append(sigma) lower_bound = max(bin_edges[roi_ix], min(x))# if roi_ix>0 else 2*mu-bin_edges[roi_ix+1] upper_bound = bin_edges[roi_ix+1] if len(bin_edges)>roi_ix+1 else max(x)#2*mu-bin_edges[roi_ix] if roi_ix, head_length = 0.05, head_width = .005", lw=1)) #ax_gauss.arrow(1, 0.5, 0., 0.1) handles = [plg.mpatches.Patch(facecolor=dist_gt_col, label='Inexact Seg.', alpha=0.7, edgecolor='none'), mlines.Line2D([], [], color=dist_gt_col, marker=r'$\curlywedge$', linestyle='none', markersize=11, label='GT Sampling Distr.'), mlines.Line2D([], [], color=true_gt_col, marker='|', markersize=12, label='Exact GT Radius.', linestyle='none'), plg.mpatches.Patch(facecolor=true_gt_col, label='a)-c) Exact Seg., d) Bin', alpha=0.7, edgecolor='none')] fig.legend(handles=handles, loc="lower center", ncol=len(handles), fontsize=text_fs) outfile = os.path.join(plot_dir, "toy_cylinders.png") print("Saving plot to {}".format(outfile)) plg.plt.savefig(outfile, bbox_inches='tight', dpi=600) return def seg_det_cityscapes_example(plot_dir=None): cf = get_cf('cityscapes', '') source_path = "datasets/cityscapes" if plot_dir is None: plot_dir = os.path.join(source_path, "misc") os.makedirs(plot_dir, exist_ok=True) dl = utils.import_module("dl", os.path.join(source_path, 'data_loader.py')) #from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates data_set = dl.Dataset(cf) Converter = dl.ConvertSegToBoundingBoxCoordinates(2, cf.roi_items) fig = plg.plt.figure(figsize=(9, 3)) #width, height grid = plg.plt.GridSpec(1, 2, wspace=0.05, hspace=.0, figure=fig) #rows, cols fs, text_fs = 12, 10 nice_imgs = ["bremen000099000019", "hamburg000000033506", "frankfurt000001058914",] img_id = nice_imgs[2] #img_id = np.random.choice(data_set.set_ids) print("Selected img", img_id) img = np.load(data_set[img_id]["img"]).transpose(1,2,0) seg = np.load(data_set[img_id]["seg"]) cl_targs = data_set[img_id]["class_targets"] roi_ids = np.unique(seg[seg > 0]) # ---- detection example ----- cl_id2name = {1: "h", 2: "v"} color_palette = [cf.purple, cf.aubergine, cf.magenta, cf.dark_blue, cf.blue, cf.bright_blue, cf.cyan, cf.dark_green, cf.green, cf.dark_yellow, cf.yellow, cf.orange, cf.red, cf.dark_red, cf.bright_red] n_colors = len(color_palette) cmap = {roi_id : color_palette[(roi_id-1)%n_colors] for roi_id in roi_ids} cmap[0] = (1,1,1,0.) ax = fig.add_subplot(grid[0, 1]) ax.imshow(img) ax.imshow(plg.to_rgba(seg, cmap), alpha=0.7) data_dict = Converter(**{'seg':seg[np.newaxis, np.newaxis], 'class_targets': [cl_targs]}) # needs batch dim and channel for roi_ix, bb_target in enumerate(data_dict['bb_target'][0]): [y1, x1, y2, x2] = bb_target width, height = x2 - x1, y2 - y1 cl_id = cl_targs[roi_ix] label = cf.class_id2label[cl_id] text_x, text_y = x2, y1 id_text = cl_id2name[cl_id] text_str = '{}'.format(id_text) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) #ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=text_fs, ha="center", va="center") edgecolor = label.color bbox = plg.mpatches.Rectangle((x1, y1), width, height, linewidth=1.05, edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) ax.axis('off') # ---- seg example ----- for roi_id in roi_ids: seg[seg==roi_id] = cl_targs[roi_id-1] ax = fig.add_subplot(grid[0,0]) ax.imshow(img) ax.imshow(plg.to_rgba(seg, cf.cmap), alpha=0.7) ax.axis('off') plg.plt.tight_layout() outfile = os.path.join(plot_dir, "cityscapes_example.png") print("Saving plot to {}".format(outfile)) plg.plt.savefig(outfile, bbox_inches='tight', dpi=600) if __name__=="__main__": stime = time.time() #seg_det_cityscapes_example() #box_clustering() #sketch_AP_AUC(draw_auc=False) #draw_toy_cylinders() #prostate_GT_examples(plot_dir="/home/gregor/Dropbox/Thesis/Main/MFPPresentation/graphics") #prostate_results_static() #prostate_dataset_stats(plot_dir="/home/gregor/Dropbox/Thesis/Main/MFPPresentation/graphics", show_splits=False) #lidc_dataset_stats() #lidc_sa_dataset_stats() #lidc_annotator_confusion() #lidc_merged_sa_joint_plot() #lidc_annotator_dissent_images() exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6" #multiple_clustering_results('prostate', exp_dir, plot_hist=True) exp_parent_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments" exp_parent_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments_debug_retinas" #get_plot_clustering_results('prostate', exp_parent_dir, res_from_file=False) exp_dir = "/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_frcnn3d_cl_bs6" #cf = get_cf('prostate', exp_dir) #plot_file = os.path.join(exp_dir, "inference_analysis/bytes_merged_boxes_fold_1_pid_177.pkl") #plot_single_results(cf, exp_dir, plot_file) exp_dir1 = "/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rg_bs8" exp_dir2 = "/home/gregor/networkdrives/E132-Cluster-Projects/lidc_sa/experiments/ms12345_mrcnn3d_rgbin_bs8" #find_suitable_examples(exp_dir1, exp_dir2) #plot_single_results_lidc() plot_dir = "/home/gregor/Dropbox/Thesis/MICCAI2019/Graphics" #lidc_results_static(plot_dir=plot_dir) #toy_results_static(plot_dir=plot_dir) plot_lidc_dissent_and_example(plot_dir=plot_dir, confusion_matrix=True, numbering=False, example_title="LIDC example result") mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/predictor.py b/predictor.py index 2e2b699..d2a9e60 100644 --- a/predictor.py +++ b/predictor.py @@ -1,1006 +1,1006 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os from multiprocessing import Pool import pickle import time import numpy as np import torch from scipy.stats import norm from collections import OrderedDict import plotting as plg import utils.model_utils as mutils import utils.exp_utils as utils def get_mirrored_patch_crops(patch_crops, org_img_shape): mirrored_patch_crops = [] mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3], ii[4], ii[5]] for ii in patch_crops]) mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) return mirrored_patch_crops def get_mirrored_patch_crops_ax_dep(patch_crops, org_img_shape, mirror_axes): mirrored_patch_crops = [] for ax_ix, axes in enumerate(mirror_axes): if isinstance(axes, (int, float)) and int(axes) == 0: mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], ii[2], ii[3], ii[4], ii[5]] for ii in patch_crops]) elif isinstance(axes, (int, float)) and int(axes) == 1: mirrored_patch_crops.append([[ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [ii[0], ii[1], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) elif hasattr(axes, "__iter__") and (tuple(axes) == (0, 1) or tuple(axes) == (1, 0)): mirrored_patch_crops.append([[org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2]] if len(ii) == 4 else [org_img_shape[2] - ii[1], org_img_shape[2] - ii[0], org_img_shape[3] - ii[3], org_img_shape[3] - ii[2], ii[4], ii[5]] for ii in patch_crops]) else: raise Exception("invalid mirror axes {} in get mirrored patch crops".format(axes)) return mirrored_patch_crops def apply_wbc_to_patient(inputs): """ wrapper around prediction box consolidation: weighted box clustering (wbc). processes a single patient. loops over batch elements in patient results (1 in 3D, slices in 2D) and foreground classes, aggregates and stores results in new list. :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. :return. pid: string. patient id. """ regress_flag, in_patient_results_list, pid, class_dict, clustering_iou, n_ens = inputs out_patient_results_list = [[] for _ in range(len(in_patient_results_list))] for bix, b in enumerate(in_patient_results_list): for cl in list(class_dict.keys()): boxes = [(ix, box) for ix, box in enumerate(b) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] box_coords = np.array([b[1]['box_coords'] for b in boxes]) box_scores = np.array([b[1]['box_score'] for b in boxes]) box_center_factor = np.array([b[1]['box_patch_center_factor'] for b in boxes]) box_n_overlaps = np.array([b[1]['box_n_overlaps'] for b in boxes]) try: box_patch_id = np.array([b[1]['patch_id'] for b in boxes]) except KeyError: #backward compatibility for already saved pred results ... omg box_patch_id = np.array([b[1]['ens_ix'] for b in boxes]) box_regressions = np.array([b[1]['regression'] for b in boxes]) if regress_flag else None box_rg_bins = np.array([b[1]['rg_bin'] if 'rg_bin' in b[1].keys() else float('NaN') for b in boxes]) box_rg_uncs = np.array([b[1]['rg_uncertainty'] if 'rg_uncertainty' in b[1].keys() else float('NaN') for b in boxes]) if 0 not in box_scores.shape: keep_scores, keep_coords, keep_n_missing, keep_regressions, keep_rg_bins, keep_rg_uncs = \ weighted_box_clustering(box_coords, box_scores, box_center_factor, box_n_overlaps, box_rg_bins, box_rg_uncs, box_regressions, box_patch_id, clustering_iou, n_ens) for boxix in range(len(keep_scores)): clustered_box = {'box_type': 'det', 'box_coords': keep_coords[boxix], 'box_score': keep_scores[boxix], 'cluster_n_missing': keep_n_missing[boxix], 'box_pred_class_id': cl} if regress_flag: clustered_box.update({'regression': keep_regressions[boxix], 'rg_uncertainty': keep_rg_uncs[boxix], 'rg_bin': keep_rg_bins[boxix]}) out_patient_results_list[bix].append(clustered_box) # add gt boxes back to new output list. out_patient_results_list[bix].extend([box for box in b if box['box_type'] == 'gt']) return [out_patient_results_list, pid] def weighted_box_clustering(box_coords, scores, box_pc_facts, box_n_ovs, box_rg_bins, box_rg_uncs, box_regress, box_patch_id, thresh, n_ens): """Consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal ensembling. clusters predictions together with iou > thresh (like in NMS). Output score and coordinate for one cluster are the average weighted by individual patch center factors (how trustworthy is this candidate measured by how centered its position within the patch is) and the size of the corresponding box. The number of expected predictions at a position is n_data_aug * n_temp_ens * n_overlaps_at_position (1 prediction per unique patch). Missing predictions at a cluster position are defined as the number of unique patches in the cluster, which did not contribute any predict any boxes. :param dets: (n_dets, (y1, x1, y2, x2, (z1), (z2), scores, box_pc_facts, box_n_ovs). :param box_coords: y1, x1, y2, x2, (z1), (z2). :param scores: confidence scores. :param box_pc_facts: patch-center factors from position on patch tiles. :param box_n_ovs: number of patch overlaps at box position. :param box_rg_bins: regression bin predictions. :param box_rg_uncs: (n_dets,) regression uncertainties (from model mrcnn_aleatoric). :param box_regress: (n_dets, n_regression_features). :param box_patch_id: ensemble index. :param thresh: threshold for iou_matching. :param n_ens: number of models, that are ensembled. (-> number of expected predictions per position). :return: keep_scores: (n_keep) new scores of boxes to be kept. :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept. """ dim = 2 if box_coords.shape[1] == 4 else 3 y1 = box_coords[:,0] x1 = box_coords[:,1] y2 = box_coords[:,2] x2 = box_coords[:,3] areas = (y2 - y1 + 1) * (x2 - x1 + 1) if dim == 3: z1 = box_coords[:, 4] z2 = box_coords[:, 5] areas *= (z2 - z1 + 1) # order is the sorted index. maps order to index o[1] = 24 (rank1, ix 24) order = scores.argsort()[::-1] keep_scores = [] keep_coords = [] keep_n_missing = [] keep_regress = [] keep_rg_bins = [] keep_rg_uncs = [] while order.size > 0: i = order[0] # highest scoring element yy1 = np.maximum(y1[i], y1[order]) xx1 = np.maximum(x1[i], x1[order]) yy2 = np.minimum(y2[i], y2[order]) xx2 = np.minimum(x2[i], x2[order]) w = np.maximum(0, xx2 - xx1 + 1) h = np.maximum(0, yy2 - yy1 + 1) inter = w * h if dim == 3: zz1 = np.maximum(z1[i], z1[order]) zz2 = np.minimum(z2[i], z2[order]) d = np.maximum(0, zz2 - zz1 + 1) inter *= d # overlap between currently highest scoring box and all boxes. ovr = inter / (areas[i] + areas[order] - inter) ovr_fl = inter.astype('float64') / (areas[i] + areas[order] - inter.astype('float64')) assert np.all(ovr==ovr_fl), "ovr {}\n ovr_float {}".format(ovr, ovr_fl) # get all the predictions that match the current box to build one cluster. matches = np.nonzero(ovr > thresh)[0] match_n_ovs = box_n_ovs[order[matches]] match_pc_facts = box_pc_facts[order[matches]] match_patch_id = box_patch_id[order[matches]] match_ov_facts = ovr[matches] match_areas = areas[order[matches]] match_scores = scores[order[matches]] # weight all scores in cluster by patch factors, and size. match_score_weights = match_ov_facts * match_areas * match_pc_facts match_scores *= match_score_weights # for the weighted average, scores have to be divided by the number of total expected preds at the position # of the current cluster. 1 Prediction per patch is expected. therefore, the number of ensembled models is # multiplied by the mean overlaps of patches at this position (boxes of the cluster might partly be # in areas of different overlaps). n_expected_preds = n_ens * np.mean(match_n_ovs) # the number of missing predictions is obtained as the number of patches, # which did not contribute any prediction to the current cluster. n_missing_preds = np.max((0, n_expected_preds - np.unique(match_patch_id).shape[0])) # missing preds are given the mean weighting # (expected prediction is the mean over all predictions in cluster). denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights) # compute weighted average score for the cluster avg_score = np.sum(match_scores) / denom # compute weighted average of coordinates for the cluster. now only take existing # predictions into account. avg_coords = [np.sum(y1[order[matches]] * match_scores) / np.sum(match_scores), np.sum(x1[order[matches]] * match_scores) / np.sum(match_scores), np.sum(y2[order[matches]] * match_scores) / np.sum(match_scores), np.sum(x2[order[matches]] * match_scores) / np.sum(match_scores)] if dim == 3: avg_coords.append(np.sum(z1[order[matches]] * match_scores) / np.sum(match_scores)) avg_coords.append(np.sum(z2[order[matches]] * match_scores) / np.sum(match_scores)) if box_regress is not None: # compute wt. avg. of regression vectors (component-wise average) avg_regress = np.sum(box_regress[order[matches]] * match_scores[:, np.newaxis], axis=0) / np.sum( match_scores) avg_rg_bins = np.round(np.sum(box_rg_bins[order[matches]] * match_scores) / np.sum(match_scores)) avg_rg_uncs = np.sum(box_rg_uncs[order[matches]] * match_scores) / np.sum(match_scores) else: avg_regress = np.array(float('NaN')) avg_rg_bins = np.array(float('NaN')) avg_rg_uncs = np.array(float('NaN')) # some clusters might have very low scores due to high amounts of missing predictions. # filter out the with a conservative threshold, to speed up evaluation. if avg_score > 0.01: keep_scores.append(avg_score) keep_coords.append(avg_coords) keep_n_missing.append((n_missing_preds / n_expected_preds * 100)) # relative keep_regress.append(avg_regress) keep_rg_uncs.append(avg_rg_uncs) keep_rg_bins.append(avg_rg_bins) # get index of all elements that were not matched and discard all others. inds = np.nonzero(ovr <= thresh)[0] inds_where = np.where(ovr<=thresh)[0] assert np.all(inds == inds_where), "inds_nonzero {} \ninds_where {}".format(inds, inds_where) order = order[inds] return keep_scores, keep_coords, keep_n_missing, keep_regress, keep_rg_bins, keep_rg_uncs def apply_nms_to_patient(inputs): in_patient_results_list, pid, class_dict, iou_thresh = inputs out_patient_results_list = [] # collect box predictions over batch dimension (slices) and store slice info as slice_ids. for batch in in_patient_results_list: batch_el_boxes = [] for cl in list(class_dict.keys()): det_boxes = [box for box in batch if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] box_coords = np.array([box['box_coords'] for box in det_boxes]) box_scores = np.array([box['box_score'] for box in det_boxes]) if 0 not in box_scores.shape: keep_ix = mutils.nms_numpy(box_coords, box_scores, iou_thresh) else: keep_ix = [] batch_el_boxes += [det_boxes[ix] for ix in keep_ix] batch_el_boxes += [box for box in batch if box['box_type'] == 'gt'] out_patient_results_list.append(batch_el_boxes) assert len(in_patient_results_list) == len(out_patient_results_list), "batch dim needs to be maintained, in: {}, out {}".format(len(in_patient_results_list), len(out_patient_results_list)) return [out_patient_results_list, pid] def nms_2to3D(dets, thresh): """ Merges 2D boxes to 3D cubes. For this purpose, boxes of all slices are regarded as lying in one slice. An adaptation of Non-maximum suppression is applied where clusters are found (like in NMS) with the extra constraint that suppressed boxes have to have 'connected' z coordinates w.r.t the core slice (cluster center, highest scoring box, the prevailing box). 'connected' z-coordinates are determined as the z-coordinates with predictions until the first coordinate for which no prediction is found. example: a cluster of predictions was found overlap > iou thresh in xy (like NMS). The z-coordinate of the highest scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57. Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not because nothing was found in 47, so 47 is a 'hole', which interrupts the connection). Only the boxes corresponding to these coordinates are suppressed. All others are kept for building of further clusters. This algorithm works better with a certain min_confidence of predictions, because low confidence (e.g. noisy/cluttery) predictions can break the relatively strong assumption of defining cubes' z-boundaries at the first 'hole' in the cluster. :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id) :param thresh: iou matchin threshold (like in NMS). :return: keep: (n_keep,) 1D tensor of indices to be kept. :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes. """ y1 = dets[:, 0] x1 = dets[:, 1] y2 = dets[:, 2] x2 = dets[:, 3] assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: where maximum is taken needs to be the lower coordinate""" scores = dets[:, -2] slice_id = dets[:, -1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort()[::-1] keep = [] keep_z = [] while order.size > 0: # order is the sorted index. maps order to index: order[1] = 24 means (rank1, ix 24) i = order[0] # highest scoring element yy1 = np.maximum(y1[i], y1[order]) # highest scoring element still in >order<, is compared to itself: okay? xx1 = np.maximum(x1[i], x1[order]) yy2 = np.minimum(y2[i], y2[order]) xx2 = np.minimum(x2[i], x2[order]) h = np.maximum(0.0, yy2 - yy1 + 1) w = np.maximum(0.0, xx2 - xx1 + 1) inter = h * w iou = inter / (areas[i] + areas[order] - inter) matches = np.argwhere( iou > thresh) # get all the elements that match the current box and have a lower score slice_ids = slice_id[order[matches]] core_slice = slice_id[int(i)] upper_holes = [ii for ii in np.arange(core_slice, np.max(slice_ids)) if ii not in slice_ids] lower_holes = [ii for ii in np.arange(np.min(slice_ids), core_slice) if ii not in slice_ids] max_valid_slice_id = np.min(upper_holes) if len(upper_holes) > 0 else np.max(slice_ids) min_valid_slice_id = np.max(lower_holes) if len(lower_holes) > 0 else np.min(slice_ids) z_matches = matches[(slice_ids <= max_valid_slice_id) & (slice_ids >= min_valid_slice_id)] # expand by one z voxel since box content is surrounded w/o overlap, i.e., z-content computed as z2-z1 z1 = np.min(slice_id[order[z_matches]]) - 1 z2 = np.max(slice_id[order[z_matches]]) + 1 keep.append(i) keep_z.append([z1, z2]) order = np.delete(order, z_matches, axis=0) return keep, keep_z def apply_2d_3d_merging_to_patient(inputs): """ wrapper around 2Dto3D merging operation. Processes a single patient. Takes 2D patient results (slices in batch dimension) and returns 3D patient results (dummy batch dimension of 1). Applies an adaption of Non-Maximum Surpression (Detailed methodology is described in nms_2to3D). :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. :return. pid: string. patient id. """ in_patient_results_list, pid, class_dict, merge_3D_iou = inputs out_patient_results_list = [] for cl in list(class_dict.keys()): det_boxes, slice_ids = [], [] # collect box predictions over batch dimension (slices) and store slice info as slice_ids. for batch_ix, batch in enumerate(in_patient_results_list): batch_element_det_boxes = [(ix, box) for ix, box in enumerate(batch) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] det_boxes += batch_element_det_boxes slice_ids += [batch_ix] * len(batch_element_det_boxes) box_coords = np.array([batch[1]['box_coords'] for batch in det_boxes]) box_scores = np.array([batch[1]['box_score'] for batch in det_boxes]) slice_ids = np.array(slice_ids) if 0 not in box_scores.shape: keep_ix, keep_z = nms_2to3D( np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou) else: keep_ix, keep_z = [], [] # store kept predictions in new results list and add corresponding z-dimension info to coordinates. for kix, kz in zip(keep_ix, keep_z): keep_box = det_boxes[kix][1] keep_box['box_coords'] = list(keep_box['box_coords']) + kz out_patient_results_list.append(keep_box) gt_boxes = [box for b in in_patient_results_list for box in b if box['box_type'] == 'gt'] if len(gt_boxes) > 0: assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D." out_patient_results_list += gt_boxes return [[out_patient_results_list], pid] # additional list wrapping is extra batch dim. class Predictor: """ Prediction pipeline: - receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader. - forwards patches through model in chunks of batch_size. (method: batch_tiling_forward) - unmolds predictions (boxes and segmentations) to original patient coordinates. (method: spatial_tiling_forward) Ensembling (mode == 'test'): - for inference, forwards 4 mirrored versions of image to through model and unmolds predictions afterwards accordingly (method: data_aug_forward) - for inference, loads multiple parameter-sets of the trained model corresponding to different epochs. for each parameter-set loops over entire test set, runs prediction pipeline for each patient. (method: predict_test_set) Consolidation of predictions: - consolidates a patient's predictions (boxes, segmentations) collected over patches, data_aug- and temporal ensembling, performs clustering and weighted averaging (external function: apply_wbc_to_patient) to obtain consistent outptus. - for 2D networks, consolidates box predictions to 3D cubes via clustering (adaption of non-maximum surpression). (external function: apply_2d_3d_merging_to_patient) Ground truth handling: - dissmisses any ground truth boxes returned by the model (happens in validation mode, patch-based groundtruth) - if provided by data loader, adds patient-wise ground truth to the final predictions to be passed to the evaluator. """ def __init__(self, cf, net, logger, mode): self.cf = cf self.batch_size = cf.batch_size self.logger = logger self.mode = mode self.net = net self.n_ens = 1 self.rank_ix = '0' self.regress_flag = any(['regression' in task for task in self.cf.prediction_tasks]) if self.cf.merge_2D_to_3D_preds: assert self.cf.dim == 2, "Merge 2Dto3D only valid for 2D preds, but current dim is {}.".format(self.cf.dim) if self.mode == 'test': last_state_path = os.path.join(self.cf.fold_dir, 'last_state.pth') try: self.model_index = torch.load(last_state_path)["model_index"] self.model_index = self.model_index[self.model_index["rank"] <= self.cf.test_n_epochs] except FileNotFoundError: raise FileNotFoundError('no last_state/model_index file in fold directory. ' 'seems like you are trying to run testing without prior training...') self.n_ens = cf.test_n_epochs if self.cf.test_aug_axes is not None: self.n_ens *= (len(self.cf.test_aug_axes)+1) self.example_plot_dir = os.path.join(cf.test_dir, "example_plots") os.makedirs(self.example_plot_dir, exist_ok=True) def batch_tiling_forward(self, batch): """ calls the actual network forward method. in patch-based prediction, the batch dimension might be overladed with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are discarded). test mode calls the test forward method, no ground truth required / involved. :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - loss / class_loss (only in validation mode) """ img = batch['data'] if img.shape[0] <= self.batch_size: if self.mode == 'val': # call training method to monitor losses results_dict = self.net.train_forward(batch, is_validation=True) # discard returned ground-truth boxes (also training info boxes). results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] elif self.mode == 'test': results_dict = self.net.test_forward(batch, return_masks=self.cf.return_masks_in_test) else: # needs batch tiling split_ixs = np.split(np.arange(img.shape[0]), np.arange(img.shape[0])[::self.batch_size]) chunk_dicts = [] for chunk_ixs in split_ixs[1:]: # first split is elements before 0, so empty b = {k: batch[k][chunk_ixs] for k in batch.keys() if (isinstance(batch[k], np.ndarray) and batch[k].shape[0] == img.shape[0])} if self.mode == 'val': chunk_dicts += [self.net.train_forward(b, is_validation=True)] else: chunk_dicts += [self.net.test_forward(b, return_masks=self.cf.return_masks_in_test)] results_dict = {} # flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...]) results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']] results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']]) if self.mode == 'val': # if hasattr(self.cf, "losses_to_monitor"): # loss_names = self.cf.losses_to_monitor # else: # loss_names = {name for dic in chunk_dicts for name in dic if 'loss' in name} # estimate patient loss by mean over batch_chunks. Most similar to training loss. results_dict['torch_loss'] = torch.mean(torch.cat([d['torch_loss'] for d in chunk_dicts])) results_dict['class_loss'] = np.mean([d['class_loss'] for d in chunk_dicts]) # discard returned ground-truth boxes (also training info boxes). results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']] return results_dict def spatial_tiling_forward(self, batch, patch_crops = None, n_aug='0'): """ forwards batch to batch_tiling_forward method and receives and returns a dictionary with results. if patch-based prediction, the results received from batch_tiling_forward will be on a per-patch-basis. this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates. Patch-origin information of all box-predictions will be needed for consolidation, hence it is stored as 'patch_id', which is a unique string for each patch (also takes current data aug and temporal epoch instances into account). all box predictions get additional information about the amount overlapping patches at the respective position (used for consolidation). :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - monitor_values (only in validation mode) returned dict is a flattened version with 1 batch instance (3D) or slices (2D) """ if patch_crops is not None: #print("patch_crops not None, applying patch center factor") patches_dict = self.batch_tiling_forward(batch) results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]} #bc of ohe--> channel dim of seg has size num_classes out_seg_shape = list(batch['original_img_shape']) out_seg_shape[1] = patches_dict["seg_preds"].shape[1] out_seg_preds = np.zeros(out_seg_shape, dtype=np.float16) patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8') for pix, pc in enumerate(patch_crops): if self.cf.dim == 3: out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix] patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1 elif self.cf.dim == 2: out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix] patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1 out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0] results_dict['seg_preds'] = out_seg_preds for pix, pc in enumerate(patch_crops): patch_boxes = patches_dict['boxes'][pix] for box in patch_boxes: # add unique patch id for consolidation of predictions. box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix) # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers. # hence they will be down-weighted for consolidation, using the 'box_patch_center_factor', which is # obtained by a gaussian distribution over positions in the patch and average over spatial dimensions. # Also the info 'box_n_overlaps' is stored for consolidation, which represents the amount of # overlapping patches at the box's position. c = box['box_coords'] #box_centers = np.array([(c[ii] + c[ii+2])/2 for ii in range(len(c)//2)]) box_centers = [(c[ii] + c[ii + 2]) / 2 for ii in range(2)] if self.cf.dim == 3: box_centers.append((c[4] + c[5]) / 2) box['box_patch_center_factor'] = np.mean( [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in zip(box_centers, np.array(self.cf.patch_size) / 2)]) if self.cf.dim == 3: c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]]) int_c = [int(np.floor(ii)) if ix%2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)] box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, int_c[1]:int_c[3], int_c[0]:int_c[2], int_c[4]:int_c[5]]) results_dict['boxes'][0].append(box) else: c += np.array([pc[0], pc[2], pc[0], pc[2]]) int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)] box['box_n_overlaps'] = np.mean( patch_overlap_map[pc[4], :, int_c[1]:int_c[3], int_c[0]:int_c[2]]) results_dict['boxes'][pc[4]].append(box) if self.mode == 'val': results_dict['torch_loss'] = patches_dict['torch_loss'] results_dict['class_loss'] = patches_dict['class_loss'] else: results_dict = self.batch_tiling_forward(batch) for b in results_dict['boxes']: for box in b: box['box_patch_center_factor'] = 1 box['box_n_overlaps'] = 1 box['patch_id'] = self.rank_ix + '_' + n_aug return results_dict def data_aug_forward(self, batch): """ in val_mode: passes batch through to spatial_tiling method without data_aug. in test_mode: if cf.test_aug is set in configs, createst 4 mirrored versions of the input image, passes all of them to the next processing step (spatial_tiling method) and re-transforms returned predictions to original image version. :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - loss / class_loss (only in validation mode) """ patch_crops = batch['patch_crop_coords'] if self.patched_patient else None results_list = [self.spatial_tiling_forward(batch, patch_crops)] org_img_shape = batch['original_img_shape'] if self.mode == 'test' and self.cf.test_aug_axes is not None: if isinstance(self.cf.test_aug_axes, (int, float)): self.cf.test_aug_axes = (self.cf.test_aug_axes,) #assert np.all(np.array(self.cf.test_aug_axes)= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] chunk_dict['boxes'][ix][boxix]['box_coords'] = coords # re-transform segmentation predictions. chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=axis) elif hasattr(sp_axis, "__iter__") and tuple(sp_axis)==(0,1) or tuple(sp_axis)==(1,0): #NEED: mirrored patch crops are given as [(y-axis), (x-axis), (y-,x-axis)], obey this order! # mirroring along two axes at same time batch['data'] = np.flip(np.flip(img, axis=axis[0]), axis=axis[1]).copy() chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[n_aug], n_aug=str(n_aug)) # re-transform coordinates. for ix in range(len(chunk_dict['boxes'])): for boxix in range(len(chunk_dict['boxes'][ix])): coords = chunk_dict['boxes'][ix][boxix]['box_coords'].copy() coords[sp_axis[0]] = org_img_shape[axis[0]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[0]+2] coords[sp_axis[0]+2] = org_img_shape[axis[0]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[0]] coords[sp_axis[1]] = org_img_shape[axis[1]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[1]+2] coords[sp_axis[1]+2] = org_img_shape[axis[1]] - chunk_dict['boxes'][ix][boxix]['box_coords'][sp_axis[1]] assert coords[2] >= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] chunk_dict['boxes'][ix][boxix]['box_coords'] = coords # re-transform segmentation predictions. chunk_dict['seg_preds'] = np.flip(np.flip(chunk_dict['seg_preds'], axis=axis[0]), axis=axis[1]).copy() else: raise Exception("Invalid axis type {} in test augs".format(type(axis))) results_list.append(chunk_dict) batch['data'] = img # aggregate all boxes/seg_preds per batch element from data_aug predictions. results_dict = {} results_dict['boxes'] = [[item for d in results_list for item in d['boxes'][batch_instance]] for batch_instance in range(org_img_shape[0])] # results_dict['seg_preds'] = np.array([[item for d in results_list for item in d['seg_preds'][batch_instance]] # for batch_instance in range(org_img_shape[0])]) results_dict['seg_preds'] = np.stack([dic['seg_preds'] for dic in results_list], axis=1) # needs segs probs in seg_preds entry: results_dict['seg_preds'] = np.sum(results_dict['seg_preds'], axis=1) #add up seg probs from different augs per class if self.mode == 'val': results_dict['torch_loss'] = results_list[0]['torch_loss'] results_dict['class_loss'] = results_list[0]['class_loss'] return results_dict def load_saved_predictions(self): """loads raw predictions saved by self.predict_test_set. aggregates and/or merges 2D boxes to 3D cubes for evaluation (if model predicts 2D but evaluation is run in 3D), according to settings config. :return: list_of_results_per_patient: list over patient results. each entry is a dict with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'batch_dices': dice scores as recorded in raw prediction results. - 'seg_preds': not implemented yet. could replace dices by seg preds to have raw seg info available, however would consume critically large memory amount. todo evaluation of instance/semantic segmentation. """ - results_file = 'pred_results.pkl' if not self.cf.held_out_test_set else 'pred_results_held_out.pkl' - if not self.cf.held_out_test_set or self.cf.eval_test_fold_wise: + results_file = 'pred_results.pkl' if not self.cf.hold_out_test_set else 'pred_results_held_out.pkl' + if not self.cf.hold_out_test_set or not self.cf.ensemble_folds: self.logger.info("loading saved predictions of fold {}".format(self.cf.fold)) with open(os.path.join(self.cf.fold_dir, results_file), 'rb') as handle: results_list = pickle.load(handle) box_results_list = [(res_dict["boxes"], pid) for res_dict, pid in results_list] da_factor = len(self.cf.test_aug_axes)+1 if self.cf.test_aug_axes is not None else 1 self.n_ens = self.cf.test_n_epochs * da_factor self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format( len(results_list), self.n_ens)) else: self.logger.info("loading saved predictions of hold-out test set") fold_dirs = sorted([os.path.join(self.cf.exp_dir, f) for f in os.listdir(self.cf.exp_dir) if os.path.isdir(os.path.join(self.cf.exp_dir, f)) and f.startswith("fold")]) results_list = [] folds_loaded = 0 for fold in range(self.cf.n_cv_splits): fold_dir = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold)) if fold_dir in fold_dirs: with open(os.path.join(fold_dir, results_file), 'rb') as handle: fold_list = pickle.load(handle) results_list += fold_list folds_loaded += 1 else: self.logger.info("Skipping fold {} since no saved predictions found.".format(fold)) box_results_list = [] for res_dict, pid in results_list: #without filtering gt out: box_results_list.append((res_dict['boxes'], pid)) #it's usually not right to filter out gts here, is it? da_factor = len(self.cf.test_aug_axes)+1 if self.cf.test_aug_axes is not None else 1 self.n_ens = self.cf.test_n_epochs * da_factor * folds_loaded # -------------- aggregation of boxes via clustering ----------------- if self.cf.clustering == "wbc": self.logger.info('applying WBC to test-set predictions with iou {} and n_ens {} over {} patients'.format( self.cf.clustering_iou, self.n_ens, len(box_results_list))) mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens] for ii in box_results_list] del box_results_list pool = Pool(processes=self.cf.n_workers) box_results_list = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs elif self.cf.clustering == "nms": self.logger.info('applying standard NMS to test-set predictions with iou {} over {} patients.'.format( self.cf.clustering_iou, len(box_results_list))) pool = Pool(processes=self.cf.n_workers) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in box_results_list] del box_results_list box_results_list = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs if self.cf.merge_2D_to_3D_preds: self.logger.info('applying 2Dto3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou)) pool = Pool(processes=self.cf.n_workers) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in box_results_list] box_results_list = pool.map(apply_2d_3d_merging_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs for ix in range(len(results_list)): assert np.all(results_list[ix][1] == box_results_list[ix][1]), "pid mismatch between loaded and aggregated results" results_list[ix][0]["boxes"] = box_results_list[ix][0] return results_list # holds (results_dict, pid) def predict_patient(self, batch): """ predicts one patient. called either directly via loop over validation set in exec.py (mode=='val') or from self.predict_test_set (mode=='test). in val mode: adds 3D ground truth info to predictions and runs consolidation and 2Dto3D merging of predictions. in test mode: returns raw predictions (ground truth addition, consolidation, 2D to 3D merging are done in self.predict_test_set, because patient predictions across several epochs might be needed to be collected first, in case of temporal ensembling). :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - loss / class_loss (only in validation mode) """ #if self.mode=="test": # self.logger.info('predicting patient {} for fold {} '.format(np.unique(batch['pid']), self.cf.fold)) # True if patient is provided in patches and predictions need to be tiled. self.patched_patient = 'patch_crop_coords' in list(batch.keys()) # forward batch through prediction pipeline. results_dict = self.data_aug_forward(batch) #has seg probs in entry 'seg_preds' if self.mode == 'val': for b in range(batch['patient_bb_target'].shape[0]): for t in range(len(batch['patient_bb_target'][b])): gt_box = {'box_type': 'gt', 'box_coords': batch['patient_bb_target'][b][t], 'class_targets': batch['patient_class_targets'][b][t]} for name in self.cf.roi_items: gt_box.update({name : batch['patient_'+name][b][t]}) results_dict['boxes'][b].append(gt_box) if 'dice' in self.cf.metrics: if self.patched_patient: assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape." results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'], batch["patient_seg"] if self.patched_patient else batch['seg'], self.cf.num_seg_classes, convert_to_ohe=True) if self.patched_patient and self.cf.clustering == "wbc": wbc_input = [self.regress_flag, results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.clustering_iou, self.n_ens] results_dict['boxes'] = apply_wbc_to_patient(wbc_input)[0] elif self.patched_patient: nms_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.clustering_iou] results_dict['boxes'] = apply_nms_to_patient(nms_inputs)[0] if self.cf.merge_2D_to_3D_preds: results_dict['2D_boxes'] = results_dict['boxes'] merge_dims_inputs = [results_dict['boxes'], 'dummy_pid', self.cf.class_dict, self.cf.merge_3D_iou] results_dict['boxes'] = apply_2d_3d_merging_to_patient(merge_dims_inputs)[0] return results_dict def predict_test_set(self, batch_gen, return_results=True): """ wrapper around test method, which loads multiple (or one) epoch parameters (temporal ensembling), loops through the test set and collects predictions per patient. Also flattens the results per patient and epoch and adds optional ground truth boxes for evaluation. Saves out the raw result list for later analysis and optionally consolidates and returns predictions immediately. :return: (optionally) list_of_results_per_patient: list over patient results. each entry is a dict with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': not implemented yet. todo evaluation of instance/semantic segmentation. """ # -------------- raw predicting ----------------- dict_of_patients_results = OrderedDict() set_of_result_types = set() self.model_index = self.model_index.sort_values(by="rank") # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling). weight_paths = [os.path.join(self.cf.fold_dir, file_name) for file_name in self.model_index["file_name"]] for rank_ix, weight_path in enumerate(weight_paths): self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path))) self.net.load_state_dict(torch.load(weight_path)) self.net.eval() self.rank_ix = str(rank_ix) plot_batches = np.random.choice(np.arange(batch_gen['n_test']), size=min(batch_gen['n_test'], self.cf.n_test_plots), replace=False) with torch.no_grad(): for i in range(batch_gen['n_test']): batch = next(batch_gen['test']) pid = np.unique(batch['pid']) assert len(pid)==1 pid = pid[0] if not pid in dict_of_patients_results.keys(): # store batch info in patient entry of results dict. dict_of_patients_results[pid] = {} dict_of_patients_results[pid]['results_dicts'] = [] dict_of_patients_results[pid]['patient_bb_target'] = batch['patient_bb_target'] for name in self.cf.roi_items: dict_of_patients_results[pid]["patient_"+name] = batch["patient_"+name] stime = time.time() results_dict = self.predict_patient(batch) #only holds "boxes", "seg_preds" # needs ohe seg probs in seg_preds entry: results_dict['seg_preds'] = np.argmax(results_dict['seg_preds'], axis=1)[:,np.newaxis] print("\rpredicting patient {} with weight rank {} (progress: {}/{}) took {:.2f}s".format( str(pid), rank_ix, (rank_ix)*batch_gen['n_test']+(i+1), len(weight_paths)*batch_gen['n_test'], time.time()-stime), end="", flush=True) if i in plot_batches and (not self.patched_patient or 'patient_data' in batch.keys()): try: # view qualitative results of random test case out_file = os.path.join(self.example_plot_dir, 'batch_example_test_{}_rank_{}.png'.format(self.cf.fold, rank_ix)) utils.split_off_process(plg.view_batch, self.cf, batch, results_dict, has_colorchannels=self.cf.has_colorchannels, show_gt_labels=True, show_seg_ids='dice' in self.cf.metrics, get_time="test-example plot", out_file=out_file) except Exception as e: self.logger.info("WARNING: error in view_batch: {}".format(e)) if 'dice' in self.cf.metrics: if self.patched_patient: assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape." results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'], batch["patient_seg"] if self.patched_patient else batch['seg'], self.cf.num_seg_classes, convert_to_ohe=True) dict_of_patients_results[pid]['results_dicts'].append({k:v for k,v in results_dict.items() if k in ["boxes", "batch_dices"]}) # collect result types to know which ones to look for when saving set_of_result_types.update(dict_of_patients_results[pid]['results_dicts'][-1].keys()) # -------------- re-order, save raw results ----------------- self.logger.info('finished predicting test set. starting aggregation of predictions.') results_per_patient = [] for pid, p_dict in dict_of_patients_results.items(): # dict_of_patients_results[pid]['results_list'] has length batch['n_test'] results_dict = {} # collect all boxes/seg_preds of same batch_instance over temporal instances. b_size = len(p_dict['results_dicts'][0]["boxes"]) for res_type in [rtype for rtype in set_of_result_types if rtype in ["boxes", "batch_dices"]]:#, "seg_preds"]]: if not 'batch' in res_type: #assume it's results on batch-element basis results_dict[res_type] = [[item for rank_dict in p_dict['results_dicts'] for item in rank_dict[res_type][batch_instance]] for batch_instance in range(b_size)] else: results_dict[res_type] = [] for dict in p_dict['results_dicts']: if 'dice' in res_type: item = dict[res_type] #dict['batch_dices'] has shape (num_seg_classes,) assert len(item) == self.cf.num_seg_classes, \ "{}, {}".format(len(item), self.cf.num_seg_classes) else: raise NotImplementedError results_dict[res_type].append(item) # rdict[dice] shape (n_rank_epochs (n_saved_ranks), nsegclasses) # calc mean over test epochs so inline with shape from sampling results_dict[res_type] = np.mean(results_dict[res_type], axis=0) #maybe error type with other than dice if not hasattr(self.cf, "eval_test_separately") or not self.cf.eval_test_separately: # add unpatched 2D or 3D (if dim==3 or merge_2D_to_3D) ground truth boxes for evaluation. for b in range(p_dict['patient_bb_target'].shape[0]): for targ in range(len(p_dict['patient_bb_target'][b])): gt_box = {'box_type': 'gt', 'box_coords':p_dict['patient_bb_target'][b][targ], 'class_targets': p_dict['patient_class_targets'][b][targ]} for name in self.cf.roi_items: gt_box.update({name: p_dict["patient_"+name][b][targ]}) results_dict['boxes'][b].append(gt_box) results_per_patient.append([results_dict, pid]) - out_string = 'pred_results_held_out' if self.cf.held_out_test_set else 'pred_results' + out_string = 'pred_results_held_out' if self.cf.hold_out_test_set else 'pred_results' with open(os.path.join(self.cf.fold_dir, '{}.pkl'.format(out_string)), 'wb') as handle: pickle.dump(results_per_patient, handle) if return_results: # -------------- results processing, clustering, etc. ----------------- final_patient_box_results = [ (res_dict["boxes"], pid) for res_dict,pid in results_per_patient ] if self.cf.clustering == "wbc": self.logger.info('applying WBC to test-set predictions with iou = {} and n_ens = {}.'.format( self.cf.clustering_iou, self.n_ens)) mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens] for ii in final_patient_box_results] del final_patient_box_results pool = Pool(processes=self.cf.n_workers) final_patient_box_results = pool.map(apply_wbc_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs elif self.cf.clustering == "nms": self.logger.info('applying standard NMS to test-set predictions with iou = {}.'.format(self.cf.clustering_iou)) pool = Pool(processes=self.cf.n_workers) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in final_patient_box_results] del final_patient_box_results final_patient_box_results = pool.map(apply_nms_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs if self.cf.merge_2D_to_3D_preds: self.logger.info('applying 2D-to-3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou)) mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in final_patient_box_results] del final_patient_box_results pool = Pool(processes=self.cf.n_workers) final_patient_box_results = pool.map(apply_2d_3d_merging_to_patient, mp_inputs, chunksize=1) pool.close() pool.join() del mp_inputs # final_patient_box_results holds [avg_boxes, pid] if wbc for ix in range(len(results_per_patient)): assert results_per_patient[ix][1] == final_patient_box_results[ix][1], "should be same pid" results_per_patient[ix][0]["boxes"] = final_patient_box_results[ix][0] # results_per_patient = [(res_dict["boxes"] = boxes, pid) for (boxes,pid) in final_patient_box_results] return results_per_patient # holds list of (results_dict, pid) diff --git a/utils/exp_utils.py b/utils/exp_utils.py index c734481..fc22592 100644 --- a/utils/exp_utils.py +++ b/utils/exp_utils.py @@ -1,727 +1,727 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from typing import Union, Iterable import sys import os import subprocess from multiprocessing import Process import threading import pickle import importlib.util import psutil import time import nvidia_smi import logging from torch.utils.tensorboard import SummaryWriter from collections import OrderedDict import numpy as np import pandas as pd import torch def import_module(name, path): """ correct way of importing a module dynamically in python 3. :param name: name given to module instance. :param path: path to module. :return: module: returned module instance. """ spec = importlib.util.spec_from_file_location(name, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module def save_obj(obj, name): """Pickle a python object.""" with open(name + '.pkl', 'wb') as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) def IO_safe(func, *args, _tries=5, _raise=True, **kwargs): """ Wrapper calling function func with arguments args and keyword arguments kwargs to catch input/output errors on cluster. :param func: function to execute (intended to be read/write operation to a problematic cluster drive, but can be any function). :param args: positional args of func. :param kwargs: kw args of func. :param _tries: how many attempts to make executing func. """ for _try in range(_tries): try: return func(*args, **kwargs) except OSError as e: # to catch cluster issues with network drives if _raise: raise e else: print("After attempting execution {} time{}, following error occurred:\n{}".format(_try + 1, "" if _try == 0 else "s", e)) continue def split_off_process(target, *args, daemon=False, **kwargs): """Start a process that won't block parent script. No join(), no return value. If daemon=False: before parent exits, it waits for this to finish. """ p = Process(target=target, args=tuple(args), kwargs=kwargs, daemon=daemon) p.start() return p def query_nvidia_gpu(device_id, d_keyword=None, no_units=False): """ :param device_id: :param d_keyword: -d, --display argument (keyword(s) for selective display), all are selected if None :return: dict of gpu-info items """ cmd = ['nvidia-smi', '-i', str(device_id), '-q'] if d_keyword is not None: cmd += ['-d', d_keyword] outp = subprocess.check_output(cmd).strip().decode('utf-8').split("\n") outp = [x for x in outp if len(x) > 0] headers = [ix for ix, item in enumerate(outp) if len(item.split(":")) == 1] + [len(outp)] out_dict = {} for lix, hix in enumerate(headers[:-1]): head = outp[hix].strip().replace(" ", "_").lower() out_dict[head] = {} for lix2 in range(hix, headers[lix + 1]): try: key, val = [x.strip().lower() for x in outp[lix2].split(":")] if no_units: val = val.split()[0] out_dict[head][key] = val except: pass return out_dict class _AnsiColorizer(object): """ A colorizer is an object that loosely wraps around a stream, allowing callers to write text to the stream in a particular color. Colorizer classes must implement C{supported()} and C{write(text, color)}. """ _colors = dict(black=30, red=31, green=32, yellow=33, blue=34, magenta=35, cyan=36, white=37, default=39) def __init__(self, stream): self.stream = stream @classmethod def supported(cls, stream=sys.stdout): """ A class method that returns True if the current platform supports coloring terminal output using this method. Returns False otherwise. """ if not stream.isatty(): return False # auto color only on TTYs try: import curses except ImportError: return False else: try: try: return curses.tigetnum("colors") > 2 except curses.error: curses.setupterm() return curses.tigetnum("colors") > 2 except: raise # guess false in case of error return False def write(self, text, color): """ Write the given text to the stream in the given color. @param text: Text to be written to the stream. @param color: A string label for a color. e.g. 'red', 'white'. """ color = self._colors[color] self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text)) class ColorHandler(logging.StreamHandler): def __init__(self, stream=sys.stdout): super(ColorHandler, self).__init__(_AnsiColorizer(stream)) def emit(self, record): msg_colors = { logging.DEBUG: "green", logging.INFO: "default", logging.WARNING: "red", logging.ERROR: "red" } color = msg_colors.get(record.levelno, "blue") self.stream.write(record.msg + "\n", color) class CombinedPrinter(object): """combined print function. prints to logger and/or file if given, to normal print if non given. """ def __init__(self, logger=None, file=None): if logger is None and file is None: self.out = [print] elif logger is None: self.out = [file.write] elif file is None: self.out = [logger.info] else: self.out = [logger.info, file.write] def __call__(self, string): for fct in self.out: fct(string) class Nvidia_GPU_Logger(object): def __init__(self): self.count = None def get_vals(self): # cmd = ['nvidia-settings', '-t', '-q', 'GPUUtilization'] # gpu_util = subprocess.check_output(cmd).strip().decode('utf-8').split(",") # gpu_util = dict([f.strip().split("=") for f in gpu_util]) # cmd[-1] = 'UsedDedicatedGPUMemory' # gpu_used_mem = subprocess.check_output(cmd).strip().decode('utf-8') nvidia_smi.nvmlInit() # card id 0 hardcoded here, there is also a call to get all available card ids, so we could iterate self.gpu_handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0) util_res = nvidia_smi.nvmlDeviceGetUtilizationRates(self.gpu_handle) #mem_res = nvidia_smi.nvmlDeviceGetMemoryInfo(self.gpu_handle) # current_vals = {"gpu_mem_alloc": mem_res.used / (1024**2), "gpu_graphics_util": int(gpu_util['graphics']), # "gpu_mem_util": gpu_util['memory'], "time": time.time()} current_vals = {"gpu_graphics_util": float(util_res.gpu), "time": time.time()} return current_vals def loop(self, interval): i = 0 while True: current_vals = self.get_vals() self.log["time"].append(time.time()) self.log["gpu_util"].append(current_vals["gpu_graphics_util"]) if self.count is not None: i += 1 if i == self.count: exit(0) time.sleep(self.interval) def start(self, interval=1.): self.interval = interval self.start_time = time.time() self.log = {"time": [], "gpu_util": []} if self.interval is not None: thread = threading.Thread(target=self.loop) thread.daemon = True thread.start() class CombinedLogger(object): """Combine console and tensorboard logger and record system metrics. """ def __init__(self, name, log_dir, server_env=True, fold="all", sysmetrics_interval=2): self.pylogger = logging.getLogger(name) self.tboard = SummaryWriter(log_dir=os.path.join(log_dir, "tboard")) self.times = {} self.log_dir = log_dir self.fold = str(fold) self.server_env = server_env self.pylogger.setLevel(logging.DEBUG) self.log_file = os.path.join(log_dir, "fold_"+self.fold, 'exec.log') os.makedirs(os.path.dirname(self.log_file), exist_ok=True) self.pylogger.addHandler(logging.FileHandler(self.log_file)) if not server_env: self.pylogger.addHandler(ColorHandler()) else: self.pylogger.addHandler(logging.StreamHandler()) self.pylogger.propagate = False # monitor system metrics (cpu, mem, ...) if not server_env and sysmetrics_interval > 0: self.sysmetrics = pd.DataFrame( columns=["global_step", "rel_time", r"CPU (%)", "mem_used (GB)", r"mem_used (%)", r"swap_used (GB)", r"gpu_utilization (%)"], dtype="float16") for device in range(torch.cuda.device_count()): self.sysmetrics[ "mem_allocd (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan self.sysmetrics[ "mem_cached (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan self.sysmetrics_start(sysmetrics_interval) pass else: print("NOT logging sysmetrics") def __getattr__(self, attr): """delegate all undefined method requests to objects of this class in order pylogger, tboard (first find first serve). E.g., combinedlogger.add_scalars(...) should trigger self.tboard.add_scalars(...) """ for obj in [self.pylogger, self.tboard]: if attr in dir(obj): return getattr(obj, attr) print("logger attr not found") #raise AttributeError("CombinedLogger has no attribute {}".format(attr)) def set_logfile(self, fold=None, log_file=None): if fold is not None: self.fold = str(fold) if log_file is None: self.log_file = os.path.join(self.log_dir, "fold_"+self.fold, 'exec.log') else: self.log_file = log_file os.makedirs(os.path.dirname(self.log_file), exist_ok=True) for hdlr in self.pylogger.handlers: hdlr.close() self.pylogger.handlers = [] self.pylogger.addHandler(logging.FileHandler(self.log_file)) if not self.server_env: self.pylogger.addHandler(ColorHandler()) else: self.pylogger.addHandler(logging.StreamHandler()) def time(self, name, toggle=None): """record time-spans as with a stopwatch. :param name: :param toggle: True^=On: start time recording, False^=Off: halt rec. if None determine from current status. :return: either start-time or last recorded interval """ if toggle is None: if name in self.times.keys(): toggle = not self.times[name]["toggle"] else: toggle = True if toggle: if not name in self.times.keys(): self.times[name] = {"total": 0, "last": 0} elif self.times[name]["toggle"] == toggle: self.info("restarting running stopwatch") self.times[name]["last"] = time.time() self.times[name]["toggle"] = toggle return time.time() else: if toggle == self.times[name]["toggle"]: self.info("WARNING: tried to stop stopped stop watch: {}.".format(name)) self.times[name]["last"] = time.time() - self.times[name]["last"] self.times[name]["total"] += self.times[name]["last"] self.times[name]["toggle"] = toggle return self.times[name]["last"] def get_time(self, name=None, kind="total", format=None, reset=False): """ :param name: :param kind: 'total' or 'last' :param format: None for float, "hms"/"ms" for (hours), mins, secs as string :param reset: reset time after retrieving :return: """ if name is None: times = self.times if reset: self.reset_time() return times else: if self.times[name]["toggle"]: self.time(name, toggle=False) time = self.times[name][kind] if format == "hms": m, s = divmod(time, 60) h, m = divmod(m, 60) time = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(m), int(s)) elif format == "ms": m, s = divmod(time, 60) time = "{:02d}m:{:02d}s".format(int(m), int(s)) if reset: self.reset_time(name) return time def reset_time(self, name=None): if name is None: self.times = {} else: del self.times[name] def sysmetrics_update(self, global_step=None): if global_step is None: global_step = time.strftime("%x_%X") mem = psutil.virtual_memory() mem_used = (mem.total - mem.available) gpu_vals = self.gpu_logger.get_vals() rel_time = time.time() - self.sysmetrics_start_time self.sysmetrics.loc[len(self.sysmetrics)] = [global_step, rel_time, psutil.cpu_percent(), mem_used / 1024 ** 3, mem_used / mem.total * 100, psutil.swap_memory().used / 1024 ** 3, int(gpu_vals['gpu_graphics_util']), *[torch.cuda.memory_allocated(d) / 1024 ** 3 for d in range(torch.cuda.device_count())], *[torch.cuda.memory_cached(d) / 1024 ** 3 for d in range(torch.cuda.device_count())] ] return self.sysmetrics.loc[len(self.sysmetrics) - 1].to_dict() def sysmetrics2tboard(self, metrics=None, global_step=None, suptitle=None): tag = "per_time" if metrics is None: metrics = self.sysmetrics_update(global_step=global_step) tag = "per_epoch" if suptitle is not None: suptitle = str(suptitle) elif self.fold != "": suptitle = "Fold_" + str(self.fold) if suptitle is not None: self.tboard.add_scalars(suptitle + "/System_Metrics/" + tag, {k: v for (k, v) in metrics.items() if (k != "global_step" and k != "rel_time")}, global_step) def sysmetrics_loop(self): try: os.nice(-19) self.info("Logging system metrics with superior process priority.") except: self.info("Logging system metrics without superior process priority.") while True: metrics = self.sysmetrics_update() self.sysmetrics2tboard(metrics, global_step=metrics["rel_time"]) # print("thread alive", self.thread.is_alive()) time.sleep(self.sysmetrics_interval) def sysmetrics_start(self, interval): if interval is not None and interval > 0: self.sysmetrics_interval = interval self.gpu_logger = Nvidia_GPU_Logger() self.sysmetrics_start_time = time.time() self.sys_metrics_process = split_off_process(target=self.sysmetrics_loop, daemon=True) # self.thread = threading.Thread(target=self.sysmetrics_loop) # self.thread.daemon = True # self.thread.start() def sysmetrics_save(self, out_file): self.sysmetrics.to_pickle(out_file) def metrics2tboard(self, metrics, global_step=None, suptitle=None): """ :param metrics: {'train': dataframe, 'val':df}, df as produced in evaluator.py.evaluate_predictions """ # print("metrics", metrics) if global_step is None: global_step = len(metrics['train'][list(metrics['train'].keys())[0]]) - 1 if suptitle is not None: suptitle = str(suptitle) else: suptitle = "Fold_" + str(self.fold) for key in ['train', 'val']: # series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k} loss_series = {} unc_series = {} bin_stat_series = {} mon_met_series = {} for tag, val in metrics[key].items(): val = val[-1] # maybe remove list wrapping, recording in evaluator? if 'bin_stats' in tag.lower() and not np.isnan(val): bin_stat_series["{}".format(tag.split("/")[-1])] = val elif 'uncertainty' in tag.lower() and not np.isnan(val): unc_series["{}".format(tag)] = val elif 'loss' in tag.lower() and not np.isnan(val): loss_series["{}".format(tag)] = val elif not np.isnan(val): mon_met_series["{}".format(tag)] = val self.tboard.add_scalars(suptitle + "/Binary_Statistics/{}".format(key), bin_stat_series, global_step) self.tboard.add_scalars(suptitle + "/Uncertainties/{}".format(key), unc_series, global_step) self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step) self.tboard.add_scalars(suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series, global_step) self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step) return def batchImgs2tboard(self, batch, results_dict, cmap, boxtype2color, img_bg=False, global_step=None): raise NotImplementedError("not up-to-date, problem with importing plotting-file, torchvision dependency.") if len(batch["seg"].shape) == 5: # 3D imgs slice_ix = np.random.randint(batch["seg"].shape[-1]) seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :, slice_ix], cmap) seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :, slice_ix], cmap) mod_img = plg.mod_to_rgb(batch["data"][:, 0, :, :, slice_ix]) if img_bg else None elif len(batch["seg"].shape) == 4: seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :], cmap) seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :], cmap) mod_img = plg.mod_to_rgb(batch["data"][:, 0]) if img_bg else None else: raise Exception("batch content has wrong format: {}".format(batch["seg"].shape)) # from here on only works in 2D seg_gt = np.transpose(seg_gt, axes=(0, 3, 1, 2)) # previous shp: b,x,y,c seg_pred = np.transpose(seg_pred, axes=(0, 3, 1, 2)) seg = np.concatenate((seg_gt, seg_pred), axis=0) # todo replace torchvision (tv) dependency seg = tv.utils.make_grid(torch.from_numpy(seg), nrow=2) self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.", seg, global_step=global_step) if img_bg: bg_img = np.transpose(mod_img, axes=(0, 3, 1, 2)) else: bg_img = seg_gt box_imgs = plg.draw_boxes_into_batch(bg_img, results_dict["boxes"], boxtype2color) box_imgs = tv.utils.make_grid(torch.from_numpy(box_imgs), nrow=4) self.tboard.add_image("Batch bboxes", box_imgs, global_step=global_step) return def __del__(self): # otherwise might produce multiple prints e.g. in ipython console #self.sys_metrics_process.terminate() for hdlr in self.pylogger.handlers: hdlr.close() self.pylogger.handlers = [] del self.pylogger self.tboard.flush() # close holds up main script exit. maybe revise this issue with a later pytorch version. #self.tboard.close() def get_logger(exp_dir, server_env=False, sysmetrics_interval=2): log_dir = os.path.join(exp_dir, "logs") logger = CombinedLogger('Reg R-CNN', log_dir, server_env=server_env, sysmetrics_interval=sysmetrics_interval) print("logging to {}".format(logger.log_file)) return logger def prepare_monitoring(cf): """ creates dictionaries, where train/val metrics are stored. """ metrics = {} # first entry for loss dict accounts for epoch starting at 1. metrics['train'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) metrics['val'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) metric_classes = [] if 'rois' in cf.report_score_level: metric_classes.extend([v for k, v in cf.class_dict.items()]) if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: metric_classes.extend([v for k, v in cf.bin_dict.items()]) if 'patient' in cf.report_score_level: metric_classes.extend(['patient_' + cf.class_dict[cf.patient_class_of_interest]]) if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: metric_classes.extend(['patient_' + cf.bin_dict[cf.patient_bin_of_interest]]) for cl in metric_classes: for m in cf.metrics: metrics['train'][cl + '_' + m] = [np.nan] metrics['val'][cl + '_' + m] = [np.nan] return metrics class ModelSelector: ''' saves a checkpoint after each epoch as 'last_state' (can be loaded to continue interrupted training). saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be ensembled to improve performance. ''' def __init__(self, cf, logger): self.cf = cf self.logger = logger self.model_index = pd.DataFrame(columns=["rank", "score", "criteria_values", "file_name"], index=pd.RangeIndex(self.cf.min_save_thresh, self.cf.num_epochs, name="epoch")) def run_model_selection(self, net, optimizer, monitor_metrics, epoch): """rank epoch via weighted mean from self.cf.model_selection_criteria: {criterion : weight} :param net: :param optimizer: :param monitor_metrics: :param epoch: :return: """ crita = self.cf.model_selection_criteria # shorter alias metrics = monitor_metrics['val'] epoch_score = np.sum([metrics[criterion][-1] * weight for criterion, weight in crita.items() if not np.isnan(metrics[criterion][-1])]) if not self.cf.resume: epoch_score_check = np.sum([metrics[criterion][epoch] * weight for criterion, weight in crita.items() if not np.isnan(metrics[criterion][epoch])]) assert np.all(epoch_score == epoch_score_check) self.model_index.loc[epoch, ["score", "criteria_values"]] = epoch_score, {cr: metrics[cr][-1] for cr in crita.keys()} nonna_ics = self.model_index["score"].dropna(axis=0).index order = np.argsort(self.model_index.loc[nonna_ics, "score"].to_numpy(), kind="stable")[::-1] self.model_index.loc[nonna_ics, "rank"] = np.argsort(order) + 1 # no zero-indexing for ranks (best rank is 1). rank = int(self.model_index.loc[epoch, "rank"]) if rank <= self.cf.save_n_models: name = '{}_best_params.pth'.format(epoch) if self.cf.server_env: IO_safe(torch.save, net.state_dict(), os.path.join(self.cf.fold_dir, name)) else: torch.save(net.state_dict(), os.path.join(self.cf.fold_dir, name)) self.model_index.loc[epoch, "file_name"] = name self.logger.info("saved current epoch {} at rank {}".format(epoch, rank)) clean_up = self.model_index.dropna(axis=0, subset=["file_name"]) clean_up = clean_up[clean_up["rank"] > self.cf.save_n_models] if clean_up.size > 0: file_name = clean_up["file_name"].to_numpy().item() subprocess.call("rm {}".format(os.path.join(self.cf.fold_dir, file_name)), shell=True) self.logger.info("removed outranked epoch {} at {}".format(clean_up.index.values.item(), os.path.join(self.cf.fold_dir, file_name))) self.model_index.loc[clean_up.index, "file_name"] = np.nan state = { 'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), 'model_index': self.model_index, } if self.cf.server_env: IO_safe(torch.save, state, os.path.join(self.cf.fold_dir, 'last_state.pth')) else: torch.save(state, os.path.join(self.cf.fold_dir, 'last_state.pth')) -def parse_params_for_optim(net: torch.nn.Module, weight_decay: float = 0., exclude_from_wd: Iterable = ("norm", "bias")): +def parse_params_for_optim(net: torch.nn.Module, weight_decay: float = 0., exclude_from_wd: Iterable = ("norm",)): """Format network parameters for the optimizer. Convenience function to include options for group-specific settings like weight decay. :param net: :param weight_decay: :param exclude_from_wd: List of strings of parameter-group names to exclude from weight decay. Options: "norm", "bias". :return: """ # pytorch implements parameter groups as dicts {'params': ...} and # weight decay as p.data.mul_(1 - group['lr'] * group['weight_decay']) norm_types = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.SyncBatchNorm, torch.nn.LocalResponseNorm ] level_map = {"bias": "weight", "norm": "module"} type_map = {"norm": norm_types} exclude_from_wd = [str(name).lower() for name in exclude_from_wd] exclude_weight_names = [k for k, v in level_map.items() if k in exclude_from_wd and v == "weight"] exclude_module_types = tuple([type_ for k, v in level_map.items() if (k in exclude_from_wd and v == "module") for type_ in type_map[k]]) if exclude_from_wd: print("excluding {} from weight decay.".format(exclude_from_wd)) with_dec, no_dec = [], [] for name, module in net.named_modules(): if isinstance(module, exclude_module_types): no_dec.extend(module.parameters()) else: for param_name, param in module.named_parameters(): if np.any([ename in param_name for ename in exclude_weight_names]): no_dec.append(param) else: with_dec.append(param) groups = [{'params': gr, 'weight_decay': wd} for gr, wd in [(no_dec, 0.), (with_dec, weight_decay)] if len(gr) > 0] return groups def load_checkpoint(checkpoint_path, net, optimizer, model_selector): checkpoint = torch.load(checkpoint_path) net.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) model_selector.model_index = checkpoint["model_index"] return checkpoint['epoch'] + 1, net, optimizer, model_selector def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True): """ I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir. This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime. Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone). Provides robust structure for cloud deployment. :param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp) :param exp_path: path to experiment directory. :param server_env: boolean flag. pass to configs script for cloud deployment. :param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing experiment directory, else creates experiment directory on the fly using configs/model scripts from source code. :param is_training: boolean flag. distinguishes train vs. inference mode. :return: configs object. """ if is_training: if use_stored_settings: cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) cf = cf_file.Configs(server_env) # in this mode, previously saved model and backbone need to be found in exp dir. if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \ not os.path.isfile(os.path.join(exp_path, 'backbone.py')): raise Exception( "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.") cf.model_path = os.path.join(exp_path, 'model.py') cf.backbone_path = os.path.join(exp_path, 'backbone.py') else: # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model os.makedirs(exp_path, exist_ok=True) # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.) subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), shell=True) subprocess.call( 'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), shell=True) cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py')) cf = cf_file.Configs(server_env) subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True) subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True) if os.path.isfile(os.path.join(exp_path, "fold_ids.pickle")): subprocess.call('rm {}'.format(os.path.join(exp_path, "fold_ids.pickle")), shell=True) else: # testing, use model and backbone stored in exp dir. cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) cf = cf_file.Configs(server_env) cf.model_path = os.path.join(exp_path, 'model.py') cf.backbone_path = os.path.join(exp_path, 'backbone.py') cf.exp_dir = exp_path cf.test_dir = os.path.join(cf.exp_dir, 'test') cf.plot_dir = os.path.join(cf.exp_dir, 'plots') if not os.path.exists(cf.test_dir): os.mkdir(cf.test_dir) if not os.path.exists(cf.plot_dir): os.mkdir(cf.plot_dir) cf.experiment_name = exp_path.split("/")[-1] cf.dataset_name = dataset_path cf.server_env = server_env cf.created_fold_id_pickle = False return cf