diff --git a/datasets/lidc/data_loader.py b/datasets/lidc/data_loader.py index e848649..8588173 100644 --- a/datasets/lidc/data_loader.py +++ b/datasets/lidc/data_loader.py @@ -1,978 +1,971 @@ # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== ''' Data Loader for the LIDC data set. This dataloader expects preprocessed data in .npy or .npz files per patient and a pandas dataframe containing the meta info e.g. file paths, and some ground-truth info like labels, foreground slice ids. LIDC 4-fold annotations storage capacity problem: keep segmentation gts compressed (npz), unpack at each batch generation. 
def save_obj(obj, name):
    """Pickle a python object to the file ``name + '.pkl'``.

    :param obj: any picklable python object.
    :param name: target file path without the ``.pkl`` extension.
    """
    with open(name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def vector(item):
    """Ensure item is vector-like (list, tuple or array with at least one axis).

    :param item: anything.
    :return: ``item`` itself if it is a list, a tuple or an array with ndim >= 1;
        a 0-d array reshaped to shape (1,); otherwise ``[item]``.
    """
    if isinstance(item, np.ndarray):
        # a 0-d array is an ndarray but can be neither iterated nor indexed,
        # so it has to be promoted just like a plain scalar.
        return item.reshape(1) if item.ndim == 0 else item
    if not isinstance(item, (list, tuple)):
        item = [item]
    return item
class Dataset(dutils.Dataset):
    r"""Load a dict holding memmapped arrays and clinical parameters for each patient,
    evtly subset of those.

    If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dest.

    :param cf: config object.
    :param logger: logger.
    :param subset_ids: subset of patient/sample identifiers to load from whole set.
    :param data_sourcedir: directory in which to find data, defaults to cf.data_sourcedir if None.
    :param mode: 'train' loads the GT kind configured in cf.training_gts, any other value loads the merged GTs.
    :return: dict with imgs, segs, pids, class_labels, observables
    """

    def __init__(self, cf, logger=None, subset_ids=None, data_sourcedir=None, mode='train'):
        super(Dataset, self).__init__(cf, data_sourcedir)

        # single-annotator GTs are only loaded in train mode and only if cf.training_gts asks for them.
        if mode == 'train' and not cf.training_gts == "merged":
            self.gt_dir = "patient_gts_sa"
            self.gt_kind = cf.training_gts
        else:
            self.gt_dir = "patient_gts_merged"
            self.gt_kind = "merged"
        if logger is not None:
            logger.info("loading {} ground truths for {}".format(
                self.gt_kind, 'training and validation' if mode == 'train' else 'testing'))

        p_df = pd.read_pickle(os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name))
        #exclude_pids = ["0305a", "0447a"]  # due to non-bg segmentation but bg mal label in nodules 5728, 8840
        #p_df = p_df[~p_df.pid.isin(exclude_pids)]

        if subset_ids is not None:
            p_df = p_df[p_df.pid.isin(subset_ids)]
            if logger is not None:
                logger.info('subset: selected {} instances from df'.format(len(p_df)))
        if cf.select_prototype_subset is not None:
            prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset]
            p_df = p_df[p_df.pid.isin(prototype_pids)]
            if logger is not None:
                logger.warning('WARNING: using prototyping data subset of length {}!!!'.format(len(p_df)))

        pids = p_df.pid.tolist()

        # evtly copy data from data_sourcedir to data_dest
        if cf.server_env and not hasattr(cf, 'data_dir') and hasattr(cf, "data_dest"):
            # copy and unpack images (only those not yet present as .npy in cf.data_dest)
            file_subset = ["{}_img.npz".format(pid) for pid in pids if not
                           os.path.isfile(os.path.join(cf.data_dest, '{}_img.npy'.format(pid)))]
            file_subset += [os.path.join(self.data_sourcedir, self.gt_dir, cf.input_df_name)]
            self.copy_data(cf, file_subset=file_subset, keep_packed=False, del_after_unpack=True)
            # copy and do not unpack segmentations
            file_subset = [os.path.join(self.gt_dir, "{}_rois.np*".format(pid)) for pid in pids]
            keep_packed = not cf.training_gts == "merged"
            self.copy_data(cf, file_subset=file_subset, keep_packed=keep_packed,
                           del_after_unpack=(not keep_packed))
        else:
            cf.data_dir = self.data_sourcedir

        # merged GTs are stored unpacked (.npy), single-annotator GTs stay compressed (.npz).
        ext = 'npy' if self.gt_kind == "merged" else 'npz'

        imgs = [os.path.join(self.data_dir, '{}_img.npy'.format(pid)) for pid in pids]
        segs = [os.path.join(self.data_dir, self.gt_dir, '{}_rois.{}'.format(pid, ext)) for pid in pids]
        orig_class_targets = p_df['class_target'].tolist()

        # loop-invariant: whether any configured prediction task is a regression task.
        has_regression = any('regression' in task for task in cf.prediction_tasks)

        data = OrderedDict()

        if self.gt_kind == 'merged':
            # materialise the column once instead of once per patient.
            fg_slices = p_df['fg_slices'].tolist()
            for ix, pid in enumerate(pids):
                data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid}
                data[pid]['fg_slices'] = np.array(fg_slices[ix])
                if 'class' in cf.prediction_tasks:
                    if len(cf.class_labels) == 3:
                        # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2)
                        data[pid]['class_targets'] = np.array(
                            [2 if ii >= 3 else 1 for ii in orig_class_targets[ix]], dtype='uint8')
                    elif len(cf.class_labels) == 6:
                        # classify each malignancy score
                        data[pid]['class_targets'] = np.array(
                            [1 if ii == 0.5 else np.round(ii) for ii in orig_class_targets[ix]], dtype='uint8')
                    else:
                        raise Exception("mismatch class labels and data-loading implementations.")
                else:
                    # no class task: every roi gets the same foreground label 1.
                    data[pid]['class_targets'] = np.ones_like(np.array(orig_class_targets[ix]), dtype='uint8')
                if has_regression:
                    data[pid]["regression_targets"] = np.array(
                        [vector(v) for v in orig_class_targets[ix]], dtype='float16')
                    data[pid]["rg_bin_targets"] = np.array(
                        [cf.rg_val_to_bin_id(v) for v in data[pid]["regression_targets"]], dtype='uint8')
        else:
            # single-annotator GTs: per roi, targets hold one entry per rater.
            fg_slices = p_df['fg_slices'].values
            for ix, pid in enumerate(pids):
                data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid}
                data[pid]['fg_slices'] = np.array(fg_slices[ix])
                if 'class' in cf.prediction_tasks:
                    # malignancy scores are binarized: (benign: 1-2 --> cl 1, malignant: 3-5 --> cl 2)
                    raise NotImplementedError
                    # todo need to consider bg
                    # data[pid]['class_targets'] = np.array(
                    #     [[2 if ii >= 3 else 1 for ii in four_fold_targs] for four_fold_targs in orig_class_targets[ix]])
                else:
                    # rating > 0 means the rater marked the roi (fg: 1), otherwise bg (0).
                    data[pid]['class_targets'] = np.array(
                        [[1 if ii > 0 else 0 for ii in four_fold_targs] for four_fold_targs in
                         orig_class_targets[ix]], dtype='uint8')
                if has_regression:
                    data[pid]["regression_targets"] = np.array(
                        [[vector(v) for v in four_fold_targs] for four_fold_targs in orig_class_targets[ix]],
                        dtype='float16')
                    data[pid]["rg_bin_targets"] = np.array(
                        [[cf.rg_val_to_bin_id(v) for v in four_fold_targs] for four_fold_targs in
                         data[pid]["regression_targets"]], dtype='uint8')

        # names of the per-roi items the batch generators copy into each batch.
        cf.roi_items = cf.observables_rois[:]
        cf.roi_items += ['class_targets']
        if has_regression:
            cf.roi_items += ['regression_targets']
            cf.roi_items += ['rg_bin_targets']

        self.data = data
        self.set_ids = np.array(list(self.data.keys()))
        self.df = None
# merged GTs
class BatchGenerator_merged(dutils.BatchGenerator):
    """
    creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
    from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
    Actual patch_size is obtained after data augmentation.
    :param data: data dictionary as provided by 'load_dataset'.
    :param batch_size: number of patients to sample for the batch
    :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
    """
    def __init__(self, cf, data):
        super(BatchGenerator_merged, self).__init__(cf, data)

        self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch.
        self.p_fg = 0.5  # probability of actively sampling a foreground slice/patch
        self.empty_samples_max_ratio = 0.6  # above this share of empty samples in a batch, fg is forced

        self.random_count = int(cf.batch_random_ratio * cf.batch_size)
        self.class_targets = {k: v["class_targets"] for (k, v) in self._data.items()}

        # NOTE(review): presumably sets self.p_probs and self.unique_ts in the base class - confirm.
        self.balance_target_distribution(plot=True)
        self.stats = {"roi_counts": np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count": 0}

    def generate_train_batch(self):
        """Sample one training batch.

        :return: dict with 'data', 'seg', 'pid', 'roi_counts', 'empty_samples_count' and one array per item
            named in cf.roi_items.
        """
        # samples patients towards equilibrium of foreground classes on a roi-level after sampling a random ratio
        # fully random patients
        batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
        # target-balanced patients
        batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count,
                                                   replace=False, p=self.p_probs))

        batch_data, batch_segs, batch_pids = [], [], []
        batch_roi_items = {name: [] for name in self.cf.roi_items}
        # record roi count of classes in batch
        # empty count for full bg samples (empty slices in 2D/patients in 3D)
        batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0

        for sample in range(self.batch_size):
            patient = self._data[batch_patient_ids[sample]]

            data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]
            seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
            batch_pids.append(patient['pid'])
            (_, _, _, z) = data.shape

            if self.cf.dim == 2:
                elig_slices, choose_fg = [], False
                if len(patient['fg_slices']) > 0:
                    if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand(1)<=self.p_fg:
                        # fg is to be picked
                        for tix in np.argsort(batch_roi_counts):
                            # pick slices of patient that have roi of sought-for target
                            # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                            elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0])-1] ==
                                self.unique_ts[tix]) > 0]
                            if len(elig_slices) > 0:
                                choose_fg = True
                                break
                    else:
                        # pick bg
                        elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'])
                if len(elig_slices)>0:
                    sl_pick_ix = np.random.choice(elig_slices, size=None)
                else:
                    sl_pick_ix = np.random.choice(z, size=None)
                data = data[..., sl_pick_ix]
                seg = seg[..., sl_pick_ix]

            # pad data if smaller than pre_crop_size.
            if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
                new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)]
                data = dutils.pad_nd_image(data, new_shape, mode='constant')
                seg = dutils.pad_nd_image(seg, new_shape, mode='constant')

            # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
            crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps]
            if len(crop_dims) > 0:
                if self.cf.dim == 3:
                    choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg
                if choose_fg and np.any(seg):
                    # roi ids are the non-zero labels; do not rely on a background label being present
                    # (np.unique(seg)[1:] would drop a real roi id if the sample holds no background).
                    available_roi_ids = np.unique(seg[seg > 0])
                    for tix in np.argsort(batch_roi_counts):
                        elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]]
                        if len(elig_roi_ids)>0:
                            seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                            break
                    # NOTE(review): seg_ics is only bound if some roi target matched an entry of unique_ts - confirm
                    # that this always holds for the data.
                    roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                    assert seg[tuple(roi_anchor_pixel)] > 0

                    # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                    # distance to the desired ROI < patch_size /2.
                    # (here final patch size to account for center_crop after data augmentation).
                    sample_seg_center = {}
                    for ii in crop_dims:
                        low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                        high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                        # happens if lesion on the edge of the image. dont care about roi anymore,
                        # just make sure pre-crop is inside image.
                        if low >= high:
                            low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                            high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                        sample_seg_center[ii] = np.random.randint(low=low, high=high)
                else:
                    # not guaranteed to be empty. probability of emptiness depends on the data.
                    sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2,
                                                               high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2)
                                         for ii in crop_dims}

                for ii in crop_dims:
                    min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                    max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
                    # data carries a leading channel axis, seg does not -> axis offset of one.
                    data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
                    seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)

            batch_data.append(data)
            batch_segs.append(seg[np.newaxis])
            for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item
                batch_roi_items[o].append(patient[o])

            # update the running per-target roi counts that steer the fg sampling of the following samples.
            if self.cf.dim == 3:
                for tix in range(len(self.unique_ts)):
                    batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix])
            elif self.cf.dim == 2:
                for tix in range(len(self.unique_ts)):
                    batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
            if not np.any(seg):
                empty_samples_count += 1

        data = np.array(batch_data).astype(np.float16)
        seg = np.array(batch_segs).astype(np.uint8)
        batch = {'data': data, 'seg': seg, 'pid': batch_pids,
                 'roi_counts':batch_roi_counts, 'empty_samples_count': empty_samples_count}
        for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...)
            batch[key] = np.array(val)

        return batch
class PatientBatchIterator_merged(dutils.PatientBatchIterator):
    """
    creates a test generator that iterates over entire given dataset returning 1 patient per batch.
    Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actual evaluation (done in 3D),
    if willing to accept speed-loss during training.
    :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
    batch_size = n_2D_patches in 2D .
    """

    def __init__(self, cf, data): # threads in augmenter
        super(PatientBatchIterator_merged, self).__init__(cf, data)
        self.patient_ix = 0
        # in 2D, patches are treated as 3D patches of depth 1.
        self.patch_size = (cf.patch_size + [1]) if cf.dim == 2 else cf.patch_size

    def generate_train_batch(self, pid=None):
        """Build the batch of one patient and advance the internal patient index.

        :param pid: patient identifier to load; if None, the patient at the current internal index is used.
        :return: dict holding the patient (whole volume, or stacked patches if the volume exceeds patch_size)
            together with patient-wise targets.
        """
        if pid is None:
            pid = self.dataset_pids[self.patient_ix]
        patient = self._data[pid]

        data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))
        seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))

        # pad data if smaller than patch_size seen during training.
        if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]):
            new_shape = [np.max([data.shape[dim], ps]) for dim, ps in enumerate(self.patch_size)]
            data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape.
            seg = dutils.pad_nd_image(seg, new_shape)

        # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor.
        if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
            out_data = data[np.newaxis, np.newaxis]
            out_seg = seg[np.newaxis, np.newaxis]
            batch_3D = {'data': out_data, 'seg': out_seg}
            for o in self.cf.roi_items:
                batch_3D[o] = np.array([patient[o]])
            converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg)
            batch_3D = converter(**batch_3D)
            batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape})
            for o in self.cf.roi_items:
                batch_3D["patient_" + o] = batch_3D[o]

        if self.cf.dim == 2:
            out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, x, y )
            out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]

            batch_2D = {'data': out_data, 'seg': out_seg}
            for o in self.cf.roi_items:
                batch_2D[o] = np.repeat(np.array([patient[o]]), out_data.shape[0], axis=0)
            converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg)
            batch_2D = converter(**batch_2D)

            if self.cf.merge_2D_to_3D_preds:
                # patient-wise targets are taken from the 3D batch built above.
                batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
                                 'original_img_shape': out_data.shape})
                for o in self.cf.roi_items:
                    batch_2D["patient_" + o] = batch_3D[o]
            else:
                batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
                                 'original_img_shape': out_data.shape})
                for o in self.cf.roi_items:
                    batch_2D["patient_" + o] = batch_2D[o]

        out_batch = batch_3D if self.cf.dim == 3 else batch_2D
        out_batch.update({'pid': np.array([patient['pid']] * len(out_data))})

        # crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
        # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
        if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]):
            patient_batch = out_batch
            patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size)
            new_img_batch, new_seg_batch = [], []

            for c in patch_crop_coords_list:
                new_seg_batch.append(seg[c[0]:c[1], c[2]:c[3], c[4]:c[5]])
                new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:c[5]])
            data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z)
            seg = np.array(new_seg_batch)[:, np.newaxis] # (n_patches, 1, x, y, z)
            if self.cf.dim == 2:
                # all patches have z dimension 1 (slices). discard dimension
                data = data[..., 0]
                seg = seg[..., 0]
            patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'),
                           'pid': np.array([patient['pid']] * data.shape[0])}
            for o in self.cf.roi_items:
                patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0)
            # patient-wise (orig) batch info for putting the patches back together after prediction
            for o in self.cf.roi_items:
                patch_batch["patient_" + o] = patient_batch['patient_' + o]
                if self.cf.dim == 2:
                    # this could also be named "unpatched_2d_roi_items"
                    patch_batch["patient_" + o + "_2d"] = patient_batch[o]
            # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288
            # and enables calculating test-dice/viewing patient-wise results in test
            # remove, but also remove dice from metrics, when like to save memory
            patch_batch['patient_data'] = patient_batch['data']
            patch_batch['patient_seg'] = patient_batch['seg']
            patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
            patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
            if self.cf.dim == 2:
                patch_batch['patient_bb_target_2d'] = patient_batch['bb_target']
            patch_batch['original_img_shape'] = patient_batch['original_img_shape']

            converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False,
                                                           self.cf.class_specific_seg)
            patch_batch = converter(**patch_batch)
            out_batch = patch_batch

        # advance to the next patient, wrapping around at the end of the dataset.
        self.patient_ix += 1
        if self.patient_ix == len(self.dataset_pids):
            self.patient_ix = 0

        return out_batch
