diff --git a/experiments/pet_ct_tnm_classification/configs.py b/experiments/pet_ct_tnm_classification/configs.py
new file mode 100644
index 0000000..91e32e0
--- /dev/null
+++ b/experiments/pet_ct_tnm_classification/configs.py
@@ -0,0 +1,342 @@
+#!/usr/bin/env python
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.realpath(__file__)))
+import numpy as np
+from default_configs import DefaultConfigs
+
+
+class configs(DefaultConfigs):
+
+    def __init__(self, server_env=None):
+
+        #########################
+        #     Preprocessing     #
+        #########################
+
+        self.root_dir = '/mnt/hdd2/basel_lung/'
+        self.raw_data_dir = '{}/LungStageData'.format(self.root_dir)
+        self.pp_dir = '/media/paul/ssd1/pp_norm_basel'
+
+        #########################
+        #          I/O          #
+        #########################
+
+        # one out of [2, 3]. dimension the model operates in.
+        self.dim = 3
+
+        # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'prob_detector'].
+        self.model = 'retina_unet'
+
+        DefaultConfigs.__init__(self, self.model, server_env, self.dim)
+
+        # int [0 < dataset_size]. select n patients from dataset for prototyping.
+        self.select_prototype_subset = None
+
+        # if True, test data lies in a separate folder and is not part of the cross validation.
+        self.hold_out_test_set = True
+
+        # path to preprocessed data.
+        self.pp_name = 'pp_norm_basel'
+        self.input_df_name = 'info_df.pickle'
+        self.pp_data_path = '/media/paul/ssd1/{}'.format(self.pp_name)
+        self.pp_test_data_path = self.pp_data_path  # path to mounted input dir
+        self.pp_test_out_path = self.pp_data_path  # path to mounted output dir
+
+        # settings for deployment in cloud.
+        if server_env:
+            # path to preprocessed data.
+            self.pp_name = 'pp_fg_slices'
+            self.crop_name = 'pp_fg_slices_packed'
+            self.pp_data_path = '/datasets/datasets_paul/{}/{}'.format(self.pp_name, self.crop_name)
+            self.pp_test_data_path = self.pp_data_path
+            self.select_prototype_subset = None
+
+        #########################
+        #      Data Loader      #
+        #########################
+
+        # select modalities from preprocessed data
+        self.channels = [0, 1]
+        self.n_channels = len(self.channels)
+
+        # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
+        self.pre_crop_size_2D = [300, 300]
+        self.patch_size_2D = [288, 288]
+        self.pre_crop_size_3D = [280, 280, 48]
+        self.patch_size_3D = [192, 192, 32]
+        self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
+        self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
+
+        # ratio of freely sampled batch elements before class balancing is triggered
+        # (>0 to include "empty"/background patches.)
+        self.batch_sample_slack = 0.2
+
+        # set 2D network to operate in 3D images.
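+        # (slice-wise 2D predictions are then merged back into the patient volume and evaluated
+        # in 3D, see PatientBatchIterator in data_loader.py.)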
+        self.merge_2D_to_3D_preds = True
+
+        # feed +/- n neighbouring slices into channel dimension. set to None for no context.
+        self.n_3D_context = None
+        if self.n_3D_context is not None and self.dim == 2:
+            self.n_channels *= (self.n_3D_context * 2 + 1)
+
+
+        #########################
+        #      Architecture     #
+        #########################
+
+        self.start_filts = 48 if self.dim == 2 else 18
+        self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
+        self.res_architecture = 'resnet50'  # 'resnet101' , 'resnet50'
+        self.norm = None  # one of None, 'instance_norm', 'batch_norm'
+        self.weight_decay = 0
+
+        # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (= default = 'kaiming_uniform')
+        self.weight_init = None
+
+        #########################
+        #  Schedule / Selection #
+        #########################
+
+        self.num_epochs = 100
+        self.num_train_batches = 5 if self.dim == 2 else 60
+        self.batch_size = 20 if self.dim == 2 else 8
+
+        self.do_validation = False
+        # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
+        # the former is more accurate, while the latter is faster (depending on volume size)
+        self.val_mode = 'val_sampling'  # one of 'val_sampling' , 'val_patient'
+        if self.val_mode == 'val_patient':
+            self.max_val_patients = 50  # if 'None' iterates over entire val_set once.
+        if self.val_mode == 'val_sampling':
+            self.num_val_batches = 10
+
+        #########################
+        #   Testing / Plotting  #
+        #########################
+
+        # set the top-n-epochs to be saved for temporal averaging in testing.
+        self.save_n_models = 5
+        self.test_n_epochs = 5
+        # set a minimum epoch number for saving in case of instabilities in the first phase of training.
+        self.min_save_thresh = 0
+
+        self.report_score_level = ['patient', 'rois']  # choose list from 'patient', 'rois'
+        self.class_dict = {1: 'foreground'}  # 0 is background.
+        self.patient_class_of_interest = 1  # patient metrics are only plotted for one class.
+        self.ap_match_ious = [0.1]  # list of ious to be evaluated for ap-scoring.
+
+        self.model_selection_criteria = ['foreground_ap']  # criteria to average over for saving epochs.
+        self.min_det_thresh = 0.1  # minimum confidence value to select predictions for evaluation.
+
+        # threshold for clustering predictions together (wcs = weighted cluster scoring).
+        # needs to be >= the expected overlap of predictions coming from one model (typically NMS threshold).
+        # if too high, preds of the same object are separate clusters.
+        self.wcs_iou = 1e-5
+
+        self.plot_prediction_histograms = True
+        self.plot_stat_curves = False
+
+        #########################
+        #   Data Augmentation   #
+        #########################
+
+        self.da_kwargs = {
+            'do_elastic_deform': True,
+            'alpha': (0., 1500.),
+            'sigma': (30., 50.),
+            'do_rotation': True,
+            'angle_x': (0., 2 * np.pi),
+            'angle_y': (0., 0),
+            'angle_z': (0., 0),
+            'do_scale': True,
+            'scale': (0.8, 1.1),
+            'random_crop': False,
+            'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
+            'border_mode_data': 'constant',
+            'border_cval_data': 0,
+            'order_data': 1
+        }
+
+        if self.dim == 3:
+            self.da_kwargs['do_elastic_deform'] = False
+            self.da_kwargs['angle_x'] = (0, 0.0)
+            self.da_kwargs['angle_y'] = (0, 0.0)  # must be 0 for the anisotropic z-axis!
+            self.da_kwargs['angle_z'] = (0., 2 * np.pi)
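A note on the two crop sizes above: training patches are sampled at `pre_crop_size`, spatially augmented, and only then centre-cropped to `patch_size`, so rotation and scaling artifacts stay outside the final patch. The slack this leaves with the 3D values used here (illustrative check, not part of the config):

```python
import numpy as np

pre_crop_size_3d = np.array([280, 280, 48])
patch_size_3d = np.array([192, 192, 32])

# border slack available per side for rotation / scaling before the final center crop
print((pre_crop_size_3d - patch_size_3d) / 2)  # [44. 44.  8.]

# room for the 1.1 upper scale bound in da_kwargs
assert np.all(patch_size_3d * 1.1 <= pre_crop_size_3d)
```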
+
+
+        #########################
+        #  Add model specifics  #
+        #########################
+
+        {'detection_unet': self.add_det_unet_configs,
+         'mrcnn': self.add_mrcnn_configs,
+         'retina_net': self.add_mrcnn_configs,
+         'retina_unet': self.add_mrcnn_configs,
+         'prob_detector': self.add_mrcnn_configs,
+         }[self.model]()
+
+
+    def add_det_unet_configs(self):
+
+        # learning rate is a list with one entry per epoch. the phases have to add up to num_epochs exactly,
+        # hence the first phase runs for num_epochs // 2 (three times num_epochs // 4 would leave epochs uncovered).
+        self.learning_rate = [1e-4] * (self.num_epochs // 2) + [5 * 1e-5] * (self.num_epochs // 4) + [1e-5] * (self.num_epochs // 4)
+        assert len(self.learning_rate) == self.num_epochs
+
+        self.aggregation_operation = 'max'
+
+        # max number of roi candidates to identify per batch element and class.
+        self.n_roi_candidates = 10 if self.dim == 2 else 30
+
+        # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice'), or the sum of both ('dice_wce')
+        self.seg_loss_mode = 'dice_wce'
+
+        # if <1, false positive predictions in foreground are penalized less.
+        self.fp_dice_weight = 1
+
+        self.wce_weights = [1, 1, 1]
+        self.detection_min_confidence = self.min_det_thresh
+
+        # if 'True', loss distinguishes all classes, else only foreground vs. background (class agnostic).
+        # with a single foreground class, both settings amount to 2 segmentation classes.
+        self.class_specific_seg_flag = True
+        self.num_seg_classes = 2
+        self.head_classes = self.num_seg_classes
+
+    def add_mrcnn_configs(self):
+
+        # learning rate is a list with one entry per epoch.
+        self.learning_rate = [1e-4] * (self.num_epochs // 2) + [5 * 1e-5] * (self.num_epochs // 4) + [1e-5] * (self.num_epochs // 4)
+        assert len(self.learning_rate) == self.num_epochs, [len(self.learning_rate), self.num_epochs]
+
+        # disable the re-sampling of mask proposals to original size for speed-up.
+        # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
+        # mask-outputs are optional.
+        self.return_masks_in_val = True
+        self.return_masks_in_test = False
+
+        # set number of proposal boxes to plot after each epoch.
+        self.n_plot_rpn_props = 5 if self.dim == 2 else 30
+
+        # number of classes for head networks: n_foreground_classes + 1 (background)
+        self.head_classes = 2
+
+        # seg_classes here refers to the first stage classifier (RPN)
+        self.num_seg_classes = 2  # foreground vs. background
+
+        # feature map strides per pyramid level are inferred from architecture.
+        self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
+
+        # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
+        # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
+        self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
+
+        # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
+        self.pyramid_levels = [0, 1, 2, 3]
+
+        # number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
+        self.n_rpn_features = 512 if self.dim == 2 else 128
+
+        # anchor ratios and strides per position in feature maps.
+        self.rpn_anchor_ratios = [0.5, 1, 2]
+        self.rpn_anchor_stride = 1
+
+        # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION
+        self.rpn_nms_threshold = 0.7
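On the learning-rate schedules defined above: each `learning_rate` list is indexed by epoch, so its three phases must add up to `num_epochs` exactly or the asserts fire. A quick illustrative check with this config's values:

```python
num_epochs = 100
learning_rate = [1e-4] * (num_epochs // 2) + [5 * 1e-5] * (num_epochs // 4) + [1e-5] * (num_epochs // 4)
assert len(learning_rate) == num_epochs  # 50 + 25 + 25 == 100
```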
+
+        # loss sampling settings.
+        self.rpn_train_anchors_per_image = 6  # per batch element
+        self.train_rois_per_image = 6  # per batch element
+        self.roi_positive_ratio = 0.5
+        self.anchor_matching_iou = 0.7
+
+        # factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining).
+        # poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples.
+        self.shem_poolsize = 10
+
+        self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
+        self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
+        self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
+
+        self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
+        self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
+        self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
+        self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
+                               self.patch_size_3D[2], self.patch_size_3D[2]])
+        if self.dim == 2:
+            self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
+            self.bbox_std_dev = self.bbox_std_dev[:4]
+            self.window = self.window[:4]
+            self.scale = self.scale[:4]
+
+        # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
+        self.pre_nms_limit = 3000 if self.dim == 2 else 6000
+
+        # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
+        # since proposals of the entire batch are forwarded through the second stage as one "batch".
+        self.roi_chunk_size = 2500 if self.dim == 2 else 600
+        self.post_nms_rois_training = 500 if self.dim == 2 else 75
+        self.post_nms_rois_inference = 500
+
+        # Final selection of detections (refine_detections)
+        self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30  # per batch element and class.
+        self.detection_nms_threshold = 1e-5  # needs to be > 0, otherwise all predictions are one cluster.
+        self.model_min_confidence = 0.1
+
+        if self.dim == 2:
+            self.backbone_shapes = np.array(
+                [[int(np.ceil(self.patch_size[0] / stride)),
+                  int(np.ceil(self.patch_size[1] / stride))]
+                 for stride in self.backbone_strides['xy']])
+        else:
+            self.backbone_shapes = np.array(
+                [[int(np.ceil(self.patch_size[0] / stride)),
+                  int(np.ceil(self.patch_size[1] / stride)),
+                  int(np.ceil(self.patch_size[2] / stride_z))]
+                 for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'])])
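For orientation, the feature-map shapes computed above for the 3D setting, and the anchor count they imply in the two-stage default (3 ratios × 1 scale per level) — a back-of-the-envelope sketch, not part of the config:

```python
import numpy as np

patch_size = [192, 192, 32]
strides_xy = [4, 8, 16, 32]
strides_z = [1, 2, 4, 8]

shapes = [(int(np.ceil(patch_size[0] / s)), int(np.ceil(patch_size[1] / s)), int(np.ceil(patch_size[2] / sz)))
          for s, sz in zip(strides_xy, strides_z)]
print(shapes)  # [(48, 48, 32), (24, 24, 16), (12, 12, 8), (6, 6, 4)]

# total anchors the RPN scores per 3D patch
print(sum(x * y * z for x, y, z in shapes) * 3)  # 252720
```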
+        if self.model == 'retina_net' or self.model == 'retina_unet' or self.model == 'prob_detector':
+            # implement extra anchor-scales according to retina-net publication.
+            self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
+                                            self.rpn_anchor_scales['xy']]
+            self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
+                                           self.rpn_anchor_scales['z']]
+            self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
+
+            self.n_rpn_features = 256 if self.dim == 2 else 64
+
+            # pre-selection of detections for NMS-speedup. per entire batch.
+            self.pre_nms_limit = 10000 if self.dim == 2 else 50000
+
+            # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
+            self.anchor_matching_iou = 0.5
+
+            if self.model == 'retina_unet':
+                self.operate_stride1 = True
+
+        if self.model == 'prob_detector':
+            self.monitor_extra_values = ['kl_loss', 'mu_prior', 'mu_post', 'sigma_prior', 'sigma_post']
+            self.box_color_palette['sample'] = 'w'
+            self.n_latent_dims = 6
+            self.class_specific_seg_flag = True
+            self.n_probabilistic_samples = 4
+
+            # if 'True', seg loss distinguishes all classes, else only foreground vs. background (class agnostic).
+            # with a single foreground class, both settings amount to 2 segmentation classes.
+            self.num_seg_classes = 2
\ No newline at end of file
diff --git a/experiments/pet_ct_tnm_classification/data_loader.py b/experiments/pet_ct_tnm_classification/data_loader.py
new file mode 100644
index 0000000..2375ab9
--- /dev/null
+++ b/experiments/pet_ct_tnm_classification/data_loader.py
@@ -0,0 +1,471 @@
+#!/usr/bin/env python
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+'''
+Data loader for the PET/CT TNM classification data set. This dataloader expects preprocessed data in .npy or .npz files
+per patient and a pandas dataframe in the same directory containing the meta info, e.g. file paths, labels, foreground slice ids.
+'''
+
+
+import numpy as np
+import os
+from collections import OrderedDict
+import pandas as pd
+import pickle
+import time
+import subprocess
+import utils.dataloader_utils as dutils
+import preprocessing as pp
+
+# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
+from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
+from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.dataloading import SingleThreadedAugmenter
+from batchgenerators.transforms.spatial_transforms import SpatialTransform
+from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
+from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates
+
+import SimpleITK as sitk
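For context, the toolkit's training loop consumes the wrappers defined below roughly as sketched here (hypothetical; `cf` and `logger` come from the experiment setup, and the key names follow the code in this file):

```python
batch_gen = get_train_generators(cf, logger)

for epoch in range(cf.num_epochs):
    for _ in range(cf.num_train_batches):
        batch = next(batch_gen['train'])
        # batch['data']: (b, c, x, y(, z)), batch['seg']: (b, 1, x, y(, z)),
        # plus 'pid' and 'class_target' as assembled in BatchGenerator below.
```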
+ """ + all_data = load_dataset(cf, logger) + all_pids_list = np.unique([v['pid'] for (k, v) in all_data.items()]) + + if not cf.created_fold_id_pickle: + fg = dutils.fold_generator(seed=cf.seed, n_splits=cf.n_cv_splits, len_data=len(all_pids_list)).get_fold_names() + with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'wb') as handle: + pickle.dump(fg, handle) + cf.created_fold_id_pickle = True + else: + with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: + fg = pickle.load(handle) + + train_ix, val_ix, test_ix, _ = fg[cf.fold] + + train_pids = [all_pids_list[ix] for ix in train_ix] + val_pids = [all_pids_list[ix] for ix in val_ix] + + if cf.hold_out_test_set: + train_pids += [all_pids_list[ix] for ix in test_ix] + + train_data = {k: v for (k, v) in all_data.items() if any(p == v['pid'] for p in train_pids)} + val_data = {k: v for (k, v) in all_data.items() if any(p == v['pid'] for p in val_pids)} + + logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ix), len(val_ix), len(test_ix))) + batch_gen = {} + batch_gen['train'] = create_data_gen_pipeline(train_data, cf=cf, is_training=True) + batch_gen['val_sampling'] = create_data_gen_pipeline(val_data, cf=cf, is_training=False) + if cf.val_mode == 'val_patient': + batch_gen['val_patient'] = PatientBatchIterator(val_data, cf=cf) + batch_gen['n_val'] = len(val_ix) if cf.max_val_patients is None else cf.max_val_patients + else: + batch_gen['n_val'] = cf.num_val_batches + + return batch_gen + + +def get_test_generator(cf, logger): + """ + wrapper function for creating the test batch generator pipeline. + selects patients according to cv folds (generated by first run/fold of experiment) + If cf.hold_out_test_set is True, gets the data from an external folder instead. + """ + if cf.hold_out_test_set: + cf.pp_data_path = cf.pp_test_data_path + test_ix = None + else: + with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: + fold_list = pickle.load(handle) + _, _, test_ix, _ = fold_list[cf.fold] + # warnings.warn('WARNING: using validation set for testing!!!') + + test_paths = os.listdir(cf.pp_test_data_path) + test_data = {'data': test_paths} + logger.info("data set loaded with: {} test patients".format(len(test_paths))) + batch_gen = {} + batch_gen['test'] = PatientBatchIterator(test_data, cf=cf) + batch_gen['n_test'] = len(test_paths) + return batch_gen + + + +def load_dataset(cf, logger, subset_ixs=None): + """ + loads the dataset. if deployed in cloud also copies and unpacks the data to the working directory. + :param subset_ixs: subset indices to be loaded from the dataset. used e.g. for testing to only load the test folds. 
+
+
+def load_dataset(cf, logger, subset_ixs=None):
+    """
+    loads the dataset. if deployed in cloud also copies and unpacks the data to the working directory.
+    :param subset_ixs: subset indices to be loaded from the dataset. used e.g. for testing to only load the test folds.
+    :return: data: dictionary with one entry per patient. each entry is a dictionary containing respective meta-info
+        as well as paths to the preprocessed numpy arrays to be loaded during batch-generation.
+    """
+    if cf.server_env:
+        copy_data = True
+        target_dir = os.path.join('/ssd', cf.slurm_job_id, cf.pp_name, cf.crop_name)
+        if not os.path.exists(target_dir):
+            cf.data_source_dir = cf.pp_data_path
+            os.makedirs(target_dir)
+            subprocess.call('rsync -av {} {}'.format(
+                os.path.join(cf.data_source_dir, cf.input_df_name), os.path.join(target_dir, cf.input_df_name)), shell=True)
+            logger.info('created target dir and info df at {}'.format(os.path.join(target_dir, cf.input_df_name)))
+
+        elif subset_ixs is None:
+            copy_data = False
+
+        cf.pp_data_path = target_dir
+
+    p_df = pd.read_pickle(os.path.join(cf.pp_data_path, cf.input_df_name))
+
+    if cf.select_prototype_subset is not None:
+        prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset]
+        p_df = p_df[p_df.pid.isin(prototype_pids)]
+        logger.warning('WARNING: using prototyping data subset!!!')
+
+    if subset_ixs is not None:
+        subset_pids = [np.unique(p_df.pid.tolist())[ix] for ix in subset_ixs]
+        p_df = p_df[p_df.pid.isin(subset_pids)]
+        logger.info('subset: selected {} instances from df'.format(len(p_df)))
+
+    if cf.server_env:
+        if copy_data:
+            copy_and_unpack_data(logger, p_df.pid.tolist(), cf.fold_dir, cf.data_source_dir, target_dir)
+
+    class_targets = p_df['class_target'].tolist()
+    pids = p_df.pid.tolist()
+    imgs = [os.path.join(cf.pp_data_path, '{}_img.npy'.format(pid)) for pid in pids]
+    segs = [os.path.join(cf.pp_data_path, '{}_rois.npy'.format(pid)) for pid in pids]
+
+    data = OrderedDict()
+    for ix, pid in enumerate(pids):
+        data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid, 'class_target': class_targets[ix]}
+        data[pid]['fg_slices'] = p_df['fg_slices'].tolist()[ix]
+
+    return data
+
+
+
+def create_data_gen_pipeline(patient_data, cf, is_training=True):
+    """
+    create multi-threaded train/val/test batch generation and augmentation pipeline.
+    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
+    :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
+    :return: multithreaded_generator
+    """
+
+    # create instance of batch generator as first element in pipeline.
+    data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size, cf=cf)
+
+    # add transformations to pipeline.
+ my_transforms = [] + if is_training: + mirror_transform = Mirror(axes=np.arange(2, cf.dim+2, 1)) + my_transforms.append(mirror_transform) + spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], + patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], + do_elastic_deform=cf.da_kwargs['do_elastic_deform'], + alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], + do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], + angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], + do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], + random_crop=cf.da_kwargs['random_crop'], order_seg=0, border_cval_seg=0) + + my_transforms.append(spatial_transform) + else: + my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) + + my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=True, class_specific_seg_flag=cf.class_specific_seg_flag)) + all_transforms = Compose(my_transforms) + # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms) + multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers)) + return multithreaded_generator + + +class BatchGenerator(SlimDataLoaderBase): + """ + creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) + from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. + Actual patch_size is obtained after data augmentation. + :param data: data dictionary as provided by 'load_dataset'. + :param batch_size: number of patients to sample for the batch + :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target + """ + def __init__(self, data, batch_size, cf): + super(BatchGenerator, self).__init__(data, batch_size) + + self.cf = cf + self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch. + self.p_fg = 0.5 + + def generate_train_batch(self): + + batch_data, batch_segs, batch_pids, batch_targets, batch_patient_labels = [], [], [], [], [] + class_targets_list = [v['class_target'] for (k, v) in self._data.items()] + + batch_ixs = np.random.choice(len(class_targets_list), self.batch_size) + patients = list(self._data.items()) + + for b in batch_ixs: + patient = patients[b][1] + + data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(0, 2, 3, 1)) + seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) + batch_pids.append(str(patient['pid'])) + + + + if self.cf.dim == 2: + # draw random slice from patient while oversampling slices containing foreground objects with p_fg. + if len(patient['fg_slices']) > 0: + fg_prob = self.p_fg / len(patient['fg_slices']) + bg_prob = (1 - self.p_fg) / (data.shape[3] - len(patient['fg_slices'])) + slices_prob = [fg_prob if ix in patient['fg_slices'] else bg_prob for ix in range(data.shape[3])] + slice_id = np.random.choice(data.shape[3], p=slices_prob) + else: + slice_id = np.random.choice(data.shape[3]) + + # if set to not None, add neighbouring slices to each selected slice in channel dimension. 
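+                # (e.g. with n_3D_context = 1, the selected slice and one neighbour per side are
+                # stacked along the channel axis; cf. the n_channels scaling in configs.py.)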
+ if self.cf.n_3D_context is not None: + padded_data = dutils.pad_nd_image(data[0], [(data.shape[-1] + (self.cf.n_3D_context*2))], mode='constant') + padded_slice_id = slice_id + self.cf.n_3D_context + data = (np.concatenate([padded_data[..., ii][np.newaxis] for ii in range( + padded_slice_id - self.cf.n_3D_context, padded_slice_id + self.cf.n_3D_context + 1)], axis=0)) + else: + data = data[..., slice_id] + seg = seg[..., slice_id] + + # pad data if smaller than pre_crop_size. + if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): + new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] + data = dutils.pad_nd_image(data, new_shape, mode='constant') + seg = dutils.pad_nd_image(seg, new_shape, mode='constant') + + # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. + crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] + if len(crop_dims) > 0: + fg_prob_sample = np.random.rand(1) + # with p_fg: sample random pixel from random ROI and shift center by random value. + if fg_prob_sample < self.p_fg and np.sum(seg) > 0: + seg_ixs = np.argwhere(seg == np.random.choice(np.unique(seg)[1:], 1)) + roi_anchor_pixel = seg_ixs[np.random.choice(seg_ixs.shape[0], 1)][0] + assert seg[tuple(roi_anchor_pixel)] > 0 + # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by + # distance to the desired ROI < patch_size /2. + # (here final patch size to account for center_crop after data augmentation). + sample_seg_center = {} + for ii in crop_dims: + low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) + high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, + roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) + # happens if lesion on the edge of the image. dont care about roi anymore, + # just make sure pre-crop is inside image. + if low >= high: + low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) - 1 + high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) + 1 + sample_seg_center[ii] = np.random.randint(low=low, high=high) + + else: + # not guaranteed to be empty. probability of emptiness depends on the data. + sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, + high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} + + for ii in crop_dims: + min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) + max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) + data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) + seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) + + + class_targets = 0 if np.sum(seg) > 0 else -1 + batch_targets.append(class_targets) + + batch_data.append(data) + batch_segs.append(seg[np.newaxis]) + + data = np.array(batch_data).astype(np.float16) + seg = np.array(batch_segs).astype(np.uint8) + class_target = batch_targets + return {'data': data, 'seg': seg, 'pid': batch_pids, 'class_target': class_target} + + + +class PatientBatchIterator(SlimDataLoaderBase): + """ + creates a test generator that iterates over entire given dataset returning 1 patient per batch. 
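+    Patient volumes exceeding the training patch_size are tiled into patches, which are stacked in the batch dimension.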
+    Can be used for monitoring if cf.val_mode = 'val_patient', for validation closer to the actual evaluation (done in 3D),
+    if one is willing to accept the speed loss during training.
+    :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
+    batch_size = n_2D_patches in 2D.
+    """
+    def __init__(self, data, cf):  # threads in augmenter
+        super(PatientBatchIterator, self).__init__(data, 0)
+        self.cf = cf
+        self.patient_ix = 0
+        # expects the test-mode input {'data': [patient paths]} as assembled in get_test_generator.
+        self.patient_paths = data['data']
+        self.patch_size = cf.patch_size
+        if len(self.patch_size) == 2:
+            self.patch_size = self.patch_size + [1]
+
+
+    def preprocess_patient(self, path):
+
+        x = sitk.ReadImage(os.path.join(path, 'lsa_ct.nii.gz'))
+        p = sitk.ReadImage(os.path.join(path, 'lsa_pet.nii.gz'))
+
+        x_spacing = x.GetSpacing()
+        if x_spacing[0] < 0.95 or x_spacing[2] < 3:
+            new_spacing = (0.976562, 0.976562, 3.27)
+            new_size = [int(x.GetSize()[ii] * x_spacing[ii] / new_spacing[ii]) for ii in range(3)]
+            reference_image = sitk.Image(new_size, x.GetPixelIDValue())
+            reference_image.SetOrigin(x.GetOrigin())
+            reference_image.SetDirection(x.GetDirection())
+            reference_image.SetSpacing(new_spacing)
+
+            # Resample without any smoothing.
+            x = sitk.Resample(x, reference_image)
+
+        resampler = sitk.ResampleImageFilter()
+        resampler.SetReferenceImage(x)
+        # resampler.SetInterpolator()  # default linear
+        rp = resampler.Execute(p)
+        pi = sitk.GetArrayFromImage(rp)
+        xi = sitk.GetArrayFromImage(x)
+
+        zmin, zmax = pp.get_z_crops(xi, 0)
+        x = xi[zmin:zmax]
+        p = pi[zmin:zmax]
+        x = np.clip(x, -1200, 600)
+        x = (1200 + x) / (600 + 1200)  # min-max scale to [0, 1]: (x - a) / (b - a)
+        x = (x - np.mean(x)) / np.std(x)
+        p = np.clip(p, 0, 2)
+        p = p / 2
+        p = (p - np.mean(p)) / np.std(p)
+        return np.concatenate((x[None], p[None])).astype(np.float32)
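`preprocess_patient` mirrors `pp_patient` in preprocessing.py and has to stay in sync with it (note it keeps the PET clipping/scaling that preprocessing.py leaves commented out). A hypothetical smoke test of its output contract — the path is a placeholder for a raw patient folder holding lsa_ct.nii.gz and lsa_pet.nii.gz:

```python
it = PatientBatchIterator({'data': ['/data/raw/patient_007']}, cf=cf)
arr = it.preprocess_patient('/data/raw/patient_007')

assert arr.dtype == np.float32
assert arr.ndim == 4 and arr.shape[0] == 2  # (c, z, y, x) with one CT and one PET channel
```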
+
+
+    def generate_train_batch(self):
+
+        path = self.patient_paths[self.patient_ix]
+        data = self.preprocess_patient(path)
+        pid = path.split('/')[-1]
+        data = np.transpose(data, axes=(0, 2, 3, 1))
+
+        # pad data if smaller than patch_size seen during training.
+        if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.patch_size)]):
+            new_shape = [data.shape[0]] + [np.max([data.shape[dim + 1], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)]
+            data = dutils.pad_nd_image(data, new_shape)  # use 'return_slicer' to crop image back to original shape.
+
+        # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor.
+        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
+            out_data = data[np.newaxis]
+
+            batch_3D = {'data': out_data, 'pid': pid}
+            converter = ConvertSegToBoundingBoxCoordinates(dim=3, get_rois_from_seg_flag=True, class_specific_seg_flag=self.cf.class_specific_seg_flag)
+            batch_3D = converter(**batch_3D)
+            batch_3D.update({'patient_bb_target': batch_3D['bb_target'],
+                             'patient_roi_labels': batch_3D['roi_labels'],
+                             'original_img_shape': out_data.shape})
+
+        if self.cf.dim == 2:
+            out_data = np.transpose(data, axes=(3, 0, 1, 2))  # (z, c, x, y)
+
+            # if set to not None, add neighbouring slices to each selected slice in channel dimension.
+            if self.cf.n_3D_context is not None:
+                slice_range = range(self.cf.n_3D_context, out_data.shape[0] + self.cf.n_3D_context)
+                out_data = np.pad(out_data, ((self.cf.n_3D_context, self.cf.n_3D_context), (0, 0), (0, 0), (0, 0)), 'constant', constant_values=0)
+                out_data = np.array(
+                    [np.concatenate([out_data[ii] for ii in range(
+                        slice_id - self.cf.n_3D_context, slice_id + self.cf.n_3D_context + 1)], axis=0) for slice_id in
+                     slice_range])
+
+            batch_2D = {'data': out_data, 'pid': pid}
+            batch_2D.update({'original_img_shape': out_data.shape})
+
+        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
+        patient_batch = out_batch
+
+        # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
+        # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
+        if np.any([data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
+            patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
+            new_img_batch, new_seg_batch, new_class_targets_batch = [], [], []
+
+            for cix, c in enumerate(patch_crop_coords_list):
+
+                # if set to not None, add neighbouring slices to each selected slice in channel dimension.
+                # correct patch_crop coordinates by added slices of 3D context.
+                if self.cf.dim == 2 and self.cf.n_3D_context is not None:
+                    tmp_c_5 = c[5] + (self.cf.n_3D_context * 2)
+                    if cix == 0:
+                        data = np.pad(data, ((0, 0), (0, 0), (0, 0), (self.cf.n_3D_context, self.cf.n_3D_context)), 'constant', constant_values=0)
+                else:
+                    tmp_c_5 = c[5]
+
+                new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5])
+
+            data = np.array(new_img_batch)  # (n_patches, c, x, y, z)
+
+            if self.cf.dim == 2:
+                if self.cf.n_3D_context is not None:
+                    data = np.transpose(data[:, 0], axes=(0, 3, 1, 2))
+                else:
+                    # all patches have z dimension 1 (slices). discard dimension
+                    data = data[..., 0]
+
+            patch_batch = {'data': data, 'pid': pid}
+            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
+            patch_batch['original_img_shape'] = patient_batch['original_img_shape']
+
+            out_batch = patch_batch
+
+        self.patient_ix += 1
+        if self.patient_ix == len(self.patient_paths):
+            self.patient_ix = 0
+
+        return out_batch
+
+
+
+def copy_and_unpack_data(logger, pids, fold_dir, source_dir, target_dir):
+
+    start_time = time.time()
+    with open(os.path.join(fold_dir, 'file_list.txt'), 'w') as handle:
+        for pid in pids:
+            handle.write('{}_img.npz\n'.format(pid))
+            handle.write('{}_rois.npz\n'.format(pid))
+
+    subprocess.call('rsync -av --files-from {} {} {}'.format(os.path.join(fold_dir, 'file_list.txt'),
+                    source_dir, target_dir), shell=True)
+    dutils.unpack_dataset(target_dir)
+    copied_files = os.listdir(target_dir)
+    logger.info("copying and unpacking data set finished: {} files in target dir: {}. took {} sec".format(
+        len(copied_files), target_dir, np.round(time.time() - start_time, 0)))
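On the tiling above: `dutils.get_patch_crop_coords` presumably yields one (x1, x2, y1, y2, z1, z2) box per patch, and `patch_crop_coords` lets the predictor map patch outputs back into the volume. Illustrative patch count for this experiment's 3D patch size:

```python
import numpy as np

vol_shape = np.array([320, 320, 60])   # hypothetical patient volume (x, y, z)
patch_size = np.array([192, 192, 32])
n_patches = int(np.prod(np.ceil(vol_shape / patch_size)))
print(n_patches)  # 8: a 2 x 2 x 2 tiling, with overlaps where the volume is no exact multiple
```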
diff --git a/experiments/pet_ct_tnm_classification/preprocessing.py b/experiments/pet_ct_tnm_classification/preprocessing.py
new file mode 100644
index 0000000..134432e
--- /dev/null
+++ b/experiments/pet_ct_tnm_classification/preprocessing.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import SimpleITK as sitk
+import numpy as np
+from multiprocessing import Pool
+import pandas as pd
+import numpy.testing as npt
+from skimage.transform import resize
+import subprocess
+from scipy.ndimage.measurements import label as lb
+from scipy.ndimage.measurements import center_of_mass as com
+import nrrd
+from copy import deepcopy
+from skimage.segmentation import clear_border
+
+import configs
+cf = configs.configs()
+
+# segmentation channels whose label falls into one of the background categories (see pp_patient) are
+# suppressed: they are not written into the segmentation map and their class labels are discarded.
+
+
+def get_z_crops(x, ix, min_pix=1500, n_comps=2, rad_crit=20000):
+    final_slices = []
+
+    for six in range(x.shape[0]):
+
+        tx = np.copy(x[six]) < -600
+        img_center = np.array(tx.shape) / 2
+        tx = clear_border(tx)
+
+        clusters, n_cands = lb(tx)
+        count = np.unique(clusters, return_counts=True)
+        keep_comps = np.array([ii for ii in np.argwhere((count[1] > min_pix)) if ii > 0]).astype(int)
+
+        if len(keep_comps) > n_comps - 1:
+            coms = com(tx, clusters, index=keep_comps)
+            keep_com = [ix for ix, ii in enumerate(coms[0]) if
+                        ((ii[0] - img_center[0]) ** 2 + (ii[1] - img_center[1]) ** 2 < rad_crit)]
+            keep_comps = keep_comps[keep_com]
+
+        if len(keep_comps) > n_comps - 1:
+            final_slices.append(six)
+            # print('appending', six)
+
+    z_min = np.min(final_slices) - 7
+    z_max = np.max(final_slices) + 7
+    dist = z_max - z_min
+    if dist >= 151:
+        print('trying again with min pix', min_pix + 500, rad_crit - 500, ix, dist)
+        z_min, z_max = get_z_crops(x, ix, min_pix=min_pix + 500, rad_crit=rad_crit - 500)
+    if dist <= 43:
+        print('trying again with one component', min_pix - 100, rad_crit + 100, ix, dist)
+        z_min, z_max = get_z_crops(x, ix, n_comps=1, min_pix=min_pix - 100, rad_crit=rad_crit + 100)
+
+    print(z_min, z_max, z_max - z_min, ix)
+    return z_min, z_max
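A hypothetical smoke test for `get_z_crops`: two large, centred "air" components (< -600 HU) present on slices 10–49 of a synthetic volume should mark those slices as lung-bearing, plus the 7-slice margin on either end:

```python
import numpy as np

vol = np.zeros((60, 512, 512), dtype=np.float32)
vol[10:50, 150:250, 150:250] = -1000.  # "left lung"
vol[10:50, 150:250, 280:380] = -1000.  # "right lung"

zmin, zmax = get_z_crops(vol, ix=0)
print(zmin, zmax)  # expected around (3, 56)
```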
+def pp_patient(inputs):
+
+    ix, path = inputs
+    background_categories = ['M1b_brain', 'N_inflammation', 'T_benign', 'T_other']
+
+    # restrict processing to selected patient indices (e.g. for re-running individual patients).
+    selection = [106, 273]
+
+    if ix in selection:
+
+        pid = ix
+        print('processing', pid, path)
+        x = sitk.ReadImage(os.path.join(path, 'lsa_ct.nii.gz'))
+        p = sitk.ReadImage(os.path.join(path, 'lsa_pet.nii.gz'))
+        readdata, header = nrrd.read(os.path.join(path, 'lsa.seg.nrrd'))
+        if len(readdata.shape) == 3:
+            readdata = readdata[None]
+            spacing = np.diagonal(header['space directions'])
+        else:
+            spacing = np.diagonal(header['space directions'][1:, :])
+
+        origin = header['space origin'] * np.sign(spacing)
+        labels = [header[k].split('=')[-1] for k in header.keys() if '_Name' in k]
+        seg = np.zeros_like(readdata[0])
+        print('READDATA SHAPE', readdata.shape)
+        # note: the loop below re-binds 'ix'; the patient index is kept in 'pid'.
+        for ix in range(readdata.shape[0]):
+            if labels[ix] not in background_categories:
+                seg[readdata[ix] == 1] = ix + 1
+
+        seg = seg.astype('uint8')
+        s = sitk.GetImageFromArray(np.transpose(seg, axes=(2, 1, 0)))
+        s.SetSpacing(abs(spacing))
+        s.SetOrigin(origin)
+
+        x_spacing = x.GetSpacing()
+        if x_spacing[0] < 0.95 or x_spacing[2] < 3:
+            new_spacing = (0.976562, 0.976562, 3.27)
+            new_size = [int(x.GetSize()[ii] * x_spacing[ii] / new_spacing[ii]) for ii in range(3)]
+            reference_image = sitk.Image(new_size, x.GetPixelIDValue())
+            reference_image.SetOrigin(x.GetOrigin())
+            reference_image.SetDirection(x.GetDirection())
+            reference_image.SetSpacing(new_spacing)
+
+            # Resample without any smoothing.
+            x = sitk.Resample(x, reference_image)
+
+        resampler = sitk.ResampleImageFilter()
+        resampler.SetReferenceImage(x)
+        # resampler.SetInterpolator()  # default linear
+        rp = resampler.Execute(p)
+        rs = resampler.Execute(s)
+        pi = sitk.GetArrayFromImage(rp)
+        si = sitk.GetArrayFromImage(rs)
+        xi = sitk.GetArrayFromImage(x)
+
+        zmin, zmax = get_z_crops(xi, pid)  # pass pid: the channel loop above re-bound 'ix'.
+
+        x = xi[zmin:zmax]
+        p = pi[zmin:zmax]
+        s = si[zmin:zmax]
+
+        x = np.clip(x, -1200, 600)
+        x = (1200 + x) / (600 + 1200)  # min-max scale to [0, 1]: (x - a) / (b - a)
+        x = (x - np.mean(x)) / np.std(x)
+
+        # p = np.clip(p, 0, 2)
+        # p = p / 2
+        p = (p - np.mean(p)) / np.std(p)
+
+        assert np.all(np.array(x.shape) == np.array(s.shape))
+
+        img = np.concatenate((x[None], p[None])).astype(np.float32)
+
+        remaining_comps = np.unique(s)
+        remaining_labels = [ii for ix, ii in enumerate(labels) if ix + 1 in remaining_comps]
+        s[s > 0] = 1
+
+        fg_slices = [ii for ii in np.unique(np.argwhere(s != 0)[:, 0])]
+
+        out_df = pd.read_pickle(os.path.join(cf.pp_dir, 'info_df.pickle'))
+        out_df.loc[len(out_df)] = {'pid': pid, 'raw_pid': path.split('/')[-1], 'class_target': remaining_labels, 'fg_slices': fg_slices}
+        out_df.to_pickle(os.path.join(cf.pp_dir, 'info_df.pickle'))
+
+        np.save(os.path.join(cf.pp_dir, '{}_rois.npy'.format(pid)), s)
+        np.save(os.path.join(cf.pp_dir, '{}_img.npy'.format(pid)), img)
+
+
+def collectPaths(in_dir):
+
+    paths = []
+    for path, dirs, files in os.walk(in_dir):
+        pet_files = [f for f in files if 'lsa_pet' in f]
+        if len(files) > 0 and 'TNM' in path and len(pet_files) > 0:
+            paths.append(path)
+
+    return paths
+
+
+if __name__ == "__main__":
+
+    paths = collectPaths(cf.raw_data_dir)
+    print('all paths', len(paths))
+
+    if not os.path.exists(cf.pp_dir):
+        os.mkdir(cf.pp_dir)
+        # df = pd.DataFrame(columns=['pid', 'raw_pid', 'class_target', 'fg_slices'])
+        # df.to_pickle(os.path.join(cf.pp_dir, 'info_df.pickle'))
+
+    pool = Pool(processes=8)
+    p1 = pool.map(pp_patient, enumerate(paths), chunksize=1)
+    pool.close()
+    pool.join()
+    # for i in enumerate(paths):
+    #     pp_patient(i)
+
+    # subprocess.call('cp {} {}'.format(os.path.join(cf.pp_dir, 'info_df.pickle'), os.path.join(cf.pp_dir, 'info_df_bk.pickle')), shell=True)
\ No newline at end of file
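Finally, the .seg.nrrd layout assumed by `pp_patient` above: one segmentation channel per `Segment<i>_Name` header key, where channel `ix` is written into the label map with value `ix + 1` so that `remaining_labels` can recover the names of components surviving the z-crop. A toy illustration (header contents are hypothetical):

```python
header = {'Segment0_Name': 'name=T_primary', 'Segment1_Name': 'name=T_benign'}
labels = [header[k].split('=')[-1] for k in header.keys() if '_Name' in k]
print(labels)  # ['T_primary', 'T_benign'] -> label values 1 and 2 in the seg map
```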