diff --git a/datasets/lidc/data_loader.py b/datasets/lidc/data_loader.py index fad15fc..4f5b3b0 100644 --- a/datasets/lidc/data_loader.py +++ b/datasets/lidc/data_loader.py @@ -1,1024 +1,1025 @@ # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== ''' Data Loader for the LIDC data set. This dataloader expects preprocessed data in .npy or .npz files per patient and a pandas dataframe containing the meta info e.g. file paths, and some ground-truth info like labels, foreground slice ids. LIDC 4-fold annotations storage capacity problem: keep segmentation gts compressed (npz), unpack at each batch generation. ''' import plotting as plg import os import pickle import time from multiprocessing import Pool import numpy as np import pandas as pd from collections import OrderedDict # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates from utils.dataloader_utils import BatchGenerator as BatchGeneratorParent def save_obj(obj, name): """Pickle a python object.""" with open(name + '.pkl', 'wb') as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) def vector(item): """ensure item is vector-like (list or array or tuple) :param item: anything """ if not isinstance(item, (list, tuple, np.ndarray)): item = [item] return item class Dataset(dutils.Dataset): r"""Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dest. :param cf: config object. :param logger: logger. :param subset_ids: subset of patient/sample identifiers to load from whole set. :param data_sourcedir: directory in which to find data, defaults to cf.data_sourcedir if None. :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None, mode='train'): super(Dataset,self).__init__(cf, data_sourcedir) if mode == 'train' and not cf.training_gts == "merged": self.gt_dir = "patient_gts_sa" self.gt_kind = cf.training_gts else: self.gt_dir = "patient_gts_merged" self.gt_kind = "merged" if logger is not None: logger.info("loading {} ground truths for {}".format(self.gt_kind, 'training and validation' if mode=='train' else 'testing')) p_df = pd.read_pickle(os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)) #exclude_pids = ["0305a", "0447a"] # due to non-bg segmentation but bg mal label in nodules 5728, 8840 #p_df = p_df[~p_df.pid.isin(exclude_pids)] if subset_ids is not None: p_df = p_df[p_df.pid.isin(subset_ids)] if logger is not None: logger.info('subset: selected {} instances from df'.format(len(p_df))) if cf.select_prototype_subset is not None: prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset] p_df = p_df[p_df.pid.isin(prototype_pids)] if logger is not None: logger.warning('WARNING: using prototyping data subset of length {}!!!'.format(len(p_df))) pids = p_df.pid.tolist() # evtly copy data from data_sourcedir to data_dest if cf.server_env and not hasattr(cf, 'data_dir') and hasattr(cf, "data_dest"): # copy and unpack images file_subset = ["{}_img.npz".format(pid) for pid in pids if not os.path.isfile(os.path.join(cf.data_dest,'{}_img.npy'.format(pid)))] file_subset += [os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)] self.copy_data(cf, file_subset=file_subset, keep_packed=False, del_after_unpack=True) # copy and do not unpack segmentations file_subset = [os.path.join(self.gt_dir, "{}_rois.np*".format(pid)) for pid in pids] keep_packed = not cf.training_gts == "merged" self.copy_data(cf, file_subset=file_subset, keep_packed=keep_packed, del_after_unpack=(not keep_packed)) else: cf.data_dir = self.data_sourcedir ext = 'npy' if self.gt_kind == "merged" else 'npz' imgs = [os.path.join(self.data_dir, '{}_img.npy'.format(pid)) for pid in pids] segs = [os.path.join(self.data_dir, self.gt_dir, '{}_rois.{}'.format(pid, ext)) for pid in pids] orig_class_targets = p_df['class_target'].tolist() data = OrderedDict() if self.gt_kind == 'merged': for ix, pid in enumerate(pids): data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} data[pid]['fg_slices'] = np.array(p_df['fg_slices'].tolist()[ix]) if 'class' in cf.prediction_tasks: if len(cf.class_labels)==3: # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) data[pid]['class_targets'] = np.array([2 if ii >= 3 else 1 for ii in orig_class_targets[ix]], dtype='uint8') elif len(cf.class_labels)==6: # classify each malignancy score data[pid]['class_targets'] = np.array([1 if ii==0.5 else np.round(ii) for ii in orig_class_targets[ix]], dtype='uint8') else: raise Exception("mismatch class labels and data-loading implementations.") else: data[pid]['class_targets'] = np.ones_like(np.array(orig_class_targets[ix]), dtype='uint8') if any(['regression' in task for task in cf.prediction_tasks]): data[pid]["regression_targets"] = np.array([vector(v) for v in orig_class_targets[ix]], dtype='float16') data[pid]["rg_bin_targets"] = np.array( [cf.rg_val_to_bin_id(v) for v in data[pid]["regression_targets"]], dtype='uint8') else: for ix, pid in enumerate(pids): data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid} data[pid]['fg_slices'] = np.array(p_df['fg_slices'].values[ix]) if 'class' in cf.prediction_tasks: # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2) raise NotImplementedError # todo need to consider bg # data[pid]['class_targets'] = np.array( # [[2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]]) else: data[pid]['class_targets'] = np.array( [[1 if ii > 0 else 0 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='uint8') if any(['regression' in task for task in cf.prediction_tasks]): data[pid]["regression_targets"] = np.array( [[vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix]], dtype='float16') data[pid]["rg_bin_targets"] = np.array( [[cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in data[pid]["regression_targets"]], dtype='uint8') cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if any(['regression' in task for task in cf.prediction_tasks]): cf.roi_items += ['regression_targets'] cf.roi_items += ['rg_bin_targets'] self.data = data self.set_ids = np.array(list(self.data.keys())) self.df = None # merged GTs class BatchGenerator_merged(dutils.BatchGenerator): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ def __init__(self, cf, data, name="train"): super(BatchGenerator_merged, self).__init__(cf, data) self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.class_targets = {k: v["class_targets"] for (k, v) in self._data.items()} self.balance_target_distribution(plot=name=="train") def generate_train_batch(self): # samples patients towards equilibrium of foreground classes on a roi-level after sampling a random ratio # fully random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False, p=self.p_probs)) batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count of classes in batch batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') # empty count for full bg samples (empty slices in 2D/patients in 3D) per class for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) batch_pids.append(patient['pid']) (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1)<=self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0])-1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) if len(elig_slices)>0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] # pad data if smaller than pre_crop_size. if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] data = dutils.pad_nd_image(data, new_shape, mode='constant') seg = dutils.pad_nd_image(seg, new_shape, mode='constant') # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] if len(crop_dims) > 0: if self.cf.dim == 3: choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio)\ or np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg)[1:] for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) break roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] assert seg[tuple(roi_anchor_pixel)] > 0 # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by # distance to the desired ROI < patch_size /2. # (here final patch size to account for center_crop after data augmentation). sample_seg_center = {} for ii in crop_dims: low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) # happens if lesion on the edge of the image. dont care about roi anymore, # just make sure pre-crop is inside image. if low >= high: low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) sample_seg_center[ii] = np.random.randint(low=low, high=high) else: # not guaranteed to be empty. probability of emptiness depends on the data. sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} for ii in crop_dims: min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item batch_roi_items[o].append(patient[o]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero==0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 data = np.array(batch_data).astype(np.float16) seg = np.array(batch_segs).astype(np.uint8) batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'roi_counts':batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) batch[key] = np.array(val) return batch class PatientBatchIterator_merged(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actualy evaluation (done in 3D), if willing to accept speed-loss during training. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . """ def __init__(self, cf, data): # threads in augmenter super(PatientBatchIterator_merged, self).__init__(cf, data) self.patient_ix = 0 self.patch_size = cf.patch_size + [1] if cf.dim == 2 else cf.patch_size def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) # pad data if smaller than patch_size seen during training. if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: out_data = data[np.newaxis, np.newaxis] out_seg = seg[np.newaxis, np.newaxis] batch_3D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_3D[o] = np.array([patient[o]]) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, x, y ) out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis] batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[o]]), out_data.shape[0], axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension. # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): patient_batch = out_batch patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) new_img_batch, new_seg_batch = [], [] for cix, c in enumerate(patch_crop_coords_list): seg_patch = seg[c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) tmp_c_5 = c[5] new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) seg = np.array(new_seg_batch)[:, np.newaxis] # (n_patches, 1, x, y, z) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), 'pid': np.array([patient['pid']] * data.shape[0])} for o in self.cf.roi_items: patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_" + o] = patient_batch['patient_' + o] if self.cf.dim == 2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_" + o + "_2d"] = patient_batch[o] # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 # and enables calculating test-dice/viewing patient-wise results in test # remove, but also remove dice from metrics, when like to save memory patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim == 2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) patch_batch = converter(**patch_batch) out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch # single-annotator GTs class BatchGenerator_sa(BatchGeneratorParent): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ # noinspection PyMethodOverriding def balance_target_distribution(self, rater, plot=False): """ :param rater: for which rater slot to generate the distribution :param self.targets: dic holding {patient_specifier : patient-wise-unique ROI targets} :param plot: whether to plot the generated patient distributions :return: probability distribution over all pids. draw without replace from this. """ unique_ts = np.unique([v[rater] for pat in self.targets.values() for v in pat]) sample_stats = pd.DataFrame(columns=[str(ix) + suffix for ix in unique_ts for suffix in ["", "_bg"]], index=list(self.targets.keys())) for pid in sample_stats.index: for targ in unique_ts: fg_count = 0 if len(self.targets[pid]) == 0 else np.count_nonzero(self.targets[pid][:, rater] == targ) sample_stats.loc[pid, str(targ)] = int(fg_count > 0) sample_stats.loc[pid, str(targ) + "_bg"] = int(fg_count == 0) target_stats = sample_stats.agg( ("sum", lambda col: col.sum() / len(self._data)), axis=0, sort=False).rename({"": "relative"}) anchor = 1. - target_stats.loc["relative"].iloc[0] fg_bg_weights = anchor / target_stats.loc["relative"] cum_weights = anchor * len(fg_bg_weights) fg_bg_weights /= cum_weights p_probs = sample_stats.apply(self.sample_targets_to_weights, args=(fg_bg_weights,), axis=1).sum(axis=1) p_probs = p_probs / p_probs.sum() if plot: print("Rater: {}. Applying class-weights:\n {}".format(rater, fg_bg_weights)) if len(sample_stats.columns) == 2: # assert that probs are calc'd correctly: # (p_probs * sample_stats["1"]).sum() == (p_probs * sample_stats["1_bg"]).sum() # only works if one label per patient (multi-label expectations depend on multi-label occurences). for rater in range(self.rater_bsize): expectations = [] for targ in sample_stats.columns: expectations.append((p_probs[rater] * sample_stats[targ]).sum()) assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format( expectations) if plot: plg.plot_batchgen_distribution(self.cf, self.dataset_pids, p_probs, self.balance_target, out_file=os.path.join(self.plot_dir, "train_gen_distr_"+str(self.cf.fold)+"_rater"+str(rater)+".png")) return p_probs, unique_ts, sample_stats def __init__(self, cf, data, name="train"): super(BatchGenerator_sa, self).__init__(cf, data) self.name = name self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.rater_bsize = 4 unique_ts_total = set() self.p_probs = [] self.sample_stats = [] # todo resolve pickling error # p = Pool(processes=min(self.rater_bsize, cf.n_workers)) # mp_res = p.starmap(self.balance_target_distribution, [(r, name=="train") for r in range(self.rater_bsize)]) # p.close() # p.join() # for r, res in enumerate(mp_res): # p_probs, unique_ts, sample_stats = res # self.p_probs.append(p_probs) # self.sample_stats.append(sample_stats) # unique_ts_total.update(unique_ts) for r in range(self.rater_bsize): # todo multiprocess. takes forever p_probs, unique_ts, sample_stats = self.balance_target_distribution(r, plot=name == "train") self.p_probs.append(p_probs) self.sample_stats.append(sample_stats) unique_ts_total.update(unique_ts) self.unique_ts = sorted(list(unique_ts_total)) self.stats = {"roi_counts": np.zeros(len(self.unique_ts,), dtype='uint32'), "empty_counts": np.zeros(len(self.unique_ts,), dtype='uint32')} def generate_train_batch(self): rater = np.random.randint(self.rater_bsize) # samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio batch_random_ratio). # random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count, replace=False, p=self.p_probs[rater])) batch_data, batch_segs, batch_pids, batch_patient_labels = [], [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count of classes in batch batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') # empty count for full bg samples (empty slices in 2D/patients in 3D) for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] patient_balance_ts = np.array([roi[rater] for roi in patient[self.balance_target]]) data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] seg = np.load(patient['seg'], mmap_mode='r') seg = np.transpose(seg[list(seg.keys())[0]][rater], axes=(1, 2, 0)) batch_pids.append(patient['pid']) (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1) <= self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient_balance_ts[np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'][rater]) if len(elig_slices) > 0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] # pad data if smaller than pre_crop_size. if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]): new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)] data = dutils.pad_nd_image(data, new_shape, mode='constant') seg = dutils.pad_nd_image(seg, new_shape, mode='constant') # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg. crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps] if len(crop_dims) > 0: if self.cf.dim == 3: choose_fg = np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or \ np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg[seg>0]) assert np.all(patient_balance_ts[available_roi_ids-1]>0), "trying to choose roi with rating 0" for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[ patient_balance_ts[available_roi_ids-1] == self.unique_ts[tix] ] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] break assert seg[tuple(roi_anchor_pixel)] > 0, "roi_anchor_pixel not inside roi: {}, pb_ts {}, elig ids {}".format(tuple(roi_anchor_pixel), patient_balance_ts, elig_roi_ids) # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by # distance to the desired ROI < patch_size /2. # (here final patch size to account for center_crop after data augmentation). sample_seg_center = {} for ii in crop_dims: low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii]))) # happens if lesion on the edge of the image. dont care about roi anymore, # just make sure pre-crop is inside image. if low >= high: low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2) sample_seg_center[ii] = np.random.randint(low=low, high=high) else: # not guaranteed to be empty. probability of emptiness depends on the data. sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2, high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims} for ii in crop_dims: min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2) max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2) data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1) seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item batch_roi_items[o].append([roi[rater] for roi in patient[o]]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero==0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 data = np.array(batch_data).astype('float16') seg = np.array(batch_segs).astype('uint8') batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'rater_id': rater, 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) batch[key] = np.array(val) return batch class PatientBatchIterator_sa(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actual evaluation (done in 3D), if willing to accept speed loss during training. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . This is the data & gt loader for the 4-fold single-annotator GTs: each data input has separate annotations of 4 annotators. the way the pipeline is currently setup, the single-annotator GTs are only used if training with validation mode val_patient; during testing the Iterator with the merged GTs is used. # todo mode val_patient not implemented yet (since very slow). would need to sample from all available rater GTs. """ def __init__(self, cf, data): #threads in augmenter super(PatientBatchIterator_sa, self).__init__(cf, data) self.cf = cf self.patient_ix = 0 self.dataset_pids = list(self._data.keys()) self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size self.rater_bsize = 4 def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) # all gts are 4-fold and npz! seg = np.load(patient['seg'], mmap_mode='r') seg = np.transpose(seg[list(seg.keys())[0]], axes=(0, 2, 3, 1)) # pad data if smaller than patch_size seen during training. if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: out_data = data[np.newaxis, np.newaxis] out_seg = seg[:, np.newaxis] batch_3D = {'data': out_data, 'seg': out_seg} for item in self.cf.roi_items: batch_3D[item] = [] for r in range(self.rater_bsize): for item in self.cf.roi_items: batch_3D[item].append(np.array([roi[r] for roi in patient[item]])) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, y, x ) out_seg = np.transpose(seg, axes=(0, 3, 1, 2))[:, :, np.newaxis] # (n_raters, z, 1, y,x) batch_2D = {'data': out_data} for item in ["seg", "bb_target"]+self.cf.roi_items: batch_2D[item] = [] converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) for r in range(self.rater_bsize): tmp_batch = {"seg": out_seg[r]} for item in self.cf.roi_items: tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), out_data.shape[0], axis=0) tmp_batch = converter(**tmp_batch) for item in ["seg", "bb_target"]+self.cf.roi_items: batch_2D[item].append(tmp_batch[item]) # for item in ["seg", "bb_target"]+self.cf.roi_items: # batch_2D[item] = np.array(batch_2D[item]) if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * out_data.shape[0])}) # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension. # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): patient_batch = out_batch patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) new_img_batch = [] new_seg_batch = [] for cix, c in enumerate(patch_crop_coords_list): seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) tmp_c_5 = c[5] new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) seg = np.transpose(np.array(new_seg_batch), axes=(1,0,2,3,4))[:,:,np.newaxis] # (n_raters, n_patches, x, y, z) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'pid': np.array([patient['pid']] * data.shape[0])} # for o in self.cf.roi_items: # patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) for item in ["seg", "bb_target"]+self.cf.roi_items: patch_batch[item] = [] # coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, # IndexError: index 2 is out of bounds for axis 1 with size 2 for r in range(self.rater_bsize): tmp_batch = {"seg": seg[r]} for item in self.cf.roi_items: tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), len(patch_crop_coords_list), axis=0) tmp_batch = converter(**tmp_batch) for item in ["seg", "bb_target"]+self.cf.roi_items: patch_batch[item].append(tmp_batch[item]) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_" + o] = patient_batch['patient_'+o] if self.cf.dim==2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_"+o+"_2d"] = patient_batch[o] # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 # and enables calculating test-dice/viewing patient-wise results in test # remove, but also remove dice from metrics, if you like to save memory patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim==2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, is_training=True): """ create multi-threaded train/val/test batch generation and augmentation pipeline. :param cf: configs object. :param patient_data: dictionary containing one dictionary per patient in the train/test subset. :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ BG_name = "train" if is_training else "val" data_gen = BatchGenerator_merged(cf, patient_data, name=BG_name) if cf.training_gts=='merged' else \ BatchGenerator_sa(cf, patient_data, name=BG_name) # add transformations to pipeline. my_transforms = [] if is_training: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) if cf.create_bounding_box_targets: my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) all_transforms = Compose(my_transforms) - multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers)) + multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads, + seeds=range(data_gen.n_filled_threads)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=True): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators. selects patients according to cv folds (generated by first run/fold of experiment): splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) If cf.held_out_test_set is True, adds the test split to the training data. """ dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) if cf.held_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, is_training=True) batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, is_training=False) if cf.val_mode == 'val_patient': assert cf.training_gts == 'merged', 'val_patient not yet implemented for sa gts' batch_gen['val_patient'] = PatientBatchIterator_merged(cf, val_data) if cf.training_gts=='merged' \ else PatientBatchIterator_sa(cf, val_data) batch_gen['n_val'] = len(val_data) if cf.max_val_patients=="all" else min(len(val_data), cf.max_val_patients) else: batch_gen['n_val'] = cf.num_val_batches return batch_gen def get_test_generator(cf, logger): """ wrapper function for creating the test batch generator pipeline. selects patients according to cv folds (generated by first run/fold of experiment) If cf.held_out_test_set is True, gets the data from an external folder instead. """ if cf.held_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_data = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode="test").data logger.info("data set loaded with: {} test patients".format(len(test_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator_merged(cf, test_data) batch_gen['n_test'] = len(test_ids) if cf.max_test_patients == "all" else min(cf.max_test_patients, len(test_ids)) return batch_gen if __name__ == "__main__": import sys sys.path.append('../') import plotting as plg import utils.exp_utils as utils from configs import Configs cf = Configs() cf.batch_size = 3 #dataset_path = os.path.dirname(os.path.realpath(__file__)) #exp_path = os.path.join(dataset_path, "experiments/dev") #cf = utils.prep_exp(dataset_path, exp_path, server_env=False, use_stored_settings=False, is_training=True) cf.created_fold_id_pickle = False total_stime = time.time() times = {} # cf.server_env = True # cf.data_dir = "experiments/dev_data" # dataset = Dataset(cf) # patient = dataset['Master_00018'] cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] for i in range(1): stime = time.time() #ex_batch = next(train_loader) print("train batch", i) times["train_batch"] = time.time() - stime #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True) # # # with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file: # # train_loader.generator.print_stats(logger, file) # val_loader = gens['val_sampling'] stime = time.time() ex_batch = next(val_loader) times["val_batch"] = time.time() - stime stime = time.time() #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, # show_info=False) times["val_plot"] = time.time() - stime # test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch() times["test_batch"] = time.time() - stime stime = time.time() plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", get_time=False)#, sample_picks=[0,1,2,3]) times["test_patchbatch_plot"] = time.time() - stime # ex_batch['data'] = ex_batch['patient_data'] # ex_batch['seg'] = ex_batch['patient_seg'] # ex_batch['bb_target'] = ex_batch['patient_bb_target'] # for item in cf.roi_items: # ex_batch[] # stime = time.time() # #ex_batch = next(test_loader) # ex_batch = next(test_loader) # plg.view_batch(cf, ex_batch, show_gt_labels=False, show_gt_boxes=True, patient_items=True,# vol_slice_picks=[146,148, 218,220], # out_file="experiments/dev/dev_expatientbatch.png") # , sample_picks=[0,1,2,3]) # times["test_patient_batch_plot"] = time.time() - stime print("Times recorded throughout:") for (k, v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) diff --git a/datasets/prostate/data_loader.py b/datasets/prostate/data_loader.py index 69c53e6..23797a3 100644 --- a/datasets/prostate/data_loader.py +++ b/datasets/prostate/data_loader.py @@ -1,716 +1,716 @@ __author__ = '' #credit derives from Paul Jaeger, Simon Kohl import os import time import warnings from collections import OrderedDict import pickle import numpy as np import pandas as pd # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.augmentations.utils import resize_image_by_padding, center_crop_2D_image, center_crop_3D_image from batchgenerators.dataloading.data_loader import SlimDataLoaderBase from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.dataloading import SingleThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform #from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates from batchgenerators.transforms import AbstractTransform from batchgenerators.transforms.color_transforms import GammaTransform #sys.path.append(os.path.dirname(os.path.realpath(__file__))) #import utils.exp_utils as utils import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates import data_manager as dmanager def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) def id_to_spec(id, base_spec): """Construct subject specifier from base string and an integer subject number.""" num_zeros = 5 - len(str(id)) assert num_zeros>=0, "id_to_spec: patient id too long to fit into 5 figures" return base_spec + '_' + ('').join(['0'] * num_zeros) + str(id) def convert_3d_to_2d_generator(data_dict, shape="bcxyz"): """Fold/Shape z-dimension into color-channel. :param shape: bcxyz or bczyx :return: shape b(c*z)xy or b(c*z)yx """ if shape=="bcxyz": data_dict['data'] = np.transpose(data_dict['data'], axes=(0,1,4,3,2)) data_dict['seg'] = np.transpose(data_dict['seg'], axes=(0,1,4,3,2)) elif shape=="bczyx": pass else: raise Exception("unknown datashape {} in 3d_to_2d transform converter".format(shape)) shp = data_dict['data'].shape data_dict['orig_shape_data'] = shp seg_shp = data_dict['seg'].shape data_dict['orig_shape_seg'] = seg_shp data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4])) data_dict['seg'] = data_dict['seg'].reshape((seg_shp[0], seg_shp[1] * seg_shp[2], seg_shp[3], seg_shp[4])) return data_dict def convert_2d_to_3d_generator(data_dict, shape="bcxyz"): """Unfold z-dimension from color-channel. data needs to be in shape bcxy or bcyx, x,y dims won't be swapped relative to each other. :param shape: target shape, bcxyz or bczyx """ shp = data_dict['orig_shape_data'] cur_shape = data_dict['data'].shape seg_shp = data_dict['orig_shape_seg'] cur_shape_seg = data_dict['seg'].shape data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1], shp[2], cur_shape[-2], cur_shape[-1])) data_dict['seg'] = data_dict['seg'].reshape((seg_shp[0], seg_shp[1], seg_shp[2], cur_shape_seg[-2], cur_shape_seg[-1])) if shape=="bcxyz": data_dict['data'] = np.transpose(data_dict['data'], axes=(0,1,4,3,2)) data_dict['seg'] = np.transpose(data_dict['seg'], axes=(0,1,4,3,2)) return data_dict class Convert3DTo2DTransform(AbstractTransform): def __init__(self): pass def __call__(self, **data_dict): return convert_3d_to_2d_generator(data_dict) class Convert2DTo3DTransform(AbstractTransform): def __init__(self): pass def __call__(self, **data_dict): return convert_2d_to_3d_generator(data_dict) def vector(item): """ensure item is vector-like (list or array or tuple) :param item: anything """ if not isinstance(item, (list, tuple, np.ndarray)): item = [item] return item class Dataset(dutils.Dataset): r"""Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dest. :param cf: config file :param data_dir: directory in which to find data, defaults to cf.data_dir if None. :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None): super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir) info_dict = load_obj(cf.info_dict_path) if subset_ids is not None: pids = subset_ids if logger is None: print('subset: selected {} instances from df'.format(len(pids))) else: logger.info('subset: selected {} instances from df'.format(len(pids))) else: pids = list(info_dict.keys()) #evtly copy data from data_rootdir to data_dir if cf.server_env and not hasattr(cf, "data_dir"): file_subset = [info_dict[pid]['img'][:-3]+"*" for pid in pids] file_subset+= [info_dict[pid]['seg'][:-3]+"*" for pid in pids] file_subset += [cf.info_dict_path] self.copy_data(cf, file_subset=file_subset) cf.data_dir = self.data_dir img_paths = [os.path.join(self.data_dir, info_dict[pid]['img']) for pid in pids] seg_paths = [os.path.join(self.data_dir, info_dict[pid]['seg']) for pid in pids] # load all subject files self.data = OrderedDict() for i, pid in enumerate(pids): subj_spec = id_to_spec(pid, cf.prepro['dir_spec']) subj_data = {'pid':pid, "spec":subj_spec} subj_data['img'] = img_paths[i] subj_data['seg'] = seg_paths[i] #read, add per-roi labels for obs in cf.observables_patient+cf.observables_rois: subj_data[obs] = np.array(info_dict[pid][obs]) if 'class' in self.cf.prediction_tasks: subj_data['class_targets'] = np.array(info_dict[pid]['roi_classes'], dtype='uint8') + 1 else: subj_data['class_targets'] = np.ones_like(np.array(info_dict[pid]['roi_classes']), dtype='uint8') if any(['regression' in task for task in self.cf.prediction_tasks]): if hasattr(cf, "rg_map"): subj_data["regression_targets"] = np.array([vector(cf.rg_map[v]) for v in info_dict[pid][cf.regression_target]], dtype='float16') else: subj_data["regression_targets"] = np.array([vector(v) for v in info_dict[pid][cf.regression_target]], dtype='float16') subj_data["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in subj_data["regression_targets"]], dtype='uint8') subj_data['fg_slices'] = info_dict[pid]['fg_slices'] self.data[pid] = subj_data cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if any(['regression' in task for task in self.cf.prediction_tasks]): cf.roi_items += ['regression_targets'] cf.roi_items += ['rg_bin_targets'] #cf.patient_items = cf.observables_patient[:] #patient-wise items not used currently self.set_ids = np.array(list(self.data.keys())) self.df = None class BatchGenerator(dutils.BatchGenerator): """ create the training/validation batch generator. Randomly sample batch_size patients from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array. :param data: data dictionary as provided by 'load_dataset' :param img_modalities: list of strings ['adc', 'b1500'] from config :param batch_size: number of patients to sample for the batch :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.) :param sample_pids_w_replace: whether to randomly draw pids from dataset for batch generation. if False, step through whole dataset before repition. :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array. """ def __init__(self, cf, data, n_batches=None, sample_pids_w_replace=True): super(BatchGenerator, self).__init__(cf, data, n_batches) self.dataset_length = len(self._data) self.cf = cf self.sample_pids_w_replace = sample_pids_w_replace self.eligible_pids = list(self._data.keys()) self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.balance_target_distribution(plot=sample_pids_w_replace) self.stats = {"roi_counts" : np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count" : 0} def generate_train_batch(self): #everything done in here is per batch #print statements in here get confusing due to multithreading if self.sample_pids_w_replace: # fully random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice( self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) else: batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size, replace=False) if self.sample_pids_w_replace == False: self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids] if len(self.eligible_pids) < self.batch_size: self.eligible_pids = self.dataset_pids batch_data, batch_segs, batch_patient_specs = [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} #record roi count of classes in batch batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0 #empty count for full bg samples (empty slices in 2D/patients in 3D) for sample in range(self.batch_size): patient = self._data[batch_patient_ids[sample]] #swap dimensions from (c,)z,y,x to (c,)y,x,z or h,w,d to ease 2D/3D-case handling data = np.transpose(np.load(patient['img'], mmap_mode='r'), axes=(0, 2, 3, 1))[self.chans] seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0)) (c,y,x,z) = data.shape #original data is 3D MRIs, so need to pick (e.g. randomly) single slice to make it 2D, #consider batch roi-class balance if self.cf.dim == 2: elig_slices, choose_fg = [], False if self.sample_pids_w_replace and len(patient['fg_slices']) > 0: if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand( 1) <= self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) if len(elig_slices) == 0: elig_slices = z sl_pick_ix = np.random.choice(elig_slices, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] spatial_shp = data[0].shape assert spatial_shp==seg.shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] data = dutils.pad_nd_image(data, (len(data), *new_shape)) seg = dutils.pad_nd_image(seg, new_shape) #eventual cropping to pre_crop_size: with prob self.p_fg sample pixel from random ROI and shift center, #if possible, to that pixel, so that img still contains ROI after pre-cropping dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] if np.any(dim_cropflags): print("dim crop applied") # sample pixel from random ROI and shift center, if possible, to that pixel if self.cf.dim==3: choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg if self.sample_pids_w_replace and choose_fg and np.any(seg): available_roi_ids = np.unique(seg)[1:] for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[ patient[self.balance_target][available_roi_ids - 1] == self.unique_ts[tix]] if len(elig_roi_ids) > 0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) break roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] assert seg[tuple(roi_anchor_pixel)] > 0 # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and # distance to the selected ROI < patch_size /2 def get_cropped_centercoords(dim): low = np.max((self.cf.pre_crop_size[dim]//2, roi_anchor_pixel[dim] - (self.cf.patch_size[dim]//2 - self.cf.crop_margin[dim]))) high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim]//2, roi_anchor_pixel[dim] + (self.cf.patch_size[dim]//2 - self.cf.crop_margin[dim]))) if low >= high: #happens if lesion on the edge of the image. #print('correcting low/high:', low, high, spatial_shp, roi_anchor_pixel, dim) low = self.cf.pre_crop_size[dim] // 2 high = spatial_shp[dim] - self.cf.pre_crop_size[dim]//2 assert low0]) - 1] == self.unique_ts[tix]) if not np.any(seg): empty_samples_count += 1 #self.stats['roi_counts'] += batch_roi_counts #DOESNT WORK WITH MULTITHREADING! do outside #self.stats['empty_samples_count'] += empty_samples_count batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'), 'pid': batch_patient_ids, 'spec': batch_patient_specs, 'roi_counts':batch_roi_counts, 'empty_samples_count': empty_samples_count} for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...) batch[key] = np.array(val) return batch class PatientBatchIterator(dutils.PatientBatchIterator): """ creates a val/test generator. Step through the dataset and return dictionaries per patient. 2D is a special case of 3D patching with patch_size[2] == 1 (slices) Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets. Appends patient targets anyway for evaluation. For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions. This iterator/these batches are not intended to go through MTaugmenter afterwards """ def __init__(self, cf, data): super(PatientBatchIterator, self).__init__(cf, data) self.patient_ix = 0 #running index over all patients in set self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" def generate_train_batch(self, pid=None): if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 if pid is None: pid = self.dataset_pids[self.patient_ix] # + self.thread_id patient = self._data[pid] #swap dimensions from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling data = np.transpose(np.load(patient['img'], mmap_mode='r'), axes=(0, 2, 3, 1)) seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] data_shp_raw = data.shape plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None data = data[self.chans] discarded_chans = len( [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan]) spatial_shp = data[0].shape # spatial dims need to be in order x,y,z assert spatial_shp==seg[0].shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]): new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) if plot_bg is not None: plot_bg = dutils.pad_nd_image(plot_bg, new_shape) if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: #adds the batch dim here bc won't go through MTaugmenter out_data = data[np.newaxis] out_seg = seg[np.newaxis] if plot_bg is not None: out_plot_bg = plot_bg[np.newaxis] #data and seg shape: (1,c,x,y,z), where c=1 for seg batch_3D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_3D[o] = np.array([patient[o]]) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(3,0,1,2)) #(c,y,x,z) to (b=z,c,x,y), use z=b as batchdim out_seg = np.transpose(seg, axes=(3,0,1,2)).astype('uint8') #(c,y,x,z) to (b=z,c,x,y) batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[o]]), len(out_data), axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) if plot_bg is not None: out_plot_bg = np.transpose(plot_bg, axes=(2,0,1)).astype('float32') if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D['patient_'+o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data)), 'spec':np.array([patient['spec']] * len(out_data))}) if self.cf.plot_bg_chan in self.chans and discarded_chans>0: assert plot_bg is None plot_bg = int(self.cf.plot_bg_chan - discarded_chans) out_plot_bg = plot_bg if plot_bg is not None: out_batch['plot_bg'] = out_plot_bg #eventual tiling into patches spatial_shp = out_batch["data"].shape[2:] if np.any([spatial_shp[ix] > self.patch_size[ix] for ix in range(len(spatial_shp))]): patient_batch = out_batch #print("patientiterator produced patched batch!") patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) new_img_batch, new_seg_batch = [], [] for c in patch_crop_coords_list: new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]]) seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) shps = [] for arr in new_img_batch: shps.append(arr.shape) data = np.array(new_img_batch) # (patches, c, x, y, z) seg = np.array(new_seg_batch) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data, 'seg': seg.astype('uint8'), 'pid': np.array([patient['pid']] * data.shape[0]), 'spec':np.array([patient['spec']] * data.shape[0])} for o in self.cf.roi_items: patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_"+o] = patient_batch['patient_'+o] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] #patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels'] patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] if plot_bg is not None: patch_batch['patient_plot_bg'] = patient_batch['plot_bg'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) patch_batch = converter(**patch_batch) out_batch = patch_batch self.patient_ix += 1 # todo raise stopiteration when in test mode if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True): """ create mutli-threaded train/val/test batch generation and augmentation pipeline. :param patient_data: dictionary containing one dictionary per patient in the train/test subset :param test_pids: (optional) list of test patient ids, calls the test generator. :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace) my_transforms = [] if do_aug: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) if cf.da_kwargs["gamma_transform"]: gamma_transform = GammaTransform(gamma_range=cf.da_kwargs["gamma_range"], invert_image=False, per_channel=False, retain_stats=True) my_transforms.append(gamma_transform) if cf.dim == 3: # augmentations with desired effect on z-dimension spatial_transform = SpatialTransform(patch_size=cf.patch_size, patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=False, do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop'], border_mode_data=cf.da_kwargs['border_mode_data']) my_transforms.append(spatial_transform) # augmentations that are only meant to affect x-y my_transforms.append(Convert3DTo2DTransform()) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:2], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=False, do_scale=False, random_crop=False, border_mode_data=cf.da_kwargs['border_mode_data']) my_transforms.append(spatial_transform) my_transforms.append(Convert2DTo3DTransform()) else: spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'][:2], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop'], border_mode_data=cf.da_kwargs['border_mode_data']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) if cf.create_bounding_box_targets: my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) #batch receives entry 'bb_target' w bbox coordinates as [y1,x1,y2,x2,z1,z2]. #my_transforms.append(ConvertSegToOnehotTransform(classes=range(cf.num_seg_classes))) all_transforms = Compose(my_transforms) #MTAugmenter creates iterator from data iterator data_gen after applying the composed transform all_transforms - multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, - seeds=list(np.random.randint(0,cf.n_workers*2,size=cf.n_workers))) + multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads, + seeds=range(data_gen.n_filled_threads)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=True): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators need to select cv folds on patient level, but be able to include both breasts of each patient. """ dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold-1) train_ids = np.concatenate(set_splits, axis=0) if cf.held_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train":train_ids, "val":val_ids, "test":test_ids}, plot_dir=os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug) batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False) if cf.val_mode == 'val_patient': batch_gen['val_patient'] = PatientBatchIterator(cf, val_data) batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else cf.max_val_patients elif cf.val_mode == 'val_sampling': batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches!="all" else len(val_ids) return batch_gen def get_test_generator(cf, logger): """ if get_test_generators is called multiple times in server env, every time of Dataset initiation rsync will check for copying the data; this should be okay since rsync will not copy if files already exist in destination. """ if cf.held_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_set = Dataset(cf, logger, test_ids, data_sourcedir=sourcedir) logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator(cf, test_set.data) batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else min(cf.max_test_patients, len(test_set.set_ids)) return batch_gen if __name__=="__main__": import sys sys.path.append('../') # works on cluster indep from where sbatch job is started import plotting as plg import utils.exp_utils as utils from configs import Configs cf = configs() total_stime = time.time() times = {} #cf.server_env = True #cf.data_dir = "experiments/dev_data" #dataset = Dataset(cf) #patient = dataset['Master_00018'] cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir+"plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] #for i in range(train_loader.dataset_length): # print("batch", i) stime = time.time() ex_batch = next(train_loader) #ex_batch = next(train_loader) times["train_batch"] = time.time()-stime plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True) #with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file: # train_loader.generator.print_stats(logger, file) val_loader = gens['val_sampling'] stime = time.time() ex_batch = next(val_loader) times["val_batch"] = time.time()-stime stime = time.time() plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, show_info=False) times["val_plot"] = time.time()-stime test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch() print(ex_batch["data"].shape) times["test_batch"] = time.time()-stime stime = time.time() plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/ex_patchbatch.png", show_gt_boxes=False, show_info=False, dpi=400, sample_picks=[2,5], plot_mods=False) times["test_patchbatch_plot"] = time.time()-stime #stime = time.time() #ex_batch['data'] = ex_batch['patient_data'] #ex_batch['seg'] = ex_batch['patient_seg'] #if 'patient_plot_bg' in ex_batch.keys(): # ex_batch['plot_bg'] = ex_batch['patient_plot_bg'] #plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png") #times["test_patientbatch_plot"] = time.time() - stime #print("patch batch keys", ex_batch.keys()) #print("patch batch les gle", ex_batch["lesion_gleasons"].shape) #print("patch batch gsbx", ex_batch["GSBx"].shape) #print("patch batch class_targ", ex_batch["class_targets"].shape) #print("patient b roi labels", ex_batch["patient_roi_labels"].shape) #print("patient les gleas", ex_batch["patient_lesion_gleasons"].shape) #print("patch&patient batch pid", ex_batch["pid"], len(ex_batch["pid"])) #print("unique patient_seg", np.unique(ex_batch["patient_seg"])) #print("pb patient roi labels", len(ex_batch["patient_roi_labels"]), ex_batch["patient_roi_labels"]) #print("pid", ex_batch["pid"]) #patient_batch = {k[len("patient_"):]:v for (k,v) in ex_batch.items() if k.lower().startswith("patient")} #patient_batch["pid"] = ex_batch["pid"] #stime = time.time() #plg.view_batch(cf, patient_batch, out_file="experiments/dev_expatientbatch") #times["test_plot"] = time.time()-stime print("Times recorded throughout:") for (k,v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py index 74f5927..8210f14 100644 --- a/datasets/toy/configs.py +++ b/datasets/toy/configs.py @@ -1,490 +1,490 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np from default_configs import DefaultConfigs from collections import namedtuple boxLabel = namedtuple('boxLabel', ["name", "color"]) Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion']) binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals']) class Configs(DefaultConfigs): def __init__(self, server_env=None): super(Configs, self).__init__(server_env) ######################### # Prepro # ######################### - self.pp_rootdir = os.path.join('/media/gregor/HDD2TB/data/toy', "cyl1ps_dev") + self.pp_rootdir = os.path.join('/home/gregor/datasets/toy', "cyl1ps_dev") self.pp_npz_dir = self.pp_rootdir+"_npz" self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now) self.min_2d_radius = 6 #in pixels self.n_train_samples, self.n_test_samples = 1200, 1000 # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes. self.pp_create_ohe_seg = False self.pp_empty_samples_ratio = 0.1 self.pp_place_radii_mid_bin = True self.pp_only_distort_2d = True # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing. # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity. self.pp_blur_min_intensity = 0.2 self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d) self.max_instances_per_class = self.max_instances_per_sample # how many max instances per image per class self.noise_scale = 0. # std-dev of gaussian noise self.ambigs_sampling = "gaussian" #"gaussian" or "uniform" """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from image by distinguishing it from the rest of the object. blurring width around edge will be shifted so that symmetric rel to orig radius. blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius. if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2. """ self.ambiguities = { #set which classes to apply which ambs to below in class labels #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'. #kind #probability #scale (gaussian std, relative to unperturbed value) #"outer_radius": (1., 0.5), #"outer_radius_xy": (1., 0.5), #"inner_radius": (0.5, 0.1), #"radii_relations": (0.5, 0.1), "radius_calib": (1., 1./6) } # shape choices: 'cylinder', 'block' # id, name, shape, radius, color, regression, ambiguities, gt_distortion self.pp_classes = [Label(1, 'cylinder', 'cylinder', ((6,6,1),(40,40,8)), (*self.blue, 1.), "radius_2d", (), ()), #Label(2, 'block', 'block', ((6,6,1),(40,40,8)), (*self.aubergine,1.), "radii_2d", (), ('radius_calib',)) ] ######################### # I/O # ######################### - self.data_sourcedir = '/home/gregor/data/toy/cyl1ps_dev' + self.data_sourcedir = '/home/gregor/datasets/toy/cyl1ps_dev' if server_env: self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_dev_npz' self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test') self.data_sourcedir = os.path.join(self.data_sourcedir, "train") self.info_df_name = 'info_df.pickle' # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn']. self.model = 'mrcnn' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) ######################### # Architecture # ######################### # one out of [2, 3]. dimension the model operates in. self.dim = 3 # 'class', 'regression', 'regression_bin', 'regression_ken_gal' # currently only tested mode is a single-task at a time (i.e., only one task in below list) # but, in principle, tasks could be combined (e.g., object classes and regression per class) self.prediction_tasks = ['class', ] self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm' self.relu = 'relu' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') self.weight_init = None self.regression_n_features = 1 # length of regressor target vector ######################### # Data Loader # ######################### self.num_epochs = 32 self.num_train_batches = 120 if self.dim == 2 else 80 - self.batch_size = 16 if self.dim == 2 else 8 + self.batch_size = 8 if self.dim == 2 else 4 self.n_cv_splits = 4 # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels self.plot_bg_chan = 0 self.crop_margin = [20, 20, 1] # has to be smaller than respective patch_size//2 self.patch_size_2D = self.pre_crop_size[:2] self.patch_size_3D = self.pre_crop_size[:2]+[8] # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D # ratio of free sampled batch elements before class balancing is triggered # (>0 to include "empty"/background patches.) self.batch_random_ratio = 0.2 self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets" self.observables_patient = [] self.observables_rois = [] self.seed = 3 #for generating folds ############################# # Colors, Classes, Legends # ############################# self.plot_frequency = 1 binary_bin_labels = [binLabel(1, 'r<=25', (*self.green, 1.), (1,25)), binLabel(2, 'r>25', (*self.red, 1.), (25,))] quintuple_bin_labels = [binLabel(1, 'r2-10', (*self.green, 1.), (2,10)), binLabel(2, 'r10-20', (*self.yellow, 1.), (10,20)), binLabel(3, 'r20-30', (*self.orange, 1.), (20,30)), binLabel(4, 'r30-40', (*self.bright_red, 1.), (30,40)), binLabel(5, 'r>40', (*self.red, 1.), (40,))] # choose here if to do 2-way or 5-way regression-bin classification task_spec_bin_labels = quintuple_bin_labels self.class_labels = [ # regression: regression-task label, either value or "(x,y,z)_radius" or "radii". # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables! # gt_distortion: name of ambig to apply to gt only; needs to be iterable! # #id #name #shape #radius #color #regression #ambiguities #gt_distortion Label( 0, 'bg', None, (0, 0, 0), (*self.white, 0.), (0, 0, 0), (), ())] if "class" in self.prediction_tasks: self.class_labels += self.pp_classes else: self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))] if any(['regression' in task for task in self.prediction_tasks]): self.bin_labels = [binLabel(0, 'bg', (*self.white, 1.), (0,))] self.bin_labels += task_spec_bin_labels self.bin_id2label = {label.id: label for label in self.bin_labels} bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0} if self.class_specific_seg: self.seg_labels = self.class_labels self.box_type2label = {label.name: label for label in self.box_labels} self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.has_colorchannels = False self.plot_class_ids = True self.num_classes = len(self.class_dict) self.num_seg_classes = len(self.seg_labels) ######################### # Data Augmentation # ######################### self.do_aug = True self.da_kwargs = { 'mirror': True, 'mirror_axes': tuple(np.arange(0, self.dim, 1)), 'do_elastic_deform': False, 'alpha': (500., 1500.), 'sigma': (40., 45.), 'do_rotation': False, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': False, 'scale': (0.8, 1.1), 'random_crop': False, 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1 } if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) # must be 0!! self.da_kwargs['angle_z'] = (0., 2 * np.pi) ######################### # Schedule / Selection # ######################### # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient' if self.val_mode == 'val_patient': self.max_val_patients = 220 # if 'all' iterates over entire val_set once. if self.val_mode == 'val_sampling': self.num_val_batches = 35 if self.dim==2 else 25 self.save_n_models = 2 self.min_save_thresh = 1 if self.dim == 2 else 1 # =wait time in epochs if "class" in self.prediction_tasks: self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()} elif any("regression" in task for task in self.prediction_tasks): self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()} self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()}) self.lr_decay_factor = 0.5 self.scheduling_patience = int(self.num_epochs / 5) self.weight_decay = 1e-5 self.clip_norm = None # number or None ######################### # Testing / Plotting # ######################### self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) self.held_out_test_set = True self.max_test_patients = "all" # number or "all" for all self.test_against_exact_gt = True # only True implemented self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario. self.report_score_level = ['rois'] # 'patient' or 'rois' (incl) self.patient_class_of_interest = 1 self.patient_bin_of_interest = 2 self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False self.metrics = ['ap', 'auc', 'dice'] if any(['regression' in task for task in self.prediction_tasks]): self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] if 'aleatoric' in self.model: self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] self.evaluate_fold_means = True self.ap_match_ious = [0.5] # threshold(s) for considering a prediction as true positive self.min_det_thresh = 0.3 self.model_max_iou_resolution = 0.2 # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!) for regarding two preds as concerning the same ROI self.clustering_iou = self.model_max_iou_resolution # has to be larger than desired possible overlap iou of model predictions self.merge_2D_to_3D_preds = False self.merge_3D_iou = self.model_max_iou_resolution self.n_test_plots = 1 # per fold and rank self.test_n_epochs = self.save_n_models # should be called n_test_ens, since is number of models to ensemble over during testing # is multiplied by (1 + nr of test augs) ######################### # Assertions # ######################### if not 'class' in self.prediction_tasks: assert self.num_classes == 1 ######################### # Add model specifics # ######################### {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, 'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs }[self.model]() def rg_val_to_bin_id(self, rg_val): #only meant for isotropic radii!! # only 2D radii (x and y dims) or 1D (x or y) are expected return np.round(np.digitize(rg_val, self.bin_edges).mean()) def add_det_fpn_configs(self): self.learning_rate = [1 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = 'torch_loss' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' self.n_roi_candidates = 4 if self.dim == 2 else 6 # max number of roi candidates to identify per image (slice in 2D, volume in 3D) # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1] self.fp_dice_weight = 1 if self.dim == 2 else 1 # if <1, false positive predictions in foreground are penalized less. self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' def add_det_unet_configs(self): self.learning_rate = [1 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = "torch_loss" self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # max number of roi candidates to identify per image (slice in 2D, volume in 3D) self.n_roi_candidates = 4 if self.dim == 2 else 6 # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1] # if <1, false positive predictions in foreground are penalized less. self.fp_dice_weight = 1 if self.dim == 2 else 1 self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' self.init_filts = 32 self.kernel_size = 3 # ks for horizontal, normal convs self.kernel_size_m = 2 # ks for max pool self.pad = "same" # "same" or integer, padding of horizontal convs def add_mrcnn_configs(self): self.learning_rate = [1e-4] * self.num_epochs self.dynamic_lr_scheduling = True # with scheduler set in exec self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2 # feed +/- n neighbouring slices into channel dimension. set to None for no context. self.n_3D_context = None if self.n_3D_context is not None and self.dim == 2: self.n_channels *= (self.n_3D_context * 2 + 1) self.detect_while_training = True # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask outputs are optional. self.return_masks_in_train = True self.return_masks_in_val = True self.return_masks_in_test = True # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 64 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1., 2.] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution) # loss sampling settings. self.rpn_train_anchors_per_image = 4 self.train_rois_per_image = 6 # per batch_instance self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.8 # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), # where k<=#positive examples. self.shem_poolsize = 2 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) # y1,x1,y2,x2,z1,z2 if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] self.plot_y_max = 1.5 self.n_plot_rpn_props = 5 if self.dim == 2 else 30 # per batch_instance (slice in 2D / patient in 3D) # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 2000 if self.dim == 2 else 4000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage as one "batch". self.roi_chunk_size = 1300 if self.dim == 2 else 500 self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400 self.post_nms_rois_inference = 200 * (self.head_classes-1) # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class. self.detection_nms_threshold = self.model_max_iou_resolution # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.2 # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': # whether to use focal loss or SHEM for loss-sample selection self.focal_loss = False # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.7 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/datasets/toy/data_loader.py b/datasets/toy/data_loader.py index e77b3db..f4bf28f 100644 --- a/datasets/toy/data_loader.py +++ b/datasets/toy/data_loader.py @@ -1,594 +1,595 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys sys.path.append('../') # works on cluster indep from where sbatch job is started import plotting as plg import numpy as np import os from multiprocessing import Lock from collections import OrderedDict import pandas as pd import pickle import time # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform sys.path.append(os.path.dirname(os.path.realpath(__file__))) import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) class Dataset(dutils.Dataset): r""" Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dir. :param cf: config file :param folds: number of folds out of @params n_cv folds to include :param n_cv: number of total folds :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'): super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir) load_exact_gts = (mode=='test' or cf.val_mode=="val_patient") and self.cf.test_against_exact_gt p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name)) if subset_ids is not None: p_df = p_df[p_df.pid.isin(subset_ids)] logger.info('subset: selected {} instances from df'.format(len(p_df))) pids = p_df.pid.tolist() #evtly copy data from data_sourcedir to data_dest if cf.server_env and not hasattr(cf, "data_dir"): file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids] file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids] file_subset += [cf.info_df_name] if load_exact_gts: file_subset += [os.path.join(self.data_dir, '{}_exact_seg.*'.format(pid)) for pid in pids] self.copy_data(cf, file_subset=file_subset) img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids] seg_paths = [os.path.join(self.data_dir, '{}_seg.npy'.format(pid)) for pid in pids] if load_exact_gts: exact_seg_paths = [os.path.join(self.data_dir, '{}_exact_seg.npy'.format(pid)) for pid in pids] class_targets = p_df['class_ids'].tolist() rg_targets = p_df['regression_vectors'].tolist() if load_exact_gts: exact_rg_targets = p_df['undistorted_rg_vectors'].tolist() fg_slices = p_df['fg_slices'].tolist() self.data = OrderedDict() for ix, pid in enumerate(pids): self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid, 'fg_slices': np.array(fg_slices[ix])} if load_exact_gts: self.data[pid]['exact_seg'] = exact_seg_paths[ix] if 'class' in self.cf.prediction_tasks: self.data[pid]['class_targets'] = np.array(class_targets[ix], dtype='uint8') else: self.data[pid]['class_targets'] = np.ones_like(np.array(class_targets[ix]), dtype='uint8') if load_exact_gts: self.data[pid]['exact_class_targets'] = self.data[pid]['class_targets'] if any(['regression' in task for task in self.cf.prediction_tasks]): self.data[pid]['regression_targets'] = np.array(rg_targets[ix], dtype='float16') self.data[pid]["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in rg_targets[ix]], dtype='uint8') if load_exact_gts: self.data[pid]['exact_regression_targets'] = np.array(exact_rg_targets[ix], dtype='float16') self.data[pid]["exact_rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in exact_rg_targets[ix]], dtype='uint8') cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if any(['regression' in task for task in self.cf.prediction_tasks]): cf.roi_items += ['regression_targets'] cf.roi_items += ['rg_bin_targets'] self.set_ids = np.array(list(self.data.keys())) self.df = None class BatchGenerator(dutils.BatchGenerator): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. :param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, seed=0): super(BatchGenerator, self).__init__(cf, data, sample_pids_w_replace=sample_pids_w_replace, max_batches=max_batches, raise_stop_iteration=raise_stop_iteration, seed=seed) self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.balance_target_distribution(plot=sample_pids_w_replace) def generate_train_batch(self): # everything done in here is per batch # print statements in here get confusing due to multithreading batch_pids = self.get_batch_pids() batch_data, batch_segs, batch_patient_targets = [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count and empty count of classes in batch # empty count for no presence of resp. class in whole sample (empty slices in 2D/patients in 3D) batch_roi_counts = np.zeros((len(self.unique_ts),), dtype='uint32') batch_empty_counts = np.zeros((len(self.unique_ts),), dtype='uint32') for b in range(len(batch_pids)): patient = self._data[batch_pids[b]] data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis] seg = np.load(patient['seg'], mmap_mode='r').astype('uint8') (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if np.all(batch_empty_counts / self.batch_size >= self.empty_samples_max_ratio) or np.random.rand( 1) <= self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) if len(elig_slices) > 0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] spatial_shp = data[0].shape assert spatial_shp == seg.shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] data = dutils.pad_nd_image(data, (len(data), *new_shape)) seg = dutils.pad_nd_image(seg, new_shape) # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center, # if possible, to that pixel, so that img still contains ROI after pre-cropping dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] if np.any(dim_cropflags): # sample pixel from random ROI and shift center, if possible, to that pixel if self.cf.dim==3: choose_fg = np.any(batch_empty_counts/self.batch_size>=self.empty_samples_max_ratio) or \ np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg)[1:] for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) break roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] assert seg[tuple(roi_anchor_pixel)] > 0 # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and # distance to the selected ROI < patch_size /2 def get_cropped_centercoords(dim): low = np.max((self.cf.pre_crop_size[dim] // 2, roi_anchor_pixel[dim] - ( self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim]))) high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2, roi_anchor_pixel[dim] + ( self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim]))) if low >= high: # happens if lesion on the edge of the image. low = self.cf.pre_crop_size[dim] // 2 high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2 assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format( dim, spatial_shp, patient['pid'], low, high) return np.random.randint(low=low, high=high) else: # sample crop center regardless of ROIs, not guaranteed to be empty def get_cropped_centercoords(dim): return np.random.randint(low=self.cf.pre_crop_size[dim] // 2, high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2) sample_seg_center = {} for dim in np.where(dim_cropflags)[0]: sample_seg_center[dim] = get_cropped_centercoords(dim) min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2) max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2) data = np.take(data, indices=range(min_, max_), axis=dim + 1) # +1 for channeldim seg = np.take(seg, indices=range(min_, max_), axis=dim) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable batch_roi_items[o].append(patient[o]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero==0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): non_zero = np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) batch_roi_counts[tix] += non_zero batch_empty_counts[tix] += int(non_zero == 0) # todo remove assert when checked if not np.any(seg): assert non_zero==0 batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'), 'pid': batch_pids, 'roi_counts': batch_roi_counts, 'empty_counts': batch_empty_counts} for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic batch[key] = np.array(val) return batch class PatientBatchIterator(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D), if willing to accept speed-loss during training. Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . """ def __init__(self, cf, data, mode='test'): super(PatientBatchIterator, self).__init__(cf, data) self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" if (mode=="validation" and hasattr(self.cf, 'val_against_exact_gt') and self.cf.val_against_exact_gt) or \ (mode == 'test' and self.cf.test_against_exact_gt): self.gt_prefix = 'exact_' print("PatientIterator: Loading exact Ground Truths.") else: self.gt_prefix = '' self.patient_ix = 0 # running index over all patients in set def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis] seg = np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis] data_shp_raw = data.shape plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None data = data[self.chans] discarded_chans = len( [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan]) spatial_shp = data[0].shape # spatial dims need to be in order x,y,z assert spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]): new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) if plot_bg is not None: plot_bg = dutils.pad_nd_image(plot_bg, new_shape) if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: # adds the batch dim here bc won't go through MTaugmenter out_data = data[np.newaxis] out_seg = seg[np.newaxis] if plot_bg is not None: out_plot_bg = plot_bg[np.newaxis] # data and seg shape: (1,c,x,y,z), where c=1 for seg batch_3D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_3D[o] = np.array([patient[self.gt_prefix+o]]) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32') # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8') # (c,y,x,z) to (b=z,c,x,y) batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) if plot_bg is not None: out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32') if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) if self.cf.plot_bg_chan in self.chans and discarded_chans > 0: # len(self.chans[:self.cf.plot_bg_chan]) self.patch_size[ix] for ix in range(len(spatial_shp))]): patient_batch = out_batch print("patientiterator produced patched batch!") patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) new_img_batch, new_seg_batch = [], [] for c in patch_crop_coords_list: new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]]) seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) shps = [] for arr in new_img_batch: shps.append(arr.shape) data = np.array(new_img_batch) # (patches, c, x, y, z) seg = np.array(new_seg_batch) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), 'pid': np.array([patient['pid']] * data.shape[0])} for o in self.cf.roi_items: patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0) #patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_"+o] = patient_batch["patient_"+o] if self.cf.dim == 2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_" + o + "_2d"] = patient_batch[o] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim == 2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] if plot_bg is not None: patch_batch['patient_plot_bg'] = patient_batch['plot_bg'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False, class_specific_seg=self.cf.class_specific_seg) patch_batch = converter(**patch_batch) out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs): """ create mutli-threaded train/val/test batch generation and augmentation pipeline. :param patient_data: dictionary containing one dictionary per patient in the train/test subset. :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ # create instance of batch generator as first element in pipeline. data_gen = BatchGenerator(cf, patient_data, **kwargs) my_transforms = [] if do_aug: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) all_transforms = Compose(my_transforms) # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms) - multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers)) + multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=data_gen.n_filled_threads, + seeds=range(data_gen.n_filled_threads)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=False): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators. selects patients according to cv folds (generated by first run/fold of experiment): splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) If cf.hold_out_test_set is True, adds the test split to the training data. """ dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) if cf.held_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir= os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True) if cf.val_mode == 'val_patient': batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation') batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else min(len(val_ids), cf.max_val_patients) elif cf.val_mode == 'val_sampling': batch_gen['n_val'] = int(np.ceil(len(val_data)/cf.batch_size)) if cf.num_val_batches == "all" else cf.num_val_batches # in current setup, val loader is used like generator. with max_batches being applied in train routine. batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False, max_batches=None, raise_stop_iteration=False) return batch_gen def get_test_generator(cf, logger): """ if get_test_generators is possibly called multiple times in server env, every time of Dataset initiation rsync will check for copying the data; this should be okay since rsync will not copy if files already exist in destination. """ if cf.held_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test') logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator(cf, test_set.data) batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \ min(cf.max_test_patients, len(test_set.set_ids)) return batch_gen if __name__=="__main__": import utils.exp_utils as utils from datasets.toy.configs import Configs cf = Configs() total_stime = time.time() times = {} # cf.server_env = True # cf.data_dir = "experiments/dev_data" cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] for i in range(0): stime = time.time() print("producing training batch nr ", i) ex_batch = next(train_loader) times["train_batch"] = time.time() - stime #experiments/dev/dev_exbatch_{}.png".format(i) plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False) val_loader = gens['val_sampling'] stime = time.time() for i in range(1): ex_batch = next(val_loader) times["val_batch"] = time.time() - stime stime = time.time() #"experiments/dev/dev_exvalbatch_{}.png" plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True) times["val_plot"] = time.time() - stime # test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch(pid=None) times["test_batch"] = time.time() - stime stime = time.time() plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0) times["test_patchbatch_plot"] = time.time() - stime print("Times recorded throughout:") for (k, v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f8f0800..6d9b914 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,62 +1,65 @@ absl-py==0.8.1 backcall==0.1.0 -batchgenerators==0.19.3 +batchgenerators==0.19.7 cachetools==3.1.1 certifi==2019.11.28 chardet==3.0.4 cycler==0.10.0 Cython==0.29.14 decorator==4.4.1 future==0.18.2 google-auth==1.7.2 google-auth-oauthlib==0.4.1 grpcio==1.25.0 idna==2.8 imageio==2.6.1 ipython==7.9.0 ipython-genutils==0.2.0 jedi==0.15.1 joblib==0.14.0 kiwisolver==1.1.0 linecache2==1.0.0 Markdown==3.1.1 matplotlib==3.1.2 networkx==2.4 nms-extension==0.0.0 numpy==1.17.4 oauthlib==3.1.0 pandas==0.25.3 parso==0.5.1 pexpect==4.7.0 pickleshare==0.7.5 Pillow==6.2.1 prompt-toolkit==2.0.10 protobuf==3.11.1 psutil==5.7.0 ptyprocess==0.6.0 pyasn1==0.4.8 pyasn1-modules==0.2.7 Pygments==2.5.2 pyparsing==2.4.5 python-dateutil==2.8.1 pytz==2019.3 PyWavelets==1.1.1 RegRCNN==0.0.2 requests==2.22.0 requests-oauthlib==1.3.0 +RoIAlign-extension-2D==0.0.0 +RoIAlign-extension-3D==0.0.0 rsa==4.0 scikit-image==0.16.2 scikit-learn==0.21.3 scipy==1.3.1 SimpleITK==1.2.3 six==1.13.0 tensorboard==2.0.2 +threadpoolctl==2.0.0 torch==1.3.1 torchvision==0.4.2 tqdm==4.39.0 traceback2==1.4.0 traitlets==4.3.3 unittest2==1.1.0 urllib3==1.25.7 wcwidth==0.1.7 Werkzeug==0.16.0 diff --git a/unittests.py b/unittests.py index 329fdcf..2811c00 100644 --- a/unittests.py +++ b/unittests.py @@ -1,622 +1,623 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import unittest import os import pickle import time from multiprocessing import Pool import subprocess from pathlib import Path import numpy as np import pandas as pd import torch import torchvision as tv import tqdm import plotting as plg import utils.exp_utils as utils import utils.model_utils as mutils """ Note on unittests: run this file either in the way intended for unittests by starting the script with python -m unittest unittests.py or start it as a normal python file as python unittests.py. You can selective run single tests by calling python -m unittest unittests.TestClassOfYourChoice, where TestClassOfYourChoice is the name of the test defined below, e.g., CompareFoldSplits. """ def inspect_info_df(pp_dir): """ use your debugger to look into the info df of a pp dir. :param pp_dir: preprocessed-data directory """ info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle")) return def generate_boxes(count, dim=2, h=100, w=100, d=20, normalize=False, on_grid=False, seed=0): """ generate boxes of format [y1, x1, y2, x2, (z1, z2)]. :param count: nr of boxes :param dim: dimension of boxes (2 or 3) :return: boxes in format (n_boxes, 4 or 6), scores """ np.random.seed(seed) if on_grid: lower_y = np.random.randint(0, h // 2, (count,)) lower_x = np.random.randint(0, w // 2, (count,)) upper_y = np.random.randint(h // 2, h, (count,)) upper_x = np.random.randint(w // 2, w, (count,)) if dim == 3: lower_z = np.random.randint(0, d // 2, (count,)) upper_z = np.random.randint(d // 2, d, (count,)) else: lower_y = np.random.rand(count) * h / 2. lower_x = np.random.rand(count) * w / 2. upper_y = (np.random.rand(count) + 1.) * h / 2. upper_x = (np.random.rand(count) + 1.) * w / 2. if dim == 3: lower_z = np.random.rand(count) * d / 2. upper_z = (np.random.rand(count) + 1.) * d / 2. if dim == 3: boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x, lower_z, upper_z))) # add an extreme box that tests the boundaries boxes = np.concatenate((boxes, np.array([[0., 0., h, w, 0, d]]))) else: boxes = np.array(list(zip(lower_y, lower_x, upper_y, upper_x))) boxes = np.concatenate((boxes, np.array([[0., 0., h, w]]))) scores = np.random.rand(count + 1) if normalize: divisor = np.array([h, w, h, w, d, d]) if dim == 3 else np.array([h, w, h, w]) boxes = boxes / divisor return boxes, scores #------- perform integrity checks on data set(s) ----------- class VerifyLIDCSAIntegrity(unittest.TestCase): """ Perform integrity checks on preprocessed single-annotator GTs of LIDC data set. """ @staticmethod def check_patient_sa_gt(pid, pp_dir, check_meta_files, check_info_df): faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids']) all_segs = np.load(os.path.join(pp_dir, pid + "_rois.npz"), mmap_mode='r') all_segs = all_segs[list(all_segs.keys())[0]] all_roi_ids = np.unique(all_segs[all_segs > 0]) assert len(all_roi_ids) == np.max(all_segs), "roi ids not consecutive" if check_meta_files: meta_file = os.path.join(pp_dir, pid + "_meta_info.pickle") with open(meta_file, "rb") as handle: info = pickle.load(handle) assert info["pid"] == pid, "wrong pid in meta_file" all_cl_targets = info["class_target"] if check_info_df: info_df = pd.read_pickle(os.path.join(pp_dir, "info_df.pickle")) pid_info = info_df[info_df.pid == pid] assert len(pid_info) == 1, "found {} entries for pid {} in info df, expected exactly 1".format(len(pid_info), pid) if check_meta_files: assert pid_info[ "class_target"] == all_cl_targets, "meta_info and info_df class targets mismatch:\n{}\n{}".format( pid_info["class_target"], all_cl_targets) all_cl_targets = pid_info["class_target"].iloc[0] assert len(all_roi_ids) == len(all_cl_targets) for rater in range(4): seg = all_segs[rater] roi_ids = np.unique(seg[seg > 0]) cl_targs = np.array([roi[rater] for roi in all_cl_targets]) assert np.count_nonzero(cl_targs) == len(roi_ids), "rater {} has targs {} but roi ids {}".format(rater, cl_targs, roi_ids) assert len(cl_targs) >= len(roi_ids), "not all marked rois have a label" for zeroix_roi_id, rating in enumerate(cl_targs): if not ((rating > 0) == (np.any(seg == zeroix_roi_id + 1))): print("\n\nFAULTY CASE:", end=" ", ) print("pid {}, rater {}, cl_targs {}, ids {}\n".format(pid, rater, cl_targs, roi_ids)) faulty_cases = faulty_cases.append( {'pid': pid, 'rater': rater, 'cl_targets': cl_targs, 'roi_ids': roi_ids}, ignore_index=True) print("finished checking pid {}, {} faulty cases".format(pid, len(faulty_cases))) return faulty_cases def check_sa_gts(cf, pp_dir, pid_subset=None, check_meta_files=False, check_info_df=True, processes=os.cpu_count()): report_name = "verify_seg_label_pairings.csv" pids = {file_name.split("_")[0] for file_name in os.listdir(pp_dir) if file_name not in [report_name, "info_df.pickle"]} if pid_subset is not None: pids = [pid for pid in pids if pid in pid_subset] faulty_cases = pd.DataFrame(columns=['pid', 'rater', 'cl_targets', 'roi_ids']) p = Pool(processes=processes) mp_args = zip(pids, [pp_dir]*len(pids), [check_meta_files]*len(pids), [check_info_df]*len(pids)) patient_cases = p.starmap(self.check_patient_sa_gt, mp_args) p.close(); p.join() faulty_cases = faulty_cases.append(patient_cases, sort=False) print("\n\nfaulty case count {}".format(len(faulty_cases))) print(faulty_cases) findings_file = os.path.join(pp_dir, "verify_seg_label_pairings.csv") faulty_cases.to_csv(findings_file) assert len(faulty_cases)==0, "there was a faulty case in data set {}.\ncheck {}".format(pp_dir, findings_file) def test(self): pp_root = "/mnt/HDD2TB/Documents/data/" pp_dir = "lidc/pp_20190805" gt_dir = os.path.join(pp_root, pp_dir, "patient_gts_sa") self.check_sa_gts(gt_dir, check_meta_files=True, check_info_df=False, pid_subset=None) # ["0811a", "0812a"]) #------ compare segmentation gts of preprocessed data sets ------ class CompareSegGTs(unittest.TestCase): """ load and compare pre-processed gts by dice scores of segmentations. """ @staticmethod def group_seg_paths(ref_path, comp_paths): # not working recursively ref_files = [fn for fn in os.listdir(ref_path) if os.path.isfile(os.path.join(ref_path, fn)) and 'seg' in fn and fn.endswith('.npy')] comp_files = [[os.path.join(c_path, fn) for c_path in comp_paths] for fn in ref_files] ref_files = [os.path.join(ref_path, fn) for fn in ref_files] return zip(ref_files, comp_files) @staticmethod def load_calc_dice(paths): dices = [] ref_seg = np.load(paths[0])[np.newaxis, np.newaxis] n_classes = len(np.unique(ref_seg)) ref_seg = mutils.get_one_hot_encoding(ref_seg, n_classes) for c_file in paths[1]: c_seg = np.load(c_file)[np.newaxis, np.newaxis] assert n_classes == len(np.unique(c_seg)), "unequal nr of objects/classes betw segs {} {}".format(paths[0], c_file) c_seg = mutils.get_one_hot_encoding(c_seg, n_classes) dice = mutils.dice_per_batch_inst_and_class(c_seg, ref_seg, n_classes, convert_to_ohe=False) dices.append(dice) print("processed ref_path {}".format(paths[0])) return np.mean(dices), np.std(dices) def iterate_files(self, grouped_paths, processes=os.cpu_count()): p = Pool(processes) means_stds = np.array(p.map(self.load_calc_dice, grouped_paths)) p.close(); p.join() min_dice = np.min(means_stds[:, 0]) print("min mean dice {:.2f}, max std {:.4f}".format(min_dice, np.max(means_stds[:, 1]))) assert min_dice > 1-1e5, "compared seg gts have insufficient minimum mean dice overlap of {}".format(min_dice) def test(self): ref_path = '/mnt/HDD2TB/Documents/data/prostate/data_t2_250519_ps384_gs6071' comp_paths = ['/mnt/HDD2TB/Documents/data/prostate/data_t2_190419_ps384_gs6071', ] paths = self.group_seg_paths(ref_path, comp_paths) self.iterate_files(paths) #------- check if cross-validation fold splits of different experiments are identical ---------- class CompareFoldSplits(unittest.TestCase): """ Find evtl. differences in cross-val file splits across different experiments. """ @staticmethod def group_id_paths(ref_exp_dir, comp_exp_dirs): f_name = 'fold_ids.pickle' ref_paths = os.path.join(ref_exp_dir, f_name) assert os.path.isfile(ref_paths), "ref file {} does not exist.".format(ref_paths) ref_paths = [ref_paths for comp_ed in comp_exp_dirs] comp_paths = [os.path.join(comp_ed, f_name) for comp_ed in comp_exp_dirs] return zip(ref_paths, comp_paths) @staticmethod def comp_fold_ids(mp_input): fold_ids1, fold_ids2 = mp_input with open(fold_ids1, 'rb') as f: fold_ids1 = pickle.load(f) try: with open(fold_ids2, 'rb') as f: fold_ids2 = pickle.load(f) except FileNotFoundError: print("comp file {} does not exist.".format(fold_ids2)) return n_splits = len(fold_ids1) assert n_splits == len(fold_ids2), "mismatch n splits: ref has {}, comp {}".format(n_splits, len(fold_ids2)) split_diffs = [np.setdiff1d(fold_ids1[s], fold_ids2[s]) for s in range(n_splits)] all_equal = np.any(split_diffs) return (split_diffs, all_equal) def iterate_exp_dirs(self, ref_exp, comp_exps, processes=os.cpu_count()): grouped_paths = list(self.group_id_paths(ref_exp, comp_exps)) print("performing {} comparisons of cross-val file splits".format(len(grouped_paths))) p = Pool(processes) split_diffs = p.map(self.comp_fold_ids, grouped_paths) p.close(); p.join() df = pd.DataFrame(index=range(0,len(grouped_paths)), columns=["ref", "comp", "all_equal"])#, "diffs"]) for ix, (ref, comp) in enumerate(grouped_paths): df.iloc[ix] = [ref, comp, split_diffs[ix][1]]#, split_diffs[ix][0]] print("Any splits not equal?", df.all_equal.any()) assert not df.all_equal.any(), "a split set is different from reference split set, {}".format(df[~df.all_equal]) def test(self): exp_parent_dir = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/' ref_exp = '/home/gregor/networkdrives/E132-Cluster-Projects/prostate/experiments/gs6071_detfpn2d_cl_bs10' comp_exps = [os.path.join(exp_parent_dir, p) for p in os.listdir(exp_parent_dir)] comp_exps = [p for p in comp_exps if os.path.isdir(p) and p != ref_exp] self.iterate_exp_dirs(ref_exp, comp_exps) #------- check if cross-validation fold splits of a single experiment are actually incongruent (as required) ---------- class VerifyFoldSplits(unittest.TestCase): """ Check, for a single fold_ids file, i.e., for a single experiment, if the assigned folds (assignment of data identifiers) is actually incongruent. No overlaps between folds are required for a correct cross validation. """ @staticmethod def verify_fold_ids(splits): for i, split1 in enumerate(splits): for j, split2 in enumerate(splits): if j > i: inter = np.intersect1d(split1, split2) if len(inter) > 0: raise Exception("Split {} and {} intersect by pids {}".format(i, j, inter)) def test(self): exp_dir = "/home/gregor/Documents/medicaldetectiontoolkit/datasets/lidc/experiments/dev" check_file = os.path.join(exp_dir, 'fold_ids.pickle') with open(check_file, 'rb') as handle: splits = pickle.load(handle) self.verify_fold_ids(splits) # -------- check own nms CUDA implement against own numpy implement ------ class CheckNMSImplementation(unittest.TestCase): @staticmethod def assert_res_equality(keep_ics1, keep_ics2, boxes, scores, tolerance=0, names=("res1", "res2")): """ :param keep_ics1: keep indices (results), torch.Tensor of shape (n_ics,) :param keep_ics2: :return: """ keep_ics1, keep_ics2 = keep_ics1.cpu().numpy(), keep_ics2.cpu().numpy() discrepancies = np.setdiff1d(keep_ics1, keep_ics2) try: checks = np.array([ len(discrepancies) <= tolerance ]) except: checks = np.zeros((1,)).astype("bool") msgs = np.array([ """{}: {} \n{}: {} \nboxes: {}\n {}\n""".format(names[0], keep_ics1, names[1], keep_ics2, boxes, scores) ]) assert np.all(checks), "NMS: results mismatch: " + "\n".join(msgs[~checks]) def single_case(self, count=20, dim=3, threshold=0.2, seed=0): boxes, scores = generate_boxes(count, dim, seed=seed, h=320, w=280, d=30) keep_numpy = torch.tensor(mutils.nms_numpy(boxes, scores, threshold)) # for some reason torchvision nms requires box coords as floats. boxes = torch.from_numpy(boxes).type(torch.float32) scores = torch.from_numpy(scores).type(torch.float32) if dim == 2: """need to wait until next pytorch release where they fixed nms on cpu (currently they have >= where it needs to be >. """ keep_ops = tv.ops.nms(boxes, scores, threshold) # self.assert_res_equality(keep_numpy, keep_ops, boxes, scores, tolerance=0, names=["np", "ops"]) pass boxes = boxes.cuda() scores = scores.cuda() keep = self.nms_ext.nms(boxes, scores, threshold) self.assert_res_equality(keep_numpy, keep, boxes, scores, tolerance=0, names=["np", "cuda"]) def test(self, n_cases=200, box_count=30, threshold=0.5): # dynamically import module so that it doesn't affect other tests if import fails self.nms_ext = utils.import_module("nms_ext", 'custom_extensions/nms/nms.py') # change seed to something fix if you want exactly reproducible test seed0 = np.random.randint(50) print("NMS test progress (done/total box configurations) 2D:", end="\n") for i in tqdm.tqdm(range(n_cases)): self.single_case(count=box_count, dim=2, threshold=threshold, seed=seed0+i) print("NMS test progress (done/total box configurations) 3D:", end="\n") for i in tqdm.tqdm(range(n_cases)): self.single_case(count=box_count, dim=3, threshold=threshold, seed=seed0+i) return class CheckRoIAlignImplementation(unittest.TestCase): def prepare(self, dim=2): b, c, h, w = 1, 3, 50, 50 # feature map, (b, c, h, w(, z)) if dim == 2: fmap = torch.rand(b, c, h, w).cuda() # rois = torch.tensor([[ # [0.1, 0.1, 0.3, 0.3], # [0.2, 0.2, 0.4, 0.7], # [0.5, 0.7, 0.7, 0.9], # ]]).cuda() pool_size = (7, 7) rois = generate_boxes(5, dim=dim, h=h, w=w, on_grid=True, seed=np.random.randint(50))[0] elif dim == 3: d = 20 fmap = torch.rand(b, c, h, w, d).cuda() # rois = torch.tensor([[ # [0.1, 0.1, 0.3, 0.3, 0.1, 0.1], # [0.2, 0.2, 0.4, 0.7, 0.2, 0.4], # [0.5, 0.0, 0.7, 1.0, 0.4, 0.5], # [0.0, 0.0, 0.9, 1.0, 0.0, 1.0], # ]]).cuda() pool_size = (7, 7, 3) rois = generate_boxes(5, dim=dim, h=h, w=w, d=d, on_grid=True, seed=np.random.randint(50), normalize=False)[0] else: raise ValueError("dim needs to be 2 or 3") rois = [torch.from_numpy(rois).type(dtype=torch.float32).cuda(), ] fmap.requires_grad_(True) return fmap, rois, pool_size def check_2d(self): """ check vs torchvision ops not possible as on purpose different approach. :return: """ raise NotImplementedError # fmap, rois, pool_size = self.prepare(dim=2) # ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1) # align_ext = ra_object(fmap, rois) # loss_ext = align_ext.sum() # loss_ext.backward() # # rois_swapped = [rois[0][:, [1,3,0,2]]] # align_ops = tv.ops.roi_align(fmap, rois_swapped, pool_size) # loss_ops = align_ops.sum() # loss_ops.backward() # # assert (loss_ops == loss_ext), "sum of roialign ops and extension 2D diverges" # assert (align_ops == align_ext).all(), "ROIAlign failed 2D test" def check_3d(self): fmap, rois, pool_size = self.prepare(dim=3) ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1., sampling_ratio=-1) align_ext = ra_object(fmap, rois) loss_ext = align_ext.sum() loss_ext.backward() align_np = mutils.roi_align_3d_numpy(fmap.cpu().detach().numpy(), [roi.cpu().numpy() for roi in rois], pool_size) align_np = np.squeeze(align_np) # remove singleton batch dim align_ext = align_ext.cpu().detach().numpy() assert np.allclose(align_np, align_ext, rtol=1e-5, atol=1e-8), "RoIAlign differences in numpy and CUDA implement" def specific_example_check(self): # dummy input self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py') exp = 6 pool_size = (2,2) fmap = torch.arange(exp**2).view(exp,exp).unsqueeze(0).unsqueeze(0).cuda().type(dtype=torch.float32) boxes = torch.tensor([[1., 1., 5., 5.]]).cuda()/exp ind = torch.tensor([0.]*len(boxes)).cuda().type(torch.float32) y_exp, x_exp = fmap.shape[2:] # exp = expansion boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda()) boxes = torch.cat((ind.unsqueeze(1), boxes), dim=1) aligned_tv = tv.ops.roi_align(fmap, boxes, output_size=pool_size, sampling_ratio=-1) aligned = self.ra_ext.roi_align_2d(fmap, boxes, output_size=pool_size, sampling_ratio=-1) boxes_3d = torch.cat((boxes, torch.tensor([[-1.,1.]]*len(boxes)).cuda()), dim=1) fmap_3d = fmap.unsqueeze(dim=-1) pool_size = (*pool_size,1) ra_object = self.ra_ext.RoIAlign(output_size=pool_size, spatial_scale=1.,) aligned_3d = ra_object(fmap_3d, boxes_3d) expected_res = torch.tensor([[[[10.5000, 12.5000], [22.5000, 24.5000]]]]).cuda() expected_res_3d = torch.tensor([[[[[10.5000],[12.5000]], [[22.5000],[24.5000]]]]]).cuda() assert torch.all(aligned==expected_res), "2D RoIAlign check vs. specific example failed. res: {}\n expected: {}\n".format(aligned, expected_res) assert torch.all(aligned_3d==expected_res_3d), "3D RoIAlign check vs. specific example failed. res: {}\n expected: {}\n".format(aligned_3d, expected_res_3d) def manual_check(self): """ print examples from a toy batch to file. :return: """ self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py') # actual mrcnn mask input from datasets.toy import configs cf = configs.Configs() cf.exp_dir = "datasets/toy/experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 cf.n_workers = 1 logger = utils.get_logger(cf.exp_dir) data_loader = utils.import_module('data_loader', os.path.join("datasets", "toy", 'data_loader.py')) batch_gen = data_loader.get_train_generators(cf, logger=logger) batch = next(batch_gen['train']) roi_mask = np.zeros((1, 320, 200)) bb_target = (np.array([50, 40, 90, 120])).astype("int") roi_mask[:, bb_target[0]+1:bb_target[2]+1, bb_target[1]+1:bb_target[3]+1] = 1. #batch = {"roi_masks": np.array([np.array([roi_mask, roi_mask]), np.array([roi_mask])]), "bb_target": [[bb_target, bb_target + 25], [bb_target-20]]} #batch_boxes_cor = [torch.tensor(batch_el_boxes).cuda().float() for batch_el_boxes in batch_cor["bb_target"]] batch_boxes = [torch.tensor(batch_el_boxes).cuda().float() for batch_el_boxes in batch["bb_target"]] #import IPython; IPython.embed() for b in range(len(batch_boxes)): roi_masks = batch["roi_masks"][b] #roi_masks_cor = batch_cor["roi_masks"][b] if roi_masks.sum()>0: boxes = batch_boxes[b] roi_masks = torch.tensor(roi_masks).cuda().type(dtype=torch.float32) box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).type(dtype=torch.float32) masks = tv.ops.roi_align(roi_masks, [boxes], cf.mask_shape) masks = masks.squeeze(1) masks = torch.round(masks) masks_own = self.ra_ext.roi_align_2d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape) boxes = boxes.type(torch.int) #print("check roi mask", roi_masks[0, 0, boxes[0][0]:boxes[0][2], boxes[0][1]:boxes[0][3]].sum(), (boxes[0][2]-boxes[0][0]) * (boxes[0][3]-boxes[0][1])) #print("batch masks", batch["roi_masks"]) masks_own = masks_own.squeeze(1) masks_own = torch.round(masks_own) #import IPython; IPython.embed() for mix, mask in enumerate(masks): fig = plg.plt.figure() ax = fig.add_subplot() ax.imshow(roi_masks[mix][0].cpu().numpy(), cmap="gray", vmin=0.) ax.axis("off") y1, x1, y2, x2 = boxes[mix] bbox = plg.mpatches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=0.9, edgecolor="c", facecolor='none') ax.add_patch(bbox) x1, y1, x2, y2 = boxes[mix] bbox = plg.mpatches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=0.9, edgecolor="r", facecolor='none') ax.add_patch(bbox) debug_dir = Path("/home/gregor/Documents/regrcnn/datasets/toy/experiments/debugroial") os.makedirs(debug_dir, exist_ok=True) plg.plt.savefig(debug_dir/"mask_b{}_{}.png".format(b, mix)) plg.plt.imsave(debug_dir/"mask_b{}_{}_pooled_tv.png".format(b, mix), mask.cpu().numpy(), cmap="gray", vmin=0.) plg.plt.imsave(debug_dir/"mask_b{}_{}_pooled_own.png".format(b, mix), masks_own[mix].cpu().numpy(), cmap="gray", vmin=0.) return def test(self): # dynamically import module so that it doesn't affect other tests if import fails self.ra_ext = utils.import_module("ra_ext", 'custom_extensions/roi_align/roi_align.py') self.specific_example_check() # 2d test #self.check_2d() # 3d test self.check_3d() return class CheckRuntimeErrors(unittest.TestCase): """ Check if minimal examples of the exec.py module finish without runtime errors. This check requires a working path to data in the toy-dataset configs. """ def test(self): cf = utils.import_module("toy_cf", 'datasets/toy/configs.py').Configs() exp_dir = "./unittesting/" #checks = {"retina_net": False, "mrcnn": False} #print("Testing for runtime errors with models {}".format(list(checks.keys()))) #for model in tqdm.tqdm(list(checks.keys())): # cf.model = model # cf.model_path = 'models/{}.py'.format(cf.model if not 'retina' in cf.model else 'retina_net') # cf.model_path = os.path.join(cf.source_dir, cf.model_path) # {'mrcnn': cf.add_mrcnn_configs, # 'retina_net': cf.add_mrcnn_configs, 'retina_unet': cf.add_mrcnn_configs, # 'detection_unet': cf.add_det_unet_configs, 'detection_fpn': cf.add_det_fpn_configs # }[model]() # todo change structure of configs-handling with exec.py so that its dynamically parseable instead of needing to # todo be changed in the file all the time. checks = {cf.model:False} completed_process = subprocess.run("python exec.py --dev --dataset_name toy -m train_test --exp_dir {}".format(exp_dir), shell=True, capture_output=True, text=True) if completed_process.returncode!=0: print("Runtime test of model {} failed due to\n{}".format(cf.model, completed_process.stderr)) else: checks[cf.model] = True subprocess.call("rm -rf {}".format(exp_dir), shell=True) assert all(checks.values()), "A runtime test crashed." class MulithreadedDataiterator(unittest.TestCase): def test(self): print("Testing multithreaded iterator.") dataset = "toy" exp_dir = Path("datasets/{}/experiments/dev".format(dataset)) cf_file = utils.import_module("cf_file", exp_dir/"configs.py") cf = cf_file.Configs() dloader = utils.import_module('data_loader', 'datasets/{}/data_loader.py'.format(dataset)) cf.exp_dir = Path(exp_dir) cf.n_workers = 5 cf.batch_size = 3 cf.fold = 0 cf.plot_dir = cf.exp_dir / "plots" logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) cf.num_val_batches = "all" cf.val_mode = "val_sampling" cf.n_workers = 8 batch_gens = dloader.get_train_generators(cf, logger, data_statistics=False) val_loader = batch_gens["val_sampling"] for epoch in range(4): produced_ids = [] for i in range(batch_gens['n_val']): batch = next(val_loader) produced_ids.append(batch["pid"]) uni, cts = np.unique(np.concatenate(produced_ids), return_counts=True) assert np.all(cts < 3), "with batch size one: every item should occur exactly once.\n uni {}, cts {}".format( uni[cts>2], cts[cts>2]) #assert len(np.setdiff1d(val_loader.generator.dataset_pids, uni))==0, "not all val pids were shown." assert len(np.setdiff1d(uni, val_loader.generator.dataset_pids))==0, "pids shown that are not val set. impossible?" - + cf.n_workers = os.cpu_count() + cf.batch_size = int(val_loader.generator.dataset_length / cf.n_workers) + 2 val_loader = dloader.create_data_gen_pipeline(cf, val_loader.generator._data, do_aug=False, sample_pids_w_replace=False, max_batches=None, raise_stop_iteration=True) for epoch in range(2): produced_ids = [] for b, batch in enumerate(val_loader): produced_ids.append(batch["pid"]) uni, cts = np.unique(np.concatenate(produced_ids), return_counts=True) assert np.all(cts == 1), "with batch size one: every item should occur exactly once.\n uni {}, cts {}".format( uni[cts>1], cts[cts>1]) assert len(np.setdiff1d(val_loader.generator.dataset_pids, uni))==0, "not all val pids were shown." assert len(np.setdiff1d(uni, val_loader.generator.dataset_pids))==0, "pids shown that are not val set. impossible?" pass if __name__=="__main__": stime = time.time() t = CheckRoIAlignImplementation() t.manual_check() #unittest.main() mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/utils/dataloader_utils.py b/utils/dataloader_utils.py index 0724b28..8760408 100644 --- a/utils/dataloader_utils.py +++ b/utils/dataloader_utils.py @@ -1,723 +1,730 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import plotting as plg import os from multiprocessing import Pool, Lock import pickle import warnings import numpy as np import pandas as pd from batchgenerators.transforms.abstract_transforms import AbstractTransform from scipy.ndimage.measurements import label as lb from torch.utils.data import Dataset as torchDataset from batchgenerators.dataloading.data_loader import SlimDataLoaderBase import utils.exp_utils as utils import data_manager as dmanager for msg in ["This figure includes Axes that are not compatible with tight_layout", "Data has no positive values, and therefore cannot be log-scaled."]: warnings.filterwarnings("ignore", msg) class AttributeDict(dict): __getattr__ = dict.__getitem__ __setattr__ = dict.__setitem__ ################################## # data loading, organisation # ################################## class fold_generator: """ generates splits of indices for a given length of a dataset to perform n-fold cross-validation. splits each fold into 3 subsets for training, validation and testing. This form of cross validation uses an inner loop test set, which is useful if test scores shall be reported on a statistically reliable amount of patients, despite limited size of a dataset. If hold out test set is provided and hence no inner loop test set needed, just add test_idxs to the training data in the dataloader. This creates straight-forward train-val splits. :returns names list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix. """ def __init__(self, seed, n_splits, len_data): """ :param seed: Random seed for splits. :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation :param len_data: number of elements in the dataset. """ self.tr_ix = [] self.val_ix = [] self.te_ix = [] self.slicer = None self.missing = 0 self.fold = 0 self.len_data = len_data self.n_splits = n_splits self.myseed = seed self.boost_val = 0 def init_indices(self): t = list(np.arange(self.l)) # round up to next splittable data amount. split_length = int(np.ceil(len(t) / float(self.n_splits))) self.slicer = split_length self.mod = len(t) % self.n_splits if self.mod > 0: # missing is the number of folds, in which the new splits are reduced to account for missing data. self.missing = self.n_splits - self.mod self.te_ix = t[:self.slicer] self.tr_ix = t[self.slicer:] self.val_ix = self.tr_ix[:self.slicer] self.tr_ix = self.tr_ix[self.slicer:] def new_fold(self): slicer = self.slicer if self.fold < self.missing : slicer = self.slicer - 1 temp = self.te_ix # catch exception mod == 1: test set collects 1+ data since walk through both roudned up splits. # account for by reducing last fold split by 1. if self.fold == self.n_splits-2 and self.mod ==1: temp += self.val_ix[-1:] self.val_ix = self.val_ix[:-1] self.te_ix = self.val_ix self.val_ix = self.tr_ix[:slicer] self.tr_ix = self.tr_ix[slicer:] + temp def get_fold_names(self): names_list = [] rgen = np.random.RandomState(self.myseed) cv_names = np.arange(self.len_data) rgen.shuffle(cv_names) self.l = len(cv_names) self.init_indices() for split in range(self.n_splits): train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix] names_list.append([train_names, val_names, test_names, self.fold]) self.new_fold() self.fold += 1 return names_list class FoldGenerator(): r"""takes a set of elements (identifiers) and randomly splits them into the specified amt of subsets. """ def __init__(self, identifiers, seed, n_splits=5): self.ids = np.array(identifiers) self.n_splits = n_splits self.seed = seed def generate_splits(self, n_splits=None): if n_splits is None: n_splits = self.n_splits rgen = np.random.RandomState(self.seed) rgen.shuffle(self.ids) self.splits = list(np.array_split(self.ids, n_splits, axis=0)) # already returns list, but to be sure return self.splits class Dataset(torchDataset): r"""Parent Class for actual Dataset classes to inherit from! """ def __init__(self, cf, data_sourcedir=None): super(Dataset, self).__init__() self.cf = cf self.data_sourcedir = cf.data_sourcedir if data_sourcedir is None else data_sourcedir self.data_dir = cf.data_dir if hasattr(cf, 'data_dir') else self.data_sourcedir self.data_dest = cf.data_dest if hasattr(cf, "data_dest") else self.data_sourcedir self.data = {} self.set_ids = [] def copy_data(self, cf, file_subset, keep_packed=False, del_after_unpack=False): if os.path.normpath(self.data_sourcedir) != os.path.normpath(self.data_dest): self.data_sourcedir = os.path.join(self.data_sourcedir, '') args = AttributeDict({ "source" : self.data_sourcedir, "destination" : self.data_dest, "recursive" : True, "cp_only_npz" : False, "keep_packed" : keep_packed, "del_after_unpack" : del_after_unpack, "threads" : 16 if self.cf.server_env else os.cpu_count() }) dmanager.copy(args, file_subset=file_subset) self.data_dir = self.data_dest def __len__(self): return len(self.data) def __getitem__(self, id): """Return a sample of the dataset, i.e.,the dict of the id """ return self.data[id] def __iter__(self): return self.data.__iter__() def init_FoldGenerator(self, seed, n_splits): self.fg = FoldGenerator(self.set_ids, seed=seed, n_splits=n_splits) def generate_splits(self, check_file): if not os.path.exists(check_file): self.fg.generate_splits() with open(check_file, 'wb') as handle: pickle.dump(self.fg.splits, handle) else: with open(check_file, 'rb') as handle: self.fg.splits = pickle.load(handle) def calc_statistics(self, subsets=None, plot_dir=None, overall_stats=True): if self.df is None: self.df = pd.DataFrame() balance_t = self.cf.balance_target if hasattr(self.cf, "balance_target") else "class_targets" self.df._metadata.append(balance_t) if balance_t=="class_targets": mapper = lambda cl_id: self.cf.class_id2label[cl_id] labels = self.cf.class_id2label.values() elif balance_t=="rg_bin_targets": mapper = lambda rg_bin: self.cf.bin_id2label[rg_bin] labels = self.cf.bin_id2label.values() # elif balance_t=="regression_targets": # # todo this wont work # mapper = lambda rg_val: AttributeDict({"name":rg_val}) #self.cf.bin_id2label[self.cf.rg_val_to_bin_id(rg_val)] # labels = self.cf.bin_id2label.values() elif balance_t=="lesion_gleasons": mapper = lambda gs: self.cf.gs2label[gs] labels = self.cf.gs2label.values() else: mapper = lambda x: AttributeDict({"name":x}) labels = None for pid, subj_data in self.data.items(): unique_ts, counts = np.unique(subj_data[balance_t], return_counts=True) self.df = self.df.append(pd.DataFrame({"pid": [pid], **{mapper(unique_ts[i]).name: [counts[i]] for i in range(len(unique_ts))}}), ignore_index=True, sort=True) self.df = self.df.fillna(0) if overall_stats: df = self.df.drop("pid", axis=1) df = df.reindex(sorted(df.columns), axis=1).astype('uint32') print("Overall dataset roi counts per target kind:"); print(df.sum()) if subsets is not None: self.df["subset"] = np.nan self.df["display_order"] = np.nan for ix, (subset, pids) in enumerate(subsets.items()): self.df.loc[self.df.pid.isin(pids), "subset"] = subset self.df.loc[self.df.pid.isin(pids), "display_order"] = ix df = self.df.groupby("subset").agg("sum").drop("pid", axis=1, errors='ignore').astype('int64') df = df.sort_values(by=['display_order']).drop('display_order', axis=1) df = df.reindex(sorted(df.columns), axis=1) print("Fold {} dataset roi counts per target kind:".format(self.cf.fold)); print(df) if plot_dir is not None: os.makedirs(plot_dir, exist_ok=True) if subsets is not None: plg.plot_fold_stats(self.cf, df, labels, os.path.join(plot_dir, "data_stats_fold_" + str(self.cf.fold))+".pdf") if overall_stats: plg.plot_data_stats(self.cf, df, labels, os.path.join(plot_dir, 'data_stats_overall.pdf')) return df, labels def get_class_balanced_patients(all_pids, class_targets, batch_size, num_classes, random_ratio=0): ''' samples towards equilibrium of classes (on basis of total RoI counts). for highly imbalanced dataset, this might be a too strong requirement. :param class_targets: dic holding {patient_specifier : ROI class targets}, list position of ROI target corresponds to respective seg label - 1 :param batch_size: :param num_classes: :return: ''' # assert len(all_pids)>=batch_size, "not enough eligible pids {} to form a single batch of size {}".format(len(all_pids), batch_size) class_counts = {k: 0 for k in range(1,num_classes+1)} not_picked = np.array(all_pids) batch_patients = np.empty((batch_size,), dtype=not_picked.dtype) rarest_class = np.random.randint(1,num_classes+1) for ix in range(batch_size): if len(not_picked) == 0: warnings.warn("Dataset too small to generate batch with unique samples; => recycling.") not_picked = np.array(all_pids) np.random.shuffle(not_picked) #this could actually go outside(above) the loop. pick = not_picked[0] for cand in not_picked: if np.count_nonzero(class_targets[cand] == rarest_class) > 0: pick = cand cand_rarest_class = np.argmin([np.count_nonzero(class_targets[cand] == cl) for cl in range(1,num_classes+1)])+1 # if current batch already bigger than the batch random ratio, then # check that weakest class in this patient is not the weakest in current batch (since needs to be boosted) # also that at least one roi of this patient belongs to weakest class. If True, keep patient, else keep looking. if (cand_rarest_class != rarest_class and np.count_nonzero(class_targets[cand] == rarest_class) > 0) \ or ix < int(batch_size * random_ratio): break for c in range(1,num_classes+1): class_counts[c] += np.count_nonzero(class_targets[pick] == c) if not ix < int(batch_size * random_ratio) and class_counts[rarest_class] == 0: # means searched thru whole set without finding rarest class print("Class {} not represented in current dataset.".format(rarest_class)) rarest_class = np.argmin(([class_counts[c] for c in range(1,num_classes+1)]))+1 batch_patients[ix] = pick not_picked = not_picked[not_picked != pick] # removes pick return batch_patients class BatchGenerator(SlimDataLoaderBase): """ create the training/validation batch generator. Randomly sample batch_size patients from the data set, (draw a random slice if 2D), pad-crop them to equal sizes and merge to an array. :param data: data dictionary as provided by 'load_dataset' :param img_modalities: list of strings ['adc', 'b1500'] from config :param batch_size: number of patients to sample for the batch :param pre_crop_size: equal size for merging the patients to a single array (before the final random-crop in data aug.) :return dictionary containing the batch data / seg / pids as lists; the augmenter will later concatenate them into an array. """ def __init__(self, cf, data, sample_pids_w_replace=True, max_batches=None, raise_stop_iteration=False, n_threads=None, seed=0): if n_threads is None: n_threads = cf.n_workers super(BatchGenerator, self).__init__(data, cf.batch_size, number_of_threads_in_multithreaded=n_threads) self.cf = cf self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.plot_dir = os.path.join(self.cf.plot_dir, 'train_generator') os.makedirs(self.plot_dir, exist_ok=True) self.max_batches = max_batches self.raise_stop = raise_stop_iteration self.thread_id = 0 self.batches_produced = 0 self.dataset_length = len(self._data) self.dataset_pids = list(self._data.keys()) + + self.n_filled_threads = min(int(self.dataset_length/self.batch_size), self.number_of_threads_in_multithreaded) + if self.n_filled_threads != self.number_of_threads_in_multithreaded: + print("Adjusting nr of threads from {} to {}.".format(self.number_of_threads_in_multithreaded, + self.n_filled_threads)) + self.rgen = np.random.RandomState(seed=seed) self.eligible_pids = self.rgen.permutation(self.dataset_pids.copy()) - self.eligible_pids = np.array_split(self.eligible_pids, self.number_of_threads_in_multithreaded) + self.eligible_pids = np.array_split(self.eligible_pids, self.n_filled_threads) self.eligible_pids = sorted(self.eligible_pids, key=len, reverse=True) + self.sample_pids_w_replace = sample_pids_w_replace if not self.sample_pids_w_replace: - assert len(self.dataset_pids) / self.number_of_threads_in_multithreaded >= self.batch_size, \ + assert len(self.dataset_pids) / self.n_filled_threads >= self.batch_size, \ "at least one batch needed per thread. dataset size: {}, n_threads: {}, batch_size: {}.".format( - len(self.dataset_pids), self.number_of_threads_in_multithreaded, self.batch_size) + len(self.dataset_pids), self.n_filled_threads, self.batch_size) self.lock = Lock() if hasattr(cf, "balance_target"): # WARNING: "balance targets are only implemented for 1-d targets (or 1-component vectors)" self.balance_target = cf.balance_target else: self.balance_target = "class_targets" self.targets = {k:v[self.balance_target] for (k,v) in self._data.items()} def set_thread_id(self, thread_id): self.thread_ids = self.eligible_pids[thread_id] self.thread_id = thread_id def reset(self): self.batches_produced = 0 self.thread_ids = self.rgen.permutation(self.eligible_pids[self.thread_id]) @staticmethod def sample_targets_to_weights(targets, fg_bg_weights): weights = targets * fg_bg_weights return weights def balance_target_distribution(self, plot=False): """Impose a drawing distribution over samples. Distribution should be designed so that classes' fg and bg examples are (as good as possible) shown in equal frequency. Since we are dealing with rois, fg/bg weights count a sample (e.g., a patient) with **at least** one occurrence as fg, otherwise bg. For fg weights among classes, each RoI counts. :param all_pids: :param self.targets: dic holding {patient_specifier : patient-wise-unique ROI targets} :return: probability distribution over all pids. draw without replace from this. """ self.unique_ts = np.unique([v for pat in self.targets.values() for v in pat]) self.sample_stats = pd.DataFrame(columns=[str(ix)+suffix for ix in self.unique_ts for suffix in ["", "_bg"]], index=list(self.targets.keys())) for pid in self.sample_stats.index: for targ in self.unique_ts: fg_count = np.count_nonzero(self.targets[pid] == targ) self.sample_stats.loc[pid, str(targ)] = int(fg_count > 0) self.sample_stats.loc[pid, str(targ)+"_bg"] = int(fg_count == 0) self.targ_stats = self.sample_stats.agg( ("sum", lambda col: col.sum() / len(self._data)), axis=0, sort=False).rename({"": "relative"}) anchor = 1. - self.targ_stats.loc["relative"].iloc[0] self.fg_bg_weights = anchor / self.targ_stats.loc["relative"] cum_weights = anchor * len(self.fg_bg_weights) self.fg_bg_weights /= cum_weights self.p_probs = self.sample_stats.apply(self.sample_targets_to_weights, args=(self.fg_bg_weights,), axis=1).sum(axis=1) self.p_probs = self.p_probs / self.p_probs.sum() if plot: print("Applying class-weights:\n {}".format(self.fg_bg_weights)) if len(self.sample_stats.columns) == 2: # assert that probs are calc'd correctly: # (self.p_probs * self.sample_stats["1"]).sum() == (self.p_probs * self.sample_stats["1_bg"]).sum() # only works if one label per patient (multi-label expectations depend on multi-label occurences). expectations = [] for targ in self.sample_stats.columns: expectations.append((self.p_probs * self.sample_stats[targ]).sum()) assert np.allclose(expectations, expectations[0], atol=1e-4), "expectation values for fgs/bgs: {}".format(expectations) self.stats = {"roi_counts": np.zeros(len(self.unique_ts,), dtype='uint32'), "empty_counts": np.zeros(len(self.unique_ts,), dtype='uint32')} if plot: os.makedirs(self.plot_dir, exist_ok=True) plg.plot_batchgen_distribution(self.cf, self.dataset_pids, self.p_probs, self.balance_target, out_file=os.path.join(self.plot_dir, "train_gen_distr_"+str(self.cf.fold)+".png")) return self.p_probs def get_batch_pids(self): - if self.max_batches is not None and self.batches_produced * self.number_of_threads_in_multithreaded \ + if self.max_batches is not None and self.batches_produced * self.n_filled_threads \ + self.thread_id >= self.max_batches: self.reset() raise StopIteration if self.sample_pids_w_replace: # fully random patients batch_pids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_pids += list(np.random.choice( self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) else: with self.lock: if len(self.thread_ids) == 0: if self.raise_stop: self.reset() raise StopIteration else: self.thread_ids = self.rgen.permutation(self.eligible_pids[self.thread_id]) batch_pids = self.thread_ids[:self.batch_size] # batch_pids = np.random.choice(self.thread_ids, size=self.batch_size, replace=False) self.thread_ids = [pid for pid in self.thread_ids if pid not in batch_pids] self.batches_produced += 1 return batch_pids def generate_train_batch(self): # to be overriden by child # everything done in here is per batch # print statements in here get confusing due to multithreading raise NotImplementedError def print_stats(self, logger=None, file=None, plot_file=None, plot=True): print_f = utils.CombinedPrinter(logger, file) print_f('\n***Final Training Stats***') total_count = np.sum(self.stats['roi_counts']) for tix, count in enumerate(self.stats['roi_counts']): #name = self.cf.class_dict[tix] if self.balance_target=="class_targets" else str(self.unique_ts[tix]) name=str(self.unique_ts[tix]) print_f('{}: {} rois seen ({:.1f}%).'.format(name, count, count / total_count * 100)) total_samples = self.cf.num_epochs*self.cf.num_train_batches*self.cf.batch_size empties = [ '{}: {} ({:.1f}%)'.format(str(name), self.stats['empty_counts'][tix], self.stats['empty_counts'][tix]/total_samples*100) for tix, name in enumerate(self.unique_ts) ] empties = ", ".join(empties) print_f('empty samples seen: {}\n'.format(empties)) if plot: if plot_file is None: plot_file = os.path.join(self.plot_dir, "train_gen_stats_{}.png".format(self.cf.fold)) os.makedirs(self.plot_dir, exist_ok=True) plg.plot_batchgen_stats(self.cf, self.stats, empties, self.balance_target, self.unique_ts, plot_file) class PatientBatchIterator(SlimDataLoaderBase): """ creates a val/test generator. Step through the dataset and return dictionaries per patient. 2D is a special case of 3D patching with patch_size[2] == 1 (slices) Creates whole Patient batch and targets, and - if necessary - patchwise batch and targets. Appends patient targets anyway for evaluation. For Patching, shifts all patches into batch dimension. batch_tiling_forward will take care of exceeding batch dimensions. This iterator/these batches are not intended to go through MTaugmenter afterwards """ def __init__(self, cf, data): super(PatientBatchIterator, self).__init__(data, 0) self.cf = cf self.dataset_length = len(self._data) self.dataset_pids = list(self._data.keys()) def generate_train_batch(self, pid=None): # to be overriden by child return ################################### # transforms, image manipulation # ################################### def get_patch_crop_coords(img, patch_size, min_overlap=30): """ _:param img (y, x, (z)) _:param patch_size: list of len 2 (2D) or 3 (3D). _:param min_overlap: minimum required overlap of patches. If too small, some areas are poorly represented only at edges of single patches. _:return ndarray: shape (n_patches, 2*dim). crop coordinates for each patch. """ crop_coords = [] for dim in range(len(img.shape)): n_patches = int(np.ceil(img.shape[dim] / patch_size[dim])) # no crops required in this dimension, add image shape as coordinates. if n_patches == 1: crop_coords.append([(0, img.shape[dim])]) continue # fix the two outside patches to coords patchsize/2 and interpolate. center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1) if (patch_size[dim] - center_dists) < min_overlap: n_patches += 1 center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1) patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)]) dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers] crop_coords.append(dim_crop_coords) coords_mesh_grid = [] for ymin, ymax in crop_coords[0]: for xmin, xmax in crop_coords[1]: if len(crop_coords) == 3 and patch_size[2] > 1: for zmin, zmax in crop_coords[2]: coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax]) elif len(crop_coords) == 3 and patch_size[2] == 1: for zmin in range(img.shape[2]): coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1]) else: coords_mesh_grid.append([ymin, ymax, xmin, xmax]) return np.array(coords_mesh_grid).astype(int) def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None): """ one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee :param image: nd image. can be anything :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape) Example: image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh? image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768). :param mode: see np.pad for documentation :param return_slicer: if True then this function will also return what coords you will need to use when cropping back to original shape :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is divisibly by that number (can also be a list with an entry for each axis). Whatever is missing to match that will be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None) :param kwargs: see np.pad for documentation """ if kwargs is None: kwargs = {} if new_shape is not None: old_shape = np.array(image.shape[-len(new_shape):]) else: assert shape_must_be_divisible_by is not None assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)) new_shape = image.shape[-len(shape_must_be_divisible_by):] old_shape = new_shape num_axes_nopad = len(image.shape) - len(new_shape) new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))] if not isinstance(new_shape, np.ndarray): new_shape = np.array(new_shape) if shape_must_be_divisible_by is not None: if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)): shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape) else: assert len(shape_must_be_divisible_by) == len(new_shape) for i in range(len(new_shape)): if new_shape[i] % shape_must_be_divisible_by[i] == 0: new_shape[i] -= shape_must_be_divisible_by[i] new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))]) difference = new_shape - old_shape pad_below = difference // 2 pad_above = difference // 2 + difference % 2 pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)]) res = np.pad(image, pad_list, mode, **kwargs) if not return_slicer: return res else: pad_list = np.array(pad_list) pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1] slicer = list(slice(*i) for i in pad_list) return res, slicer def convert_seg_to_bounding_box_coordinates(data_dict, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False): '''adapted from batchgenerators :param data_dict: seg: segmentation with labels indicating roi_count (get_rois_from_seg=False) or classes (get_rois_from_seg=True), class_targets: list where list index corresponds to roi id (roi_count) :param dim: :param roi_item_keys: keys of the roi-wise items in data_dict to process :param n_rg_feats: nr of regression vector features :param get_rois_from_seg: :return: coords (y1,x1,y2,x2 (,z1,z2)) where the segmentation GT is framed by +1 voxel, i.e., for an object with z-extensions z1=0 through z2=5, bbox target coords will be z1=-1, z2=6. (analogically for x,y). data_dict['roi_masks']: (b, n(b), c, h(n), w(n) (z(n))) list like roi_labels but with arrays (masks) inplace of integers. c==1 if segmentation not one-hot encoded. ''' bb_target = [] roi_masks = [] roi_items = {name:[] for name in roi_item_keys} out_seg = np.copy(data_dict['seg']) for b in range(data_dict['seg'].shape[0]): p_coords_list = [] #p for patient? p_roi_masks_list = [] p_roi_items_lists = {name:[] for name in roi_item_keys} if np.sum(data_dict['seg'][b] != 0) > 0: if get_rois_from_seg: clusters, n_cands = lb(data_dict['seg'][b]) data_dict['class_targets'][b] = [data_dict['class_targets'][b]] * n_cands else: n_cands = int(np.max(data_dict['seg'][b])) rois = np.array( [(data_dict['seg'][b] == ii) * 1 for ii in range(1, n_cands + 1)], dtype='uint8') # separate clusters for rix, r in enumerate(rois): if np.sum(r != 0) > 0: # check if the roi survived slicing (3D->2D) and data augmentation (cropping etc.) seg_ixs = np.argwhere(r != 0) coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, np.max(seg_ixs[:, 2]) + 1] if dim == 3: coord_list.extend([np.min(seg_ixs[:, 3]) - 1, np.max(seg_ixs[:, 3]) + 1]) p_coords_list.append(coord_list) p_roi_masks_list.append(r) # add background class = 0. rix is a patient wide index of lesions. since 'class_targets' is # also patient wide, this assignment is not dependent on patch occurrences. for name in roi_item_keys: p_roi_items_lists[name].append(data_dict[name][b][rix]) assert data_dict["class_targets"][b][rix]>=1, "convertsegtobbox produced bg roi w cl targ {} and unique roi seg {}".format(data_dict["class_targets"][b][rix], np.unique(r)) if class_specific_seg: out_seg[b][data_dict['seg'][b] == rix + 1] = data_dict['class_targets'][b][rix] if not class_specific_seg: out_seg[b][data_dict['seg'][b] > 0] = 1 bb_target.append(np.array(p_coords_list)) roi_masks.append(np.array(p_roi_masks_list)) for name in roi_item_keys: roi_items[name].append(np.array(p_roi_items_lists[name])) else: bb_target.append([]) roi_masks.append(np.zeros_like(data_dict['seg'][b], dtype='uint8')[None]) for name in roi_item_keys: roi_items[name].append(np.array([])) if get_rois_from_seg: data_dict.pop('class_targets', None) data_dict['bb_target'] = np.array(bb_target) data_dict['roi_masks'] = np.array(roi_masks) data_dict['seg'] = out_seg for name in roi_item_keys: data_dict[name] = np.array(roi_items[name]) return data_dict class ConvertSegToBoundingBoxCoordinates(AbstractTransform): """ Converts segmentation masks into bounding box coordinates. """ def __init__(self, dim, roi_item_keys, get_rois_from_seg=False, class_specific_seg=False): self.dim = dim self.roi_item_keys = roi_item_keys self.get_rois_from_seg = get_rois_from_seg self.class_specific_seg = class_specific_seg def __call__(self, **data_dict): return convert_seg_to_bounding_box_coordinates(data_dict, self.dim, self.roi_item_keys, self.get_rois_from_seg, self.class_specific_seg) ############################# # data packing / unpacking # not used, data_manager.py used instead ############################# def get_case_identifiers(folder): case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")] return case_identifiers def convert_to_npy(npz_file): if not os.path.isfile(npz_file[:-3] + "npy"): a = np.load(npz_file)['data'] np.save(npz_file[:-3] + "npy", a) def unpack_dataset(folder, threads=8): case_identifiers = get_case_identifiers(folder) p = Pool(threads) npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers] p.map(convert_to_npy, npz_files) p.close() p.join() def delete_npy(folder): case_identifiers = get_case_identifiers(folder) npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers] npy_files = [i for i in npy_files if os.path.isfile(i)] for n in npy_files: os.remove(n) \ No newline at end of file