# single-annotator GTs
class BatchGenerator_sa(dutils.BatchGenerator):
    """
    creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
    from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
    Actual patch_size is obtained after data augmentation.
    :param data: data dictionary as provided by 'load_dataset'.
    :param batch_size: number of patients to sample for the batch
    :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
    """

    # noinspection PyMethodOverriding
    def balance_target_distribution(self, rater, plot=False):
        """
        :param rater: for which rater slot to generate the distribution
        :param self.targets: dic holding {patient_specifier : patient-wise-unique ROI targets}
        :param plot: whether to plot the generated patient distributions
        :return: (p_probs, unique_ts): probability distribution over all pids (draw without replace from this)
            and the unique foreground targets of this rater.
        """
        # get unique foreground targets per patient, assign -1 to an "empty" patient (has no foreground)
        patient_ts = [[roi[rater] for roi in patient_rois_lst] for patient_rois_lst in self.targets.values()]
        # assign [-1] to empty patients
        patient_ts = [np.unique(lst) if any(np.any(t > 0) for t in lst) else [-1] for lst in patient_ts]
        # sort out bg labels (are 0)
        unique_ts, t_counts = np.unique([t for lst in patient_ts for t in lst if t>0], return_counts=True)
        # inverse-frequency weighting of the targets, normalised to sum to one.
        t_probs = t_counts.sum() / t_counts
        t_probs /= t_probs.sum()
        t_probs = {t : t_probs[ix] for ix, t in enumerate(unique_ts)}
        t_probs[-1] = 0.
        t_probs[0] = 0.
        # fail if balance target is not a number (i.e., a vector)
        # a patient is weighted by its rarest (highest-weight) target.
        p_probs = np.array([max(t_probs[t] for t in lst) for lst in patient_ts])
        #normalize
        p_probs /= p_probs.sum()

        if plot:
            plg.plot_batchgen_distribution(self.cf, self.dataset_pids, p_probs, self.balance_target,
                                           out_file=os.path.join(self.cf.plot_dir,
                                                                 "train_gen_distr_"+str(self.cf.fold)+"_rater"+str(rater)+".png"))
        return p_probs, unique_ts

    def __init__(self, cf, data):
        super(BatchGenerator_sa, self).__init__(cf, data)

        self.crop_margin = np.array(self.cf.patch_size) / 8.  # min distance of ROI center to edge of cropped_patch.
        self.p_fg = 0.5  # probability of actively sampling a foreground slice/patch
        self.empty_samples_max_ratio = 0.6  # above this share of empty samples in a batch, fg is forced

        self.random_count = int(cf.batch_random_ratio * cf.batch_size)

        self.rater_bsize = 4  # number of rater slots per roi
        # one patient-sampling distribution per rater; unique_ts is the union over all raters.
        unique_ts_total = set()
        self.rater_p_probs = []
        for r in range(self.rater_bsize):
            p_probs, unique_ts = self.balance_target_distribution(r, plot=True)
            self.rater_p_probs.append(p_probs)
            unique_ts_total.update(unique_ts)
        self.unique_ts = sorted(list(unique_ts_total))
        self.stats = {"roi_counts": np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count": 0}

    def generate_train_batch(self):
        """Sample one training batch using the annotations of one randomly drawn rater.

        :return: dict with 'data', 'seg', 'pid', 'rater_id', 'roi_counts', 'empty_samples_count' and one array
            per item named in cf.roi_items.
        """
        rater = np.random.randint(self.rater_bsize)

        # samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio batch_random_ratio).
        # random patients
        batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False))
        # target-balanced patients
        batch_patient_ids += list(np.random.choice(self.dataset_pids, size=self.batch_size-self.random_count,
                                                   replace=False, p=self.rater_p_probs[rater]))

        batch_data, batch_segs, batch_pids = [], [], []
        batch_roi_items = {name: [] for name in self.cf.roi_items}
        # record roi count of classes in batch
        # empty count for full bg samples (empty slices in 2D/patients in 3D)
        batch_roi_counts, empty_samples_count = np.zeros((len(self.unique_ts),), dtype='uint32'), 0

        for sample in range(self.batch_size):
            patient = self._data[batch_patient_ids[sample]]

            patient_balance_ts = np.array([roi[rater] for roi in patient[self.balance_target]])
            data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis]
            # the npz archive keeps a file handle open; close it as soon as the rater's seg is read.
            with np.load(patient['seg'], mmap_mode='r') as seg_archive:
                seg = np.transpose(seg_archive[list(seg_archive.keys())[0]][rater], axes=(1, 2, 0))
            batch_pids.append(patient['pid'])
            (_, _, _, z) = data.shape

            if self.cf.dim == 2:
                elig_slices, choose_fg = [], False
                if len(patient['fg_slices']) > 0:
                    if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand(
                            1) <= self.p_fg:
                        # fg is to be picked
                        for tix in np.argsort(batch_roi_counts):
                            # pick slices of patient that have roi of sought-for target
                            # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix
                            elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero(
                                patient_balance_ts[np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] ==
                                self.unique_ts[tix]) > 0]
                            if len(elig_slices) > 0:
                                choose_fg = True
                                break
                    else:
                        # pick bg
                        # NOTE(review): fg_slices is indexed by rater here but tested as a whole above - confirm layout.
                        elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices'][rater])
                if len(elig_slices) > 0:
                    sl_pick_ix = np.random.choice(elig_slices, size=None)
                else:
                    sl_pick_ix = np.random.choice(z, size=None)
                data = data[..., sl_pick_ix]
                seg = seg[..., sl_pick_ix]

            # pad data if smaller than pre_crop_size.
            if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
                new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)]
                data = dutils.pad_nd_image(data, new_shape, mode='constant')
                seg = dutils.pad_nd_image(seg, new_shape, mode='constant')

            # crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
            crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps]
            if len(crop_dims) > 0:
                if self.cf.dim == 3:
                    choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg
                if choose_fg and np.any(seg):
                    available_roi_ids = np.unique(seg[seg>0])
                    assert np.all(patient_balance_ts[available_roi_ids-1]>0), "trying to choose roi with rating 0"
                    for tix in np.argsort(batch_roi_counts):
                        elig_roi_ids = available_roi_ids[ patient_balance_ts[available_roi_ids-1] == self.unique_ts[tix] ]
                        if len(elig_roi_ids)>0:
                            seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None))
                            roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)]
                            break
                    assert seg[tuple(roi_anchor_pixel)] > 0, "roi_anchor_pixel not inside roi: {}, pb_ts {}, elig ids {}".format(tuple(roi_anchor_pixel), patient_balance_ts, elig_roi_ids)

                    # sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
                    # distance to the desired ROI < patch_size /2.
                    # (here final patch size to account for center_crop after data augmentation).
                    sample_seg_center = {}
                    for ii in crop_dims:
                        low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                        high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
                        # happens if lesion on the edge of the image. dont care about roi anymore,
                        # just make sure pre-crop is inside image.
                        if low >= high:
                            low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                            high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
                        sample_seg_center[ii] = np.random.randint(low=low, high=high)
                else:
                    # not guaranteed to be empty. probability of emptiness depends on the data.
                    sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2,
                                                               high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2)
                                         for ii in crop_dims}

                for ii in crop_dims:
                    min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
                    max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
                    # data carries a leading channel axis, seg does not -> axis offset of one.
                    data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
                    seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)

            batch_data.append(data)
            batch_segs.append(seg[np.newaxis])
            for o in batch_roi_items: #after loop, holds every entry of every batchpatient per roi-item
                batch_roi_items[o].append([roi[rater] for roi in patient[o]])

            # update the running per-target roi counts that steer the fg sampling of the following samples.
            if self.cf.dim == 3:
                for tix in range(len(self.unique_ts)):
                    batch_roi_counts[tix] += np.count_nonzero(patient_balance_ts == self.unique_ts[tix])
            elif self.cf.dim == 2:
                for tix in range(len(self.unique_ts)):
                    batch_roi_counts[tix] += np.count_nonzero(patient_balance_ts[np.unique(seg[seg>0]) - 1] == self.unique_ts[tix])
            if not np.any(seg):
                empty_samples_count += 1

        data = np.array(batch_data).astype('float16')
        seg = np.array(batch_segs).astype('uint8')
        batch = {'data': data, 'seg': seg, 'pid': batch_pids, 'rater_id': rater,
                 'roi_counts':batch_roi_counts, 'empty_samples_count': empty_samples_count}
        for key,val in batch_roi_items.items(): #extend batch dic by roi-wise items (obs, class ids, regression vectors...)
            batch[key] = np.array(val)

        return batch
Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actual evaluation (done in 3D), if willing to accept speed loss during training. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . This is the data & gt loader for the 4-fold single-annotator GTs: each data input has separate annotations of 4 annotators. the way the pipeline is currently setup, the single-annotator GTs are only used if training with validation mode val_patient; during testing the Iterator with the merged GTs is used. # todo mode val_patient not implemented yet (since very slow). would need to sample from all available rater GTs. """ def __init__(self, cf, data): #threads in augmenter super(PatientBatchIterator_sa, self).__init__(cf, data) self.cf = cf self.patient_ix = 0 self.dataset_pids = list(self._data.keys()) self.patch_size = cf.patch_size+[1] if cf.dim==2 else cf.patch_size self.rater_bsize = 4 def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0)) # all gts are 4-fold and npz! seg = np.load(patient['seg'], mmap_mode='r') seg = np.transpose(seg[list(seg.keys())[0]], axes=(0, 2, 3, 1)) # pad data if smaller than patch_size seen during training. if np.any([data.shape[dim] < ps for dim, ps in enumerate(self.patch_size)]): new_shape = [np.max([data.shape[dim], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) # get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor. 
if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: out_data = data[np.newaxis, np.newaxis] out_seg = seg[:, np.newaxis] batch_3D = {'data': out_data, 'seg': out_seg} for item in self.cf.roi_items: batch_3D[item] = [] for r in range(self.rater_bsize): for item in self.cf.roi_items: batch_3D[item].append(np.array([roi[r] for roi in patient[item]])) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(2, 0, 1))[:, np.newaxis] # (z, c, y, x ) out_seg = np.transpose(seg, axes=(0, 3, 1, 2))[:, :, np.newaxis] # (n_raters, z, 1, y,x) batch_2D = {'data': out_data} for item in ["seg", "bb_target"]+self.cf.roi_items: batch_2D[item] = [] converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) for r in range(self.rater_bsize): tmp_batch = {"seg": out_seg[r]} for item in self.cf.roi_items: tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), out_data.shape[0], axis=0) tmp_batch = converter(**tmp_batch) for item in ["seg", "bb_target"]+self.cf.roi_items: batch_2D[item].append(tmp_batch[item]) # for item in ["seg", "bb_target"]+self.cf.roi_items: # batch_2D[item] = np.array(batch_2D[item]) if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * out_data.shape[0])}) # crop 
patient-volume to patches of patch_size used during training. stack patches up in batch dimension. # in this case, 2D is treated as a special case of 3D with patch_size[z] = 1. if np.any([data.shape[dim] > self.patch_size[dim] for dim in range(3)]): patient_batch = out_batch patch_crop_coords_list = dutils.get_patch_crop_coords(data, self.patch_size) new_img_batch = [] new_seg_batch = [] for cix, c in enumerate(patch_crop_coords_list): seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) tmp_c_5 = c[5] new_img_batch.append(data[c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5]) data = np.array(new_img_batch)[:, np.newaxis] # (n_patches, c, x, y, z) seg = np.transpose(np.array(new_seg_batch), axes=(1,0,2,3,4))[:,:,np.newaxis] # (n_raters, n_patches, x, y, z) if self.cf.dim == 2: # all patches have z dimension 1 (slices). discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'pid': np.array([patient['pid']] * data.shape[0])} # for o in self.cf.roi_items: # patch_batch[o] = np.repeat(np.array([patient[o]]), len(patch_crop_coords_list), axis=0) converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, False, self.cf.class_specific_seg) for item in ["seg", "bb_target"]+self.cf.roi_items: patch_batch[item] = [] # coord_list = [np.min(seg_ixs[:, 1]) - 1, np.min(seg_ixs[:, 2]) - 1, np.max(seg_ixs[:, 1]) + 1, # IndexError: index 2 is out of bounds for axis 1 with size 2 for r in range(self.rater_bsize): tmp_batch = {"seg": seg[r]} for item in self.cf.roi_items: tmp_batch[item] = np.repeat(np.array([[roi[r] for roi in patient[item]]]), len(patch_crop_coords_list), axis=0) tmp_batch = converter(**tmp_batch) for item in ["seg", "bb_target"]+self.cf.roi_items: patch_batch[item].append(tmp_batch[item]) # patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_" + o] = patient_batch['patient_'+o] if self.cf.dim==2: 
# this could also be named "unpatched_2d_roi_items" patch_batch["patient_"+o+"_2d"] = patient_batch[o] # adding patient-wise data and seg adds about 2 GB of additional RAM consumption to a batch 20x288x288 # and enables calculating test-dice/viewing patient-wise results in test # remove, but also remove dice from metrics, if you like to save memory patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim==2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, is_training=True): """ create multi-threaded train/val/test batch generation and augmentation pipeline. :param cf: configs object. :param patient_data: dictionary containing one dictionary per patient in the train/test subset. :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ data_gen = BatchGenerator_merged(cf, patient_data) if cf.training_gts=='merged' else BatchGenerator_sa(cf, patient_data) # add transformations to pipeline. 
my_transforms = [] if is_training: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) if cf.create_bounding_box_targets: my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) all_transforms = Compose(my_transforms) multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=True): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators. selects patients according to cv folds (generated by first run/fold of experiment): splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) If cf.held_out_test_set is True, adds the test split to the training data. 
""" dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) if cf.held_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if k in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if k in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir=os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, is_training=True) batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, is_training=False) if cf.val_mode == 'val_patient': assert cf.training_gts == 'merged', 'val_patient not yet implemented for sa gts' batch_gen['val_patient'] = PatientBatchIterator_merged(cf, val_data) if cf.training_gts=='merged' \ else PatientBatchIterator_sa(cf, val_data) batch_gen['n_val'] = len(val_data) if cf.max_val_patients=="all" else min(len(val_data), cf.max_val_patients) else: batch_gen['n_val'] = cf.num_val_batches return batch_gen def get_test_generator(cf, logger): """ wrapper function for creating the test batch generator pipeline. selects patients according to cv folds (generated by first run/fold of experiment) If cf.held_out_test_set is True, gets the data from an external folder instead. 
""" if cf.held_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_data = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode="test").data logger.info("data set loaded with: {} test patients".format(len(test_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator_merged(cf, test_data) batch_gen['n_test'] = len(test_ids) if cf.max_test_patients == "all" else min(cf.max_test_patients, len(test_ids)) return batch_gen if __name__ == "__main__": import sys sys.path.append('../') import plotting as plg import utils.exp_utils as utils from configs import Configs cf = Configs() cf.batch_size = 3 #dataset_path = os.path.dirname(os.path.realpath(__file__)) #exp_path = os.path.join(dataset_path, "experiments/dev") #cf = utils.prep_exp(dataset_path, exp_path, server_env=False, use_stored_settings=False, is_training=True) cf.created_fold_id_pickle = False total_stime = time.time() times = {} # cf.server_env = True # cf.data_dir = "experiments/dev_data" # dataset = Dataset(cf) # patient = dataset['Master_00018'] cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] for i in range(1): stime = time.time() #ex_batch = next(train_loader) print("train batch", i) times["train_batch"] = time.time() - stime #plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch.png", show_gt_labels=True) # # # with open(os.path.join(cf.exp_dir, "fold_"+str(cf.fold), "BatchGenerator_stats.txt"), mode="w") as file: # # train_loader.generator.print_stats(logger, file) # val_loader = gens['val_sampling'] stime = time.time() ex_batch = next(val_loader) times["val_batch"] = time.time() - stime stime = time.time() 
#plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch.png", show_gt_labels=True, plot_mods=False, # show_info=False) times["val_plot"] = time.time() - stime # test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch() times["test_batch"] = time.time() - stime stime = time.time() plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png")#, sample_picks=[0,1,2,3]) times["test_patchbatch_plot"] = time.time() - stime # ex_batch['data'] = ex_batch['patient_data'] # ex_batch['seg'] = ex_batch['patient_seg'] # ex_batch['bb_target'] = ex_batch['patient_bb_target'] # for item in cf.roi_items: # ex_batch[] # stime = time.time() # #ex_batch = next(test_loader) # ex_batch = next(test_loader) # plg.view_batch(cf, ex_batch, show_gt_labels=False, show_gt_boxes=True, patient_items=True,# vol_slice_picks=[146,148, 218,220], # out_file="experiments/dev/dev_expatientbatch.png") # , sample_picks=[0,1,2,3]) # times["test_patient_batch_plot"] = time.time() - stime print("Times recorded throughout:") for (k, v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py index fe14898..8f81931 100644 --- a/datasets/toy/configs.py +++ b/datasets/toy/configs.py @@ -1,495 +1,490 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np from default_configs import DefaultConfigs from collections import namedtuple boxLabel = namedtuple('boxLabel', ["name", "color"]) Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion']) binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals']) class Configs(DefaultConfigs): def __init__(self, server_env=None): super(Configs, self).__init__(server_env) ######################### # Prepro # ######################### self.pp_rootdir = os.path.join('/mnt/HDD2TB/Documents/data/toy', "cyl1ps_dev") self.pp_npz_dir = self.pp_rootdir+"_npz" self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now) self.min_2d_radius = 6 #in pixels self.n_train_samples, self.n_test_samples = 80, 80 # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes. self.pp_create_ohe_seg = False self.pp_empty_samples_ratio = 0.1 self.pp_place_radii_mid_bin = True self.pp_only_distort_2d = True # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing. # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity. 
self.pp_blur_min_intensity = 0.2 self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d) self.max_instances_per_class = self.max_instances_per_sample # how many max instances per image per class self.noise_scale = 0. # std-dev of gaussian noise self.ambigs_sampling = "gaussian" #"gaussian" or "uniform" """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from image by distinguishing it from the rest of the object. blurring width around edge will be shifted so that symmetric rel to orig radius. blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius. if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2. """ self.ambiguities = { #set which classes to apply which ambs to below in class labels #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'. 
#kind #probability #scale (gaussian std, relative to unperturbed value) #"outer_radius": (1., 0.5), #"outer_radius_xy": (1., 0.5), #"inner_radius": (0.5, 0.1), #"radii_relations": (0.5, 0.1), "radius_calib": (1., 1./6) } # shape choices: 'cylinder', 'block' # id, name, shape, radius, color, regression, ambiguities, gt_distortion self.pp_classes = [Label(1, 'cylinder', 'cylinder', ((6,6,1),(40,40,8)), (*self.blue, 1.), "radius_2d", (), ()), #Label(2, 'block', 'block', ((6,6,1),(40,40,8)), (*self.aubergine,1.), "radii_2d", (), ('radius_calib',)) ] ######################### # I/O # ######################### - self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_dev' - #self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_exact' + self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_exact' if server_env: - #self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_exact_npz' - self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_ambig_beyond_bin_npz' + self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_exact_npz' + self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test') self.data_sourcedir = os.path.join(self.data_sourcedir, "train") self.info_df_name = 'info_df.pickle' # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn']. - self.model = 'retina_net' + self.model = 'retina_unet' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) ######################### # Architecture # ######################### # one out of [2, 3]. dimension the model operates in. 
self.dim = 3 # 'class', 'regression', 'regression_bin', 'regression_ken_gal' # currently only tested mode is a single-task at a time (i.e., only one task in below list) # but, in principle, tasks could be combined (e.g., object classes and regression per class) self.prediction_tasks = ['class',] self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm' self.relu = 'relu' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') self.weight_init = None self.regression_n_features = 1 # length of regressor target vector ######################### # Data Loader # ######################### self.num_epochs = 32 self.num_train_batches = 120 if self.dim == 2 else 80 self.batch_size = 16 if self.dim == 2 else 8 self.n_cv_splits = 4 # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels self.plot_bg_chan = 0 self.crop_margin = [20, 20, 1] # has to be smaller than respective patch_size//2 self.patch_size_2D = self.pre_crop_size[:2] self.patch_size_3D = self.pre_crop_size[:2]+[8] # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D # ratio of free sampled batch elements before class balancing is triggered # (>0 to include "empty"/background patches.) 
self.batch_random_ratio = 0.2 self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets" self.observables_patient = [] self.observables_rois = [] self.seed = 3 #for generating folds ############################# # Colors, Classes, Legends # ############################# self.plot_frequency = 1 binary_bin_labels = [binLabel(1, 'r<=25', (*self.green, 1.), (1,25)), binLabel(2, 'r>25', (*self.red, 1.), (25,))] quintuple_bin_labels = [binLabel(1, 'r2-10', (*self.green, 1.), (2,10)), binLabel(2, 'r10-20', (*self.yellow, 1.), (10,20)), binLabel(3, 'r20-30', (*self.orange, 1.), (20,30)), binLabel(4, 'r30-40', (*self.bright_red, 1.), (30,40)), binLabel(5, 'r>40', (*self.red, 1.), (40,))] # choose here if to do 2-way or 5-way regression-bin classification task_spec_bin_labels = quintuple_bin_labels self.class_labels = [ # regression: regression-task label, either value or "(x,y,z)_radius" or "radii". # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables! # gt_distortion: name of ambig to apply to gt only; needs to be iterable! 
# #id #name #shape #radius #color #regression #ambiguities #gt_distortion Label( 0, 'bg', None, (0, 0, 0), (*self.white, 0.), (0, 0, 0), (), ())] if "class" in self.prediction_tasks: self.class_labels += self.pp_classes else: self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))] if any(['regression' in task for task in self.prediction_tasks]): self.bin_labels = [binLabel(0, 'bg', (*self.white, 1.), (0,))] self.bin_labels += task_spec_bin_labels self.bin_id2label = {label.id: label for label in self.bin_labels} bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0} if self.class_specific_seg: self.seg_labels = self.class_labels self.box_type2label = {label.name: label for label in self.box_labels} self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.has_colorchannels = False self.plot_class_ids = True self.num_classes = len(self.class_dict) self.num_seg_classes = len(self.seg_labels) ######################### # Data Augmentation # ######################### self.do_aug = True self.da_kwargs = { 'mirror': True, 'mirror_axes': tuple(np.arange(0, self.dim, 1)), 'do_elastic_deform': False, 'alpha': (500., 1500.), 'sigma': (40., 45.), 'do_rotation': False, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': False, 'scale': (0.8, 1.1), 'random_crop': False, 'rand_crop_dist': (self.patch_size[0] / 2. 
- 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1 } if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) # must be 0!! self.da_kwargs['angle_z'] = (0., 2 * np.pi) ######################### # Schedule / Selection # ######################### # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is morge accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient' if self.val_mode == 'val_patient': self.max_val_patients = 220 # if 'all' iterates over entire val_set once. if self.val_mode == 'val_sampling': self.num_val_batches = 25 if self.dim==2 else 15 self.save_n_models = 2 self.min_save_thresh = 1 if self.dim == 2 else 1 # =wait time in epochs if "class" in self.prediction_tasks: self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()} elif any("regression" in task for task in self.prediction_tasks): self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()} self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()}) self.lr_decay_factor = 0.5 self.scheduling_patience = int(self.num_epochs / 5) self.weight_decay = 1e-5 self.clip_norm = None # number or None ######################### # Testing / Plotting # ######################### self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) self.held_out_test_set = True self.max_test_patients = "all" # number or "all" for all self.test_against_exact_gt = not 'exact' in self.data_sourcedir self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario. 
self.report_score_level = ['rois'] # 'patient' or 'rois' (incl) self.patient_class_of_interest = 1 self.patient_bin_of_interest = 2 self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False self.metrics = ['ap', 'auc', 'dice'] if any(['regression' in task for task in self.prediction_tasks]): self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] if 'aleatoric' in self.model: self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] self.evaluate_fold_means = True self.ap_match_ious = [0.5] # threshold(s) for considering a prediction as true positive self.min_det_thresh = 0.3 self.model_max_iou_resolution = 0.2 # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!) 
for regarding two preds as concerning the same ROI self.clustering_iou = self.model_max_iou_resolution # has to be larger than desired possible overlap iou of model predictions self.merge_2D_to_3D_preds = False self.merge_3D_iou = self.model_max_iou_resolution self.n_test_plots = 1 # per fold and rank self.test_n_epochs = self.save_n_models # should be called n_test_ens, since is number of models to ensemble over during testing # is multiplied by (1 + nr of test augs) - #self.losses_to_monitor += ['class_loss', 'rg_loss'] - ######################### # Assertions # ######################### if not 'class' in self.prediction_tasks: assert self.num_classes == 1 ######################### # Add model specifics # ######################### {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, 'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs }[self.model]() def rg_val_to_bin_id(self, rg_val): #only meant for isotropic radii!! # only 2D radii (x and y dims) or 1D (x or y) are expected return np.round(np.digitize(rg_val, self.bin_edges).mean()) def add_det_fpn_configs(self): self.learning_rate = [5 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = 'torch_loss' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' self.n_roi_candidates = 4 if self.dim == 2 else 6 # max number of roi candidates to identify per image (slice in 2D, volume in 3D) # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1] self.fp_dice_weight = 1 if self.dim == 2 else 1 # if <1, false positive predictions in foreground are penalized less. 
self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' def add_det_unet_configs(self): self.learning_rate = [5 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = "torch_loss" self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # max number of roi candidates to identify per image (slice in 2D, volume in 3D) self.n_roi_candidates = 4 if self.dim == 2 else 6 # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1] # if <1, false positive predictions in foreground are penalized less. self.fp_dice_weight = 1 if self.dim == 2 else 1 self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' self.init_filts = 32 self.kernel_size = 3 # ks for horizontal, normal convs self.kernel_size_m = 2 # ks for max pool self.pad = "same" # "same" or integer, padding of horizontal convs def add_mrcnn_configs(self): self.learning_rate = [1e-4] * self.num_epochs self.dynamic_lr_scheduling = True # with scheduler set in exec self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2 # feed +/- n neighbouring slices into channel dimension. set to None for no context. self.n_3D_context = None if self.n_3D_context is not None and self.dim == 2: self.n_channels *= (self.n_3D_context * 2 + 1) self.detect_while_training = True # disable the re-sampling of mask proposals to original size for speed-up. 
# since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask outputs are optional. self.return_masks_in_train = True self.return_masks_in_val = True self.return_masks_in_test = True # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 64 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1., 2.] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution) # loss sampling settings. self.rpn_train_anchors_per_image = 4 self.train_rois_per_image = 6 # per batch_instance self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.8 # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), # where k<=#positive examples. 
self.shem_poolsize = 2 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) # y1,x1,y2,x2,z1,z2 if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] self.plot_y_max = 1.5 self.n_plot_rpn_props = 5 if self.dim == 2 else 30 # per batch_instance (slice in 2D / patient in 3D) # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 2000 if self.dim == 2 else 4000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage as one "batch". self.roi_chunk_size = 1300 if self.dim == 2 else 500 self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400 self.post_nms_rois_inference = 200 * (self.head_classes-1) # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class. self.detection_nms_threshold = self.model_max_iou_resolution # needs to be > 0, otherwise all predictions are one cluster. 
self.model_min_confidence = 0.2 # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': # whether to use focal loss or SHEM for loss-sample selection self.focal_loss = False # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 - #self.n_rpn_features = 256 if self.dim == 2 else 64 - # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.7 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/datasets/toy/data_loader.py b/datasets/toy/data_loader.py index f3c7ac0..dc9a03f 100644 --- a/datasets/toy/data_loader.py +++ b/datasets/toy/data_loader.py @@ -1,601 +1,597 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys -sys.path.append('../') #works on cluster indep from where sbatch job is started +sys.path.append('../') # works on cluster indep from where sbatch job is started import plotting as plg import numpy as np import os from collections import OrderedDict import pandas as pd import pickle import time # batch generator tools from https://github.com/MIC-DKFZ/batchgenerators -from batchgenerators.dataloading.data_loader import SlimDataLoaderBase from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror from batchgenerators.transforms.abstract_transforms import Compose from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter -from batchgenerators.dataloading import SingleThreadedAugmenter from batchgenerators.transforms.spatial_transforms import SpatialTransform from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform -#from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates sys.path.append(os.path.dirname(os.path.realpath(__file__))) -import utils.exp_utils as utils import utils.dataloader_utils as dutils from utils.dataloader_utils import ConvertSegToBoundingBoxCoordinates def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) class Dataset(dutils.Dataset): r""" Load a dict holding memmapped arrays and clinical parameters for each patient, evtly subset of those. 
If server_env: copy and evtly unpack (npz->npy) data in cf.data_rootdir to cf.data_dir. :param cf: config file :param folds: number of folds out of @params n_cv folds to include :param n_cv: number of total folds :return: dict with imgs, segs, pids, class_labels, observables """ def __init__(self, cf, logger, subset_ids=None, data_sourcedir=None, mode='train'): super(Dataset,self).__init__(cf, data_sourcedir=data_sourcedir) load_exact_gts = (mode=='test' or cf.val_mode=="val_patient") and self.cf.test_against_exact_gt p_df = pd.read_pickle(os.path.join(self.data_dir, cf.info_df_name)) if subset_ids is not None: p_df = p_df[p_df.pid.isin(subset_ids)] logger.info('subset: selected {} instances from df'.format(len(p_df))) pids = p_df.pid.tolist() #evtly copy data from data_sourcedir to data_dest if cf.server_env and not hasattr(cf, "data_dir"): file_subset = [os.path.join(self.data_dir, '{}.*'.format(pid)) for pid in pids] file_subset += [os.path.join(self.data_dir, '{}_seg.*'.format(pid)) for pid in pids] file_subset += [cf.info_df_name] if load_exact_gts: file_subset += [os.path.join(self.data_dir, '{}_exact_seg.*'.format(pid)) for pid in pids] self.copy_data(cf, file_subset=file_subset) img_paths = [os.path.join(self.data_dir, '{}.npy'.format(pid)) for pid in pids] seg_paths = [os.path.join(self.data_dir, '{}_seg.npy'.format(pid)) for pid in pids] if load_exact_gts: exact_seg_paths = [os.path.join(self.data_dir, '{}_exact_seg.npy'.format(pid)) for pid in pids] class_targets = p_df['class_ids'].tolist() rg_targets = p_df['regression_vectors'].tolist() if load_exact_gts: exact_rg_targets = p_df['undistorted_rg_vectors'].tolist() fg_slices = p_df['fg_slices'].tolist() self.data = OrderedDict() for ix, pid in enumerate(pids): self.data[pid] = {'data': img_paths[ix], 'seg': seg_paths[ix], 'pid': pid, 'fg_slices': np.array(fg_slices[ix])} if load_exact_gts: self.data[pid]['exact_seg'] = exact_seg_paths[ix] if 'class' in self.cf.prediction_tasks: 
self.data[pid]['class_targets'] = np.array(class_targets[ix], dtype='uint8') else: self.data[pid]['class_targets'] = np.ones_like(np.array(class_targets[ix]), dtype='uint8') if load_exact_gts: self.data[pid]['exact_class_targets'] = self.data[pid]['class_targets'] if any(['regression' in task for task in self.cf.prediction_tasks]): self.data[pid]['regression_targets'] = np.array(rg_targets[ix], dtype='float16') self.data[pid]["rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in rg_targets[ix]], dtype='uint8') if load_exact_gts: self.data[pid]['exact_regression_targets'] = np.array(exact_rg_targets[ix], dtype='float16') self.data[pid]["exact_rg_bin_targets"] = np.array([cf.rg_val_to_bin_id(v) for v in exact_rg_targets[ix]], dtype='uint8') cf.roi_items = cf.observables_rois[:] cf.roi_items += ['class_targets'] if any(['regression' in task for task in self.cf.prediction_tasks]): cf.roi_items += ['regression_targets'] cf.roi_items += ['rg_bin_targets'] self.set_ids = np.array(list(self.data.keys())) self.df = None class BatchGenerator(dutils.BatchGenerator): """ creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D) from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size. Actual patch_size is obtained after data augmentation. :param data: data dictionary as provided by 'load_dataset'. 
:param batch_size: number of patients to sample for the batch :return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target """ def __init__(self, cf, data, sample_pids_w_replace=True): super(BatchGenerator, self).__init__(cf, data) self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" self.sample_pids_w_replace = sample_pids_w_replace self.eligible_pids = list(self._data.keys()) self.crop_margin = np.array(self.cf.patch_size) / 8. # min distance of ROI center to edge of cropped_patch. self.p_fg = 0.5 self.empty_samples_max_ratio = 0.6 self.random_count = int(cf.batch_random_ratio * cf.batch_size) self.balance_target_distribution(plot=sample_pids_w_replace) self.stats = {"roi_counts": np.zeros((len(self.unique_ts),), dtype='uint32'), "empty_samples_count": 0} def generate_train_batch(self): # everything done in here is per batch # print statements in here get confusing due to multithreading if self.sample_pids_w_replace: # fully random patients batch_patient_ids = list(np.random.choice(self.dataset_pids, size=self.random_count, replace=False)) # target-balanced patients batch_patient_ids += list(np.random.choice( self.dataset_pids, size=self.batch_size - self.random_count, replace=False, p=self.p_probs)) else: batch_patient_ids = np.random.choice(self.eligible_pids, size=self.batch_size, replace=False) if self.sample_pids_w_replace == False: self.eligible_pids = [pid for pid in self.eligible_pids if pid not in batch_patient_ids] if len(self.eligible_pids) < self.batch_size: self.eligible_pids = self.dataset_pids batch_data, batch_segs, batch_patient_targets = [], [], [] batch_roi_items = {name: [] for name in self.cf.roi_items} # record roi count of classes in batch # empty count for full bg samples (empty slices in 2D/patients in 3D) in slot num_classes (last) batch_roi_counts, empty_samples_count 
= np.zeros((len(self.unique_ts),), dtype='uint32'), 0 for b in range(self.batch_size): patient = self._data[batch_patient_ids[b]] data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis] seg = np.load(patient['seg'], mmap_mode='r').astype('uint8') (c, y, x, z) = data.shape if self.cf.dim == 2: elig_slices, choose_fg = [], False if len(patient['fg_slices']) > 0: if empty_samples_count / self.batch_size >= self.empty_samples_max_ratio or np.random.rand( 1) <= self.p_fg: # fg is to be picked for tix in np.argsort(batch_roi_counts): # pick slices of patient that have roi of sought-for target # np.unique(seg[...,sl_ix][seg[...,sl_ix]>0]) gives roi_ids (numbering) of rois in slice sl_ix elig_slices = [sl_ix for sl_ix in np.arange(z) if np.count_nonzero( patient[self.balance_target][np.unique(seg[..., sl_ix][seg[..., sl_ix] > 0]) - 1] == self.unique_ts[tix]) > 0] if len(elig_slices) > 0: choose_fg = True break else: # pick bg elig_slices = np.setdiff1d(np.arange(z), patient['fg_slices']) if len(elig_slices) > 0: sl_pick_ix = np.random.choice(elig_slices, size=None) else: sl_pick_ix = np.random.choice(z, size=None) data = data[..., sl_pick_ix] seg = seg[..., sl_pick_ix] spatial_shp = data[0].shape assert spatial_shp == seg.shape, "spatial shape incongruence betw. 
data and seg" if np.any([spatial_shp[ix] < self.cf.pre_crop_size[ix] for ix in range(len(spatial_shp))]): new_shape = [np.max([spatial_shp[ix], self.cf.pre_crop_size[ix]]) for ix in range(len(spatial_shp))] data = dutils.pad_nd_image(data, (len(data), *new_shape)) seg = dutils.pad_nd_image(seg, new_shape) # eventual cropping to pre_crop_size: sample pixel from random ROI and shift center, # if possible, to that pixel, so that img still contains ROI after pre-cropping dim_cropflags = [spatial_shp[i] > self.cf.pre_crop_size[i] for i in range(len(spatial_shp))] if np.any(dim_cropflags): # sample pixel from random ROI and shift center, if possible, to that pixel if self.cf.dim==3: choose_fg = (empty_samples_count/self.batch_size>=self.empty_samples_max_ratio) or np.random.rand(1) <= self.p_fg if choose_fg and np.any(seg): available_roi_ids = np.unique(seg)[1:] for tix in np.argsort(batch_roi_counts): elig_roi_ids = available_roi_ids[patient[self.balance_target][available_roi_ids-1] == self.unique_ts[tix]] if len(elig_roi_ids)>0: seg_ics = np.argwhere(seg == np.random.choice(elig_roi_ids, size=None)) break roi_anchor_pixel = seg_ics[np.random.choice(seg_ics.shape[0], size=None)] assert seg[tuple(roi_anchor_pixel)] > 0 # sample the patch center coords. constrained by edges of image - pre_crop_size /2 and # distance to the selected ROI < patch_size /2 def get_cropped_centercoords(dim): low = np.max((self.cf.pre_crop_size[dim] // 2, roi_anchor_pixel[dim] - ( self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim]))) high = np.min((spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2, roi_anchor_pixel[dim] + ( self.cf.patch_size[dim] // 2 - self.cf.crop_margin[dim]))) if low >= high: # happens if lesion on the edge of the image. 
low = self.cf.pre_crop_size[dim] // 2 high = spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2 assert low < high, 'low greater equal high, data dimension {} too small, shp {}, patient {}, low {}, high {}'.format( dim, spatial_shp, patient['pid'], low, high) return np.random.randint(low=low, high=high) else: # sample crop center regardless of ROIs, not guaranteed to be empty def get_cropped_centercoords(dim): return np.random.randint(low=self.cf.pre_crop_size[dim] // 2, high=spatial_shp[dim] - self.cf.pre_crop_size[dim] // 2) sample_seg_center = {} for dim in np.where(dim_cropflags)[0]: sample_seg_center[dim] = get_cropped_centercoords(dim) min_ = int(sample_seg_center[dim] - self.cf.pre_crop_size[dim] // 2) max_ = int(sample_seg_center[dim] + self.cf.pre_crop_size[dim] // 2) data = np.take(data, indices=range(min_, max_), axis=dim + 1) # +1 for channeldim seg = np.take(seg, indices=range(min_, max_), axis=dim) batch_data.append(data) batch_segs.append(seg[np.newaxis]) for o in batch_roi_items: #after loop, holds every entry of every batchpatient per observable batch_roi_items[o].append(patient[o]) if self.cf.dim == 3: for tix in range(len(self.unique_ts)): batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target] == self.unique_ts[tix]) elif self.cf.dim == 2: for tix in range(len(self.unique_ts)): batch_roi_counts[tix] += np.count_nonzero(patient[self.balance_target][np.unique(seg[seg>0]) - 1] == self.unique_ts[tix]) if not np.any(seg): empty_samples_count += 1 batch = {'data': np.array(batch_data), 'seg': np.array(batch_segs).astype('uint8'), 'pid': batch_patient_ids, 'roi_counts': batch_roi_counts, 'empty_samples_count': empty_samples_count} for key,val in batch_roi_items.items(): #extend batch dic by entries of observables dic batch[key] = np.array(val) return batch class PatientBatchIterator(dutils.PatientBatchIterator): """ creates a test generator that iterates over entire given dataset returning 1 patient per batch. 
Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actually evaluation (done in 3D), if willing to accept speed-loss during training. Specific properties of toy data set: toy data may be created with added ground-truth noise. thus, there are exact ground truths (GTs) and noisy ground truths available. the normal or noisy GTs are used in training by the BatchGenerator. The PatientIterator, however, may use the exact GTs if set in configs. :return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or batch_size = n_2D_patches in 2D . """ def __init__(self, cf, data, mode='test'): super(PatientBatchIterator, self).__init__(cf, data) self.patch_size = cf.patch_size_2D + [1] if cf.dim == 2 else cf.patch_size_3D self.chans = cf.channels if cf.channels is not None else np.index_exp[:] assert hasattr(self.chans, "__iter__"), "self.chans has to be list-like to maintain dims when slicing" if (mode=="validation" and hasattr(self.cf, 'val_against_exact_gt') and self.cf.val_against_exact_gt) or \ (mode == 'test' and self.cf.test_against_exact_gt): self.gt_prefix = 'exact_' print("PatientIterator: Loading exact Ground Truths.") else: self.gt_prefix = '' self.patient_ix = 0 # running index over all patients in set def generate_train_batch(self, pid=None): if pid is None: pid = self.dataset_pids[self.patient_ix] patient = self._data[pid] # already swapped dimensions in pp from (c,)z,y,x to c,y,x,z or h,w,d to ease 2D/3D-case handling data = np.load(patient['data'], mmap_mode='r').astype('float16')[np.newaxis] seg = np.load(patient[self.gt_prefix+'seg']).astype('uint8')[np.newaxis] data_shp_raw = data.shape plot_bg = data[self.cf.plot_bg_chan] if self.cf.plot_bg_chan not in self.chans else None data = data[self.chans] discarded_chans = len( [c for c in np.setdiff1d(np.arange(data_shp_raw[0]), self.chans) if c < self.cf.plot_bg_chan]) spatial_shp = data[0].shape # spatial dims need to be in order x,y,z assert 
spatial_shp == seg[0].shape, "spatial shape incongruence betw. data and seg" if np.any([spatial_shp[i] < ps for i, ps in enumerate(self.patch_size)]): new_shape = [np.max([spatial_shp[i], self.patch_size[i]]) for i in range(len(self.patch_size))] data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape. seg = dutils.pad_nd_image(seg, new_shape) if plot_bg is not None: plot_bg = dutils.pad_nd_image(plot_bg, new_shape) if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds: # adds the batch dim here bc won't go through MTaugmenter out_data = data[np.newaxis] out_seg = seg[np.newaxis] if plot_bg is not None: out_plot_bg = plot_bg[np.newaxis] # data and seg shape: (1,c,x,y,z), where c=1 for seg batch_3D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_3D[o] = np.array([patient[self.gt_prefix+o]]) converter = ConvertSegToBoundingBoxCoordinates(3, self.cf.roi_items, False, self.cf.class_specific_seg) batch_3D = converter(**batch_3D) batch_3D.update({'patient_bb_target': batch_3D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_3D["patient_" + o] = batch_3D[o] if self.cf.dim == 2: out_data = np.transpose(data, axes=(3, 0, 1, 2)).astype('float32') # (c,y,x,z) to (b=z,c,x,y), use z=b as batchdim out_seg = np.transpose(seg, axes=(3, 0, 1, 2)).astype('uint8') # (c,y,x,z) to (b=z,c,x,y) batch_2D = {'data': out_data, 'seg': out_seg} for o in self.cf.roi_items: batch_2D[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(out_data), axis=0) converter = ConvertSegToBoundingBoxCoordinates(2, self.cf.roi_items, False, self.cf.class_specific_seg) batch_2D = converter(**batch_2D) if plot_bg is not None: out_plot_bg = np.transpose(plot_bg, axes=(2, 0, 1)).astype('float32') if self.cf.merge_2D_to_3D_preds: batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_3D[o] 
else: batch_2D.update({'patient_bb_target': batch_2D['bb_target'], 'original_img_shape': out_data.shape}) for o in self.cf.roi_items: batch_2D["patient_" + o] = batch_2D[o] out_batch = batch_3D if self.cf.dim == 3 else batch_2D out_batch.update({'pid': np.array([patient['pid']] * len(out_data))}) if self.cf.plot_bg_chan in self.chans and discarded_chans > 0: # len(self.chans[:self.cf.plot_bg_chan]) self.patch_size[ix] for ix in range(len(spatial_shp))]): patient_batch = out_batch print("patientiterator produced patched batch!") patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size) new_img_batch, new_seg_batch = [], [] for c in patch_crop_coords_list: new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:c[5]]) seg_patch = seg[:, c[0]:c[1], c[2]: c[3], c[4]:c[5]] new_seg_batch.append(seg_patch) shps = [] for arr in new_img_batch: shps.append(arr.shape) data = np.array(new_img_batch) # (patches, c, x, y, z) seg = np.array(new_seg_batch) if self.cf.dim == 2: # all patches have z dimension 1 (slices). 
discard dimension data = data[..., 0] seg = seg[..., 0] patch_batch = {'data': data.astype('float32'), 'seg': seg.astype('uint8'), 'pid': np.array([patient['pid']] * data.shape[0])} for o in self.cf.roi_items: patch_batch[o] = np.repeat(np.array([patient[self.gt_prefix+o]]), len(patch_crop_coords_list), axis=0) #patient-wise (orig) batch info for putting the patches back together after prediction for o in self.cf.roi_items: patch_batch["patient_"+o] = patient_batch["patient_"+o] if self.cf.dim == 2: # this could also be named "unpatched_2d_roi_items" patch_batch["patient_" + o + "_2d"] = patient_batch[o] patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list) patch_batch['patient_bb_target'] = patient_batch['patient_bb_target'] if self.cf.dim == 2: patch_batch['patient_bb_target_2d'] = patient_batch['bb_target'] patch_batch['patient_data'] = patient_batch['data'] patch_batch['patient_seg'] = patient_batch['seg'] patch_batch['original_img_shape'] = patient_batch['original_img_shape'] if plot_bg is not None: patch_batch['patient_plot_bg'] = patient_batch['plot_bg'] converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, self.cf.roi_items, get_rois_from_seg=False, class_specific_seg=self.cf.class_specific_seg) patch_batch = converter(**patch_batch) out_batch = patch_batch self.patient_ix += 1 if self.patient_ix == len(self.dataset_pids): self.patient_ix = 0 return out_batch def create_data_gen_pipeline(cf, patient_data, do_aug=True, sample_pids_w_replace=True): """ create mutli-threaded train/val/test batch generation and augmentation pipeline. :param patient_data: dictionary containing one dictionary per patient in the train/test subset. :param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing) :return: multithreaded_generator """ # create instance of batch generator as first element in pipeline. 
data_gen = BatchGenerator(cf, patient_data, sample_pids_w_replace=sample_pids_w_replace) my_transforms = [] if do_aug: if cf.da_kwargs["mirror"]: mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes']) my_transforms.append(mirror_transform) spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim], patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'], do_elastic_deform=cf.da_kwargs['do_elastic_deform'], alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'], do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'], angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'], do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'], random_crop=cf.da_kwargs['random_crop']) my_transforms.append(spatial_transform) else: my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim])) my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg)) all_transforms = Compose(my_transforms) # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms) multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers)) return multithreaded_generator def get_train_generators(cf, logger, data_statistics=False): """ wrapper function for creating the training batch generator pipeline. returns the train/val generators. selects patients according to cv folds (generated by first run/fold of experiment): splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set) If cf.hold_out_test_set is True, adds the test split to the training data. 
""" dataset = Dataset(cf, logger) dataset.init_FoldGenerator(cf.seed, cf.n_cv_splits) dataset.generate_splits(check_file=os.path.join(cf.exp_dir, 'fold_ids.pickle')) set_splits = dataset.fg.splits test_ids, val_ids = set_splits.pop(cf.fold), set_splits.pop(cf.fold - 1) train_ids = np.concatenate(set_splits, axis=0) if cf.held_out_test_set: train_ids = np.concatenate((train_ids, test_ids), axis=0) test_ids = [] train_data = {k: v for (k, v) in dataset.data.items() if str(k) in train_ids} val_data = {k: v for (k, v) in dataset.data.items() if str(k) in val_ids} logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ids), len(val_ids), len(test_ids))) if data_statistics: dataset.calc_statistics(subsets={"train": train_ids, "val": val_ids, "test": test_ids}, plot_dir= os.path.join(cf.plot_dir,"dataset")) batch_gen = {} batch_gen['train'] = create_data_gen_pipeline(cf, train_data, do_aug=cf.do_aug, sample_pids_w_replace=True) batch_gen['val_sampling'] = create_data_gen_pipeline(cf, val_data, do_aug=False, sample_pids_w_replace=False) if cf.val_mode == 'val_patient': batch_gen['val_patient'] = PatientBatchIterator(cf, val_data, mode='validation') batch_gen['n_val'] = len(val_ids) if cf.max_val_patients=="all" else min(len(val_ids), cf.max_val_patients) elif cf.val_mode == 'val_sampling': batch_gen['n_val'] = cf.num_val_batches if cf.num_val_batches != "all" else len(val_data) return batch_gen def get_test_generator(cf, logger): """ if get_test_generators is possibly called multiple times in server env, every time of Dataset initiation rsync will check for copying the data; this should be okay since rsync will not copy if files already exist in destination. 
""" if cf.held_out_test_set: sourcedir = cf.test_data_sourcedir test_ids = None else: sourcedir = None with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle: set_splits = pickle.load(handle) test_ids = set_splits[cf.fold] test_set = Dataset(cf, logger, subset_ids=test_ids, data_sourcedir=sourcedir, mode='test') logger.info("data set loaded with: {} test patients".format(len(test_set.set_ids))) batch_gen = {} batch_gen['test'] = PatientBatchIterator(cf, test_set.data) batch_gen['n_test'] = len(test_set.set_ids) if cf.max_test_patients=="all" else \ min(cf.max_test_patients, len(test_set.set_ids)) return batch_gen if __name__=="__main__": import utils.exp_utils as utils from configs import Configs cf = Configs() total_stime = time.time() times = {} # cf.server_env = True # cf.data_dir = "experiments/dev_data" cf.exp_dir = "experiments/dev/" cf.plot_dir = cf.exp_dir + "plots" os.makedirs(cf.exp_dir, exist_ok=True) cf.fold = 0 logger = utils.get_logger(cf.exp_dir) gens = get_train_generators(cf, logger) train_loader = gens['train'] for i in range(0): stime = time.time() print("producing training batch nr ", i) ex_batch = next(train_loader) times["train_batch"] = time.time() - stime #experiments/dev/dev_exbatch_{}.png".format(i) plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=False) val_loader = gens['val_sampling'] stime = time.time() for i in range(1): ex_batch = next(val_loader) times["val_batch"] = time.time() - stime stime = time.time() #"experiments/dev/dev_exvalbatch_{}.png" plg.view_batch(cf, ex_batch, out_file="experiments/dev/dev_exvalbatch_{}.png".format(i), show_gt_labels=True, vmin=0, show_info=True) times["val_plot"] = time.time() - stime import IPython; IPython.embed() # test_loader = get_test_generator(cf, logger)["test"] stime = time.time() ex_batch = test_loader.generate_train_batch(pid=None) times["test_batch"] = time.time() - stime stime = time.time() 
plg.view_batch(cf, ex_batch, show_gt_labels=True, out_file="experiments/dev/dev_expatchbatch.png", vmin=0) times["test_patchbatch_plot"] = time.time() - stime print("Times recorded throughout:") for (k, v) in times.items(): print(k, "{:.2f}".format(v)) mins, secs = divmod((time.time() - total_stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) print("{} total runtime: {}".format(os.path.split(__file__)[1], t)) \ No newline at end of file diff --git a/exec.py b/exec.py index 74c10ba..34fd5fe 100644 --- a/exec.py +++ b/exec.py @@ -1,343 +1,344 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ execution script. this where all routines come together and the only script you need to call. refer to parse args below to see options for execution. 
""" import plotting as plg import os import warnings import argparse import time import torch import utils.exp_utils as utils from evaluator import Evaluator from predictor import Predictor for msg in ["Attempting to set identical bottom==top results", "This figure includes Axes that are not compatible with tight_layout", "Data has no positive values, and therefore cannot be log-scaled.", ".*invalid value encountered in true_divide.*"]: warnings.filterwarnings("ignore", msg) def train(cf, logger): """ performs the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. logs to file and tensorboard. """ logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format( cf.dim, cf.fold, cf.exp_dir, cf.model)) logger.time("train_val") # -------------- inits and settings ----------------- net = model.net(cf, logger).cuda() if cf.optimizer == "ADAM": optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) elif cf.optimizer == "SGD": optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3) if cf.dynamic_lr_scheduling: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, patience=cf.scheduling_patience) model_selector = utils.ModelSelector(cf, logger) starting_epoch = 1 if cf.resume_from_checkpoint: starting_epoch = utils.load_checkpoint(cf.resume_from_checkpoint, net, optimizer) logger.info('resumed from checkpoint {} at epoch {}'.format(cf.resume_from_checkpoint, starting_epoch)) # prepare monitoring monitor_metrics = utils.prepare_monitoring(cf) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) # -------------- training ----------------- for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch 
{}/{}'.format(epoch, cf.num_epochs)) logger.time("train_epoch") net.train() train_results_list = [] train_evaluator = Evaluator(cf, logger, mode='train') for i in range(cf.num_train_batches): logger.time("train_batch_loadfw") batch = next(batch_gen['train']) batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts'] batch_gen['train'].generator.stats['empty_samples_count'] += batch['empty_samples_count'] logger.time("train_batch_loadfw") logger.time("train_batch_netfw") results_dict = net.train_forward(batch) logger.time("train_batch_netfw") logger.time("train_batch_bw") optimizer.zero_grad() results_dict['torch_loss'].backward() if cf.clip_norm: torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping optimizer.step() train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict if not cf.server_env: print("\rFinished training batch " + "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches, logger.get_time("train_batch_loadfw")+ logger.get_time("train_batch_netfw") +logger.time("train_batch_bw"), logger.get_time("train_batch_loadfw",reset=True), logger.get_time("train_batch_netfw", reset=True), logger.get_time("train_batch_bw", reset=True)), end="", flush=True) print() #--------------- train eval ---------------- if (epoch-1)%cf.plot_frequency==0: # view an example batch logger.time("train_plot") plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold))) logger.info("generated train-example plot in {:.2f}s".format(logger.time("train_plot"))) logger.time("evals") _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train']) logger.time("evals") logger.time("train_epoch", toggle=False) del train_results_list #----------- validation ------------ 
logger.info('starting validation in mode {}.'.format(cf.val_mode)) logger.time("val_epoch") with torch.no_grad(): net.eval() val_results_list = [] val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) val_predictor = Predictor(cf, net, logger, mode='val') for i in range(batch_gen['n_val']): logger.time("val_batch") batch = next(batch_gen[cf.val_mode]) if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) val_results_list.append([results_dict, batch["pid"]]) if not cf.server_env: print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch', i + 1, batch_gen['n_val'], logger.time("val_batch")), end="", flush=True) print() #------------ val eval ------------- if (epoch - 1) % cf.plot_frequency == 0: logger.time("val_plot") plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold))) logger.info("generated val plot in {:.2f}s".format(logger.time("val_plot"))) logger.time("evals") _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val']) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) del val_results_list #----------- monitoring ------------- monitor_metrics.update({"lr": {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}}) logger.metrics2tboard(monitor_metrics, global_step=epoch) logger.time("evals") logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. 
val total: {:.2f}s, average: {:.2f}s.'.format( epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"), logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"), logger.get_time("val_epoch", reset=True)/batch_gen["n_val"])) logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True))) #-------------- scheduling ----------------- if not cf.dynamic_lr_scheduling: for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch-1] else: scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) logger.time("train_val") logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True))) batch_gen['train'].generator.print_stats(logger, plot=True) def test(cf, logger, max_fold=None): """performs testing for a given fold (or held out set). saves stats in evaluator. """ logger.time("test_fold") logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir)) net = model.net(cf, logger).cuda() batch_gen = data_loader.get_test_generator(cf, logger) test_predictor = Predictor(cf, net, logger, mode='test') test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr( cf, "eval_test_separately") or not cf.eval_test_separately) if test_results_list is not None: test_evaluator = Evaluator(cf, logger, mode='test') test_evaluator.evaluate_predictions(test_results_list) test_evaluator.score_test_df(max_fold=max_fold) logger.info('Testing of fold {} took {}.'.format(cf.fold, logger.get_time("test_fold", reset=True, format="hms"))) if __name__ == '__main__': stime = time.time() parser = argparse.ArgumentParser() - parser.add_argument('-m', '--mode', type=str, default='train_test', help='one out of: create_exp, analysis, train, train_test, or test') - parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, 
help='None runs over all folds in CV. otherwise specify list of folds.') + parser.add_argument('--dataset_name', type=str, default='toy', + help="path to the dataset-specific code in source_dir/datasets") parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/regrcnn/datasets/toy/experiments/dev', help='path to experiment dir. will be created if non existent.') + parser.add_argument('-m', '--mode', type=str, default='train_test', help='one out of: create_exp, analysis, train, train_test, or test') + parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.') parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.') parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config") parser.add_argument('--use_stored_settings', default=False, action='store_true', help='load configs from existing exp_dir instead of source dir. always done for testing, ' 'but can be set to true to do the same for training. useful in job scheduler environment, ' 'where source code might change before the job actually runs.') parser.add_argument('--resume_from_checkpoint', type=str, default=None, help='path to checkpoint. 
if resuming from checkpoint, the desired fold still needs to be parsed via --folds.') - parser.add_argument('--dataset_name', type=str, default='toy', help="path to the dataset-specific code in source_dir/datasets") parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything") args = parser.parse_args() args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name folds = args.folds resume_from_checkpoint = None if args.resume_from_checkpoint in ['None', 'none'] else args.resume_from_checkpoint if args.mode == 'create_exp': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, -1) logger.info('created experiment directory at {}'.format(args.exp_dir)) elif args.mode == 'train' or args.mode == 'train_test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings) if args.dev: folds = [0,1] cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 1 cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1 cf.test_n_epochs = cf.save_n_models cf.max_test_patients = 1 torch.backends.cudnn.benchmark = cf.dim==3 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) if folds is None: folds = range(cf.n_cv_splits) for fold in folds: """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation, one for testing, the rest for training. 
This loop iterates k-times over the dataset, cyclically moving the splits. k==folds, fold in [0,folds) says which split is used for testing. """ cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)) cf.fold, logger.fold = fold, fold cf.resume_from_checkpoint = resume_from_checkpoint if not os.path.exists(cf.fold_dir): os.mkdir(cf.fold_dir) train(cf, logger) cf.resume_from_checkpoint = None if args.mode == 'train_test': test(cf, logger) elif args.mode == 'test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if folds is None: folds = range(cf.n_cv_splits) if args.dev: folds = folds[:2] cf.batch_size, cf.max_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark for fold in folds: cf.fold = fold cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) if cf.fold_dir in fold_dirs: test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs])) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation. elif args.mode == 'analysis': """ analyse already saved predictions. 
""" cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) if cf.held_out_test_set and not cf.eval_test_fold_wise: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() logger.info('starting evaluation...') cf.fold = 0 evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) evaluator.score_test_df(max_fold=0) else: fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if args.dev: fold_dirs = fold_dirs[:1] if folds is None: folds = range(cf.n_cv_splits) for fold in folds: cf.fold = fold cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) if cf.fold_dir in fold_dirs: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x logger.info('starting evaluation...') evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) max_fold = max([int(f[-1]) for f in fold_dirs]) evaluator.score_test_df(max_fold=max_fold) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) else: raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode)) mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t)) del logger torch.cuda.empty_cache() diff --git a/models/mrcnn.py b/models/mrcnn.py index cdfcfb2..9e9c157 100644 --- a/models/mrcnn.py 
+++ b/models/mrcnn.py @@ -1,760 +1,755 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn published under MIT license. """ import os from multiprocessing import Pool import time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils import utils.model_utils as mutils import utils.exp_utils as utils class RPN(nn.Module): """ Region Proposal Network. """ def __init__(self, cf, conv): super(RPN, self).__init__() self.dim = conv.dim self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu) self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None) def forward(self, x): """ :param x: input feature maps (b, in_channels, y, x, (z)) :return: rpn_class_logits (b, 2, n_anchors) :return: rpn_probs_logits (b, 2, n_anchors) :return: rpn_bbox (b, 2 * dim, n_anchors) """ # Shared convolutional base of the RPN. x = self.conv_shared(x) # Anchor Score. (batch, anchors per location * 2, y, x, (z)). 
rpn_class_logits = self.conv_class(x) # Reshape to (batch, 2, anchors) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) rpn_class_logits = rpn_class_logits.permute(*axes) rpn_class_logits = rpn_class_logits.contiguous() rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2) # Softmax on last dimension (fg vs. bg). rpn_probs = F.softmax(rpn_class_logits, dim=2) # Bounding box refinement. (batch, anchors_per_location * (y, x, (z), log(h), log(w), (log(d)), y, x, (z)) rpn_bbox = self.conv_bbox(x) # Reshape to (batch, 2*dim, anchors) rpn_bbox = rpn_bbox.permute(*axes) rpn_bbox = rpn_bbox.contiguous() rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2) return [rpn_class_logits, rpn_probs, rpn_bbox] class Classifier(nn.Module): """ Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a shared convolutional base and finally branches off the classifier- and regression head. """ def __init__(self, cf, conv): super(Classifier, self).__init__() self.cf = cf self.dim = conv.dim self.in_channels = cf.end_filts self.pool_size = cf.pool_size self.pyramid_levels = cf.pyramid_levels # instance_norm does not work with spatial dims (1, 1, (1)) norm = cf.norm if cf.norm != 'instance_norm' else None self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu) self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu) self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim) if 'regression' in self.cf.prediction_tasks: self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * cf.regression_n_features) self.rg_n_feats = cf.regression_n_features #classify into bins of regression values elif 'regression_bin' in self.cf.prediction_tasks: self.linear_regressor = nn.Linear(cf.end_filts * 4, cf.head_classes * len(cf.bin_labels)) self.rg_n_feats = len(cf.bin_labels) else: self.linear_regressor = lambda 
x: torch.zeros((x.shape[0], cf.head_classes * 1), dtype=torch.float32).fill_(float('NaN')).cuda() self.rg_n_feats = 1 #cf.regression_n_features if 'class' in self.cf.prediction_tasks: self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes) else: assert cf.head_classes == 2, "#head classes {} needs to be 2 (bg/fg) when not predicting classes".format(cf.head_classes) self.linear_class = lambda x: torch.zeros((x.shape[0], cf.head_classes), dtype=torch.float64).cuda() def forward(self, x, rois): """ :param x: input feature maps (b, in_channels, y, x, (z)) :param rois: normalized box coordinates as proposed by the RPN to be forwarded through the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements have been merged to one vector, while the origin info has been stored for re-allocation. :return: mrcnn_class_logits (n_proposals, n_head_classes) :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. """ x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) x = self.conv1(x) x = self.conv2(x) x = x.view(-1, self.in_channels * 4) mrcnn_bbox = self.linear_bbox(x) mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2) mrcnn_class_logits = self.linear_class(x) mrcnn_regress = self.linear_regressor(x) mrcnn_regress = mrcnn_regress.view(mrcnn_regress.size()[0], -1, self.rg_n_feats) return [mrcnn_bbox, mrcnn_class_logits, mrcnn_regress] class Mask(nn.Module): """ Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the output logits to allow for overlapping classes. 
""" def __init__(self, cf, conv): super(Mask, self).__init__() self.pool_size = cf.mask_pool_size self.pyramid_levels = cf.pyramid_levels self.dim = conv.dim self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu) if conv.dim == 2: self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) else: self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2) self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True) self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None) self.sigmoid = nn.Sigmoid() def forward(self, x, rois): """ :param x: input feature maps (b, in_channels, y, x, (z)) :param rois: normalized box coordinates as proposed by the RPN to be forwarded through the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix). Proposals of all batch elements have been merged to one vector, while the origin info has been stored for re-allocation. :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z)) """ x = mutils.pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim) x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.conv4(x) x = self.relu(self.deconv(x)) x = self.conv5(x) x = self.sigmoid(x) return x ############################################################ # Loss Functions ############################################################ def compute_rpn_class_loss(rpn_class_logits, rpn_match, shem_poolsize): """ :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :param rpn_class_logits: (n_anchors, 2). 
logits from RPN classifier. :param SHEM_poolsize: int. factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining). :return: loss: torch tensor :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. """ # Filter out netural anchors pos_indices = torch.nonzero(rpn_match == 1) neg_indices = torch.nonzero(rpn_match == -1) # loss for positive samples if not 0 in pos_indices.size(): pos_indices = pos_indices.squeeze(1) roi_logits_pos = rpn_class_logits[pos_indices] pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda()) else: pos_loss = torch.FloatTensor([0]).cuda() # loss for negative samples: draw hard negative examples (SHEM) # that match the number of positive samples, but at least 1. if not 0 in neg_indices.size(): neg_indices = neg_indices.squeeze(1) roi_logits_neg = rpn_class_logits[neg_indices] negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) roi_probs_neg = F.softmax(roi_logits_neg, dim=1) neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) np_neg_ix = neg_ix.cpu().data.numpy() #print("pos, neg count", pos_indices.cpu().data.numpy().size, negative_count) else: neg_loss = torch.FloatTensor([0]).cuda() np_neg_ix = np.array([]).astype('int32') loss = (pos_loss + neg_loss) / 2 return loss, np_neg_ix def compute_rpn_bbox_loss(rpn_pred_deltas, rpn_target_deltas, rpn_match): """ :param rpn_target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unsed bbox deltas. :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :return: loss: torch 1D tensor. 
""" if not 0 in torch.nonzero(rpn_match == 1).size(): indices = torch.nonzero(rpn_match == 1).squeeze(1) # Pick bbox deltas that contribute to the loss rpn_pred_deltas = rpn_pred_deltas[indices] # Trim target bounding box deltas to the same length as rpn_bbox. target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :] # Smooth L1 loss loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids): """ :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dh))) :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dh))) :param target_class_ids: (n_sampled_rois) :return: loss: torch 1D tensor. """ if not 0 in torch.nonzero(target_class_ids > 0).size(): positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] positive_roi_class_ids = target_class_ids[positive_roi_ix].long() target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach() pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :] loss = F.smooth_l1_loss(pred_bbox, target_bbox) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_mrcnn_mask_loss(pred_masks, target_masks, target_class_ids): """ :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array. :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1]. :param target_class_ids: (n_sampled_rois) :return: loss: torch 1D tensor. """ if not 0 in torch.nonzero(target_class_ids > 0).size(): # Only positive ROIs contribute to the loss. And only # the class-specific mask of each ROI. 
positive_ix = torch.nonzero(target_class_ids > 0)[:, 0] positive_class_ids = target_class_ids[positive_ix].long() y_true = target_masks[positive_ix, :, :].detach() y_pred = pred_masks[positive_ix, positive_class_ids, :, :] loss = F.binary_cross_entropy(y_pred, y_true) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_mrcnn_class_loss(tasks, pred_class_logits, target_class_ids): """ :param pred_class_logits: (n_sampled_rois, n_classes) :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension. :return: loss: torch 1D tensor. """ if 'class' in tasks and not 0 in target_class_ids.size(): loss = F.cross_entropy(pred_class_logits, target_class_ids.long()) else: loss = torch.FloatTensor([0.]).cuda() return loss def compute_mrcnn_regression_loss(tasks, pred, target, target_class_ids): """regression loss is a distance metric between target vector and predicted regression vector. :param pred: (n_sampled_rois, n_classes, [n_rg_feats if real regression or 1 if rg_bin task) :param target: (n_sampled_rois, [n_rg_feats or n_rg_bins]) :return: differentiable loss, torch 1D tensor on cuda """ if not 0 in target.shape and not 0 in torch.nonzero(target_class_ids > 0).shape: positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0] positive_roi_class_ids = target_class_ids[positive_roi_ix].long() target = target[positive_roi_ix].detach() pred = pred[positive_roi_ix, positive_roi_class_ids] if "regression_bin" in tasks: loss = F.cross_entropy(pred, target.long()) else: loss = F.smooth_l1_loss(pred, target) #loss = F.mse_loss(pred, target) else: loss = torch.FloatTensor([0.]).cuda() return loss ############################################################ # Detection Layer ############################################################ def compute_roi_scores(tasks, batch_rpn_proposals, mrcnn_cl_logits): """ Depending on the predicition tasks: if no class prediction beyong fg/bg (--> means no additional class head was applied) use RPN 
objectness scores as roi scores, otherwise class head scores. :param cf: :param batch_rpn_proposals: :param mrcnn_cl_logits: :return: """ if not 'class' in tasks: scores = batch_rpn_proposals[:, :, -1].view(-1, 1) scores = torch.cat((1 - scores, scores), dim=1) else: scores = F.softmax(mrcnn_cl_logits, dim=1) return scores ############################################################ # MaskRCNN Class ############################################################ class net(nn.Module): def __init__(self, cf, logger): super(net, self).__init__() self.cf = cf self.logger = logger self.build() loss_order = ['rpn_class', 'rpn_bbox', 'mrcnn_bbox', 'mrcnn_mask', 'mrcnn_class', 'mrcnn_rg'] if hasattr(cf, "mrcnn_loss_weights"): - #bring into right order + # bring into right order self.loss_weights = np.array([cf.mrcnn_loss_weights[k] for k in loss_order]) else: self.loss_weights = np.array([1.]*len(loss_order)) if self.cf.weight_init=="custom": logger.info("Tried to use custom weight init which is not defined. Using pytorch default.") elif self.cf.weight_init: mutils.initialize_weights(self) else: logger.info("using default pytorch weight init") def build(self): """Build Mask R-CNN architecture.""" # Image size must be dividable by 2 multiple times. h, w = self.cf.patch_size[:2] if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5): raise Exception("Image size must be divisible by 2 at least 5 times " "to avoid fractions when downscaling and upscaling." "For example, use 256, 288, 320, 384, 448, 512, ... etc.,i.e.," "any number x*32 will do!") # instantiate abstract multi-dimensional conv generator and load backbone module. 
backbone = utils.import_module('bbone', self.cf.backbone_path) self.logger.info("loaded backbone from {}".format(self.cf.backbone_path)) conv = backbone.ConvGenerator(self.cf.dim) # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) self.anchors = torch.from_numpy(self.np_anchors).float().cuda() self.fpn = backbone.FPN(self.cf, conv, relu_enc=self.cf.relu, operate_stride1=False).cuda() self.rpn = RPN(self.cf, conv) self.classifier = Classifier(self.cf, conv) self.mask = Mask(self.cf, conv) def forward(self, img, is_training=True): """ :param img: input images (b, c, y, x, (z)). :return: rpn_pred_logits: (b, n_anchors, 2) :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d)))) :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting. :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score) :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. """ # extract features. fpn_outs = self.fpn(img) rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels] self.mrcnn_feature_maps = rpn_feature_maps # loop through pyramid layers and apply RPN. layer_outputs = [ self.rpn(p_feats) for p_feats in rpn_feature_maps ] # concatenate layer outputs. # convert from list of lists of level outputs to list of lists of outputs across levels. # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] outputs = list(zip(*layer_outputs)) outputs = [torch.cat(list(o), dim=1) for o in outputs] rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs # # # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier. 
proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference batch_normed_props, batch_unnormed_props = mutils.refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, self.anchors, self.cf) # merge batch dimension of proposals while storing allocation info in coordinate dimension. batch_ixs = torch.arange( batch_normed_props.shape[0]).cuda().unsqueeze(1).repeat(1,batch_normed_props.shape[1]).view(-1).float() rpn_rois = batch_normed_props[:, :, :-1].view(-1, batch_normed_props[:, :, :-1].shape[2]) self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1) # this is the first of two forward passes in the second stage, where no activations are stored for backprop. # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.) # for inference/monitoring as well as sampling of rois for the loss functions. # processed in chunks of roi_chunk_size to re-adjust to gpu-memory. chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size) bboxes_list, class_logits_list, regressions_list = [], [], [] with torch.no_grad(): for chunk in chunked_rpn_rois: chunk_bboxes, chunk_class_logits, chunk_regressions = self.classifier(self.mrcnn_feature_maps, chunk) bboxes_list.append(chunk_bboxes) class_logits_list.append(chunk_class_logits) regressions_list.append(chunk_regressions) mrcnn_bbox = torch.cat(bboxes_list, 0) mrcnn_class_logits = torch.cat(class_logits_list, 0) mrcnn_regressions = torch.cat(regressions_list, 0) self.mrcnn_roi_scores = compute_roi_scores(self.cf.prediction_tasks, batch_normed_props, mrcnn_class_logits) # refine classified proposals, filter and return final detections. # returns (cf.max_inst_per_batch_element, n_coords+1+...) detections = mutils.refine_detections(self.cf, batch_ixs, rpn_rois, mrcnn_bbox, self.mrcnn_roi_scores, mrcnn_regressions) # forward remaining detections through mask-head to generate corresponding masks. 
scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2 scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda() # first self.cf.dim * 2 entries on axis 1 are always the box coords, +1 is batch_ix detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale with torch.no_grad(): detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes) return [rpn_pred_logits, rpn_pred_deltas, batch_unnormed_props, detections, detection_masks] def loss_samples_forward(self, batch_gt_boxes, batch_gt_masks, batch_gt_class_ids, batch_gt_regressions=None): """ this is the second forward pass through the second stage (features from stage one are re-used). samples few rois in loss_example_mining and forwards only those for loss computation. :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels. :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates. :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c) :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores. :return: sample_deltas: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement. :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal. :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals. :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement. :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals. :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting. """ # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses. 
sample_ics, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions = \ mutils.loss_example_mining(self.cf, self.rpn_rois_batch_info, batch_gt_boxes, batch_gt_masks, self.mrcnn_roi_scores, batch_gt_class_ids, batch_gt_regressions) # re-use feature maps and RPN output from first forward pass. sample_proposals = self.rpn_rois_batch_info[sample_ics] if not 0 in sample_proposals.size(): sample_deltas, sample_logits, sample_regressions = self.classifier(self.mrcnn_feature_maps, sample_proposals) sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals) else: sample_logits = torch.FloatTensor().cuda() sample_deltas = torch.FloatTensor().cuda() sample_regressions = torch.FloatTensor().cuda() sample_mask = torch.FloatTensor().cuda() return [sample_deltas, sample_mask, sample_logits, sample_regressions, sample_proposals, sample_target_deltas, sample_target_mask, sample_target_class_ids, sample_target_regressions] def get_results(self, img_shape, detections, detection_masks, box_results_list=None, return_masks=True): """ Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. :param img_shape: :param detections: shape (n_final_detections, len(info)), where info=( y1, x1, y2, x2, (z1,z2), batch_ix, pred_class_id, pred_score ) :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head. :param box_results_list: None or list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off). :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. 
class-specific return of masks will come with implementation of instance segmentation evaluation. """ detections = detections.cpu().data.numpy() if self.cf.dim == 2: detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy() else: detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy() # det masks shape now (n_dets, y,x(,z), n_classes) # restore batch dimension of merged detections using the batch_ix info. batch_ixs = detections[:, self.cf.dim*2] detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])] - #mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size + # mrcnn_mask: shape (b_size, variable, variable, n_classes), variable bc depends on single instance mask size if box_results_list == None: # for test_forward, where no previous list exists. box_results_list = [[] for _ in range(img_shape[0])] # seg_logits == seg_probs in mrcnn since mask head finishes with sigmoid (--> image space = [0,1]) seg_probs = [] # loop over batch and unmold detections. for ix in range(img_shape[0]): # final masks are one-hot encoded (b, n_classes, y, x, (z)) final_masks = np.zeros((self.cf.num_classes + 1, *img_shape[2:])) #+1 for bg, 0.5 bc mask head classifies only bg/fg with logits between 0,1--> bg is <0.5 if self.cf.num_classes + 1 != self.cf.num_seg_classes: self.logger.warning("n of roi-classifier head classes {} doesnt match cf.num_seg_classes {}".format( self.cf.num_classes + 1, self.cf.num_seg_classes)) if not 0 in detections[ix].shape: boxes = detections[ix][:, :self.cf.dim*2].astype(np.int32) class_ids = detections[ix][:, self.cf.dim*2 + 1].astype(np.int32) scores = detections[ix][:, self.cf.dim*2 + 2] masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids] regressions = detections[ix][:,self.cf.dim*2+3:] # Filter out detections with zero area. 
Often only happens in early # stages of training when the network weights are still a bit random. if self.cf.dim == 2: exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0] else: exclude_ix = np.where( (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0] if exclude_ix.shape[0] > 0: boxes = np.delete(boxes, exclude_ix, axis=0) masks = np.delete(masks, exclude_ix, axis=0) class_ids = np.delete(class_ids, exclude_ix, axis=0) scores = np.delete(scores, exclude_ix, axis=0) regressions = np.delete(regressions, exclude_ix, axis=0) # Resize masks to original image size and set boundary threshold. if return_masks: for i in range(masks.shape[0]): #masks per this batch instance/element/image # Convert neural network mask to full size mask if self.cf.dim == 2: full_mask = mutils.unmold_mask_2D(masks[i], boxes[i], img_shape[2:]) else: full_mask = mutils.unmold_mask_3D(masks[i], boxes[i], img_shape[2:]) # take the maximum seg_logits per class of instances in that class, i.e., a pixel in a class # has the max seg_logit value over all instances of that class in one sample final_masks[class_ids[i]] = np.max((final_masks[class_ids[i]], full_mask), axis=0) final_masks[0] = np.full(final_masks[0].shape, 0.49999999) #effectively min_det_thres at 0.5 per pixel # add final predictions to results. 
if not 0 in boxes.shape: for ix2, coords in enumerate(boxes): box = {'box_coords': coords, 'box_type': 'det', 'box_score': scores[ix2], 'box_pred_class_id': class_ids[ix2]} #if (hasattr(self.cf, "convert_cl_to_rg") and self.cf.convert_cl_to_rg): if "regression_bin" in self.cf.prediction_tasks: # in this case, regression preds are actually the rg_bin_ids --> map to rg value the bin represents box['rg_bin'] = regressions[ix2].argmax() box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']] else: box['regression'] = regressions[ix2] if hasattr(self.cf, "rg_val_to_bin_id") and \ any(['regression' in task for task in self.cf.prediction_tasks]): box.update({'rg_bin': self.cf.rg_val_to_bin_id(regressions[ix2])}) box_results_list[ix].append(box) # if no detections were made--> keep full bg mask (zeros). seg_probs.append(final_masks) # create and fill results dictionary. results_dict = {} results_dict['boxes'] = box_results_list results_dict['seg_preds'] = np.array(seg_probs) return results_dict def train_forward(self, batch, is_validation=False): """ train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data for processing, computes losses, and stores outputs in a dictionary. :param batch: dictionary containing 'data', 'seg', etc. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]. 'torch_loss': 1D torch tensor for backprop. 'class_loss': classification loss for monitoring. 
""" img = batch['data'] gt_boxes = batch['bb_target'] axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1) gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))] gt_class_ids = batch['class_targets'] if 'regression' in self.cf.prediction_tasks: gt_regressions = batch["regression_targets"] elif 'regression_bin' in self.cf.prediction_tasks: gt_regressions = batch["rg_bin_targets"] else: gt_regressions = None img = torch.from_numpy(img).cuda().float() batch_rpn_class_loss = torch.FloatTensor([0]).cuda() batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda() # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. box_results_list = [[] for _ in range(img.shape[0])] #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop. rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img) mrcnn_pred_deltas, mrcnn_pred_mask, mrcnn_class_logits, mrcnn_regressions, sample_proposals, \ mrcnn_target_deltas, target_mask, target_class_ids, target_regressions = \ self.loss_samples_forward(gt_boxes, gt_masks, gt_class_ids, gt_regressions) - stime = time.time() - #loop over batch + # loop over batch for b in range(img.shape[0]): if len(gt_boxes[b]) > 0: # add gt boxes to output list for tix in range(len(gt_boxes[b])): gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]} for name in self.cf.roi_items: gt_box.update({name: batch[name][b][tix]}) box_results_list[b].append(gt_box) # match gt boxes with anchors to generate targets for RPN losses. rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b]) # add positive anchors used for loss to output list for monitoring. 
pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:]) for p in pos_anchors: box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'}) else: rpn_match = np.array([-1]*self.np_anchors.shape[0]) rpn_target_deltas = np.array([0]) rpn_match_gpu = torch.from_numpy(rpn_match).cuda() rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda() # compute RPN losses. rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_class_logits[b], rpn_match_gpu, self.cf.shem_poolsize) rpn_bbox_loss = compute_rpn_bbox_loss(rpn_pred_deltas[b], rpn_target_deltas, rpn_match_gpu) batch_rpn_class_loss += rpn_class_loss /img.shape[0] batch_rpn_bbox_loss += rpn_bbox_loss /img.shape[0] # add negative anchors used for loss to output list for monitoring. # neg_anchor_ix = neg_ix come from shem and mark positions in roi_probs_neg = rpn_class_logits[neg_indices] # with neg_indices = rpn_match == -1 neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[rpn_match == -1][neg_anchor_ix], img.shape[2:]) for n in neg_anchors: box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'}) # add highest scoring proposals to output list for monitoring. rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1] for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]: box_results_list[b].append({'box_coords': r, 'box_type': 'prop'}) - #print("gt anc matching, rpn losses loop time {:.4f}s".format(time.time()-stime)) # add positive and negative roi samples used for mrcnn losses to output list for monitoring. if not 0 in sample_proposals.shape: rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy() for ix, r in enumerate(rois): box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale, 'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'}) # compute mrcnn losses. 
mrcnn_class_loss = compute_mrcnn_class_loss(self.cf.prediction_tasks, mrcnn_class_logits, target_class_ids) mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_pred_deltas, mrcnn_target_deltas, target_class_ids) mrcnn_regressions_loss = compute_mrcnn_regression_loss(self.cf.prediction_tasks, mrcnn_regressions, target_regressions, target_class_ids) # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode). # In this case, the mask_loss is taken out of training. if not self.cf.frcnn_mode: mrcnn_mask_loss = compute_mrcnn_mask_loss(mrcnn_pred_mask, target_mask, target_class_ids) else: mrcnn_mask_loss = torch.FloatTensor([0]).cuda() loss = batch_rpn_class_loss + batch_rpn_bbox_loss +\ mrcnn_bbox_loss + mrcnn_mask_loss + mrcnn_class_loss + mrcnn_regressions_loss - # loss= [batch_rpn_class_loss, batch_rpn_bbox_loss, mrcnn_bbox_loss, mrcnn_mask_loss, mrcnn_class_loss, - # mrcnn_regressions_loss] - # loss = torch.tensor([part_loss * self.loss_weights[i] for i, part_loss in enumerate(loss)], requires_grad=True).sum(0, keepdim=True) # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class. #dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]] #self.logger.info("regression loss {:.3f}".format(mrcnn_regressions_loss.item())) #self.logger.info("loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, " # "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(), # batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(), mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount)) # run unmolding of predictions for monitoring and merge all results to one dictionary. 
return_masks = self.cf.return_masks_in_val if is_validation else self.cf.return_masks_in_train results_dict = self.get_results(img.shape, detections, detection_masks, box_results_list, return_masks=return_masks) results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:,np.newaxis] if 'dice' in self.cf.metrics: results_dict['batch_dices'] = mutils.dice_per_batch_and_class( results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True) results_dict['torch_loss'] = loss results_dict['class_loss'] = mrcnn_class_loss.item() results_dict['bbox_loss'] = mrcnn_bbox_loss.item() results_dict['rg_loss'] = mrcnn_regressions_loss.item() results_dict['rpn_class_loss'] = rpn_class_loss.item() results_dict['rpn_bbox_loss'] = rpn_bbox_loss.item() return results_dict def test_forward(self, batch, return_masks=True): """ test method. wrapper around forward pass of network without usage of any ground truth information. prepares input data for processing and stores outputs in a dictionary. :param batch: dictionary containing 'data' :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off). :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 
'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes] """ img = batch['data'] img = torch.from_numpy(img).float().cuda() _, _, _, detections, detection_masks = self.forward(img) results_dict = self.get_results(img.shape, detections, detection_masks, return_masks=return_masks) return results_dict \ No newline at end of file diff --git a/models/retina_net.py b/models/retina_net.py index d618e5a..f9dabd5 100644 --- a/models/retina_net.py +++ b/models/retina_net.py @@ -1,779 +1,779 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Retina Net. According to https://arxiv.org/abs/1708.02002""" import utils.model_utils as mutils import utils.exp_utils as utils import sys sys.path.append('../') from custom_extensions.nms import nms import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils class Classifier(nn.Module): def __init__(self, cf, conv): """ Builds the classifier sub-network. 
""" super(Classifier, self).__init__() self.dim = conv.dim self.n_classes = cf.head_classes n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * cf.head_classes anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: class_logits (b, n_anchors, n_classes) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) class_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) class_logits = class_logits.permute(*axes) class_logits = class_logits.contiguous() class_logits = class_logits.view(x.shape[0], -1, self.n_classes) return [class_logits] class BBRegressor(nn.Module): def __init__(self, cf, conv): """ Builds the bb-regression sub-network. 
""" super(BBRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * self.dim * 2 anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) bb_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) bb_logits = bb_logits.permute(*axes) bb_logits = bb_logits.contiguous() bb_logits = bb_logits.view(x.shape[0], -1, self.dim * 2) return [bb_logits] class RoIRegressor(nn.Module): def __init__(self, cf, conv, rg_feats): """ Builds the RoI-item-regression sub-network. Regression items can be, e.g., malignancy scores of tumors. 
""" super(RoIRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features self.rg_feats = rg_feats n_output_channels = cf.n_anchors_per_pos * self.rg_feats anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) x = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) x = x.permute(*axes) x = x.contiguous() x = x.view(x.shape[0], -1, self.rg_feats) return [x] ############################################################ # Loss Functions ############################################################ # def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20): """ :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param shem_poolsize: int. factor of top-k candidates to draw from per negative sample (online-hard-example-mining). :return: loss: torch tensor :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. 
pos_indices = torch.nonzero(anchor_matches > 0) neg_indices = torch.nonzero(anchor_matches == -1) # get positive samples and calucalte loss. if not 0 in pos_indices.size(): pos_indices = pos_indices.squeeze(1) roi_logits_pos = class_pred_logits[pos_indices] targets_pos = anchor_matches[pos_indices].detach() pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long()) else: pos_loss = torch.FloatTensor([0]).cuda() # get negative samples, such that the amount matches the number of positive samples, but at least 1. # get high scoring negatives by applying online-hard-example-mining. if not 0 in neg_indices.size(): neg_indices = neg_indices.squeeze(1) roi_logits_neg = class_pred_logits[neg_indices] negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) roi_probs_neg = F.softmax(roi_logits_neg, dim=1) neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) # return the indices of negative samples, who contributed to the loss for monitoring plots. np_neg_ix = neg_ix.cpu().data.numpy() else: neg_loss = torch.FloatTensor([0]).cuda() np_neg_ix = np.array([]).astype('int32') loss = (pos_loss + neg_loss) / 2 return loss, np_neg_ix def compute_bbox_loss(target_deltas, pred_deltas, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unused bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param anchor_matches: tensor (n_anchors). value in [-1, 0, class_ids] for negative, neutral, and positive matched anchors. i.e., positively matched anchors are marked by class_id >0 :return: loss: torch 1D tensor. 
""" if not 0 in torch.nonzero(anchor_matches>0).shape: indices = torch.nonzero(anchor_matches>0).squeeze(1) # Pick bbox deltas that contribute to the loss pred_deltas = pred_deltas[indices] # Trim target bounding box deltas to the same length as pred_deltas. target_deltas = target_deltas[:pred_deltas.shape[0], :].detach() # Smooth L1 loss loss = F.smooth_l1_loss(pred_deltas, target_deltas) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_rg_loss(tasks, target, pred, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unsed bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :return: loss: torch 1D tensor. """ if not 0 in target.shape and not 0 in torch.nonzero(anchor_matches>0).shape: indices = torch.nonzero(anchor_matches>0).squeeze(1) # Pick rgs that contribute to the loss pred = pred[indices] # Trim target target = target[:pred.shape[0]].detach() if 'regression_bin' in tasks: loss = F.cross_entropy(pred, target.long()) else: loss = F.smooth_l1_loss(pred, target) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_focal_class_loss(anchor_matches, class_pred_logits, gamma=2.): """ Focal Loss FL = -(1-q)^g log(q) with q = pred class probability. :param anchor_matches: (n_anchors). [-1, 0, class] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param gamma: g in above formula, good results with g=2 in original paper. :return: loss: torch tensor :return: focal loss """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0).squeeze(-1) # dim=-1 instead of 1 or 0 to cover empty matches. 
neg_indices = torch.nonzero(anchor_matches == -1).squeeze(-1) target_classes = torch.cat( (anchor_matches[pos_indices].long(), torch.LongTensor([0] * neg_indices.shape[0]).cuda()) ) non_neutral_indices = torch.cat( (pos_indices, neg_indices) ) q = F.softmax(class_pred_logits[non_neutral_indices], dim=1) # q shape: (n_non_neutral_anchors, n_classes) # one-hot encoded target classes: keep only the pred probs of the correct class. it will receive incentive to be maximized. # log(q_i) where i = target class --> FL shape (n_anchors,) # need to transform to indices into flattened tensor to use torch.take target_locs_flat = q.shape[1] * torch.arange(q.shape[0]).cuda() + target_classes q = torch.take(q, target_locs_flat) FL = torch.log(q) # element-wise log FL *= -(1-q)**gamma # take mean over all considered anchors FL = FL.sum() / FL.shape[0] return FL def refine_detections(anchors, probs, deltas, regressions, batch_ixs, cf): """Refine classified proposals, filter overlaps and return final detections. n_proposals here is typically a very large number: batch_size * n_anchors. This function is hence optimized on trimming down n_proposals. :param anchors: (n_anchors, 2 * dim) :param probs: (n_proposals, n_classes) softmax probabilities for all rois as predicted by classifier head. :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by bbox regressor head. :param regressions: (n_proposals, n_classes, n_rg_feats) :param batch_ixs: (n_proposals) batch element assignemnt info for re-allocation. :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regr)) """ anchors = anchors.repeat(batch_ixs.unique().shape[0], 1) #flatten foreground probabilities, sort and trim down to highest confidences by pre_nms limit. 
fg_probs = probs[:, 1:].contiguous() flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True) keep_ix = flat_probs_order[:cf.pre_nms_limit] # reshape indices to 2D index array with shape like fg_probs. keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) pre_nms_scores = flat_probs[:cf.pre_nms_limit] pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again. pre_nms_batch_ixs = batch_ixs[keep_arr[:, 0]] pre_nms_anchors = anchors[keep_arr[:, 0]] pre_nms_deltas = deltas[keep_arr[:, 0]] pre_nms_regressions = regressions[keep_arr[:, 0]] keep = torch.arange(pre_nms_scores.size()[0]).long().cuda() # apply bounding box deltas. re-scale to image coordinates. std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() scale = torch.from_numpy(cf.scale).float().cuda() refined_rois = mutils.apply_box_deltas_2D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale \ if cf.dim == 2 else mutils.apply_box_deltas_3D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale # round and cast to int since we're deadling with pixels now refined_rois = mutils.clip_to_window(cf.window, refined_rois) pre_nms_rois = torch.round(refined_rois) for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)): bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] bix_class_ids = pre_nms_class_ids[bixs] bix_rois = pre_nms_rois[bixs] bix_scores = pre_nms_scores[bixs] for i, class_id in enumerate(mutils.unique1d(bix_class_ids)): ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] # nms expects boxes sorted by score. ix_rois = bix_rois[ixs] ix_scores = bix_scores[ixs] ix_scores, order = ix_scores.sort(descending=True) ix_rois = ix_rois[order, :] ix_scores = ix_scores class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold) # map indices back. 
class_keep = keep[bixs[ixs[order[class_keep]]]] # merge indices over classes for current batch element b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep))) # only keep top-k boxes of current batch-element. top_ids = pre_nms_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] b_keep = b_keep[top_ids] # merge indices over batch elements. batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep))) keep = batch_keep # arrange output. result = torch.cat((pre_nms_rois[keep], pre_nms_batch_ixs[keep].unsqueeze(1).float(), pre_nms_class_ids[keep].unsqueeze(1).float(), pre_nms_scores[keep].unsqueeze(1), pre_nms_regressions[keep]), dim=1) return result def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None, gt_regressions=None): """Given the anchors and GT boxes, compute overlaps and identify positive anchors and deltas to refine them to match their corresponding GT boxes. anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))] gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))] gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN, set all positive matches to 1 (foreground) gt_regressions: [num_gt_rgs, n_rg_feats], if None empty rg_targets are returned Returns: anchor_class_matches: [N] (int32) matches between anchors and GT boxes. class_id = positive anchor, -1 = negative anchor, 0 = neutral. i.e., positively matched anchors are marked by class_id (which is >0). anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas. 
anchor_rg_targets: [n_anchors, n_rg_feats] """ anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32) anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim)) if gt_regressions is not None: if 'regression_bin' in cf.prediction_tasks: anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image,)) else: anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image, cf.regression_n_features)) else: anchor_rg_targets = np.array([]) anchor_matching_iou = cf.anchor_matching_iou if gt_boxes is None: anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1) return anchor_class_matches, anchor_delta_targets, anchor_rg_targets # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground) if gt_class_ids is None: gt_class_ids = np.array([1] * len(gt_boxes)) # Compute overlaps [num_anchors, num_gt_boxes] overlaps = mutils.compute_overlaps(anchors, gt_boxes) # Match anchors to GT Boxes # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive. # If an anchor overlaps a GT box with IoU < 0.1 then it's negative. # Neutral anchors are those that don't match the conditions above, # and they don't influence the loss function. # However, don't keep any GT box unmatched (rare, but happens). Instead, # match it to the closest anchor (even if its max IoU is < 0.1). # 1. Set negative anchors first. They get overwritten below if a GT box is # matched to them. Skip boxes in crowd areas. anchor_iou_argmax = np.argmax(overlaps, axis=1) anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] if anchors.shape[1] == 4: anchor_class_matches[(anchor_iou_max < 0.1)] = -1 elif anchors.shape[1] == 6: anchor_class_matches[(anchor_iou_max < 0.01)] = -1 else: raise ValueError('anchor shape wrong {}'.format(anchors.shape)) # 2. Set an anchor for each GT box (regardless of IoU value). 
gt_iou_argmax = np.argmax(overlaps, axis=0) for ix, ii in enumerate(gt_iou_argmax): anchor_class_matches[ii] = gt_class_ids[ix] # 3. Set anchors with high overlap as positive. above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]] # Subsample to balance positive anchors. ids = np.where(anchor_class_matches > 0)[0] extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) if extra > 0: # Reset the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) anchor_class_matches[ids] = 0 # Leave all negative proposals negative for now and sample from them later in online hard example mining. # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. ids = np.where(anchor_class_matches > 0)[0] ix = 0 # index into anchor_delta_targets for i, a in zip(ids, anchors[ids]): # closest gt box (it might have IoU < anchor_matching_iou) gt = gt_boxes[anchor_iou_argmax[i]] # convert coordinates to center plus width/height. gt_h = gt[2] - gt[0] gt_w = gt[3] - gt[1] gt_center_y = gt[0] + 0.5 * gt_h gt_center_x = gt[1] + 0.5 * gt_w # Anchor a_h = a[2] - a[0] a_w = a[3] - a[1] a_center_y = a[0] + 0.5 * a_h a_center_x = a[1] + 0.5 * a_w if cf.dim == 2: anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, np.log(gt_h / a_h), np.log(gt_w / a_w)] else: gt_d = gt[5] - gt[4] gt_center_z = gt[4] + 0.5 * gt_d a_d = a[5] - a[4] a_center_z = a[4] + 0.5 * a_d anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, (gt_center_z - a_center_z) / a_d, np.log(gt_h / a_h), np.log(gt_w / a_w), np.log(gt_d / a_d)] # normalize. 
anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev if gt_regressions is not None: anchor_rg_targets[ix] = gt_regressions[anchor_iou_argmax[i]] ix += 1 return anchor_class_matches, anchor_delta_targets, anchor_rg_targets ############################################################ # RetinaNet Class ############################################################ class net(nn.Module): """Encapsulates the RetinaNet model functionality. """ def __init__(self, cf, logger): """ cf: A Sub-class of the cf class model_dir: Directory to save training logs and trained weights """ super(net, self).__init__() self.cf = cf self.logger = logger self.build() if self.cf.weight_init is not None: logger.info("using pytorch weight init of type {}".format(self.cf.weight_init)) mutils.initialize_weights(self) else: logger.info("using default pytorch weight init") self.debug_acm = [] def build(self): """Build Retina Net architecture.""" # Image size must be dividable by 2 multiple times. h, w = self.cf.patch_size[:2] if h / 2 ** 5 != int(h / 2 ** 5) or w / 2 ** 5 != int(w / 2 ** 5): raise Exception("Image size must be divisible by 2 at least 5 times " "to avoid fractions when downscaling and upscaling." "For example, use 256, 320, 384, 448, 512, ... etc. 
") backbone = utils.import_module('bbone', self.cf.backbone_path) self.logger.info("loaded backbone from {}".format(self.cf.backbone_path)) conv = backbone.ConvGenerator(self.cf.dim) # build Anchors, FPN, Classifier / Bbox-Regressor -head self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) self.anchors = torch.from_numpy(self.np_anchors).float().cuda() self.fpn = backbone.FPN(self.cf, conv, operate_stride1=self.cf.operate_stride1).cuda() self.classifier = Classifier(self.cf, conv).cuda() self.bb_regressor = BBRegressor(self.cf, conv).cuda() if 'regression' in self.cf.prediction_tasks: self.roi_regressor = RoIRegressor(self.cf, conv, self.cf.regression_n_features).cuda() elif 'regression_bin' in self.cf.prediction_tasks: # classify into bins of regression values self.roi_regressor = RoIRegressor(self.cf, conv, len(self.cf.bin_labels)).cuda() else: self.roi_regressor = lambda x: [torch.tensor([]).cuda()] if self.cf.model == 'retina_unet': - self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=self.cf.norm, relu=None) + self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=None, relu=None) def forward(self, img): """ :param img: input img (b, c, y, x, (z)). """ # Feature extraction fpn_outs = self.fpn(img) if self.cf.model == 'retina_unet': seg_logits = self.final_conv(fpn_outs[0]) selected_fmaps = [fpn_outs[i + 1] for i in self.cf.pyramid_levels] else: seg_logits = None selected_fmaps = [fpn_outs[i] for i in self.cf.pyramid_levels] # Loop through pyramid layers class_layer_outputs, bb_reg_layer_outputs, roi_reg_layer_outputs = [], [], [] # list of lists for p in selected_fmaps: class_layer_outputs.append(self.classifier(p)) bb_reg_layer_outputs.append(self.bb_regressor(p)) roi_reg_layer_outputs.append(self.roi_regressor(p)) # Concatenate layer outputs # Convert from list of lists of level outputs to list of lists # of outputs across levels. # e.g. 
[[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] class_logits = list(zip(*class_layer_outputs)) class_logits = [torch.cat(list(o), dim=1) for o in class_logits][0] bb_outputs = list(zip(*bb_reg_layer_outputs)) bb_outputs = [torch.cat(list(o), dim=1) for o in bb_outputs][0] if not 0 == roi_reg_layer_outputs[0][0].shape[0]: rg_outputs = list(zip(*roi_reg_layer_outputs)) rg_outputs = [torch.cat(list(o), dim=1) for o in rg_outputs][0] else: if self.cf.dim == 2: n_feats = np.array([p.shape[-2] * p.shape[-1] * self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() else: n_feats = np.array([p.shape[-3]*p.shape[-2]*p.shape[-1]*self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() rg_outputs = torch.zeros((selected_fmaps[0].shape[0], n_feats, self.cf.regression_n_features), dtype=torch.float32).fill_(float('NaN')).cuda() # merge batch_dimension and store info in batch_ixs for re-allocation. batch_ixs = torch.arange(class_logits.shape[0]).unsqueeze(1).repeat(1, class_logits.shape[1]).view(-1).cuda() flat_class_softmax = F.softmax(class_logits.view(-1, class_logits.shape[-1]), 1) flat_bb_outputs = bb_outputs.view(-1, bb_outputs.shape[-1]) flat_rg_outputs = rg_outputs.view(-1, rg_outputs.shape[-1]) detections = refine_detections(self.anchors, flat_class_softmax, flat_bb_outputs, flat_rg_outputs, batch_ixs, self.cf) return detections, class_logits, bb_outputs, rg_outputs, seg_logits def get_results(self, img_shape, detections, seg_logits, box_results_list=None): """ Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. :param img_shape: :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regression) :param box_results_list: None or list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. 
each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. class-specific return of masks will come with implementation of instance segmentation evaluation. """ detections = detections.cpu().data.numpy() batch_ixs = detections[:, self.cf.dim*2] detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] if box_results_list == None: # for test_forward, where no previous list exists. box_results_list = [[] for _ in range(img_shape[0])] for ix in range(img_shape[0]): if not 0 in detections[ix].shape: boxes = detections[ix][:, :2 * self.cf.dim].astype(np.int32) class_ids = detections[ix][:, 2 * self.cf.dim + 1].astype(np.int32) scores = detections[ix][:, 2 * self.cf.dim + 2] regressions = detections[ix][:, 2 * self.cf.dim + 3:] # Filter out detections with zero area. Often only happens in early # stages of training when the network weights are still a bit random. 
def train_forward(self, batch, is_validation=False):
    """
    train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
    for processing, computes losses, and stores outputs in a dictionary.
    All tensors created here are moved to the GPU via .cuda().

    :param batch: dictionary. keys read here: 'data', 'seg', 'class_targets', 'bb_target', every name listed in
        cf.roi_items, and 'regression_targets' or 'rg_bin_targets' depending on cf.prediction_tasks.
    :param is_validation: not read inside this method.
    :return: results_dict: dictionary with keys:
            'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                    [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                    besides detections this holds 'gt', 'pos_anchor' and (without focal loss) 'neg_anchor' boxes.
            'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes].
            'torch_loss': 1D torch tensor for backprop.
            'class_loss': classification loss for monitoring. this is the value of the last batch element processed
                    in the loop below, not an average over the batch.
            'batch_dices': only set if the dice branch below is entered.
    """
    img = batch['data']
    gt_class_ids = batch['class_targets']
    gt_boxes = batch['bb_target']
    # regression targets are optional; which key is read depends on the configured prediction tasks.
    if 'regression' in self.cf.prediction_tasks:
        gt_regressions = batch["regression_targets"]
    elif 'regression_bin' in self.cf.prediction_tasks:
        gt_regressions = batch["rg_bin_targets"]
    else:
        gt_regressions = None

    var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
    var_seg = torch.LongTensor(batch['seg']).cuda()

    img = torch.from_numpy(img).float().cuda()
    # accumulator for the total loss; each batch element contributes 1 / batch_size of its loss.
    torch_loss = torch.FloatTensor([0]).cuda()

    # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
    box_results_list = [[] for _ in range(img.shape[0])]
    detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(img)

    # loop over batch
    for b in range(img.shape[0]):

        # add gt boxes to results dict for monitoring.
        if len(gt_boxes[b]) > 0:
            for tix in range(len(gt_boxes[b])):
                gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]}
                for name in self.cf.roi_items:
                    gt_box.update({name: batch[name][b][tix]})
                box_results_list[b].append(gt_box)

            # match gt boxes with anchors to generate targets.
            anchor_class_match, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching(
                self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b], gt_regressions[b] if gt_regressions is not None else None)

            # add positive anchors used for loss to results_dict for monitoring.
            pos_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:])
            for p in pos_anchors:
                box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})

        else:
            # no gt boxes in this batch element: every anchor gets match value -1, box/regression targets are empty.
            anchor_class_match = np.array([-1]*self.np_anchors.shape[0])
            anchor_target_deltas = np.array([])
            anchor_target_rgs = np.array([])

        anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
        anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda()
        anchor_target_rgs = torch.from_numpy(anchor_target_rgs).float().cuda()

        if self.cf.focal_loss:
            # compute class loss as focal loss as suggested in original publication, but multi-class.
            class_loss = compute_focal_class_loss(anchor_class_match, class_logits[b], gamma=self.cf.focal_loss_gamma)
            # sparing appendix of negative anchors for monitoring as not really relevant
        else:
            # compute class loss with SHEM.
            class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b])
            # add negative anchors used for loss to results_dict for monitoring.
            neg_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(anchor_class_match.cpu().numpy() == -1)][neg_anchor_ix, 0],
                img.shape[2:])
            for n in neg_anchors:
                box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})
        rg_loss = compute_rg_loss(self.cf.prediction_tasks, anchor_target_rgs, pred_rgs[b], anchor_class_match)
        bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match)
        torch_loss += (class_loss + bbox_loss + rg_loss) / img.shape[0]

    results_dict = self.get_results(img.shape, detections, seg_logits, box_results_list)
    # reduce the per-class channel to a single label map, keeping a channel dimension of size 1.
    results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:, np.newaxis]

    if self.cf.model == 'retina_unet':
        # segmentation loss: mean of (1 - batch dice) and cross entropy, added on top of the detection losses.
        seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),var_seg_ohe)
        seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
        torch_loss += (seg_loss_dice + seg_loss_ce) / 2
        # NOTE(review): the nesting of the following check was reconstructed from a whitespace-collapsed source;
        # confirm that it belongs inside the retina_unet branch.
        if 'dice' in self.cf.metrics:
            results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)

    results_dict['torch_loss'] = torch_loss
    # class_loss was overwritten in every loop iteration, so this is the last batch element's value.
    results_dict['class_loss'] = class_loss.item()

    return results_dict
def test_forward(self, batch, **kwargs):
    """
    test method. wrapper around forward pass of network without usage of any ground truth information.
    prepares input data for processing and stores outputs in a dictionary.

    :param batch: dictionary containing 'data' (numpy array, moved to the GPU as float tensor).
    :param kwargs: accepted for interface compatibility, not read here.
    :return: results_dict: dictionary with keys:
           'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                   [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
           'seg_preds': actually contain seg probabilities since evaluated to seg_preds (via argmax) in predictor.
           or dummy seg logits for real retina net (detection only)
    """
    net_input = torch.from_numpy(batch['data']).float().cuda()
    # only detections (first) and segmentation logits (last) of the five outputs are needed at test time.
    detections, _cls, _deltas, _rgs, seg_logits = self.forward(net_input)
    return self.get_results(net_input.shape, detections, seg_logits)
def _mirror_patch_crop(crop, org_img_shape, flip_y, flip_x):
    """Return one patch crop mirrored along the requested in-plane axes.

    :param crop: [y1, y2, x1, x2] or [y1, y2, x1, x2, z1, z2].
    :param org_img_shape: shape of the unpatched image; element 2 is the extent mirrored for y, element 3 for x.
    :param flip_y: mirror the first coordinate pair.
    :param flip_x: mirror the second coordinate pair.
    :return: new list of the same layout; z coordinates (if present) are passed through unchanged.
    """
    y1, y2, x1, x2 = crop[0], crop[1], crop[2], crop[3]
    if flip_y:
        # lower and upper bound swap roles when mirrored.
        y1, y2 = org_img_shape[2] - crop[1], org_img_shape[2] - crop[0]
    if flip_x:
        x1, x2 = org_img_shape[3] - crop[3], org_img_shape[3] - crop[2]
    mirrored = [y1, y2, x1, x2]
    if len(crop) != 4:
        mirrored += [crop[4], crop[5]]
    return mirrored


def get_mirrored_patch_crops(patch_crops, org_img_shape):
    """Mirror all patch crops along y, along x, and along both axes.

    :param patch_crops: list of crops, each [y1, y2, x1, x2] or [y1, y2, x1, x2, z1, z2].
    :param org_img_shape: shape of the unpatched image (elements 2 and 3 are used).
    :return: list of three crop lists in the fixed order [(y-axis), (x-axis), (y-,x-axis)].
    """
    return [[_mirror_patch_crop(crop, org_img_shape, flip_y, flip_x) for crop in patch_crops]
            for flip_y, flip_x in ((True, False), (False, True), (True, True))]


def get_mirrored_patch_crops_ax_dep(patch_crops, org_img_shape, mirror_axes):
    """Mirror all patch crops once per entry of mirror_axes.

    :param patch_crops: list of crops, each [y1, y2, x1, x2] or [y1, y2, x1, x2, z1, z2].
    :param org_img_shape: shape of the unpatched image (elements 2 and 3 are used).
    :param mirror_axes: iterable of axis specs. a number equal to 0 mirrors y, a number equal to 1 mirrors x
        (python and numpy scalars are both accepted), an iterable equal to (0, 1) or (1, 0) mirrors both.
    :return: list with one mirrored crop list per entry of mirror_axes, in the same order.
    :raises ValueError: (a subclass of the formerly raised Exception) for any other axis spec.
    """
    import numbers

    mirrored_patch_crops = []
    for axes in mirror_axes:
        # numbers.Real also covers numpy integer scalars, which are not instances of the builtin int.
        if isinstance(axes, numbers.Real) and int(axes) in (0, 1):
            flip_y, flip_x = int(axes) == 0, int(axes) == 1
        elif hasattr(axes, "__iter__") and tuple(axes) in ((0, 1), (1, 0)):
            flip_y = flip_x = True
        else:
            raise ValueError("invalid mirror axes {} in get mirrored patch crops".format(axes))
        mirrored_patch_crops.append(
            [_mirror_patch_crop(crop, org_img_shape, flip_y, flip_x) for crop in patch_crops])

    return mirrored_patch_crops
def apply_wbc_to_patient(inputs):
    """
    wrapper around prediction box consolidation: weighted box clustering (wbc). processes a single patient.
    loops over batch elements in patient results (1 in 3D, slices in 2D) and foreground classes,
    aggregates and stores results in new list.

    :param inputs: sequence (regress_flag, in_patient_results_list, pid, class_dict, clustering_iou, n_ens).
        only boxes with 'box_type' == 'det' and a 'box_pred_class_id' contained in class_dict are clustered;
        boxes with 'box_type' == 'gt' are passed through, all other boxes are dropped.
    :return. patient_results_list: list over batch elements. each element is a list over boxes, where each box is
                             one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D
                             predictions, and a dummy batch dimension of 1 for 3D predictions.
    :return. pid: string. patient id.
    """
    regress_flag, in_patient_results_list, pid, class_dict, clustering_iou, n_ens = inputs
    out_patient_results_list = [[] for _ in in_patient_results_list]
    nan = float('NaN')

    for bix, batch_el in enumerate(in_patient_results_list):
        for cl in class_dict:
            dets = [box for box in batch_el if box['box_type'] == 'det' and box['box_pred_class_id'] == cl]
            box_scores = np.array([box['box_score'] for box in dets])
            if 0 in box_scores.shape:
                # nothing predicted for this class in this batch element.
                continue

            box_coords = np.array([box['box_coords'] for box in dets])
            box_center_factor = np.array([box['box_patch_center_factor'] for box in dets])
            box_n_overlaps = np.array([box['box_n_overlaps'] for box in dets])
            try:
                box_patch_id = np.array([box['patch_id'] for box in dets])
            except KeyError:
                # backward compatibility for already saved pred results, which used the key 'ens_ix'.
                box_patch_id = np.array([box['ens_ix'] for box in dets])
            box_regressions = np.array([box['regression'] for box in dets]) if regress_flag else None
            box_rg_bins = np.array([box.get('rg_bin', nan) for box in dets])
            box_rg_uncs = np.array([box.get('rg_uncertainty', nan) for box in dets])

            keep_scores, keep_coords, keep_n_missing, keep_regressions, keep_rg_bins, keep_rg_uncs = \
                weighted_box_clustering(box_coords, box_scores, box_center_factor, box_n_overlaps, box_rg_bins,
                                        box_rg_uncs, box_regressions, box_patch_id, clustering_iou, n_ens)

            for boxix in range(len(keep_scores)):
                clustered_box = {'box_type': 'det', 'box_coords': keep_coords[boxix],
                                 'box_score': keep_scores[boxix], 'cluster_n_missing': keep_n_missing[boxix],
                                 'box_pred_class_id': cl}
                if regress_flag:
                    clustered_box.update({'regression': keep_regressions[boxix],
                                          'rg_uncertainty': keep_rg_uncs[boxix],
                                          'rg_bin': keep_rg_bins[boxix]})
                out_patient_results_list[bix].append(clustered_box)

        # add gt boxes back to new output list.
        out_patient_results_list[bix].extend([box for box in batch_el if box['box_type'] == 'gt'])

    return [out_patient_results_list, pid]
def weighted_box_clustering(box_coords, scores, box_pc_facts, box_n_ovs, box_rg_bins, box_rg_uncs,
                            box_regress, box_patch_id, thresh, n_ens):
    """Consolidates overlapping predictions resulting from patch overlaps, test data augmentations and temporal
    ensembling. clusters predictions together with iou > thresh (like in NMS). Output score and coordinate for one
    cluster are the average weighted by individual patch center factors (how trustworthy is this candidate measured
    by how centered its position within the patch is) and the size of the corresponding box.
    The number of expected predictions at a position is n_data_aug * n_temp_ens * n_overlaps_at_position
    (1 prediction per unique patch). Missing predictions at a cluster position are defined as the number of unique
    patches in the cluster, which did not contribute any boxes.

    All geometry is computed in float64, so integer and float32 coordinates are handled alike. The highest scoring
    box of each iteration always forms (at least) its own cluster and is always consumed, which guarantees
    termination for any thresh.

    :param box_coords: (n_dets, (y1, x1, y2, x2, (z1), (z2))).
    :param scores: (n_dets,) confidence scores.
    :param box_pc_facts: (n_dets,) patch-center factors from position on patch tiles.
    :param box_n_ovs: (n_dets,) number of patch overlaps at box position.
    :param box_rg_bins: (n_dets,) regression bin predictions.
    :param box_rg_uncs: (n_dets,) regression uncertainties (from model mrcnn_aleatoric).
    :param box_regress: (n_dets, n_regression_features) or None.
    :param box_patch_id: (n_dets,) id of the patch (incl. ensemble instance) each box originates from.
    :param thresh: threshold for iou_matching.
    :param n_ens: number of models, that are ensembled. (-> number of expected predictions per position).
    :return: keep_scores: (n_keep) new scores of boxes to be kept.
    :return: keep_coords: (n_keep, (y1, x1, y2, x2, (z1), (z2)) new coordinates of boxes to be kept.
    :return: keep_n_missing: (n_keep) missing predictions per cluster in percent of the expected predictions.
    :return: keep_regress, keep_rg_bins, keep_rg_uncs: (n_keep) weighted averages per cluster, NaN if box_regress
        is None.
    """
    scores = np.asarray(scores, dtype=np.float64)
    if scores.size == 0:
        return [], [], [], [], [], []

    # float64 avoids integer overflow of box volumes and dtype-dependent rounding of the overlaps.
    box_coords = np.asarray(box_coords, dtype=np.float64)
    box_pc_facts = np.asarray(box_pc_facts)
    box_n_ovs = np.asarray(box_n_ovs)
    box_patch_id = np.asarray(box_patch_id)

    dim = 2 if box_coords.shape[1] == 4 else 3
    y1, x1, y2, x2 = box_coords[:, 0], box_coords[:, 1], box_coords[:, 2], box_coords[:, 3]
    areas = (y2 - y1 + 1) * (x2 - x1 + 1)
    if dim == 3:
        z1, z2 = box_coords[:, 4], box_coords[:, 5]
        areas = areas * (z2 - z1 + 1)
    coord_columns = (y1, x1, y2, x2) if dim == 2 else (y1, x1, y2, x2, z1, z2)

    # order is the sorted index.  maps order to index o[1] = 24 (rank1, ix 24)
    order = scores.argsort()[::-1]

    keep_scores, keep_coords, keep_n_missing = [], [], []
    keep_regress, keep_rg_bins, keep_rg_uncs = [], [], []

    while order.size > 0:
        i = order[0]  # highest scoring element
        w = np.maximum(0, np.minimum(x2[i], x2[order]) - np.maximum(x1[i], x1[order]) + 1)
        h = np.maximum(0, np.minimum(y2[i], y2[order]) - np.maximum(y1[i], y1[order]) + 1)
        inter = w * h
        if dim == 3:
            inter = inter * np.maximum(0, np.minimum(z2[i], z2[order]) - np.maximum(z1[i], z1[order]) + 1)

        # overlap between currently highest scoring box and all boxes.
        ovr = inter / (areas[i] + areas[order] - inter)

        # get all the predictions that match the current box to build one cluster.
        in_cluster = ovr > thresh
        in_cluster[0] = True  # the reference box itself is always part of its cluster.
        cluster = order[in_cluster]

        # weight all scores in cluster by patch factors, and size.
        match_score_weights = ovr[in_cluster] * areas[cluster] * box_pc_facts[cluster]
        match_scores = scores[cluster] * match_score_weights
        score_sum = np.sum(match_scores)

        # for the weighted average, scores have to be divided by the number of total expected preds at the position
        # of the current cluster. 1 Prediction per patch is expected. therefore, the number of ensembled models is
        # multiplied by the mean overlaps of  patches at this position (boxes of the cluster might partly be
        # in areas of different overlaps).
        n_expected_preds = n_ens * np.mean(box_n_ovs[cluster])
        # the number of missing predictions is obtained as the number of patches,
        # which did not contribute any prediction to the current cluster.
        n_missing_preds = np.max((0, n_expected_preds - np.unique(box_patch_id[cluster]).shape[0]))

        # missing preds are given the mean weighting
        # (expected prediction is the mean over all predictions in cluster).
        denom = np.sum(match_score_weights) + n_missing_preds * np.mean(match_score_weights)

        # compute weighted average score for the cluster
        avg_score = score_sum / denom

        # compute weighted average of coordinates for the cluster. now only take existing
        # predictions into account.
        avg_coords = [np.sum(col[cluster] * match_scores) / score_sum for col in coord_columns]

        if box_regress is not None:
            # compute wt. avg. of regression vectors (component-wise average)
            avg_regress = np.sum(box_regress[cluster] * match_scores[:, np.newaxis], axis=0) / score_sum
            avg_rg_bins = np.round(np.sum(box_rg_bins[cluster] * match_scores) / score_sum)
            avg_rg_uncs = np.sum(box_rg_uncs[cluster] * match_scores) / score_sum
        else:
            avg_regress = np.array(float('NaN'))
            avg_rg_bins = np.array(float('NaN'))
            avg_rg_uncs = np.array(float('NaN'))

        # some clusters might have very low scores due to high amounts of missing predictions.
        # filter out the with a conservative threshold, to speed up evaluation.
        if avg_score > 0.01:
            keep_scores.append(avg_score)
            keep_coords.append(avg_coords)
            keep_n_missing.append((n_missing_preds / n_expected_preds * 100))  # relative
            keep_regress.append(avg_regress)
            keep_rg_uncs.append(avg_rg_uncs)
            keep_rg_bins.append(avg_rg_bins)

        # get index of all elements that were not matched and discard all others.
        unmatched = ovr <= thresh
        unmatched[0] = False  # never re-queue the reference box, otherwise the loop could not terminate.
        order = order[unmatched]

    return keep_scores, keep_coords, keep_n_missing, keep_regress, keep_rg_bins, keep_rg_uncs
def apply_nms_to_patient(inputs):
    """
    wrapper around non-maximum suppression. processes a single patient: per batch element and class, detections are
    reduced to the indices returned by mutils.nms_numpy; ground-truth boxes are appended unchanged afterwards.

    :param inputs: sequence (in_patient_results_list, pid, class_dict, iou_thresh).
    :return. out_patient_results_list: list over batch elements, each a list of kept 'det' boxes (grouped by class
        in class_dict order) followed by the 'gt' boxes of that batch element.
    :return. pid: patient id as passed in.
    """
    in_patient_results_list, pid, class_dict, iou_thresh = inputs

    def suppress(batch_el):
        # nms is run per class, kept detections of all classes are concatenated.
        kept = []
        for cl in list(class_dict.keys()):
            candidates = [box for box in batch_el
                          if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)]
            coords = np.array([box['box_coords'] for box in candidates])
            confidences = np.array([box['box_score'] for box in candidates])
            keep_ix = mutils.nms_numpy(coords, confidences, iou_thresh) if 0 not in confidences.shape else []
            kept.extend(candidates[ix] for ix in keep_ix)
        kept.extend(box for box in batch_el if box['box_type'] == 'gt')
        return kept

    out_patient_results_list = [suppress(batch_el) for batch_el in in_patient_results_list]

    assert len(in_patient_results_list) == len(out_patient_results_list), "batch dim needs to be maintained, in: {}, out {}".format(len(in_patient_results_list), len(out_patient_results_list))

    return [out_patient_results_list, pid]
def nms_2to3D(dets, thresh):
    """
    Merges 2D boxes to 3D cubes. For this purpose, boxes of all slices are regarded as lying in one slice.
    An adaptation of Non-maximum suppression is applied where clusters are found (like in NMS) with the extra
    constraint that suppressed boxes have to have 'connected' z coordinates w.r.t the core slice (cluster center,
    highest scoring box, the prevailing box). 'connected' z-coordinates are determined as the z-coordinates with
    predictions until the first coordinate for which no prediction is found.

    example: a cluster of predictions was found overlap > iou thresh in xy (like NMS). The z-coordinate of the
    highest scoring box is 50. Other predictions have 23, 46, 48, 49, 51, 52, 53, 56, 57.
    Only the coordinates connected with 50 are clustered to one cube: 48, 49, 51, 52, 53. (46 not because nothing
    was found in 47, so 47 is a 'hole', which interrupts the connection). Only the boxes corresponding to these
    coordinates are suppressed. All others are kept for building of further clusters.

    This algorithm works better with a certain min_confidence of predictions, because low confidence (e.g.
    noisy/cluttery) predictions can break the relatively strong assumption of defining cubes' z-boundaries at the
    first 'hole' in the cluster.

    :param dets: (n_detections, (y1, x1, y2, x2, scores, slice_id)
    :param thresh: iou matching threshold (like in NMS).
    :return: keep: list of (n_keep) row indices into dets of the boxes to be kept.
    :return: keep_z: (n_keep, [z1, z2]) z-coordinates to be added to boxes, which are kept in order to form cubes.
    """
    y1, x1, y2, x2 = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
    assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: where maximum is taken needs to be the lower coordinate"""
    scores, slice_id = dets[:, -2], dets[:, -1]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    # ranking holds row indices sorted by descending score: ranking[1] = 24 means (rank1, ix 24)
    ranking = scores.argsort()[::-1]
    keep, keep_z = [], []

    while ranking.size > 0:
        top = ranking[0]  # highest scoring element; it is compared to itself as well and thus part of its cluster.
        overlap_h = np.maximum(0.0, np.minimum(y2[top], y2[ranking]) - np.maximum(y1[top], y1[ranking]) + 1)
        overlap_w = np.maximum(0.0, np.minimum(x2[top], x2[ranking]) - np.maximum(x1[top], x1[ranking]) + 1)
        inter = overlap_h * overlap_w
        iou = inter / (areas[top] + areas[ranking] - inter)

        # positions (within ranking) of all elements matching the current box in-plane.
        cluster = np.nonzero(iou > thresh)[0]
        cluster_slices = slice_id[ranking[cluster]]
        core_slice = slice_id[int(top)]
        highest_slice, lowest_slice = np.max(cluster_slices), np.min(cluster_slices)

        # a 'hole' is a slice between the core slice and the cluster's extreme slice without any prediction.
        upper_holes = [z for z in np.arange(core_slice, highest_slice) if z not in cluster_slices]
        lower_holes = [z for z in np.arange(lowest_slice, core_slice) if z not in cluster_slices]
        max_valid_slice_id = np.min(upper_holes) if len(upper_holes) > 0 else highest_slice
        min_valid_slice_id = np.max(lower_holes) if len(lower_holes) > 0 else lowest_slice
        connected = cluster[(cluster_slices <= max_valid_slice_id) & (cluster_slices >= min_valid_slice_id)]

        # expand by one z voxel since box content is surrounded w/o overlap, i.e., z-content computed as z2-z1
        connected_slices = slice_id[ranking[connected]]
        keep.append(top)
        keep_z.append([np.min(connected_slices) - 1, np.max(connected_slices) + 1])

        # only the connected boxes are suppressed, all others stay available for further clusters.
        ranking = np.delete(ranking, connected, axis=0)

    return keep, keep_z
def apply_2d_3d_merging_to_patient(inputs):
    """
    wrapper around 2Dto3D merging operation. Processes a single patient. Takes 2D patient results (slices in batch
    dimension) and returns 3D patient results (dummy batch dimension of 1). Applies an adaption of Non-Maximum
    Suppression (Detailed methodology is described in nms_2to3D).

    Kept boxes are returned as shallow copies with extended coordinates; the boxes in the input list are not
    modified, so the function can safely be applied to the same results more than once.

    :param inputs: sequence (in_patient_results_list, pid, class_dict, merge_3D_iou).
    :return. results_dict_boxes: list over batch elements (1 in 3D). each element is a list over boxes, where each
        box is one dictionary: [[box_0, ...], [box_n,...]]. kept detections come first (grouped by class), followed
        by all ground-truth boxes.
    :return. pid: string. patient id.
    """
    in_patient_results_list, pid, class_dict, merge_3D_iou = inputs
    out_patient_results_list = []

    for cl in class_dict:
        det_boxes, slice_ids = [], []
        # collect box predictions over batch dimension (slices) and store slice info as slice_ids.
        for batch_ix, batch_el in enumerate(in_patient_results_list):
            batch_el_dets = [box for box in batch_el
                             if box['box_type'] == 'det' and box['box_pred_class_id'] == cl]
            det_boxes += batch_el_dets
            slice_ids += [batch_ix] * len(batch_el_dets)

        box_coords = np.array([box['box_coords'] for box in det_boxes])
        box_scores = np.array([box['box_score'] for box in det_boxes])
        slice_ids = np.array(slice_ids)

        if 0 not in box_scores.shape:
            keep_ix, keep_z = nms_2to3D(
                np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou)
        else:
            keep_ix, keep_z = [], []

        # store kept predictions in new results list and add corresponding z-dimension info to coordinates.
        for kix, kz in zip(keep_ix, keep_z):
            keep_box = dict(det_boxes[kix])  # copy: keeps all extra keys without touching the caller's box.
            keep_box['box_coords'] = list(keep_box['box_coords']) + kz
            out_patient_results_list.append(keep_box)

    gt_boxes = [box for batch_el in in_patient_results_list for box in batch_el if box['box_type'] == 'gt']
    if len(gt_boxes) > 0:
        assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D."
    out_patient_results_list += gt_boxes

    return [[out_patient_results_list], pid]  # additional list wrapping is extra batch dim.
for batch_ix, batch in enumerate(in_patient_results_list): batch_element_det_boxes = [(ix, box) for ix, box in enumerate(batch) if (box['box_type'] == 'det' and box['box_pred_class_id'] == cl)] det_boxes += batch_element_det_boxes slice_ids += [batch_ix] * len(batch_element_det_boxes) box_coords = np.array([batch[1]['box_coords'] for batch in det_boxes]) box_scores = np.array([batch[1]['box_score'] for batch in det_boxes]) slice_ids = np.array(slice_ids) if 0 not in box_scores.shape: keep_ix, keep_z = nms_2to3D( np.concatenate((box_coords, box_scores[:, None], slice_ids[:, None]), axis=1), merge_3D_iou) else: keep_ix, keep_z = [], [] # store kept predictions in new results list and add corresponding z-dimension info to coordinates. - # for kix, kz in zip(keep_ix, keep_z): - # out_patient_results_list.append({'box_type': 'det', 'box_coords': list(box_coords[kix]) + kz, - # 'box_score': box_scores[kix], 'box_pred_class_id': cl}) for kix, kz in zip(keep_ix, keep_z): keep_box = det_boxes[kix][1] keep_box['box_coords'] = list(keep_box['box_coords']) + kz out_patient_results_list.append(keep_box) gt_boxes = [box for b in in_patient_results_list for box in b if box['box_type'] == 'gt'] if len(gt_boxes) > 0: assert np.all([len(box["box_coords"]) == 6 for box in gt_boxes]), "expanded preds to 3D but GT is 2D." out_patient_results_list += gt_boxes return [[out_patient_results_list], pid] # additional list wrapping is extra batch dim. class Predictor: """ Prediction pipeline: - receives a patched patient image (n_patches, c, y, x, (z)) from patient data loader. - forwards patches through model in chunks of batch_size. (method: batch_tiling_forward) - unmolds predictions (boxes and segmentations) to original patient coordinates. 
def __init__(self, cf, net, logger, mode):
    """
    :param cf: config object. read here: batch_size, prediction_tasks, merge_2D_to_3D_preds, dim and, in test mode,
        fold_dir, test_n_epochs, test_aug_axes, test_dir.
    :param net: model to run predictions with.
    :param logger: logger.
    :param mode: 'val' or 'test'. in test mode the epoch ranking is loaded from cf.fold_dir and the directory
        for example plots is created below cf.test_dir.
    :raises RuntimeError: in test mode, if the epoch ranking file can not be loaded.
    """
    self.cf = cf
    self.batch_size = cf.batch_size
    self.logger = logger
    self.mode = mode
    self.net = net
    self.n_ens = 1
    self.rank_ix = '0'
    self.regress_flag = any('regression' in task for task in self.cf.prediction_tasks)

    if self.cf.merge_2D_to_3D_preds:
        assert self.cf.dim == 2, "Merge 2Dto3D only valid for 2D preds, but current dim is {}.".format(self.cf.dim)

    if self.mode == 'test':
        try:
            self.epoch_ranking = np.load(os.path.join(self.cf.fold_dir, 'epoch_ranking.npy'))[:cf.test_n_epochs]
        except Exception as err:
            # Exception instead of a bare except lets KeyboardInterrupt / SystemExit pass;
            # chaining keeps the actual cause (missing vs. unreadable file) visible.
            raise RuntimeError('no epoch ranking file in fold directory. '
                               'seems like you are trying to run testing without prior training...') from err
        self.n_ens = cf.test_n_epochs
        test_aug_axes = self.cf.test_aug_axes
        if test_aug_axes is not None:
            # a single scalar axis is a valid setting (it is wrapped into a tuple in data_aug_forward).
            n_aug = 1 if isinstance(test_aug_axes, (int, float)) else len(test_aug_axes)
            self.n_ens *= (n_aug + 1)
        self.example_plot_dir = os.path.join(cf.test_dir, "example_plots")
        os.makedirs(self.example_plot_dir, exist_ok=True)
def batch_tiling_forward(self, batch):
    """
    calls the actual network forward method. in patch-based prediction, the batch dimension might be overladed
    with n_patches >> batch_size, which would exceed gpu memory. In this case, batches are processed in chunks of
    batch_size. validation mode calls the train method to monitor losses (returned ground truth objects are
    discarded). every other mode calls the test forward method, no ground truth required / involved.

    :param batch: dictionary containing 'data'. when tiling is needed, only numpy arrays whose first dimension
        equals that of 'data' are split and passed on to the network.
    :return. results_dict: stores the results for one patient. dictionary with keys:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                        and a dummy batch dimension of 1 for 3D predictions.
             - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
             - torch_loss / class_loss (only in validation mode; mean over chunks if the batch was tiled)
    """
    is_val = self.mode == 'val'

    def forward_chunk(chunk):
        # validation monitors losses via the training pass, anything else is plain inference.
        if is_val:
            return self.net.train_forward(chunk, is_validation=True)
        return self.net.test_forward(chunk, return_masks=self.cf.return_masks_in_test)

    n_elements = batch['data'].shape[0]

    if n_elements <= self.batch_size:
        results_dict = forward_chunk(batch)
    else:  # needs batch tiling
        chunk_dicts = []
        for start in range(0, n_elements, self.batch_size):
            chunk_ixs = np.arange(start, min(start + self.batch_size, n_elements))
            chunk = {k: v[chunk_ixs] for k, v in batch.items()
                     if (isinstance(v, np.ndarray) and v.shape[0] == n_elements)}
            chunk_dicts.append(forward_chunk(chunk))

        results_dict = {}
        # flatten out batch elements from chunks ([chunk, chunk] -> [b, b, b, b, ...])
        results_dict['boxes'] = [item for d in chunk_dicts for item in d['boxes']]
        results_dict['seg_preds'] = np.array([item for d in chunk_dicts for item in d['seg_preds']])

        if is_val:
            # estimate patient loss by mean over batch_chunks. Most similar to training loss.
            results_dict['torch_loss'] = torch.mean(torch.cat([d['torch_loss'] for d in chunk_dicts]))
            results_dict['class_loss'] = np.mean([d['class_loss'] for d in chunk_dicts])

    if is_val:
        # discard returned ground-truth boxes (also training info boxes).
        results_dict['boxes'] = [[box for box in b if box['box_type'] == 'det'] for b in results_dict['boxes']]

    return results_dict
def spatial_tiling_forward(self, batch, patch_crops = None, n_aug='0'):
    """
    forwards batch to batch_tiling_forward method and receives and returns a dictionary with results.
    if patch-based prediction, the results received from batch_tiling_forward will be on a per-patch-basis.
    this method uses the provided patch_crops to re-transform all predictions to whole-image coordinates.
    Patch-origin information of all box-predictions will be needed for consolidation, hence it is stored as
    'patch_id', which is a unique string for each patch (also takes current data aug and temporal epoch instances
    into account). all box predictions get additional information about the amount overlapping patches at the
    respective position (used for consolidation).

    :param batch: dictionary passed on to batch_tiling_forward. 'original_img_shape' is read if patch_crops is given.
    :param patch_crops: None (no patching) or list of crops, indexed below as
        [dim-2 start, dim-2 stop, dim-3 start, dim-3 stop, z start, z stop] of the whole-image array.
    :param n_aug: string identifying the current test-time augmentation, becomes part of 'patch_id'.
    :return. results_dict: stores the results for one patient. dictionary with keys:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions,
                        and a dummy batch dimension of 1 for 3D predictions.
             - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z))
             - torch_loss / class_loss (only in validation mode with patching; otherwise whatever
               batch_tiling_forward returns)
             returned dict is a flattened version with 1 batch instance (3D) or slices (2D)
    """
    if patch_crops is not None:
        patches_dict = self.batch_tiling_forward(batch)

        results_dict = {'boxes': [[] for _ in range(batch['original_img_shape'][0])]}
        # bc of ohe--> channel dim of seg has size num_classes
        out_seg_shape = list(batch['original_img_shape'])
        out_seg_shape[1] = patches_dict["seg_preds"].shape[1]
        out_seg_preds = np.zeros(out_seg_shape, dtype=np.float16)
        # counts, per voxel, how many patches cover it (uint8).
        patch_overlap_map = np.zeros_like(out_seg_preds, dtype='uint8')
        for pix, pc in enumerate(patch_crops):
            if self.cf.dim == 3:
                out_seg_preds[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += patches_dict['seg_preds'][pix]
                patch_overlap_map[:, :, pc[0]:pc[1], pc[2]:pc[3], pc[4]:pc[5]] += 1
            elif self.cf.dim == 2:
                # in 2D the slice range of the patch indexes the batch dimension.
                out_seg_preds[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += patches_dict['seg_preds'][pix]
                patch_overlap_map[pc[4]:pc[5], :, pc[0]:pc[1], pc[2]:pc[3], ] += 1
        # average the summed patch predictions; voxels not covered by any patch stay 0.
        out_seg_preds[patch_overlap_map > 0] /= patch_overlap_map[patch_overlap_map > 0]
        results_dict['seg_preds'] = out_seg_preds

        for pix, pc in enumerate(patch_crops):
            patch_boxes = patches_dict['boxes'][pix]
            for box in patch_boxes:

                # add unique patch id for consolidation of predictions.
                box['patch_id'] = self.rank_ix + '_' + n_aug + '_' + str(pix)
                # boxes from the edges of a patch have a lower prediction quality, than the ones at patch-centers.
                # hence they will be down-weighted for consolidation, using the 'box_patch_center_factor', which is
                # obtained by a gaussian distribution over positions in the patch and average over spatial dimensions.
                # Also the info 'box_n_overlaps' is stored for consolidation, which represents the amount of
                # overlapping patches at the box's position.

                c = box['box_coords']
                box_centers = [(c[ii] + c[ii + 2]) / 2 for ii in range(2)]
                if self.cf.dim == 3:
                    box_centers.append((c[4] + c[5]) / 2)
                # the factor sqrt(2 * pi) * scale normalizes the gaussian pdf to 1 at the patch center.
                # the comprehension variable pc is local to the comprehension (python 3) and does not
                # overwrite the patch crop pc of the enclosing loop.
                box['box_patch_center_factor'] = np.mean(
                    [norm.pdf(bc, loc=pc, scale=pc * 0.8) * np.sqrt(2 * np.pi) * pc * 0.8 for bc, pc in
                     zip(box_centers, np.array(self.cf.patch_size) / 2)])
                # shift box from patch to whole-image coordinates. presumably box_coords is a numpy array, in which
                # case += updates box['box_coords'] in place (a list would be extended instead) - TODO confirm.
                if self.cf.dim == 3:
                    c += np.array([pc[0], pc[2], pc[0], pc[2], pc[4], pc[4]])
                    int_c = [int(np.floor(ii)) if ix%2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)]
                    # NOTE(review): the map is indexed with c[1]:c[3] on array dim 2 and c[0]:c[2] on array dim 3,
                    # while the offsets above add pc[0] (dim 2 start) to c[0] - confirm the intended axis convention.
                    box['box_n_overlaps'] = np.mean(patch_overlap_map[:, :, int_c[1]:int_c[3], int_c[0]:int_c[2], int_c[4]:int_c[5]])
                    results_dict['boxes'][0].append(box)
                else:
                    c += np.array([pc[0], pc[2], pc[0], pc[2]])
                    int_c = [int(np.floor(ii)) if ix % 2 == 0 else int(np.ceil(ii)) for ix, ii in enumerate(c)]
                    box['box_n_overlaps'] = np.mean(
                        patch_overlap_map[pc[4], :, int_c[1]:int_c[3], int_c[0]:int_c[2]])
                    # 2D: the box is assigned to the slice the patch starts at.
                    results_dict['boxes'][pc[4]].append(box)

        if self.mode == 'val':
            results_dict['torch_loss'] = patches_dict['torch_loss']
            results_dict['class_loss'] = patches_dict['class_loss']

    else:
        # no patching: every box gets neutral weights and a patch id without patch index.
        results_dict = self.batch_tiling_forward(batch)
        for b in results_dict['boxes']:
            for box in b:
                box['box_patch_center_factor'] = 1
                box['box_n_overlaps'] = 1
                box['patch_id'] = self.rank_ix + '_' + n_aug

    return results_dict
results_dict def data_aug_forward(self, batch): """ in val_mode: passes batch through to spatial_tiling method without data_aug. in test_mode: if cf.test_aug is set in configs, createst 4 mirrored versions of the input image, passes all of them to the next processing step (spatial_tiling method) and re-transforms returned predictions to original image version. :return. results_dict: stores the results for one patient. dictionary with keys: - 'boxes': list over batch elements. each element is a list over boxes, where each box is one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions, and a dummy batch dimension of 1 for 3D predictions. - 'seg_preds': pixel-wise predictions. (b, 1, y, x, (z)) - loss / class_loss (only in validation mode) """ patch_crops = batch['patch_crop_coords'] if self.patched_patient else None results_list = [self.spatial_tiling_forward(batch, patch_crops)] org_img_shape = batch['original_img_shape'] if self.mode == 'test' and self.cf.test_aug_axes is not None: if isinstance(self.cf.test_aug_axes, (int, float)): self.cf.test_aug_axes = (self.cf.test_aug_axes,) #assert np.all(np.array(self.cf.test_aug_axes)= coords[0], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] assert coords[3] >= coords[1], [coords, chunk_dict['boxes'][ix][boxix]['box_coords']] chunk_dict['boxes'][ix][boxix]['box_coords'] = coords # re-transform segmentation predictions. chunk_dict['seg_preds'] = np.flip(chunk_dict['seg_preds'], axis=axis) elif hasattr(sp_axis, "__iter__") and tuple(sp_axis)==(0,1) or tuple(sp_axis)==(1,0): #NEED: mirrored patch crops are given as [(y-axis), (x-axis), (y-,x-axis)], obey this order! # mirroring along two axes at same time batch['data'] = np.flip(np.flip(img, axis=axis[0]), axis=axis[1]).copy() chunk_dict = self.spatial_tiling_forward(batch, mirrored_patch_crops[n_aug], n_aug=str(n_aug)) # re-transform coordinates. 
def load_saved_predictions(self):
    """Load the raw predictions saved by self.predict_test_set and consolidate their boxes.

    Reads the pickled result list(s) from disk, sets self.n_ens (number of ensemble members the raw
    predictions stem from: test epochs * test-time-augmentation instances (* number of loaded folds for
    the hold-out set)), then optionally clusters the boxes (cf.clustering "wbc" or "nms") and merges 2D boxes
    to 3D cubes (cf.merge_2D_to_3D_preds) for evaluation.

    :return: results_list: list over patients, each entry is [results_dict, pid]. results_dict has keys:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                        (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
             - 'batch_dices': dice scores as recorded in raw prediction results.
             - 'seg_preds': not implemented yet. could replace dices by seg preds to have raw seg info available,
                        however would consume critically large memory amount. todo evaluation of
                        instance/semantic segmentation.
    """

    def run_in_pool(worker, inputs):
        # the context manager tears the worker pool down even if a worker raises.
        with Pool(processes=self.cf.n_workers) as pool:
            return pool.map(worker, inputs, chunksize=1)

    results_file = 'pred_results.pkl' if not self.cf.held_out_test_set else 'pred_results_held_out.pkl'
    # one ensemble member per test-time-augmentation axis setting plus the un-augmented pass.
    da_factor = len(self.cf.test_aug_axes) + 1 if self.cf.test_aug_axes is not None else 1

    # NOTE: the pickle files read below are the ones written by predict_test_set; never point this at
    # untrusted files, since unpickling executes arbitrary code.
    if not self.cf.held_out_test_set or self.cf.eval_test_fold_wise:
        self.logger.info("loading saved predictions of fold {}".format(self.cf.fold))
        with open(os.path.join(self.cf.fold_dir, results_file), 'rb') as handle:
            results_list = pickle.load(handle)
        self.n_ens = self.cf.test_n_epochs * da_factor
        self.logger.info('loaded raw test set predictions with n_patients = {} and n_ens = {}'.format(
            len(results_list), self.n_ens))
    else:
        self.logger.info("loading saved predictions of hold-out test set")
        results_list = []
        folds_loaded = 0
        for fold in range(self.cf.n_cv_splits):
            fold_file = os.path.join(self.cf.exp_dir, 'fold_{}'.format(fold), results_file)
            # check for the file itself: a fold directory without saved predictions is skipped as well.
            if os.path.isfile(fold_file):
                with open(fold_file, 'rb') as handle:
                    results_list += pickle.load(handle)
                folds_loaded += 1
            else:
                self.logger.info("Skipping fold {} since no saved predictions found.".format(fold))
        self.n_ens = self.cf.test_n_epochs * da_factor * folds_loaded

    # ground-truth boxes are not filtered out here.
    box_results_list = [(res_dict["boxes"], pid) for res_dict, pid in results_list]

    # -------------- aggregation of boxes via clustering -----------------
    if self.cf.clustering == "wbc":
        self.logger.info('applying WBC to test-set predictions with iou {} and n_ens {} over {} patients'.format(
            self.cf.clustering_iou, self.n_ens, len(box_results_list)))
        mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens]
                     for ii in box_results_list]
        del box_results_list
        box_results_list = run_in_pool(apply_wbc_to_patient, mp_inputs)
        del mp_inputs
    elif self.cf.clustering == "nms":
        self.logger.info('applying standard NMS to test-set predictions with iou {} over {} patients.'.format(
            self.cf.clustering_iou, len(box_results_list)))
        mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou] for ii in box_results_list]
        del box_results_list
        box_results_list = run_in_pool(apply_nms_to_patient, mp_inputs)
        del mp_inputs

    if self.cf.merge_2D_to_3D_preds:
        self.logger.info('applying 2Dto3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
        mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou] for ii in box_results_list]
        box_results_list = run_in_pool(apply_2d_3d_merging_to_patient, mp_inputs)
        del mp_inputs

    # write the consolidated boxes back into the loaded result dicts, patient by patient.
    for ix in range(len(results_list)):
        assert np.all(results_list[ix][1] == box_results_list[ix][1]), "pid mismatch between loaded and aggregated results"
        results_list[ix][0]["boxes"] = box_results_list[ix][0]

    return results_list  # holds (results_dict, pid)
def predict_test_set(self, batch_gen, return_results=True):
    """
    wrapper around test method, which loads multiple (or one) epoch parameters (temporal ensembling), loops through
    the test set and collects predictions per patient. Also flattens the results per patient and epoch
    and adds optional ground truth boxes for evaluation. Saves out the raw result list for later analysis and
    optionally consolidates and returns predictions immediately.

    :param batch_gen: dict, read here via batch_gen['test'] (iterator yielding one patient batch per next())
                      and batch_gen['n_test'] (number of batches to draw per parameter set).
    :param return_results: if True, boxes are clustered / merged according to the config and the results are
                           returned; if False, raw results are only saved to disk and None is returned.
    :return: (optionally) list_of_results_per_patient: list over patient results. each entry is
             [results_dict, pid], where results_dict has keys:
             - 'boxes': list over batch elements. each element is a list over boxes, where each box is
                        one dictionary: [[box_0, ...], [box_n,...]]. batch elements are slices for 2D predictions
                        (if not merged to 3D), and a dummy batch dimension of 1 for 3D predictions.
             - 'batch_dices': mean dice scores over test epochs (only if 'dice' in cf.metrics).
             - 'seg_preds': not implemented yet. todo evaluation of instance/semantic segmentation.
    """

    def run_in_pool(worker, inputs):
        # the context manager tears the worker pool down even if a worker raises.
        with Pool(processes=self.cf.n_workers) as pool:
            return pool.map(worker, inputs, chunksize=1)

    # -------------- raw predicting -----------------
    dict_of_patients_results = OrderedDict()
    set_of_result_types = set()
    # get paths of all parameter sets to be loaded for temporal ensembling. (or just one for no temp. ensembling).
    weight_paths = [os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch))
                    for epoch in self.epoch_ranking]
    n_test = batch_gen['n_test']
    # sampling without replacement cannot draw more plot cases than there are test batches.
    n_plots = min(self.cf.n_test_plots, n_test)

    for rank_ix, weight_path in enumerate(weight_paths):
        self.logger.info(('tmp ensembling over rank_ix:{} epoch:{}'.format(rank_ix, weight_path)))
        self.net.load_state_dict(torch.load(weight_path))
        self.net.eval()
        self.rank_ix = str(rank_ix)
        with torch.no_grad():
            plot_batches = np.random.choice(np.arange(n_test), size=n_plots, replace=False)
            for i in range(n_test):
                batch = next(batch_gen['test'])
                pid = np.unique(batch['pid'])
                assert len(pid) == 1
                pid = pid[0]

                if pid not in dict_of_patients_results:
                    # store batch info in patient entry of results dict.
                    dict_of_patients_results[pid] = {}
                    dict_of_patients_results[pid]['results_dicts'] = []
                    dict_of_patients_results[pid]['patient_bb_target'] = batch['patient_bb_target']
                    for name in self.cf.roi_items:
                        dict_of_patients_results[pid]["patient_" + name] = batch["patient_" + name]

                stime = time.time()
                results_dict = self.predict_patient(batch)  # only holds "boxes", "seg_preds"
                # needs ohe seg probs in seg_preds entry:
                results_dict['seg_preds'] = np.argmax(results_dict['seg_preds'], axis=1)[:, np.newaxis]
                self.logger.info("predicting patient {} with weight rank {} (progress: {}/{}) took {:.2f}s".format(
                    str(pid), rank_ix, (rank_ix) * n_test + (i + 1), len(weight_paths) * n_test,
                    time.time() - stime))

                if i in plot_batches and (not self.patched_patient or 'patient_data' in batch.keys()):
                    # plotting is best effort: a failing plot must not abort the prediction run.
                    try:
                        # view qualitative results of random test case
                        self.logger.time("test_plot")
                        out_file = os.path.join(self.example_plot_dir,
                                                'batch_example_test_{}_rank_{}.png'.format(self.cf.fold, rank_ix))
                        plg.view_batch(self.cf, batch, res_dict=results_dict, out_file=out_file,
                                       show_seg_ids='dice' in self.cf.metrics,
                                       has_colorchannels=self.cf.has_colorchannels, show_gt_labels=True)
                        self.logger.info("generated example test plot {} in {:.2f}s".format(
                            os.path.basename(out_file), self.logger.time("test_plot")))
                    except Exception as e:
                        self.logger.info("WARNING: error in view_batch: {}".format(e))

                if 'dice' in self.cf.metrics:
                    if self.patched_patient:
                        assert 'patient_seg' in batch.keys(), "Results_dict preds are in original patient shape."
                    results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                        results_dict['seg_preds'], batch["patient_seg"] if self.patched_patient else batch['seg'],
                        self.cf.num_seg_classes, convert_to_ohe=True)

                # seg_preds are dropped here, only boxes and dice scores are kept per patient and rank.
                dict_of_patients_results[pid]['results_dicts'].append(
                    {k: v for k, v in results_dict.items() if k in ["boxes", "batch_dices"]})
                # collect result types to know which ones to look for when saving
                set_of_result_types.update(dict_of_patients_results[pid]['results_dicts'][-1].keys())

    # -------------- re-order, save raw results -----------------
    self.logger.info('finished predicting test set. starting aggregation of predictions.')
    results_per_patient = []
    for pid, p_dict in dict_of_patients_results.items():
        # dict_of_patients_results[pid]['results_list'] has length batch['n_test']
        results_dict = {}
        # collect all boxes/seg_preds of same batch_instance over temporal instances.
        b_size = len(p_dict['results_dicts'][0]["boxes"])
        for res_type in [rtype for rtype in set_of_result_types if rtype in ["boxes", "batch_dices"]]:
            if 'batch' not in res_type:  # assume it's results on batch-element basis
                results_dict[res_type] = [[item for rank_dict in p_dict['results_dicts']
                                           for item in rank_dict[res_type][batch_instance]]
                                          for batch_instance in range(b_size)]
            else:
                results_dict[res_type] = []
                for rank_dict in p_dict['results_dicts']:
                    if 'dice' in res_type:
                        item = rank_dict[res_type]  # rank_dict['batch_dices'] has shape (num_seg_classes,)
                        assert len(item) == self.cf.num_seg_classes, \
                            "{}, {}".format(len(item), self.cf.num_seg_classes)
                    else:
                        raise NotImplementedError
                    results_dict[res_type].append(item)
                # rdict[dice] shape (n_rank_epochs (n_saved_ranks), nsegclasses)
                # calc mean over test epochs so inline with shape from sampling
                results_dict[res_type] = np.mean(results_dict[res_type], axis=0)

        if not hasattr(self.cf, "eval_test_separately") or not self.cf.eval_test_separately:
            # add unpatched 2D or 3D (if dim==3 or merge_2D_to_3D) ground truth boxes for evaluation.
            for b in range(p_dict['patient_bb_target'].shape[0]):
                for targ in range(len(p_dict['patient_bb_target'][b])):
                    gt_box = {'box_type': 'gt', 'box_coords': p_dict['patient_bb_target'][b][targ],
                              'class_targets': p_dict['patient_class_targets'][b][targ]}
                    for name in self.cf.roi_items:
                        gt_box.update({name: p_dict["patient_" + name][b][targ]})
                    results_dict['boxes'][b].append(gt_box)

        results_per_patient.append([results_dict, pid])

    out_string = 'pred_results_held_out' if self.cf.held_out_test_set else 'pred_results'
    with open(os.path.join(self.cf.fold_dir, '{}.pkl'.format(out_string)), 'wb') as handle:
        pickle.dump(results_per_patient, handle)

    if return_results:
        # -------------- results processing, clustering, etc. -----------------
        final_patient_box_results = [(res_dict["boxes"], pid) for res_dict, pid in results_per_patient]
        if self.cf.clustering == "wbc":
            self.logger.info('applying WBC to test-set predictions with iou = {} and n_ens = {}.'.format(
                self.cf.clustering_iou, self.n_ens))
            mp_inputs = [[self.regress_flag, ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou, self.n_ens]
                         for ii in final_patient_box_results]
            del final_patient_box_results
            final_patient_box_results = run_in_pool(apply_wbc_to_patient, mp_inputs)
            del mp_inputs
        elif self.cf.clustering == "nms":
            self.logger.info('applying standard NMS to test-set predictions with iou = {}.'.format(self.cf.clustering_iou))
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.clustering_iou]
                         for ii in final_patient_box_results]
            del final_patient_box_results
            final_patient_box_results = run_in_pool(apply_nms_to_patient, mp_inputs)
            del mp_inputs

        if self.cf.merge_2D_to_3D_preds:
            self.logger.info('applying 2D-to-3D merging to test-set predictions with iou = {}.'.format(self.cf.merge_3D_iou))
            mp_inputs = [[ii[0], ii[1], self.cf.class_dict, self.cf.merge_3D_iou]
                         for ii in final_patient_box_results]
            del final_patient_box_results
            final_patient_box_results = run_in_pool(apply_2d_3d_merging_to_patient, mp_inputs)
            del mp_inputs

        # final_patient_box_results holds [avg_boxes, pid] if wbc
        for ix in range(len(results_per_patient)):
            assert results_per_patient[ix][1] == final_patient_box_results[ix][1], "should be same pid"
            results_per_patient[ix][0]["boxes"] = final_patient_box_results[ix][0]

        return results_per_patient  # holds list of (results_dict, pid)
""" import warnings warnings.filterwarnings('ignore', '.*From scipy 0.13.0, the output shape of zoom()*') import numpy as np import scipy.misc import scipy.ndimage import scipy.interpolate from scipy.ndimage.measurements import label as lb import torch import tqdm from custom_extensions.nms import nms from custom_extensions.roi_align import roi_align ############################################################ # Segmentation Processing ############################################################ def sum_tensor(input, axes, keepdim=False): axes = np.unique(axes) if keepdim: for ax in axes: input = input.sum(ax, keepdim=True) else: for ax in sorted(axes, reverse=True): input = input.sum(int(ax)) return input def get_one_hot_encoding(y, n_classes): """ transform a numpy label array to a one-hot array of the same shape. :param y: array of shape (b, 1, y, x, (z)). :param n_classes: int, number of classes to unfold in one-hot encoding. :return y_ohe: array of shape (b, n_classes, y, x, (z)) """ dim = len(y.shape) - 2 if dim == 2: y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32') elif dim == 3: y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32') else: raise Exception("invalid dimensions {} encountered".format(y.shape)) for cl in np.arange(n_classes): y_ohe[:, cl][y[:, 0] == cl] = 1 return y_ohe def dice_per_batch_inst_and_class(pred, y, n_classes, convert_to_ohe=True, smooth=1e-8): ''' computes dice scores per batch instance and class. :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. 
softmax prediction with argmax over dim 1) :param y: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes] :param n_classes: int :return: dice scores of shape (b, c) ''' if convert_to_ohe: pred = get_one_hot_encoding(pred, n_classes) y = get_one_hot_encoding(y, n_classes) axes = tuple(range(2, len(pred.shape))) intersect = np.sum(pred*y, axis=axes) denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes) dice = (2.0*intersect + smooth) / (denominator + smooth) return dice def dice_per_batch_and_class(pred, targ, n_classes, convert_to_ohe=True, smooth=1e-8): ''' computes dice scores per batch and class. :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1) :param targ: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes]) :param n_classes: int :param smooth: Laplacian smooth, https://en.wikipedia.org/wiki/Additive_smoothing :return: dice scores of shape (b, c) ''' if convert_to_ohe: pred = get_one_hot_encoding(pred, n_classes) targ = get_one_hot_encoding(targ, n_classes) axes = (0, *list(range(2, len(pred.shape)))) #(0,2,3(,4)) intersect = np.sum(pred * targ, axis=axes) denominator = np.sum(pred, axis=axes) + np.sum(targ, axis=axes) dice = (2.0 * intersect + smooth) / (denominator + smooth) assert dice.shape==(n_classes,), "dice shp {}".format(dice.shape) return dice def batch_dice(pred, y, false_positive_weight=1.0, eps=1e-6): ''' compute soft dice over batch. this is a differentiable score and can be used as a loss function. only dice scores of foreground classes are returned, since training typically does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of. This way, single patches with missing foreground classes can not produce faulty gradients. :param pred: (b, c, y, x, (z)), softmax probabilities (network output). 
def compute_iou_2D(box, boxes, box_area, boxes_area):
    """Intersection over union of one 2D box with each box of an array.

    box: 1D vector [y1, x1, y2, x2] (the reference box, e.g. a ground-truth box).
    boxes: [boxes_count, (y1, x1, y2, x2)]
    box_area: float. the area of 'box'.
    boxes_area: array of length boxes_count holding the areas of 'boxes'.
    returns: array of length boxes_count with the IoU values.

    Note: areas are passed in by the caller so they are computed only once when this is called repeatedly.
    """
    # extent of the overlap along each axis; negative values mean the boxes are disjoint on that axis.
    overlap_w = np.minimum(box[3], boxes[:, 3]) - np.maximum(box[1], boxes[:, 1])
    overlap_h = np.minimum(box[2], boxes[:, 2]) - np.maximum(box[0], boxes[:, 0])
    shared = np.maximum(overlap_w, 0) * np.maximum(overlap_h, 0)
    total = box_area + boxes_area[:] - shared[:]
    return shared / total


def compute_iou_3D(box, boxes, box_volume, boxes_volume):
    """Intersection over union of one 3D box with each box of an array.

    box: 1D vector [y1, x1, y2, x2, z1, z2] (the reference box, e.g. a ground-truth box).
    boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)]
    box_volume: float. the volume of 'box'.
    boxes_volume: array of length boxes_count holding the volumes of 'boxes'.
    returns: array of length boxes_count with the IoU values.

    Note: volumes are passed in by the caller so they are computed only once when this is called repeatedly.
    """
    # extent of the overlap along each axis; negative values mean the boxes are disjoint on that axis.
    overlap_w = np.minimum(box[3], boxes[:, 3]) - np.maximum(box[1], boxes[:, 1])
    overlap_h = np.minimum(box[2], boxes[:, 2]) - np.maximum(box[0], boxes[:, 0])
    overlap_d = np.minimum(box[5], boxes[:, 5]) - np.maximum(box[4], boxes[:, 4])
    shared = np.maximum(overlap_w, 0) * np.maximum(overlap_h, 0) * np.maximum(overlap_d, 0)
    total = box_volume + boxes_volume[:] - shared[:]
    return shared / total
def box_refinement(box, gt_box):
    """Compute refinement needed to transform box to gt_box.

    :param box: torch tensor [N, (y1, x1, y2, x2)] / 3D: [N, (y1, x1, y2, x2, z1, z2)]
    :param gt_box: torch tensor of the same layout as box.
    :return: torch tensor [N, (dy, dx, dh, dw)] / 3D: [N, (dy, dx, dz, dh, dw, dd)]. center shifts are
             relative to the size of box, size changes are log ratios.
    """
    height = box[:, 2] - box[:, 0]
    width = box[:, 3] - box[:, 1]
    center_y = box[:, 0] + 0.5 * height
    center_x = box[:, 1] + 0.5 * width

    gt_height = gt_box[:, 2] - gt_box[:, 0]
    gt_width = gt_box[:, 3] - gt_box[:, 1]
    gt_center_y = gt_box[:, 0] + 0.5 * gt_height
    gt_center_x = gt_box[:, 1] + 0.5 * gt_width

    dy = (gt_center_y - center_y) / height
    dx = (gt_center_x - center_x) / width
    dh = torch.log(gt_height / height)
    dw = torch.log(gt_width / width)

    if box.shape[1] > 4:
        # 3D boxes carry (z1, z2) in columns 4 and 5.
        depth = box[:, 5] - box[:, 4]
        center_z = box[:, 4] + 0.5 * depth
        gt_depth = gt_box[:, 5] - gt_box[:, 4]
        gt_center_z = gt_box[:, 4] + 0.5 * gt_depth
        dz = (gt_center_z - center_z) / depth
        dd = torch.log(gt_depth / depth)
        return torch.stack([dy, dx, dz, dh, dw, dd], dim=1)

    return torch.stack([dy, dx, dh, dw], dim=1)


def unmold_mask_2D(mask, bbox, image_shape):
    """Resize a mask predicted for a box and paste it into an empty image-sized canvas.

    mask: [height, width] numpy array of type float. A small, typically 28x28 mask.
    bbox: [y1, x1, y2, x2] integer coordinates. The box to fit the mask in.
    image_shape: shape of the original image, only the first two entries (y, x) are used.

    Returns a float mask with the same (y, x) size as the original image (no thresholding is applied here).
    """
    y1, x1, y2, x2 = bbox
    out_zoom = [y2 - y1, x2 - x1]
    zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)]

    # linear interpolation (order=1) to the size of the box.
    mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32)

    # Put the mask in the right location.
    full_mask = np.zeros(image_shape[:2])  # only y,x
    full_mask[y1:y2, x1:x2] = mask
    return full_mask


def unmold_mask_2D_torch(mask, bbox, image_shape):
    """Torch version of unmold_mask_2D: resize a box mask and paste it into an empty image-sized canvas.

    mask: [height, width] torch tensor of type float. A small, typically 28x28 mask.
    bbox: [y1, x1, y2, x2] integer-valued coordinates (python ints or torch scalars). The box to fit the mask in.
    image_shape: shape of the original image, only the first two entries (y, x) are used.

    Returns a float mask with the same (y, x) size as the original image (no thresholding is applied here).
    """
    y1, x1, y2, x2 = (int(coord) for coord in bbox)

    # request the target size directly: resizing by a floating-point scale factor can miss the box size
    # by one pixel, which would make the assignment below fail.
    mask = torch.nn.functional.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(y2 - y1, x2 - x1))
    mask = mask[0][0]

    # Put the mask in the right location.
    full_mask = torch.zeros(image_shape[:2])  # only y,x
    full_mask[y1:y2, x1:x2] = mask
    return full_mask
def nms_numpy(box_coords, scores, thresh):
    """Non-maximum suppression on 2D or 3D boxes in numpy.

    Greedily keeps the highest scoring box and discards every remaining box whose IoU with it
    exceeds thresh, then repeats on what is left. The kept box is compared against the *other*
    boxes only, so the loop shrinks by at least one box per iteration for any thresh
    (comparing it against itself would never remove it when thresh >= 1).

    :param box_coords: array [N, (y1, x1, y2, x2 (, z1, z2))] with y1<=y2, x1<=x2, z1<=z2.
    :param scores: array [N], ranking scores (higher score == higher rank) of boxes.
    :param thresh: IoU threshold for clustering.
    :return: keep: list of indices into box_coords of the surviving boxes, best score first.
    :raises AssertionError: if a lower coordinate exceeds its upper counterpart.
    """
    coord_order_msg = ("the definition of the coordinates is crucially important here: "
                       "coordinates of which maxima are taken need to be the lower coordinates")

    y1 = box_coords[:, 0]
    x1 = box_coords[:, 1]
    y2 = box_coords[:, 2]
    x2 = box_coords[:, 3]
    assert np.all(y1 <= y2) and np.all(x1 <= x2), coord_order_msg
    areas = (x2 - x1) * (y2 - y1)

    is_3d = box_coords.shape[1] == 6
    if is_3d:  # 3-dim case
        z1 = box_coords[:, 4]
        z2 = box_coords[:, 5]
        assert np.all(z1 <= z2), coord_order_msg
        areas = areas * (z2 - z1)

    # order is the sorted index. maps order to index: order[1] = 24 means (rank1, ix 24)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]  # highest scoring element
        rest = order[1:]  # everything that still competes with it

        yy1 = np.maximum(y1[i], y1[rest])
        xx1 = np.maximum(x1[i], x1[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        xx2 = np.minimum(x2[i], x2[rest])

        h = np.maximum(0.0, yy2 - yy1)
        w = np.maximum(0.0, xx2 - xx1)
        inter = h * w

        if is_3d:
            zz1 = np.maximum(z1[i], z1[rest])
            zz2 = np.minimum(z2[i], z2[rest])
            d = np.maximum(0.0, zz2 - zz1)
            inter = inter * d

        iou = inter / (areas[i] + areas[rest] - inter)

        # get all elements that were not matched and discard all others.
        order = rest[np.nonzero(iou <= thresh)[0]]
        keep.append(i)

    return keep
def refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, batch_anchors, cf):
    """
    Receives anchor scores and selects a subset to pass as proposals
    to the second stage. Filtering is done based on anchor scores and
    non-max suppression to remove overlaps. It also applies bounding
    box refinment details to anchors.

    All helper tensors are created with .cuda(), so this function needs a CUDA device.

    :param rpn_pred_probs: (b, n_anchors, 2); channel 1 is used as the ranking score.
    :param rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
    :param proposal_count: number of proposals returned per batch element. If fewer boxes survive nms,
    the remainder is padded with all-zero boxes and zero scores.
    :param batch_anchors: anchors the deltas refer to. The same anchors are indexed for every batch element,
    presumably of shape (n_anchors, 2 * dim) -- TODO confirm against caller.
    :param cf: config object. Read here: rpn_bbox_std_dev, scale, pre_nms_limit, window, rpn_nms_threshold.
    :return: batch_normalized_props: Proposals in normalized coordinates (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
    :return: batch_out_proposals: Box coords + RPN foreground scores
    for monitoring/plotting (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
    :raises AssertionError: if any normalized box coordinate is larger than 1.
    """
    std_dev = torch.from_numpy(cf.rpn_bbox_std_dev[None]).float().cuda()
    norm = torch.from_numpy(cf.scale).float().cuda()
    anchors = batch_anchors.clone()

    batch_scores = rpn_pred_probs[:, :, 1]
    # norm deltas
    batch_deltas = rpn_pred_deltas * std_dev
    batch_normalized_props = []
    batch_out_proposals = []

    # loop over batch dimension.
    for ix in range(batch_scores.shape[0]):

        scores = batch_scores[ix]
        deltas = batch_deltas[ix]

        # improve performance by trimming to top anchors by score
        # and doing the rest on the smaller subset.
        pre_nms_limit = min(cf.pre_nms_limit, anchors.size()[0])
        scores, order = scores.sort(descending=True)
        order = order[:pre_nms_limit]
        scores = scores[:pre_nms_limit]
        deltas = deltas[order, :]

        # apply deltas to anchors to get refined anchors and filter with non-maximum suppression.
        # the number of delta columns (4 vs. 6) decides between the 2D and the 3D code path.
        if batch_deltas.shape[-1] == 4:
            boxes = apply_box_deltas_2D(anchors[order, :], deltas)
            boxes = clip_boxes_2D(boxes, cf.window)
        else:
            boxes = apply_box_deltas_3D(anchors[order, :], deltas)
            boxes = clip_boxes_3D(boxes, cf.window)
        # boxes are y1,x1,y2,x2, torchvision-nms requires x1,y1,x2,y2, but consistent swap x<->y is irrelevant.
        keep = nms.nms(boxes, scores, cf.rpn_nms_threshold)

        keep = keep[:proposal_count]
        boxes = boxes[keep, :]
        rpn_scores = scores[keep][:, None]

        # pad missing boxes with 0.
        if boxes.shape[0] < proposal_count:
            n_pad_boxes = proposal_count - boxes.shape[0]
            zeros = torch.zeros([n_pad_boxes, boxes.shape[1]]).cuda()
            boxes = torch.cat([boxes, zeros], dim=0)
            zeros = torch.zeros([n_pad_boxes, rpn_scores.shape[1]]).cuda()
            rpn_scores = torch.cat([rpn_scores, zeros], dim=0)

        # concat box and score info for monitoring/plotting.
        batch_out_proposals.append(torch.cat((boxes, rpn_scores), 1).cpu().data.numpy())
        # normalize dimensions to range of 0 to 1.
        normalized_boxes = boxes / norm
        assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found"

        # add again batch dimension
        batch_normalized_props.append(torch.cat((normalized_boxes, rpn_scores), 1).unsqueeze(0))

    batch_normalized_props = torch.cat(batch_normalized_props)
    batch_out_proposals = np.array(batch_out_proposals)

    return batch_normalized_props, batch_out_proposals
def pyramid_roi_align(feature_maps, rois, pool_size, pyramid_levels, dim):
    """
    Implements ROI Pooling on multiple levels of the feature pyramid.

    Each roi is assigned to one pyramid level by its (normalized) area, pooled from that level's feature map,
    and the pooled features are finally re-ordered to match the order of the incoming rois.
    Helper tensors are created with .cuda(), so this function needs a CUDA device.

    :param feature_maps: list of feature maps, each of shape (b, c, y, x , (z))
    :param rois: proposals (normalized coords.) as returned by RPN. contain info about original batch element allocation.
    (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs)
    :param pool_size: list of poolsizes in dims: [x, y, (z)]. Its length (2 or 3) selects the 2D or 3D roi-align op.
    :param pyramid_levels: list. [0, 1, 2, ...]
    :param dim: number of spatial dimensions (2 or 3); the first 2 * dim columns of rois are box coordinates.
    :return: pooled: pooled feature map rois (n_proposals, c, poolsize_y, poolsize_x, (poolsize_z))

    NOTE(review): an earlier version of this docstring additionally described the output as
    [num_boxes, height, width, channels]; that channel-last layout contradicts the :return: line above and
    looks inherited from a reference implementation -- confirm against roi_align before relying on it.
    """
    boxes = rois[:, :dim*2]
    batch_ixs = rois[:, dim*2]

    # Assign each ROI to a level in the pyramid based on the ROI area.
    if dim == 2:
        y1, x1, y2, x2 = boxes.chunk(4, dim=1)
    else:
        y1, x1, y2, x2, z1, z2 = boxes.chunk(6, dim=1)

    h = y2 - y1
    w = x2 - x1

    # Equation 1 in https://arxiv.org/abs/1612.03144. Account for
    # the fact that our coordinates are normalized here.
    # divide sqrt(h*w) by 1 instead image_area.
    roi_level = (4 + torch.log2(torch.sqrt(h*w))).round().int().clamp(pyramid_levels[0], pyramid_levels[-1])
    # if Pyramid contains additional level P6, adapt the roi_level assignment accordingly.
    # NOTE(review): the level id 5 is hard-coded; rois set to 5 are only pooled if 5 is in pyramid_levels -- confirm.
    if len(pyramid_levels) == 5:
        roi_level[h*w > 0.65] = 5

    # Loop through levels and apply ROI pooling to each.
    pooled = []
    box_to_level = []
    fmap_shapes = [f.shape for f in feature_maps]
    for level_ix, level in enumerate(pyramid_levels):
        ix = roi_level == level
        if not ix.any():
            continue
        ix = torch.nonzero(ix)[:, 0]
        level_boxes = boxes[ix, :]
        # re-assign rois to feature map of original batch element.
        ind = batch_ixs[ix].int()

        # Keep track of which box is mapped to which level
        box_to_level.append(ix)

        # Stop gradient propogation to ROI proposals
        level_boxes = level_boxes.detach()
        if len(pool_size) == 2:
            # remap to feature map coordinate system
            y_exp, x_exp = fmap_shapes[level_ix][2:]  # exp = expansion
            level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
            # the roi-align op receives rows of (batch index, box coordinates).
            pooled_features = roi_align.roi_align_2d(feature_maps[level_ix],
                                                     torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                     pool_size)
        else:
            y_exp, x_exp, z_exp = fmap_shapes[level_ix][2:]
            level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
            pooled_features = roi_align.roi_align_3d(feature_maps[level_ix],
                                                     torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                     pool_size)
        pooled.append(pooled_features)

    # Pack pooled features into one tensor
    pooled = torch.cat(pooled, dim=0)

    # Concatenate the per-level roi indices in the same order the pooled features were appended.
    box_to_level = torch.cat(box_to_level, dim=0)

    # Rearrange pooled features to match the order of the original boxes:
    # the argsort of the roi indices is the permutation that restores that order.
    _, box_to_level = torch.sort(box_to_level)
    pooled = pooled[box_to_level, :, :]

    return pooled
def roi_align_3d_numpy(input: np.ndarray, rois, output_size: tuple,
                       spatial_scale: float = 1., sampling_ratio: int = -1) -> np.ndarray:
    """ This fct mainly serves as a verification method for 3D CUDA implementation of RoIAlign, it's highly
    inefficient due to the nested loops.

    Every batch element gets its own result list (a ``[[]] * n`` construction would alias one list across
    all batch elements, so each element would collect the rois of every other element as well).

    :param input: (ndarray[N, C, H, W, D]): input feature map
    :param rois: list (N,K(n), 6), K(n) = nr of rois in batch-element n, single roi of format (y1,x1,y2,x2,z1,z2)
    :param output_size: (out_height, out_width, out_depth), number of bins per axis.
    :param spatial_scale: factor the roi coordinates are multiplied with before pooling.
    :param sampling_ratio: number of sampling points per bin and axis; if <= 0 it is derived per roi as
        ceil(roi extent / number of bins).
    :return: (List[N, K(n), C, output_size[0], output_size[1], output_size[2]]). A regular float array if all
        batch elements hold the same number of rois, otherwise an object array with one
        (K(n), C, *output_size) array per batch element.
    :raises AssertionError: if len(rois) does not match the batch dimension of input.
    """
    out_height, out_width, out_depth = output_size

    coord_grid = tuple([np.linspace(0, input.shape[dim] - 1, num=input.shape[dim]) for dim in range(2, 5)])
    # one independent list per batch element.
    pooled_rois = [[] for _ in range(len(rois))]
    assert len(rois) == input.shape[0], "batch dim mismatch, rois: {}, input: {}".format(len(rois), input.shape[0])
    print("Numpy 3D RoIAlign progress:", end="\n")
    for b in range(input.shape[0]):
        for roi in tqdm.tqdm(rois[b]):
            y1, x1, y2, x2, z1, z2 = np.array(roi) * spatial_scale
            # rois are forced to an extent of at least 1 along every axis.
            roi_height = max(float(y2 - y1), 1.)
            roi_width = max(float(x2 - x1), 1.)
            roi_depth = max(float(z2 - z1), 1.)

            if sampling_ratio <= 0:
                sampling_ratio_h = int(np.ceil(roi_height / out_height))
                sampling_ratio_w = int(np.ceil(roi_width / out_width))
                sampling_ratio_d = int(np.ceil(roi_depth / out_depth))
            else:
                sampling_ratio_h = sampling_ratio_w = sampling_ratio_d = sampling_ratio  # == n points per bin

            bin_height = roi_height / out_height
            bin_width = roi_width / out_width
            bin_depth = roi_depth / out_depth

            n_points = sampling_ratio_h * sampling_ratio_w * sampling_ratio_d
            pooled_roi = np.empty((input.shape[1], out_height, out_width, out_depth), dtype="float32")
            for chan in range(input.shape[1]):
                lin_interpolator = scipy.interpolate.RegularGridInterpolator(coord_grid, input[b, chan],
                                                                             method="linear")
                for bin_iy in range(out_height):
                    for bin_ix in range(out_width):
                        for bin_iz in range(out_depth):

                            bin_val = 0.
                            for i in range(sampling_ratio_h):
                                for j in range(sampling_ratio_w):
                                    for k in range(sampling_ratio_d):
                                        # sampling points sit at the centers of a regular sub-grid of the bin.
                                        loc_ijk = [
                                            y1 + bin_iy * bin_height + (i + 0.5) * (bin_height / sampling_ratio_h),
                                            x1 + bin_ix * bin_width + (j + 0.5) * (bin_width / sampling_ratio_w),
                                            z1 + bin_iz * bin_depth + (k + 0.5) * (bin_depth / sampling_ratio_d)]
                                        # points clearly outside of the feature map contribute 0 to the bin.
                                        # NOTE(review): points in (-1, 0) or exactly on the upper shape bound pass
                                        # this check but lie outside the interpolation grid -- confirm intended.
                                        if not (np.any([c < -1.0 for c in loc_ijk]) or loc_ijk[0] > input.shape[2]
                                                or loc_ijk[1] > input.shape[3] or loc_ijk[2] > input.shape[4]):
                                            for catch_case in range(3):
                                                # catch on-border cases
                                                if int(loc_ijk[catch_case]) == input.shape[catch_case + 2] - 1:
                                                    loc_ijk[catch_case] = input.shape[catch_case + 2] - 1
                                            # the interpolator returns an array holding one value per query point.
                                            bin_val += np.asarray(lin_interpolator(loc_ijk)).item()
                            pooled_roi[chan, bin_iy, bin_ix, bin_iz] = bin_val / n_points

            pooled_rois[b].append(pooled_roi)

    if len({len(element_rois) for element_rois in pooled_rois}) <= 1:
        return np.array(pooled_rois)

    # differing roi counts per batch element cannot be stacked into one regular array.
    ragged = np.empty(len(pooled_rois), dtype=object)
    for b, element_rois in enumerate(pooled_rois):
        ragged[b] = np.array(element_rois)
    return ragged
def refine_detections(cf, batch_ixs, rois, deltas, scores, regressions):
    """
    Refine classified proposals (apply deltas to rpn rois), filter overlaps (nms) and return final detections.

    Every roi is replicated once per foreground class, so that each (roi, class) pair is scored, refined
    and filtered on its own. Helper tensors are created with .cuda(), so this function needs a CUDA device.

    :param cf: config object. Read here: head_classes, rpn_bbox_std_dev, dim, scale, window, model_min_confidence,
    detection_nms_threshold, model_max_instances_per_batch_element.
    :param rois: (n_proposals, 2 * dim) normalized boxes as proposed by RPN. n_proposals = batch_size * POST_NMS_ROIS
    :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by mrcnn bbox regressor.
    :param batch_ixs: (n_proposals) batch element assignment info for re-allocation.
    :param scores: (n_proposals, n_classes) probabilities for all classes per roi as predicted by mrcnn classifier.
    :param regressions: (n_proposals, n_classes, regression_features (+1 for uncertainty if predicted) regression vector
    :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, *regression vector features))
    """
    # class IDs per ROI. Since scores of all classes are of interest (not just max class), all are kept at this point.
    class_ids = []
    fg_classes = cf.head_classes - 1
    # repeat vectors to fill in predictions for all foreground classes.
    for ii in range(1, fg_classes + 1):
        class_ids += [ii] * rois.shape[0]
    class_ids = torch.from_numpy(np.array(class_ids)).cuda()

    batch_ixs = batch_ixs.repeat(fg_classes)
    rois = rois.repeat(fg_classes, 1)
    deltas = deltas.repeat(fg_classes, 1, 1)
    scores = scores.repeat(fg_classes, 1)
    regressions = regressions.repeat(fg_classes, 1, 1)

    # get class-specific scores and  bounding box deltas
    idx = torch.arange(class_ids.size()[0]).long().cuda()
    # using idx instead of slice [:,] squashes first dimension.
    #len(class_ids)>scores.shape[1] --> probs is broadcasted by expansion from fg_classes-->len(class_ids)
    batch_ixs = batch_ixs[idx]
    deltas_specific = deltas[idx, class_ids]
    class_scores = scores[idx, class_ids]
    regressions = regressions[idx, class_ids]

    # apply bounding box deltas. re-scale to image coordinates.
    std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda()
    scale = torch.from_numpy(cf.scale).float().cuda()
    refined_rois = apply_box_deltas_2D(rois, deltas_specific * std_dev) * scale if cf.dim == 2 else \
        apply_box_deltas_3D(rois, deltas_specific * std_dev) * scale

    # clip to the image window and round, since we're dealing with pixels now (the dtype stays float).
    refined_rois = clip_to_window(cf.window, refined_rois)
    refined_rois = torch.round(refined_rois)

    # filter out low confidence boxes
    keep = idx
    keep_bool = (class_scores >= cf.model_min_confidence)
    if not 0 in torch.nonzero(keep_bool).size():

        score_keep = torch.nonzero(keep_bool)[:, 0]
        pre_nms_class_ids = class_ids[score_keep]
        pre_nms_rois = refined_rois[score_keep]
        pre_nms_scores = class_scores[score_keep]
        pre_nms_batch_ixs = batch_ixs[score_keep]

        # nms is run separately per batch element and, within it, per class.
        for j, b in enumerate(unique1d(pre_nms_batch_ixs)):

            bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0]
            bix_class_ids = pre_nms_class_ids[bixs]
            bix_rois = pre_nms_rois[bixs]
            bix_scores = pre_nms_scores[bixs]

            for i, class_id in enumerate(unique1d(bix_class_ids)):

                ixs = torch.nonzero(bix_class_ids == class_id)[:, 0]
                # nms expects boxes sorted by score.
                ix_rois = bix_rois[ixs]
                ix_scores = bix_scores[ixs]
                ix_scores, order = ix_scores.sort(descending=True)
                ix_rois = ix_rois[order, :]

                class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold)

                # map indices back: nms output -> score-sorted -> class subset -> batch-element subset
                # -> confidence-filtered subset -> index into the replicated (roi, class) arrays.
                class_keep = keep[score_keep[bixs[ixs[order[class_keep]]]]]
                # merge indices over classes for current batch element
                b_keep = class_keep if i == 0 else unique1d(torch.cat((b_keep, class_keep)))

            # only keep top-k boxes of current batch-element
            top_ids = class_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element]
            b_keep = b_keep[top_ids]

            # merge indices over batch elements.
            batch_keep = b_keep if j == 0 else unique1d(torch.cat((batch_keep, b_keep)))

        keep = batch_keep

    else:
        # nothing passed the confidence filter: index 0 is selected, so the output holds exactly one row.
        # NOTE(review): presumably a placeholder so downstream code never sees an empty tensor -- confirm.
        keep = torch.tensor([0]).long().cuda()

    # arrange output
    output = [refined_rois[keep], batch_ixs[keep].unsqueeze(1)]
    output += [class_ids[keep].unsqueeze(1).float(), class_scores[keep].unsqueeze(1)]
    output += [regressions[keep]]

    result = torch.cat(output, dim=1)
    # shape: (n_keeps, catted feats), catted feats: [0:dim*2] are box_coords, [dim*2] are batch_ics,
    # [dim*2+1] are class_ids, [dim*2+2] are scores, [dim*2+3:] are regression vector features (incl uncertainty)
    return result
def loss_example_mining(cf, batch_proposals, batch_gt_boxes, batch_gt_masks, batch_roi_scores,
                        batch_gt_class_ids, batch_gt_regressions):
    """
    Subsamples proposals for mrcnn losses and generates targets. Sampling is done per batch element, seems to have positive
    effects on training, as opposed to sampling over entire batch. Negatives are sampled via stochastic hard-example mining
    (SHEM), where a number of negative proposals is drawn from larger pool of highest scoring proposals for stochasticity.
    Scoring is obtained here as the max over all foreground probabilities as returned by mrcnn_classifier (worked better than
    loss-based class-balancing methods like "online hard-example mining" or "focal loss".)

    Classification-regression duality: regressions can be given along with classes (at least fg/bg, only class scores
    are used for ranking).

    Helper tensors are created with .cuda(), so this function needs a CUDA device.

    :param cf: config object. Read here: dim, patch_size, train_rois_per_image, roi_positive_ratio, bbox_std_dev,
    mask_shape, shem_poolsize and, if regressions are given, prediction_tasks and regression_n_features.
    :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs).
    boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS.
    :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
    :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c)
    :param batch_roi_scores: per-proposal scores, indexed like batch_proposals; used to rank the negative
    candidates handed to shem.
    :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
        if no classes predicted (only fg/bg from RPN): expected as pseudo classes [0, 1] for bg, fg.
    :param batch_gt_regressions: list over b elements. Each element is a regression target vector. if None--> pseudo
    :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions.
    :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement.
    :return: target_masks: (n_sampled_rois, y, x, (z)) containing target masks of sampled proposals.
    :return: target_class_ids: (n_sampled_rois)containing target class labels of sampled proposals.
    :return: target_regressions: regression targets of sampled proposals (empty tensor if batch_gt_regressions is None).
    Positive samples come first in all returned tensors, zero-filled entries for the negative samples follow.
    """
    # normalization of target coordinates
    #global sample_regressions
    if cf.dim == 2:
        h, w = cf.patch_size
        scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda()
    else:
        h, w, z = cf.patch_size
        scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda()

    positive_count = 0
    negative_count = 0
    sample_positive_indices = []
    sample_negative_indices = []
    sample_deltas = []
    sample_masks = []
    sample_class_ids = []
    if batch_gt_regressions is not None:
        sample_regressions = []
    else:
        # without regression targets an empty tensor is returned in their place.
        target_regressions = torch.FloatTensor().cuda()

    # loop over batch and get positive and negative sample rois.
    for b in range(len(batch_gt_boxes)):

        gt_masks = torch.from_numpy(batch_gt_masks[b]).float().cuda()
        gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda()
        if batch_gt_regressions is not None:
            gt_regressions = torch.from_numpy(batch_gt_regressions[b]).float().cuda()

        #if np.any(batch_gt_class_ids[b] > 0):  # skip roi selection for no gt images.
        if np.any([len(coords)>0 for coords in batch_gt_boxes[b]]):
            gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale
        else:
            gt_boxes = torch.FloatTensor().cuda()

        # get proposals and indices of current batch element.
        proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1]
        batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1)

        # Compute overlaps matrix [proposals, gt_boxes]
        if not 0 in gt_boxes.size():
            if gt_boxes.shape[1] == 4:
                assert cf.dim == 2, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim)
                overlaps = bbox_overlaps_2D(proposals, gt_boxes)
            else:
                assert cf.dim == 3, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim)
                overlaps = bbox_overlaps_3D(proposals, gt_boxes)

            # Determine positive and negative ROIs
            roi_iou_max = torch.max(overlaps, dim=1)[0]
            # 1. Positive ROIs are those with >= 0.5 (2D) / >= 0.3 (3D) IoU with a GT box
            positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3)
            # 2. Negative ROIs are those with < 0.1 (2D) / < 0.01 (3D) with every GT box.
            negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01)
        else:
            # no ground truth in this batch element: no positives, every proposal is a negative candidate.
            positive_roi_bool = torch.FloatTensor().cuda()
            negative_roi_bool = torch.from_numpy(np.array([1]*proposals.shape[0])).cuda()

        # Sample Positive ROIs
        if not 0 in torch.nonzero(positive_roi_bool).size():
            positive_indices = torch.nonzero(positive_roi_bool).squeeze(1)
            positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio)
            # random subset of at most positive_samples positives.
            rand_idx = torch.randperm(positive_indices.size()[0])
            rand_idx = rand_idx[:positive_samples].cuda()
            positive_indices = positive_indices[rand_idx]
            positive_samples = positive_indices.size()[0]
            positive_rois = proposals[positive_indices, :]
            # Assign positive ROIs to GT boxes.
            positive_overlaps = overlaps[positive_indices, :]
            roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1]
            roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :]
            roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment]
            if batch_gt_regressions is not None:
                roi_gt_regressions = gt_regressions[roi_gt_box_assignment]

            # Compute bbox refinement targets for positive ROIs
            deltas = box_refinement(positive_rois, roi_gt_boxes)
            std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda()
            deltas /= std_dev

            roi_masks = gt_masks[roi_gt_box_assignment].unsqueeze(1)  # .squeeze(-1)
            assert roi_masks.shape[-1] == 1
            # Compute mask targets
            boxes = positive_rois
            box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).float()

            if len(cf.mask_shape) == 2:
                # todo what are the dims of roi_masks? (n_matched_boxes_with_gts, 1 (dummy channel dim), y,x, 1 (WHY?))
                masks = roi_align.roi_align_2d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape)
            else:
                masks = roi_align.roi_align_3d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape)

            masks = masks.squeeze(1)
            # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
            # binary cross entropy loss.
            masks = torch.round(masks)

            sample_positive_indices.append(batch_element_indices[positive_indices])
            sample_deltas.append(deltas)
            sample_masks.append(masks)
            sample_class_ids.append(roi_gt_class_ids)
            if batch_gt_regressions is not None:
                sample_regressions.append(roi_gt_regressions)
            positive_count += positive_samples
        else:
            positive_samples = 0

        # Sample negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM.
        if not 0 in torch.nonzero(negative_roi_bool).size():
            negative_indices = torch.nonzero(negative_roi_bool).squeeze(1)
            r = 1.0 / cf.roi_positive_ratio
            b_neg_count = np.max((int(r * positive_samples - positive_samples), 1))
            roi_scores_neg = batch_roi_scores[batch_element_indices[negative_indices]]
            raw_sampled_indices = shem(roi_scores_neg, b_neg_count, cf.shem_poolsize)
            sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]])
            negative_count += raw_sampled_indices.size()[0]

    if len(sample_positive_indices) > 0:
        target_deltas = torch.cat(sample_deltas)
        target_masks = torch.cat(sample_masks)
        target_class_ids = torch.cat(sample_class_ids)
        if batch_gt_regressions is not None:
            target_regressions = torch.cat(sample_regressions)

    # Pad target information with zeros for negative ROIs.
    if positive_count > 0 and negative_count > 0:
        sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0)
        zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
        target_deltas = torch.cat([target_deltas, zeros], dim=0)
        zeros = torch.zeros(negative_count, *cf.mask_shape).cuda()
        target_masks = torch.cat([target_masks, zeros], dim=0)
        zeros = torch.zeros(negative_count).int().cuda()
        target_class_ids = torch.cat([target_class_ids, zeros], dim=0)
        if batch_gt_regressions is not None:
            # regression targets need to have 0 as background/negative with below practice
            if 'regression_bin' in cf.prediction_tasks:
                zeros = torch.zeros(negative_count, dtype=torch.float).cuda()
            else:
                zeros = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda()
            target_regressions = torch.cat([target_regressions, zeros], dim=0)

    elif positive_count > 0:
        # positives only: the targets concatenated above are already complete.
        sample_indices = torch.cat(sample_positive_indices)
    elif negative_count > 0:
        # negatives only: all targets are zero.
        sample_indices = torch.cat(sample_negative_indices)
        target_deltas = torch.zeros(negative_count, cf.dim * 2).cuda()
        target_masks = torch.zeros(negative_count, *cf.mask_shape).cuda()
        target_class_ids = torch.zeros(negative_count).int().cuda()
        if batch_gt_regressions is not None:
            if 'regression_bin' in cf.prediction_tasks:
                target_regressions = torch.zeros(negative_count, dtype=torch.float).cuda()
            else:
                target_regressions = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda()
    else:
        # nothing was sampled at all: return empty tensors.
        sample_indices = torch.LongTensor().cuda()
        target_class_ids = torch.IntTensor().cuda()
        target_deltas = torch.FloatTensor().cuda()
        target_masks = torch.FloatTensor().cuda()
        target_regressions = torch.FloatTensor().cuda()

    return sample_indices, target_deltas, target_masks, target_class_ids, target_regressions
def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
    """Build all 2D anchor boxes for one feature map.

    :param scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128]
    :param ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
    :param shape: [height, width] spatial shape of the feature map over which to generate anchors.
    :param feature_stride: Stride of the feature map relative to the image in pixels.
    :param anchor_stride: Stride of anchors on the feature map. For example, if the
        value is 2 then generate anchors for every other feature map pixel.
    :return: array [n_anchors, (y1, x1, y2, x2)] in image coordinates.
    """
    # every (scale, ratio) combination yields one anchor shape.
    scale_grid, ratio_grid = np.meshgrid(np.array(scales), np.array(ratios))
    flat_scales = scale_grid.flatten()
    flat_ratios = ratio_grid.flatten()
    heights = flat_scales / np.sqrt(flat_ratios)
    widths = flat_scales * np.sqrt(flat_ratios)

    # anchor centers: feature map positions translated to image coordinates.
    positions_y = np.arange(0, shape[0], anchor_stride) * feature_stride
    positions_x = np.arange(0, shape[1], anchor_stride) * feature_stride
    grid_x, grid_y = np.meshgrid(positions_x, positions_y)

    # pair every center with every anchor shape.
    box_widths, box_centers_x = np.meshgrid(widths, grid_x)
    box_heights, box_centers_y = np.meshgrid(heights, grid_y)

    centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2])  # (y, x)
    sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2])  # (h, w)

    # center/size representation -> corner coordinates (y1, x1, y2, x2)
    lower_corners = centers - 0.5 * sizes
    upper_corners = centers + 0.5 * sizes
    return np.concatenate([lower_corners, upper_corners], axis=1)
def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride):
    """Build all 3D anchor boxes for one feature map.

    :param scales_xy: 1D array of in-plane anchor sizes in pixels. Example: [32, 64, 128]
    :param scales_z: 1D array of anchor sizes along z; tiled to match the in-plane (scale, ratio) combinations.
    :param ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2]
    :param shape: [height, width, depth] spatial shape of the feature map over which to generate anchors.
    :param feature_stride_xy: in-plane stride of the feature map relative to the image in pixels.
    :param feature_stride_z: stride of the feature map along z relative to the image in pixels.
    :param anchor_stride: Stride of anchors on the feature map. For example, if the
        value is 2 then generate anchors for every other feature map pixel.
    :return: array [n_anchors, (y1, x1, y2, x2, z1, z2)] in image coordinates.
    """
    # every (scale, ratio) combination yields one in-plane anchor shape.
    scale_grid, ratio_grid = np.meshgrid(np.array(scales_xy), np.array(ratios))
    flat_scales = scale_grid.flatten()
    flat_ratios = ratio_grid.flatten()
    heights = flat_scales / np.sqrt(flat_ratios)
    widths = flat_scales * np.sqrt(flat_ratios)

    # repeat the z-scales so that there is one depth per in-plane anchor shape.
    z_scales = np.array(scales_z)
    n_repeats = len(flat_ratios) // z_scales[..., None].shape[0]
    depths = np.tile(z_scales, n_repeats)

    # anchor centers: translate from fm positions to input coords.
    positions_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy
    positions_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy
    positions_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z)
    grid_x, grid_y, grid_z = np.meshgrid(positions_x, positions_y, positions_z)

    # pair every center with every anchor shape.
    box_widths, box_centers_x = np.meshgrid(widths, grid_x)
    box_heights, box_centers_y = np.meshgrid(heights, grid_y)
    box_depths, box_centers_z = np.meshgrid(depths, grid_z)

    centers = np.stack([box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3])  # (y, x, z)
    sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3])  # (h, w, d)

    # corners come out as (y1, x1, z1, y2, x2, z2) ...
    corners = np.concatenate([centers - 0.5 * sizes, centers + 0.5 * sizes], axis=1)

    # ... and are re-ordered to the file-wide convention (y1, x1, y2, x2, z1, z2).
    columns = [corners[:, col] for col in (0, 1, 3, 4, 2, 5)]
    return np.transpose(np.array(columns), axes=(1, 0))
def generate_pyramid_anchors(logger, cf):
    """Generate anchors at different levels of a feature pyramid. Each scale
    is associated with a level of the pyramid, but each ratio is used in
    all levels of the pyramid.

    from configs:
    :param scales: cf.rpn_anchor_scales, dict with keys 'xy' and 'z'. For conformity with retina nets: scale
        entries need to be list, e.g. [[4], [8], [16], [32]]
    :param ratios: cf.rpn_anchor_ratios, e.g. [0.5, 1, 2]
    :param feature_shapes: cf.backbone_shapes, e.g. [array of shapes per feature map] [80, 40, 20, 10, 5]
    :param feature_strides: cf.backbone_strides, dict with keys 'xy' and 'z', e.g. [2, 4, 8, 16, 32, 64]
    :param anchors_stride: cf.rpn_anchor_stride, e.g. 1
    :param pyramid_levels: cf.pyramid_levels, the levels anchors are generated for.
    :return anchors: (N, (y1, x1, y2, x2, (z1), (z2)). All generated anchors in one array. Sorted
    with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on.
    :raises Exception: if a feature map shape is neither 2- nor 3-dimensional.
    """
    scales = cf.rpn_anchor_scales
    ratios = cf.rpn_anchor_ratios
    feature_shapes = cf.backbone_shapes
    anchor_stride = cf.rpn_anchor_stride
    pyramid_levels = cf.pyramid_levels
    feature_strides = cf.backbone_strides

    logger.info("anchor scales {} and feature map shapes {}".format(scales, feature_shapes))
    # one anchor per feature map position, ratio and in-plane scale.
    expected_anchors = [np.prod(feature_shapes[level]) * len(ratios) * len(scales['xy'][level])
                        for level in pyramid_levels]

    anchors = []
    for n_expected, level in zip(expected_anchors, pyramid_levels):
        level_shape = feature_shapes[level]
        n_spatial_dims = len(level_shape)
        if n_spatial_dims == 2:
            level_anchors = generate_anchors(scales['xy'][level], ratios, level_shape,
                                             feature_strides['xy'][level], anchor_stride)
        elif n_spatial_dims == 3:
            level_anchors = generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, level_shape,
                                                feature_strides['xy'][level], feature_strides['z'][level],
                                                anchor_stride)
        else:
            raise Exception("invalid feature_shapes[{}] size {}".format(level, feature_shapes[level]))
        anchors.append(level_anchors)
        logger.info("level {}: expected anchors {}, built anchors {}.".format(level, n_expected,
                                                                              anchors[-1].shape))

    out_anchors = np.concatenate(anchors, axis=0)
    logger.info("Total: expected anchors {}, built anchors {}.".format(np.sum(expected_anchors), out_anchors.shape))

    return out_anchors
def apply_box_deltas_2D(boxes, deltas):
    """Applies the given deltas to the given boxes.

    :param boxes: [N, 4] where each row is y1, x1, y2, x2
    :param deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)]
    :return: [N, 4] refined boxes, each row is y1, x1, y2, x2
    """
    def _refine_axis(lower, upper, shift, log_scale):
        # corner -> center/size, shift the center by a fraction of the old size, re-scale, -> corners.
        size = upper - lower
        center = lower + 0.5 * size
        center += shift * size
        size *= torch.exp(log_scale)
        new_lower = center - 0.5 * size
        return new_lower, new_lower + size

    y1, y2 = _refine_axis(boxes[:, 0], boxes[:, 2], deltas[:, 0], deltas[:, 2])
    x1, x2 = _refine_axis(boxes[:, 1], boxes[:, 3], deltas[:, 1], deltas[:, 3])
    return torch.stack([y1, x1, y2, x2], dim=1)


def apply_box_deltas_3D(boxes, deltas):
    """Applies the given deltas to the given boxes.

    :param boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2
    :param deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)]
    :return: [N, 6] refined boxes, each row is y1, x1, y2, x2, z1, z2
    """
    def _refine_axis(lower, upper, shift, log_scale):
        # corner -> center/size, shift the center by a fraction of the old size, re-scale, -> corners.
        size = upper - lower
        center = lower + 0.5 * size
        center += shift * size
        size *= torch.exp(log_scale)
        new_lower = center - 0.5 * size
        return new_lower, new_lower + size

    y1, y2 = _refine_axis(boxes[:, 0], boxes[:, 2], deltas[:, 0], deltas[:, 3])
    x1, x2 = _refine_axis(boxes[:, 1], boxes[:, 3], deltas[:, 1], deltas[:, 4])
    z1, z2 = _refine_axis(boxes[:, 4], boxes[:, 5], deltas[:, 2], deltas[:, 5])
    return torch.stack([y1, x1, y2, x2, z1, z2], dim=1)


def clip_boxes_2D(boxes, window):
    """Clamp box coordinates to a window.

    :param boxes: [N, 4] each col is y1, x1, y2, x2
    :param window: [4] in the form y1, x1, y2, x2
    :return: [N, 4] boxes with y-coordinates limited to [window y1, window y2] and
        x-coordinates limited to [window x1, window x2].
    """
    y_range = (float(window[0]), float(window[2]))
    x_range = (float(window[1]), float(window[3]))
    # y1 and y2 share the y-range, x1 and x2 share the x-range.
    column_ranges = (y_range, x_range, y_range, x_range)
    clipped = [boxes[:, col].clamp(lo, hi) for col, (lo, hi) in enumerate(column_ranges)]
    return torch.stack(clipped, 1)
def clip_boxes_3D(boxes, window):
    """Clamp 3D box coordinates to a window.

    :param boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2
    :param window: [6] in the form y1, x1, y2, x2, z1, z2
    :return: [N, 6] boxes with every coordinate limited to the window range of its axis.
    """
    boxes = torch.stack(
        [boxes[:, 0].clamp(float(window[0]), float(window[2])),
         boxes[:, 1].clamp(float(window[1]), float(window[3])),
         boxes[:, 2].clamp(float(window[0]), float(window[2])),
         boxes[:, 3].clamp(float(window[1]), float(window[3])),
         boxes[:, 4].clamp(float(window[4]), float(window[5])),
         boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1)
    return boxes


def clip_boxes_numpy(boxes, window):
    """Clip box coordinates to the image extent.

    Every coordinate is clipped against the extent of its own axis: y1 and y2 against window[0],
    x1 and x2 against window[1] (and z1, z2 against window[2]). Clipping x1 against the y-extent and
    y2 against the x-extent gives wrong boxes for every non-square image.

    :param boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D (y1, x1, y2, x2, z1, z2).
    :param window: image shape (y, x, (z))
    :return: array of the same shape as boxes with all coordinates limited to [0, extent of their axis].
    """
    # window axis each box column belongs to.
    if boxes.shape[1] == 4:
        axis_of_column = (0, 1, 0, 1)
    else:
        axis_of_column = (0, 1, 0, 1, 2, 2)

    clipped_columns = [np.clip(boxes[:, col], 0, window[axis])[:, None]
                       for col, axis in enumerate(axis_of_column)]
    return np.concatenate(clipped_columns, 1)
def bbox_overlaps_3D(boxes1, boxes2):
    """Computes IoU overlaps between two sets of 3D boxes.

    Every box in boxes1 is compared against every box in boxes2 via
    broadcasting. Negative intersection extents are clamped to zero with
    ``clamp(min=0)``, so no helper tensor has to be created on a particular
    device or with a particular dtype.

    :param boxes1: [N1, (y1, x1, y2, x2, z1, z2)] tensor.
    :param boxes2: [N2, (y1, x1, y2, x2, z1, z2)] tensor.
    :return: [N1, N2] tensor; entry (i, j) is the IoU of boxes1[i] and boxes2[j].
        Pairs whose union is zero yield NaN (0 / 0), as no special handling is applied.
    """
    # 1. Add singleton axes so that the pairwise comparison broadcasts to [N1, N2].
    b1 = boxes1[:, None, :]
    b2 = boxes2[None, :, :]

    # 2. Compute intersections.
    y1 = torch.max(b1[..., 0], b2[..., 0])
    x1 = torch.max(b1[..., 1], b2[..., 1])
    y2 = torch.min(b1[..., 2], b2[..., 2])
    x2 = torch.min(b1[..., 3], b2[..., 3])
    z1 = torch.max(b1[..., 4], b2[..., 4])
    z2 = torch.min(b1[..., 5], b2[..., 5])
    intersection = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0) * (z2 - z1).clamp(min=0)

    # 3. Compute unions.
    b1_volume = (b1[..., 2] - b1[..., 0]) * (b1[..., 3] - b1[..., 1]) * (b1[..., 5] - b1[..., 4])
    b2_volume = (b2[..., 2] - b2[..., 0]) * (b2[..., 3] - b2[..., 1]) * (b2[..., 5] - b2[..., 4])
    union = b1_volume + b2_volume - intersection

    # 4. Compute IoU, shape [boxes1, boxes2].
    return intersection / union
anchor_iou_argmax = np.argmax(overlaps, axis=1) anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] if anchors.shape[1] == 4: anchor_class_matches[(anchor_iou_max < 0.1)] = -1 elif anchors.shape[1] == 6: anchor_class_matches[(anchor_iou_max < 0.01)] = -1 else: raise ValueError('anchor shape wrong {}'.format(anchors.shape)) # 2. Set an anchor for each GT box (regardless of IoU value). gt_iou_argmax = np.argmax(overlaps, axis=0) for ix, ii in enumerate(gt_iou_argmax): anchor_class_matches[ii] = gt_class_ids[ix] # 3. Set anchors with high overlap as positive. above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]] # Subsample to balance positive anchors. ids = np.where(anchor_class_matches > 0)[0] extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) if extra > 0: # Reset the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) anchor_class_matches[ids] = 0 # Leave all negative proposals negative for now and sample from them later in online hard example mining. # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. ids = np.where(anchor_class_matches > 0)[0] ix = 0 # index into anchor_delta_targets for i, a in zip(ids, anchors[ids]): # closest gt box (it might have IoU < anchor_matching_iou) gt = gt_boxes[anchor_iou_argmax[i]] # convert coordinates to center plus width/height. 
def get_coords(binary_mask, n_components, dim):
    """Loops over batch to perform connected component analysis on binary input mask.
    Computes box coordinates around the n_components biggest components (rois).

    Boxes get a margin of one voxel in y and x (lower bound - 1, upper bound + 1);
    in z the lower bound is the first occupied slice and the upper bound is the
    last occupied slice + 1. Coordinates are clipped to [0, extent] of their own axis.

    NOTE(review): ``lb`` is resolved from module scope; it is called as
    ``labels, n = lb(array)``, which matches scipy.ndimage's ``label`` -- confirm
    against the file's imports.

    :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class.
    :param n_components: int. number of components to extract per batch element and class.
    :param dim: int. number of spatial dimensions (2 or 3).
    :return: batch_coords: list over batch; per element an int array
        (n, (y1, x1, y2, x2, (z1), (z2))) or an empty list if no component was found.
    :return: batch_components: list over batch; per element an array (n, y, x, (z))
        holding one binary mask per kept component.
    """
    assert len(binary_mask.shape) == dim + 1
    binary_mask = binary_mask.astype('uint8')
    y_extent, x_extent = binary_mask.shape[1], binary_mask.shape[2]

    batch_coords = []
    batch_components = []
    for b in binary_mask:
        clusters, n_cands = lb(b)  # performs connected component analysis.
        uniques, counts = np.unique(clusters, return_counts=True)
        # Drop the background label explicitly: label 0 is absent from `uniques`
        # when the whole element is foreground, so slicing off the first entry
        # would discard a real component in that case.
        is_component = uniques != 0
        labels, sizes = uniques[is_component], counts[is_component]
        # only keep n_components largest components.
        keep_uniques = labels[np.argsort(sizes)[::-1]][:n_components]
        # separate clusters and concat.
        p_components = np.array([(clusters == ii) * 1 for ii in keep_uniques])

        p_coords = []
        if p_components.shape[0] > 0:
            for roi in p_components:
                mask_ixs = np.argwhere(roi != 0)
                # get coordinates around component.
                roi_coords = [np.min(mask_ixs[:, 0]) - 1, np.min(mask_ixs[:, 1]) - 1,
                              np.max(mask_ixs[:, 0]) + 1, np.max(mask_ixs[:, 1]) + 1]
                if dim == 3:
                    roi_coords += [np.min(mask_ixs[:, 2]), np.max(mask_ixs[:, 2]) + 1]
                p_coords.append(roi_coords)
            p_coords = np.array(p_coords)

            # clip coords: each coordinate against the extent of its own axis.
            p_coords[:, [0, 2]] = np.clip(p_coords[:, [0, 2]], 0, y_extent)
            p_coords[:, [1, 3]] = np.clip(p_coords[:, [1, 3]], 0, x_extent)
            if dim == 3:
                p_coords[:, 4:] = np.clip(p_coords[:, 4:], 0, binary_mask.shape[3])

        batch_coords.append(p_coords)
        batch_components.append(p_components)
    return batch_coords, batch_components
:return: coords (b, n, (y1, x1, y2, x2 (,z1, z2)) :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2)) """ raise Exception("throws floating point exception") assert len(binary_mask.shape)==dim+1 binary_mask = binary_mask.type(torch.uint8) batch_coords = [] batch_components = [] for ix,b in enumerate(binary_mask): clusters, n_cands = lb(b.cpu().data.numpy()) # peforms connected component analysis. clusters = torch.from_numpy(clusters).cuda() uniques = torch.unique(clusters) counts = torch.stack([(clusters==unique).sum() for unique in uniques]) keep_uniques = uniques[1:][torch.sort(counts[1:])[1].flip(0)][:n_components] #only keep n_components largest components p_components = torch.cat([(clusters == ii).unsqueeze(0) for ii in keep_uniques]).cuda() # separate clusters and concat p_coords = [] if p_components.shape[0] > 0: for roi in p_components: mask_ixs = torch.nonzero(roi) # get coordinates around component. roi_coords = [torch.min(mask_ixs[:, 0]) - 1, torch.min(mask_ixs[:, 1]) - 1, torch.max(mask_ixs[:, 0]) + 1, torch.max(mask_ixs[:, 1]) + 1] if dim == 3: roi_coords += [torch.min(mask_ixs[:, 2]), torch.max(mask_ixs[:, 2])+1] p_coords.append(roi_coords) p_coords = torch.tensor(p_coords) #clip coords. p_coords[p_coords < 0] = 0 p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2] if dim == 3: p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1] batch_coords.append(p_coords) batch_components.append(p_components) return batch_coords, batch_components ############################################################ # Pytorch Utility Functions ############################################################ def unique1d(tensor): """discard all elements of tensor that occur more than once; make tensor unique. 
def intersect1d(tensor1, tensor2):
    """Return the values that occur in both 1D tensors, in descending order.

    Both tensors are concatenated and sorted; a value is reported wherever two
    neighbouring entries of the sorted sequence are equal. The result is
    therefore only a clean intersection if neither input contains duplicates
    itself.

    :param tensor1: 1D tensor.
    :param tensor2: 1D tensor.
    :return: 1D tensor of shared values.
    """
    merged = torch.cat((tensor1, tensor2), dim=0).sort(descending=True)[0]
    has_equal_successor = merged[1:] == merged[:-1]
    return merged[:-1][has_equal_successor]


def shem(roi_probs_neg, negative_count, poolsize):
    """
    stochastic hard example mining: from a list of indices (referring to non-matched predictions),
    determine a pool of highest scoring (worst false positives) of size negative_count*poolsize.
    Then, sample n (= negative_count) predictions of this pool as negative examples for loss.

    The sampled positions are moved to the device of the score tensor instead of
    unconditionally to CUDA, so the function also works on CPU tensors.

    :param roi_probs_neg: tensor of shape (n_predictions, n_classes).
    :param negative_count: int.
    :param poolsize: int.
    :return: (negative_count). indices refer to the positions in roi_probs_neg. If pool smaller than expected due to
    limited negative proposals available, this function will return sampled indices of number < negative_count without
    throwing an error.
    """
    n_negatives = int(negative_count)
    # sort according to highest foreground score (column 0 is background).
    _, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True)
    pool_size = min(poolsize * n_negatives, order.size()[0])
    pool_indices = order[:pool_size]
    # permutation is drawn on the CPU (default generator), then moved to the pool's device.
    rand_idx = torch.randperm(pool_indices.size()[0])
    return pool_indices[rand_idx[:n_negatives].to(pool_indices.device)]
""" init_type = net.cf.weight_init for m in [module for module in net.modules() if type(module) in [torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, torch.nn.Linear]]: if init_type == 'xavier_uniform': torch.nn.init.xavier_uniform_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif init_type == 'xavier_normal': torch.nn.init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif init_type == "kaiming_uniform": torch.nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0) if m.bias is not None: fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / np.sqrt(fan_out) torch.nn.init.uniform_(m.bias, -bound, bound) elif init_type == "kaiming_normal": torch.nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0) if m.bias is not None: fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / np.sqrt(fan_out) torch.nn.init.normal_(m.bias, -bound, bound) net.logger.info("applied {} weight init.".format(init_type)) \ No newline at end of file