diff --git a/exec.py b/exec.py index 0c03071..0903b48 100644 --- a/exec.py +++ b/exec.py @@ -1,344 +1,344 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ execution script. this where all routines come together and the only script you need to call. refer to parse args below to see options for execution. """ import plotting as plg import os import warnings import argparse import time import torch import utils.exp_utils as utils from evaluator import Evaluator from predictor import Predictor for msg in ["Attempting to set identical bottom==top results", "This figure includes Axes that are not compatible with tight_layout", "Data has no positive values, and therefore cannot be log-scaled.", ".*invalid value encountered in true_divide.*"]: warnings.filterwarnings("ignore", msg) def train(cf, logger): """ performs the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. logs to file and tensorboard. """ logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format( cf.dim, cf.fold, cf.exp_dir, cf.model)) logger.time("train_val") # -------------- inits and settings ----------------- net = model.net(cf, logger).cuda() if cf.optimizer == "ADAM": optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) elif cf.optimizer == "SGD": optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3) if cf.dynamic_lr_scheduling: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, patience=cf.scheduling_patience) model_selector = utils.ModelSelector(cf, logger) starting_epoch = 1 if cf.resume_from_checkpoint: starting_epoch = utils.load_checkpoint(cf.resume_from_checkpoint, net, optimizer) logger.info('resumed from checkpoint {} at epoch {}'.format(cf.resume_from_checkpoint, starting_epoch)) # prepare monitoring monitor_metrics = utils.prepare_monitoring(cf) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) # -------------- training ----------------- for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs)) logger.time("train_epoch") net.train() train_results_list = [] train_evaluator = Evaluator(cf, logger, mode='train') for i in range(cf.num_train_batches): logger.time("train_batch_loadfw") batch = next(batch_gen['train']) batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts'] batch_gen['train'].generator.stats['empty_samples_count'] += batch['empty_samples_count'] logger.time("train_batch_loadfw") logger.time("train_batch_netfw") results_dict = net.train_forward(batch) logger.time("train_batch_netfw") logger.time("train_batch_bw") optimizer.zero_grad() results_dict['torch_loss'].backward() if cf.clip_norm: torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) # gradient clipping optimizer.step() train_results_list.append(({k:v for k,v in results_dict.items() if k != "seg_preds"}, batch["pid"])) # slim res dict if not cf.server_env: print("\rFinished training batch " + "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches, logger.get_time("train_batch_loadfw")+ logger.get_time("train_batch_netfw") +logger.time("train_batch_bw"), logger.get_time("train_batch_loadfw",reset=True), logger.get_time("train_batch_netfw", reset=True), logger.get_time("train_batch_bw", reset=True)), end="", flush=True) print() #--------------- train eval ---------------- if (epoch-1)%cf.plot_frequency==0: # view an example batch logger.time("train_plot") plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold))) logger.info("generated train-example plot in {:.2f}s".format(logger.time("train_plot"))) logger.time("evals") _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train']) logger.time("evals") logger.time("train_epoch", toggle=False) del train_results_list #----------- validation ------------ logger.info('starting validation in mode {}.'.format(cf.val_mode)) logger.time("val_epoch") with torch.no_grad(): net.eval() val_results_list = [] val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) val_predictor = Predictor(cf, net, logger, mode='val') for i in range(batch_gen['n_val']): logger.time("val_batch") batch = next(batch_gen[cf.val_mode]) if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) val_results_list.append([results_dict, batch["pid"]]) if not cf.server_env: print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch', i + 1, batch_gen['n_val'], logger.time("val_batch")), end="", flush=True) print() #------------ val eval ------------- if (epoch - 1) % cf.plot_frequency == 0: logger.time("val_plot") plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold))) logger.info("generated val plot in {:.2f}s".format(logger.time("val_plot"))) logger.time("evals") _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val']) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) del val_results_list #----------- monitoring ------------- monitor_metrics.update({"lr": {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}}) logger.metrics2tboard(monitor_metrics, global_step=epoch) logger.time("evals") logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. val total: {:.2f}s, average: {:.2f}s.'.format( epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"), logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"), logger.get_time("val_epoch", reset=True)/batch_gen["n_val"])) logger.info("time for evals: {:.2f}s".format(logger.get_time("evals", reset=True))) #-------------- scheduling ----------------- if not cf.dynamic_lr_scheduling: for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch-1] else: scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) logger.time("train_val") logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True))) batch_gen['train'].generator.print_stats(logger, plot=True) def test(cf, logger, max_fold=None): """performs testing for a given fold (or held out set). saves stats in evaluator. """ logger.time("test_fold") logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir)) net = model.net(cf, logger).cuda() batch_gen = data_loader.get_test_generator(cf, logger) test_predictor = Predictor(cf, net, logger, mode='test') test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr( cf, "eval_test_separately") or not cf.eval_test_separately) if test_results_list is not None: test_evaluator = Evaluator(cf, logger, mode='test') test_evaluator.evaluate_predictions(test_results_list) test_evaluator.score_test_df(max_fold=max_fold) logger.info('Testing of fold {} took {}.'.format(cf.fold, logger.get_time("test_fold", reset=True, format="hms"))) if __name__ == '__main__': stime = time.time() parser = argparse.ArgumentParser() parser.add_argument('-m', '--mode', type=str, default='train_test', help='one out of: create_exp, analysis, train, train_test, or test') parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.') parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/regrcnn/datasets/toy/experiments/dev', help='path to experiment dir. will be created if non existent.') parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.') parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config") parser.add_argument('--use_stored_settings', default=False, action='store_true', help='load configs from existing exp_dir instead of source dir. always done for testing, ' 'but can be set to true to do the same for training. useful in job scheduler environment, ' 'where source code might change before the job actually runs.') parser.add_argument('--resume_from_checkpoint', type=str, default=None, help='path to checkpoint. if resuming from checkpoint, the desired fold still needs to be parsed via --folds.') parser.add_argument('--dataset_name', type=str, default='toy', help="path to the dataset-specific code in source_dir/datasets") parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything") args = parser.parse_args() args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name folds = args.folds resume_from_checkpoint = None if args.resume_from_checkpoint in ['None', 'none'] else args.resume_from_checkpoint if args.mode == 'create_exp': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, -1) logger.info('created experiment directory at {}'.format(args.exp_dir)) elif args.mode == 'train' or args.mode == 'train_test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings) if args.dev: folds = [0,1] cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 1 cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1 cf.test_n_epochs = cf.save_n_models cf.max_test_patients = 1 torch.backends.cudnn.benchmark = cf.dim==3 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) if folds is None: folds = range(cf.n_cv_splits) for fold in folds: """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation, one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the splits. k==folds, fold in [0,folds) says which split is used for testing. """ cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)) cf.fold, logger.fold = fold, fold cf.resume_from_checkpoint = resume_from_checkpoint if not os.path.exists(cf.fold_dir): os.mkdir(cf.fold_dir) train(cf, logger) cf.resume_from_checkpoint = None if args.mode == 'train_test': test(cf, logger) elif args.mode == 'test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) if args.data_dest is not None: cf.data_dest = args.data_dest - logger = utils.get_logger(cf.exp_dir, cf.server_env) + logger = utils.get_logger(cf.exp_dir, cf.server_env, cf.sysmetrics_interval) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if folds is None: folds = range(cf.n_cv_splits) if args.dev: folds = folds[:2] cf.batch_size, cf.max_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark for fold in folds: cf.fold = fold cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) if cf.fold_dir in fold_dirs: test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs])) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation. elif args.mode == 'analysis': """ analyse already saved predictions. """ cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) logger = utils.get_logger(cf.exp_dir, cf.server_env, -1) if cf.held_out_test_set and not cf.eval_test_fold_wise: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() logger.info('starting evaluation...') cf.fold = 0 evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) evaluator.score_test_df(max_fold=0) else: fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if args.dev: fold_dirs = fold_dirs[:1] if folds is None: folds = range(cf.n_cv_splits) for fold in folds: cf.fold = fold cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) if cf.fold_dir in fold_dirs: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x logger.info('starting evaluation...') evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) max_fold = max([int(f[-1]) for f in fold_dirs]) evaluator.score_test_df(max_fold=max_fold) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) else: raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode)) mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t)) del logger torch.cuda.empty_cache() diff --git a/plotting.py b/plotting.py index d53d3e5..c6425c4 100644 --- a/plotting.py +++ b/plotting.py @@ -1,2135 +1,2136 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import matplotlib # matplotlib.rcParams['font.family'] = ['serif'] # matplotlib.rcParams['font.serif'] = ['Times New Roman'] matplotlib.rcParams['mathtext.fontset'] = 'cm' matplotlib.rcParams['font.family'] = 'STIXGeneral' matplotlib.use('Agg') #complains with spyder editor, bc spyder imports mpl at startup from matplotlib.ticker import FormatStrFormatter import matplotlib.colors as mcolors import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import matplotlib.patches as mpatches from matplotlib.ticker import StrMethodFormatter, ScalarFormatter import SimpleITK as sitk from tensorboard.backend.event_processing.event_multiplexer import EventMultiplexer import sys import os import warnings from copy import deepcopy import numpy as np import pandas as pd import scipy.interpolate as interpol from utils.exp_utils import IO_safe warnings.filterwarnings("ignore", module="matplotlib.image") def make_colormap(seq): """ Return a LinearSegmentedColormap seq: a sequence of floats and RGB-tuples. The floats should be increasing and in the interval (0,1). """ seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3] cdict = {'red': [], 'green': [], 'blue': []} for i, item in enumerate(seq): if isinstance(item, float): r1, g1, b1 = seq[i - 1] r2, g2, b2 = seq[i + 1] cdict['red'].append([item, r1, r2]) cdict['green'].append([item, g1, g2]) cdict['blue'].append([item, b1, b2]) return mcolors.LinearSegmentedColormap('CustomMap', cdict) bw_cmap = make_colormap([(1.,1.,1.), (0.,0.,0.)]) #------------------------------------------------------------------------ #------------- plotting functions, not all are used --------------------- def shape_small_first(shape): """sort a tuple so that the smallest entry is swapped to the beginning """ if len(shape) <= 2: # no changing dimensions if channel-dim is missing return shape smallest_dim = np.argmin(shape) if smallest_dim != 0: # assume that smallest dim is color channel new_shape = np.array(shape) # to support mask indexing new_shape = (new_shape[smallest_dim], *new_shape[(np.arange(len(shape), dtype=int) != smallest_dim)]) return new_shape else: return shape def RGB_to_rgb(RGB): rgb = np.array(RGB) / 255. return rgb def mod_to_rgb(arr, cmap=None): """convert a single-channel modality img to 3-color-channel img. :param arr: input img, expected in shape (b,c,)x,y with c=1 :return: img of shape (...,c') with c'=3 """ if len(arr.shape) == 3: arr = np.squeeze(arr) elif len(arr.shape) != 2: raise Exception("Invalid input arr shape: {}".format(arr.shape)) if cmap is None: cmap = "gray" norm = matplotlib.colors.Normalize() norm.autoscale(arr) arr = norm(arr) arr = np.stack((arr,) * 3, axis=-1) return arr def to_rgb(arr, cmap): """ Transform an integer-labeled segmentation map using an rgb color-map. :param arr: img_arr w/o a color-channel :param cmap: dictionary mapping from integer class labels to rgb values :return: img of shape (...,c) """ new_arr = np.zeros(shape=(arr.shape) + (3,)) for l in cmap.keys(): ixs = np.where(arr == l) new_arr[ixs] = np.array([cmap[l][i] for i in range(3)]) return new_arr def to_rgba(arr, cmap): """ Transform an integer-labeled segmentation map using an rgba color-map. :param arr: img_arr w/o a color-channel :param cmap: dictionary mapping from integer class labels to rgba values :return: new array holding rgba-image """ new_arr = np.zeros(shape=(arr.shape) + (4,)) for lab, val in cmap.items(): # in case no alpha, complement with 100% alpha if len(val) == 3: cmap[lab] = (*val, 1.) assert len(cmap[lab]) == 4, "cmap has color with {} entries".format(len(val)) for lab in cmap.keys(): ixs = np.where(arr == lab) rgb = np.array(cmap[lab][:3]) new_arr[ixs] = np.append(rgb, cmap[lab][3]) return new_arr def bin_seg_to_rgba(arr, color): """ Transform a continuously labelled binary segmentation map using an rgba color-map. values are expected to be 0-1, will give alpha-value :param arr: img_arr w/o a color-channel :param color: color to give img :return: new array holding rgba-image """ new_arr = np.zeros(shape=(arr.shape) + (4,)) for i in range(arr.shape[0]): for j in range(arr.shape[1]): new_arr[i][j] = (*color, arr[i][j]) return new_arr def suppress_axes_lines(ax): """ :param ax: pyplot axes object """ ax.axes.get_xaxis().set_ticks([]) ax.axes.get_yaxis().set_ticks([]) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.spines['left'].set_visible(False) return def label_bar(ax, rects, labels=None, colors=None, fontsize=10): """Attach a text label above each bar displaying its height :param ax: :param rects: rectangles as returned by plt.bar() :param labels: :param colors: """ for ix, rect in enumerate(rects): height = rect.get_height() if labels is not None and labels[ix] is not None: label = labels[ix] else: label = '{:g}'.format(height) if colors is not None and colors[ix] is not None and np.any(np.array(colors[ix])<1): color = colors[ix] else: color = 'black' ax.text(rect.get_x() + rect.get_width() / 2., 1.007 * height, label, color=color, ha='center', va='bottom', bbox=dict(facecolor=(1., 1., 1.), edgecolor='none', clip_on=True, pad=0, alpha=0.75), fontsize=fontsize) def draw_box_into_arr(arr, box_coords, box_color=None, lw=2): """ :param arr: imgs shape, (3,y,x) :param box_coords: (x1,y1,x2,y2), in ascending order :param box_color: arr of shape (3,) :param lw: linewidth in pixels """ if box_color is None: box_color = [1., 0.4, 0.] (x1, y1, x2, y2) = box_coords[:4] arr = np.swapaxes(arr, 0, -1) arr[..., y1:y2, x1:x1 + lw, :], arr[..., y1:y2 + lw, x2:x2 + lw, :] = box_color, box_color arr[..., y1:y1 + lw, x1:x2, :], arr[..., y2:y2 + lw, x1:x2, :] = box_color, box_color arr = np.swapaxes(arr, 0, -1) return arr def draw_boxes_into_batch(imgs, batch_boxes, type2color=None, cmap=None): """ :param imgs: either the actual batch imgs or a tuple with shape of batch imgs, need to have 3 color channels, need to be rgb; """ if isinstance(imgs, tuple): img_oshp = imgs imgs = None else: img_oshp = imgs[0].shape img_shp = shape_small_first(img_oshp) # c,x/y,y/x now imgs = np.reshape(imgs, (-1, *img_shp)) box_imgs = np.empty((len(batch_boxes), *(img_shp))) for sample, boxes in enumerate(batch_boxes): # imgs in batch have shape b,c,x,y, swap c to end sample_img = np.full(img_shp, 1.) if imgs is None else imgs[sample] for box in boxes: if len(box["box_coords"]) > 0: if type2color is not None and "box_type" in box.keys(): sample_img = draw_box_into_arr(sample_img, box["box_coords"].astype(np.int32), type2color[box["box_type"]]) else: sample_img = draw_box_into_arr(sample_img, box["box_coords"].astype(np.int32)) box_imgs[sample] = sample_img return box_imgs def plot_prediction_hist(cf, spec_df, outfile, title=None, fs=11, ax=None): labels = spec_df.class_label.values preds = spec_df.pred_score.values type_list = spec_df.det_type.tolist() if hasattr(spec_df, "det_type") else None if title is None: title = outfile.split('/')[-1] + ' count:{}'.format(len(labels)) close=False if ax is None: fig = plt.figure(tight_layout=True) ax = fig.add_subplot(1,1,1) close=True ax.set_yscale('log') ax.set_xlabel("Prediction Score", fontsize=fs) ax.set_ylabel("Occurences", fontsize=fs) ax.hist(preds[labels == 0], alpha=0.3, color=cf.red, range=(0, 1), bins=50, label="fp") ax.hist(preds[labels == 1], alpha=0.3, color=cf.blue, range=(0, 1), bins=50, label="fn at score 0 and tp") ax.axvline(x=cf.min_det_thresh, alpha=1, color=cf.orange, linewidth=1.5, label="min det thresh") if type_list is not None: fp_count = type_list.count('det_fp') fn_count = type_list.count('det_fn') tp_count = type_list.count('det_tp') pos_count = fn_count + tp_count title += '\ntp:{} fp:{} fn:{} pos:{}'.format(tp_count, fp_count, fn_count, pos_count) ax.set_title(title, fontsize=fs) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) if close: ax.legend(loc="best", fontsize=fs) if cf.server_env: IO_safe(plt.savefig, fname=outfile, _raise=False) else: plt.savefig(outfile) + pass plt.close() def plot_wbc_n_missing(cf, df, outfile, fs=11, ax=None): """ WBC (weighted box clustering) has parameter n_missing, which shows how many boxes are missing per cluster. This function plots the average relative amount of missing boxes sorted by cluster score. :param cf: config. :param df: dataframe. :param outfile: path to save image under. :param fs: fontsize. :param ax: axes object. """ bins = np.linspace(0., 1., 10) names = ["{:.1f}".format((bins[i]+(bins[i+1]-bins[i])/2.)*100) for i in range(len(bins)-1)] classes = df.pred_class.unique() colors = [cf.class_id2label[cl_id].color for cl_id in classes] binned_df = df.copy() binned_df.loc[:,"pred_score"] = pd.cut(binned_df["pred_score"], bins) close=False if ax is None: ax = plt.subplot() close=True width = 1 / (len(classes) + 1) group_positions = np.arange(len(names)) legend_handles = [] for ix, cl_id in enumerate(classes): cl_df = binned_df[binned_df.pred_class==cl_id].groupby("pred_score").agg({"cluster_n_missing": 'mean'}) ax.bar(group_positions + ix * width, cl_df.cluster_n_missing.values, width=width, color=colors[ix], alpha=0.4 + ix / 2 / len(classes), edgecolor=colors[ix]) legend_handles.append(mpatches.Patch(color=colors[ix], label=cf.class_dict[cl_id])) title = "Fold {} WBC Missing Preds\nAverage over scores and classes: {:.1f}%".format(cf.fold, df.cluster_n_missing.mean()) ax.set_title(title, fontsize=fs) ax.legend(handles=legend_handles, title="Class", loc="best", fontsize=fs, title_fontsize=fs) ax.set_xticks(group_positions + (len(classes) - 1) * width / 2) # ax.xaxis.set_major_formatter(StrMethodFormatter('{x:.1f}')) THIS WONT WORK... no clue! ax.set_xticklabels(names) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"Average Missing Preds per Cluster (%)", fontsize=fs) ax.set_xlabel("Prediction Score", fontsize=fs) if close: if cf.server_env: IO_safe(plt.savefig, fname=outfile, _raise=False) else: plt.savefig(outfile) plt.close() def plot_stat_curves(cf, stats, outfile, fill=False): """ Plot precision-recall and/or receiver-operating-characteristic curve(s). :param cf: config. :param stats: statistics as supplied by Evaluator. :param outfile: path to save plot under. :param fill: whether to colorize space between plot and x-axis. :return: """ for c in ['roc', 'prc']: plt.figure() empty_plot = True for ix, s in enumerate(stats): if s[c] is not np.nan: plt.plot(s[c][1], s[c][0], label=s['name'] + '_' + c, marker=None, color=cf.color_palette[ix%len(cf.color_palette)]) empty_plot = False if fill: plt.fill_between(s[c][1], s[c][0], alpha=0.33, color=cf.color_palette[ix%len(cf.color_palette)]) if not empty_plot: plt.title(outfile.split('/')[-1] + '_' + c) plt.legend(loc=3 if c == 'prc' else 4) plt.ylabel('precision' if c == 'prc' else '1-spec.') plt.ylim((0.,1)) plt.xlabel('recall') plt.savefig(outfile + '_' + c) plt.close() def plot_grouped_bar_chart(cf, bar_values, groups, splits, colors=None, alphas=None, errors=None, ylabel='', xlabel='', xticklabels=None, yticks=None, yticklabels=None, ylim=None, label_format="{:.3f}", title=None, ax=None, out_file=None, legend=False, fs=11): """ Plot a categorically grouped bar chart. :param cf: config. :param bar_values: values of the bars. :param groups: groups/categories that bars belong to. :param splits: splits within groups, i.e., names of bars. :param colors: colors. :param alphas: 1-opacity. :param errors: values for errorbars. :param ylabel: label of y-axis. :param xlabel: label of x-axis. :param title: plot title. :param ax: axes object to draw into. if None, new is created. :param out_file: path to save plot. :param legend: whether to show a legend. :param fs: fontsize. :return: legend handles. """ bar_values = np.array(bar_values) if alphas is None: alphas = [1.,] * len(splits) if colors is None: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(splits))] if errors is None: errors = np.zeros_like(bar_values) # patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') # patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) close=False if ax is None: ax = plt.subplot() close=True width = 1 / (len(splits) +0.25) group_positions = np.arange(len(groups)) for ix, split in enumerate(splits): rects = ax.bar(group_positions + ix * width, bar_values[ix], width=width, color=(*colors[ix], 0.8), edgecolor=colors[ix], yerr=errors[ix], ecolor=(*np.array(colors[ix])*0.8, 1.), capsize=5) # for ix, bar in enumerate(rects): # bar.set_hatch(patterns[ix]) labels = [label_format.format(val) for val in bar_values[ix]] label_bar(ax, rects, labels, [colors[ix]]*len(labels), fontsize=fs) legend_handles = [mpatches.Patch(color=colors[ix], alpha=alphas[ix], label=split) for ix, split in enumerate(splits)] if legend: ax.legend(handles=legend_handles, fancybox=True, framealpha=1., loc="lower center") legend_handles = [(colors[ix], alphas[ix], split) for ix, split in enumerate(splits)] if title is not None: ax.set_title(title, fontsize=fs) ax.set_xticks(group_positions + (len(splits) - 1) * width / 2) if xticklabels is None: ax.set_xticklabels(groups, fontsize=fs) else: ax.set_xticklabels(xticklabels, fontsize=fs) ax.set_axisbelow(True) ax.set_xlabel(xlabel, fontsize=fs) ax.tick_params(labelsize=fs) ax.grid(axis='y') ax.set_ylabel(ylabel, fontsize=fs) if yticks is not None: ax.set_yticks(yticks) if yticklabels is not None: ax.set_yticklabels(yticklabels, fontsize=fs) if ylim is not None: ax.set_ylim(ylim) if out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() return legend_handles def plot_binned_rater_dissent(cf, binned_stats, out_file=None, ax=None, legend=True, fs=11): """ LIDC-specific plot: rater disagreement as standard deviations within each bin. :param cf: config. :param binned_stats: list, ix==bin_id, item: [(roi_mean, roi_std, roi_max, roi_bin_id-roi_max_bin_id) for roi in bin] :return: """ dissent = [np.array([roi[1] for roi in bin]) for bin in binned_stats] avg_dissent_first_degree = [np.mean(bin) for bin in dissent] groups = list(cf.bin_id2label.keys()) splits = [r"$1^{st}$ std. dev.",] colors = [cf.bin_id2label[bin_id].color[:3] for bin_id in groups] #colors = [cf.blue for bin_id in groups] alphas = [0.9,] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) close=False if ax is None: ax = plt.subplot() close=True width = 1/(len(splits)+1) group_positions = np.arange(len(groups)) #total_counts = [df.loc[split].sum() for split in splits] dissent = np.array(avg_dissent_first_degree) ix=0 rects = ax.bar(group_positions+ix*width, dissent, color=colors, alpha=alphas[ix], edgecolor=colors) #for ix, bar in enumerate(rects): #bar.set_hatch(patterns[ix]) labels = ["{:.2f}".format(diss) for diss in dissent] label_bar(ax, rects, labels, colors, fontsize=fs) bin_edge_color = cf.blue ax.axhline(y=0.5, color=bin_edge_color) ax.text(2.5, 0.38, "bin edge", color=cf.white, fontsize=fs, horizontalalignment="center", bbox=dict(boxstyle='round', facecolor=(*bin_edge_color, 0.85), edgecolor='none', clip_on=True, pad=0)) if legend: legend_handles = [mpatches.Patch(color=cf.blue ,alpha=alphas[ix], label=split) for ix, split in enumerate(splits)] ax.legend(handles=legend_handles, loc='lower center', fontsize=fs) title = "LIDC-IDRI: Average Std Deviation per Lesion" plt.title(title) ax.set_xticks(group_positions + (len(splits)-1)*width/2) ax.set_xticklabels(groups, fontsize=fs) ax.set_axisbelow(True) #ax.tick_params(axis='both', which='major', labelsize=fs) #ax.tick_params(axis='both', which='minor', labelsize=fs) ax.grid() ax.set_ylabel(r"Average Dissent (MS)", fontsize=fs) ax.set_xlabel("binned malignancy-score value (ms)", fontsize=fs) ax.tick_params(labelsize=fs) if out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() return def plot_confusion_matrix(cf, cm, out_file=None, ax=None, fs=11, cmap=plt.cm.Blues, color_bar=True): """ Plot a confusion matrix. :param cf: config. :param cm: confusion matrix, e.g., as supplied by metrics.confusion_matrix from scikit-learn. :return: """ close=False if ax is None: ax = plt.subplot() close=True im = ax.imshow(cm, interpolation='nearest', cmap=cmap) if color_bar: ax.figure.colorbar(im, ax=ax) # Rotate the tick labels and set their alignment. #plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") # Loop over data dimensions and create text annotations. fmt = '.0%' if np.mod(cm, 1).any() else 'd' thresh = cm.max() / 2. for i in range(cm.shape[0]): for j in range(cm.shape[1]): ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center", color="white" if cm[i, j] > thresh else "black") ax.set_ylabel(r"Binned Mean MS", fontsize=fs) ax.set_xlabel("Single-Annotator MS", fontsize=fs) #ax.tick_params(labelsize=fs) if close and out_file is not None: plt.savefig(out_file, dpi=600) if close: plt.close() else: return ax def plot_data_stats(cf, df, labels=None, out_file=None, ax=None, fs=11): """ Plot data-set statistics. Shows target counts. Mainly used by Dataset Class in dataloader.py. :param cf: configs obj :param df: pandas dataframe :param out_file: path to save fig in """ names = df.columns if labels is not None: colors = [label.color for name in names for label in labels if label.name==name] else: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(names))] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) if ax is None: fig, ax = plt.subplots(figsize=(14,6), dpi=300) return_ax = False else: return_ax = True plt.margins(x=0.01) plt.subplots_adjust(bottom=0.15) bar_positions = np.arange(len(names)) name_counts = df.sum() total_count = name_counts.sum() rects = ax.bar(bar_positions, name_counts, color=colors, alpha=0.9, edgecolor=colors) labels = ["{:.0f}%".format(count/ total_count*100) for count in name_counts] label_bar(ax, rects, labels, colors, fontsize=fs) title= "Data Set RoI-Target Balance\nTotal #RoIs: {}".format(int(total_count)) ax.set_title(title, fontsize=fs) ax.set_xticks(bar_positions) rotation = "vertical" if np.any([len(str(name)) > 3 for name in names]) else None if all([isinstance(name, (float, int)) for name in names]): ax.set_xticklabels(["{:.2f}".format(name) for name in names], rotation=rotation, fontsize=fs) else: ax.set_xticklabels(names, rotation=rotation, fontsize=fs) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs", fontsize=fs) ax.set_xlabel(str(df._metadata[0]), fontsize=fs) ax.tick_params(axis='both', which='major', labelsize=fs) ax.tick_params(axis='both', which='minor', labelsize=fs) if out_file is not None: plt.savefig(out_file) if return_ax: return ax else: plt.close() def plot_fold_stats(cf, df, labels=None, out_file=None, ax=None): """ Similar as plot_data_stats but per single cross-val fold. :param cf: configs obj :param df: pandas dataframe :param out_file: path to save fig in """ names = df.columns splits = df.index if labels is not None: colors = [label.color for name in names for label in labels if label.name==name] else: colors = [cf.color_palette[ix%len(cf.color_palette)] for ix in range(len(names))] #patterns = ('/', '\\', '*', 'O', '.', '-', '+', 'x', 'o') #patterns = tuple([patterns[ix%len(patterns)] for ix in range(len(splits))]) if ax is None: ax = plt.subplot() return_ax = False else: return_ax = True width = 1/(len(names)+1) group_positions = np.arange(len(splits)) legend_handles = [] total_counts = [df.loc[split].sum() for split in splits] for ix, name in enumerate(names): rects = ax.bar(group_positions+ix*width, df.loc[:,name], width=width, color=colors[ix], alpha=0.9, edgecolor=colors[ix]) #for ix, bar in enumerate(rects): #bar.set_hatch(patterns[ix]) labels = ["{:.0f}%".format(df.loc[split, name]/ total_counts[ii]*100) for ii, split in enumerate(splits)] label_bar(ax, rects, labels, [colors[ix]]*len(group_positions)) legend_handles.append(mpatches.Patch(color=colors[ix] ,alpha=0.9, label=name)) title= "Fold {} RoI-Target Balances\nTotal #RoIs: {}".format(cf.fold, int(df.values.sum())) plt.title(title) ax.legend(handles=legend_handles) ax.set_xticks(group_positions + (len(names)-1)*width/2) ax.set_xticklabels(splits, rotation="vertical" if len(splits)>2 else None, size=12) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs") ax.set_xlabel("Set split") if out_file is not None: plt.savefig(out_file) if return_ax: return ax plt.close() def plot_batchgen_distribution(cf, pids, p_probs, balance_target, out_file=None): """plot top n_pids probabilities for drawing a pid into a batch. :param cf: experiment config object :param pids: sorted iterable of patient ids :param p_probs: pid's drawing likelihood, order needs to match the one of pids. :param out_file: :return: """ n_pids = len(pids) zip_sorted = np.array(sorted(list(zip(p_probs, pids)), reverse=True)) names, probs = zip_sorted[:n_pids,1], zip_sorted[:n_pids,0].astype('float32') * 100 try: names = [str(int(n)) for n in names] except ValueError: names = [str(n) for n in names] lowest_p = min(p_probs)*100 fig, ax = plt.subplots(1,1,figsize=(17,5), dpi=200) rects = ax.bar(names, probs, color=cf.blue, alpha=0.9, edgecolor=cf.blue) ax = plt.gca() ax.text(0.8, 0.92, "Lowest prob.: {:.5f}%".format(lowest_p), transform=ax.transAxes, color=cf.white, bbox=dict(boxstyle='round', facecolor=cf.blue, edgecolor='none', alpha=0.9)) ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}')) ax.set_xticklabels(names, rotation="vertical", fontsize=7) plt.margins(x=0.01) plt.subplots_adjust(bottom=0.15) if balance_target=="class_targets": balance_target = "Class" elif balance_target=="lesion_gleasons": balance_target = "GS" ax.set_title(str(balance_target)+"-Balanced Train Generator: Sampling Likelihood per PID") ax.set_axisbelow(True) ax.grid(axis='y') ax.set_ylabel("Sampling Likelihood (%)") ax.set_xlabel("PID") plt.tight_layout() if out_file is not None: plt.savefig(out_file) plt.close() def plot_batchgen_stats(cf, stats, target_name, unique_ts, out_file=None): """Plot bar chart showing RoI frequencies and empty-sample count of batch stats recorded by BatchGenerator. :param cf: config. :param stats: statistics as supplied by BatchGenerator class. :param out_file: path to save plot. """ total_samples = cf.num_epochs*cf.num_train_batches*cf.batch_size if target_name=="class_targets": target_name = "Class" label_dict = {cl_id: label for (cl_id, label) in cf.class_id2label.items()} elif target_name=="lesion_gleasons": target_name = "Lesion's Gleason Score" label_dict = cf.gs2label elif target_name=="rg_bin_targets": target_name = "Regression-Bin ID" label_dict = cf.bin_id2label else: raise NotImplementedError names = [label_dict[t_id].name for t_id in unique_ts] colors = [label_dict[t_id].color for t_id in unique_ts] title = "Training Target Frequencies" title += "\nempty samples: {} ({:.1f}%)".format(stats['empty_samples_count'], stats['empty_samples_count']/total_samples*100) rects = plt.bar(names, stats['roi_counts'], color=colors, alpha=0.9, edgecolor=colors) ax = plt.gca() ax.yaxis.set_major_formatter(StrMethodFormatter('{x:g}')) ax.set_title(title) ax.set_axisbelow(True) ax.grid() ax.set_ylabel(r"#RoIs") ax.set_xlabel(target_name) total_count = np.sum(stats["roi_counts"]) labels = ["{:.0f}%".format(count/total_count*100) for count in stats["roi_counts"]] label_bar(ax, rects, labels, colors) if out_file is not None: plt.savefig(out_file) plt.close() def view_3D_array(arr, outfile, elev=30, azim=30): from mpl_toolkits.mplot3d import Axes3D fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.set_aspect("equal") ax.set_xlabel("x") ax.set_ylabel("y") ax.set_zlabel("z") ax.voxels(arr) ax.view_init(elev=elev, azim=azim) plt.savefig(outfile) def view_batch(cf, batch, res_dict=None, out_file=None, legend=True, show_info=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, roi_items="all", sample_picks=None, vol_slice_picks=None, box_score_thres=None, plot_mods=True, dpi=200, vmin=None, return_fig=False): r""" View data and target entries of a batch. Batch expected as dic with entries 'data' and 'seg' holding np.arrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. Pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. :param cf: config. :param batch: batch. :param res_dict: results dictionary. :param out_file: path to save plot. :param legend: whether to show a legend. :param show_info: whether to show text info about img sizes and type in plot. :param has_colorchannels: whether image has color channels. :param isRGB: if image is RGB. :param show_seg_ids: "all" or None or list with seg classes to show (seg_ids) :param show_seg_pred: whether to the predicted segmentation. :param show_gt_boxes: whether to show ground-truth boxes. :param show_gt_labels: whether to show labels of ground-truth boxes. :param roi_items: which roi items to show: strings "all" or "targets". --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no threshold. :param plot_mods: whether to plot input modality/modalities. :param dpi: graphics resolution. :param vmin: min value for gray-scale cmap in imshow, set to a fix value for inter-batch normalization, or None for intra-batch. :param return_fig: whether to return created figure. """ # pfix = prefix, ptfix = postfix patched_patient = 'patch_crop_coords' in list(batch.keys()) pfix = 'patient_' if patched_patient else '' ptfix = '_2d' if (patched_patient and cf.dim == 2 and pfix + 'class_targets_2d' in batch.keys()) else '' # -------------- get data, set flags ----------------- try: btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] except AttributeError: # in this case: assume it's single-annotator ground truths btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'][0] print("Showing only gts of rater 0") data_init_shp, seg_init_shp = data.shape, seg.shape seg = np.copy(seg) if show_seg_ids else None plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) elif hasattr(vol_slice_picks, "__iter__") and vol_slice_picks[0]=="random": z_ics = np.random.choice(np.arange(0, d), size=min(vol_slice_picks[1], d), replace=False) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1 and vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) if seg_preds is not None and not type(show_seg_ids)==str: seg_preds = np.copy(seg_preds) seg_preds = np.where(np.isin(seg_preds, show_seg_ids), seg_preds, 0) if seg is not None: if not type(show_seg_ids)==str: #to save time seg = np.where(np.isin(seg, show_seg_ids), seg, 0) legend_items = {cf.seg_id2label[seg_id] for seg_id in np.unique(seg) if seg_id != 0} # add seg labels else: legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = 12 # fontsize sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: plt.title("Ground Truth", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number suppress_axes_lines(ax) else: plt.axis('off') col += 1 if seg is not None and seg.shape[1] == 1: ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cf.cmap), alpha=0.8) elif seg is not None: ax.imshow(to_rgba(np.argmax(seg[s_ix][...,z_ix], axis=0), cf.cmap), alpha=0.8) # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] width, height = x2 - x1, y2 - y1 if class_targets is not None: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels: text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 text_fs = title_fs // 3 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name and cf.plot_class_ids: text_x = x2 #- 2 * text_fs * (len(str(class_targets[s_ix][j]))) # avoid overlap of scores text_y = y1 #+ 2 * text_fs text_str = '{}'.format(class_targets[s_ix][j]) elif 'regression_targets' in name: text_x, text_y = (x2, y2) text_str = "[" + " ".join( ["{:.1f}".format(x) for x in batch[name][s_ix][j]]) + "]" elif 'rg_bin_targets' in name: text_x, text_y = (x1, y2) text_str = '{}'.format(batch[name][s_ix][j]) else: text_pos = text_poss.pop(0) text_x = text_pos[0] #- 2 * text_fs * len(str(batch[name][s_ix][j])) text_y = text_pos[1] #+ 2 * text_fs text_str = '{}'.format(batch[name][s_ix][j]) ax.text(text_x, text_y, text_str, color=cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') ax.add_patch(bbox) # -----evtly visualise predictions ------------- if pr_boxes is not None or seg_preds is not None: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray") ax.axis("off") col += 1 if row == 0: plt.title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": label = cf.class_id2label[box["box_pred_class_id"]] legend_items.add(label) text_x, text_y = x2, y1 id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" text_str = '{}{:.0f}'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "["+" ".join(["{:.1f}".format(x) for x in box["regression"]])+"]" #str(box["regression"]) #+ "|" if cf.plot_class_ids else "" if 'rg_uncertainty' in box.keys() and not np.isnan(box['rg_uncertainty']): id_text += " | {:.1f}".format(box['rg_uncertainty']) text_str = '{}'.format(id_text) #, box["box_score"] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) # ------------ pred segs -------- if seg_preds is not None: # and seg_preds.shape[1] == 1: if cf.class_specific_seg: ax.imshow(to_rgba(seg_preds[s_ix][0][...,z_ix], cf.cmap), alpha=0.8) else: ax.imshow(bin_seg_to_rgba(seg_preds[s_ix][0][...,z_ix], cf.orange), alpha=0.8) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if show_info: plt.figtext(0, 0, "Batch content is of type\n{}\nand has shapes\n".format(btype) + \ "{} for 'data' and {} for 'seg'".format(data_init_shp, seg_init_shp)) if out_file is not None: if cf.server_env: IO_safe(plt.savefig, fname=out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', _raise=False) else: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight') if return_fig: return plt.gcf() plt.clf() plt.close() def view_batch_paper(cf, batch, res_dict=None, out_file=None, legend=True, show_info=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, roi_items="all", split_ens_ics=False, server_env=True, sample_picks=None, vol_slice_picks=None, patient_items=False, box_score_thres=None, plot_mods=True, dpi=400, vmin=None, return_fig=False): r"""view data and target entries of a batch. batch expected as dic with entries 'data' and 'seg' holding tensors or nparrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. :param cf: :param batch: :param res_dict: :param out_file: :param legend: :param show_info: :param has_colorchannels: :param isRGB: :param show_seg_ids: :param show_seg_pred: :param show_gt_boxes: :param show_gt_labels: :param roi_items: strings "all" or "targets" --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param split_ens_ics: :param server_env: :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param patient_items: set to true if patient-wise batch items should be displayed (need to be contained in batch and marked via 'patient_' prefix. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no thres. :param plot_mods: :param dpi: graphics resolution :param vmin: min value for gs cmap in imshow, set to fix inter-batch, or None for intra-batch. pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. show_seg_ids: "all" or None or list with seg classes to show (seg_ids) """ # pfix = prefix, ptfix = postfix pfix = 'patient_' if patient_items else '' ptfix = '_2d' if (patient_items and cf.dim==2) else '' # -------------- get data, set flags ----------------- btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] # seg = np.array(seg).mean(axis=0, keepdims=True) # seg[seg>0] = 1. print("Showing multirater GT") data_init_shp, seg_init_shp = data.shape, seg.shape fg_slices = np.where(np.sum(np.sum(np.squeeze(seg), axis=0), axis=0)>0)[0] if len(fg_slices)==0: print("skipping empty patient") return if vol_slice_picks is None: vol_slice_picks = fg_slices print("data shp, seg shp", data_init_shp, seg_init_shp) plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) # elif hasattr(vol_slice_picks, "__iter__") and vol_slice_picks[0]=="random": # z_ics = np.random.choice(np.arange(0, d), size=min(vol_slice_picks[1], d), replace=False) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1: if vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) else: z_ics = np.random.choice(vol_slice_picks, size=min(len(vol_slice_picks), max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = 12 # fontsize sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: plt.title("Ground Truth+ Pred", fontsize=title_fs) if col == 0: specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number suppress_axes_lines(ax) else: plt.axis('off') col += 1 if seg is not None and seg.shape[1] == 1: cmap = {1: cf.orange} ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cmap), alpha=0.8) # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] # [x1,y1,x2,y2] = box[:4]#:return: coords (x1, y1, x2, y2) width, height = x2 - x1, y2 - y1 if class_targets is not None: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels and cf.plot_class_ids: text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 text_fs = title_fs // 3 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name: text_x = x2 #- 2 * text_fs * (len(str(class_targets[s_ix][j]))) # avoid overlap of scores text_y = y1 #+ 2 * text_fs text_str = '{}'.format(class_targets[s_ix][j]) elif 'regression_targets' in name: text_x, text_y = (x2, y2) text_str = "[" + " ".join( ["{:.1f}".format(x) for x in batch[name][s_ix][j]]) + "]" elif 'rg_bin_targets' in name: text_x, text_y = (x1, y2) text_str = '{}'.format(batch[name][s_ix][j]) else: text_pos = text_poss.pop(0) text_x = text_pos[0] #- 2 * text_fs * len(str(batch[name][s_ix][j])) text_y = text_pos[1] #+ 2 * text_fs text_str = '{}'.format(batch[name][s_ix][j]) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') ax.add_patch(bbox) # # -----evtly visualise predictions ------------- # if pr_boxes is not None or seg_preds is not None: # ax = fig.add_subplot(grid[row, col]) # ax.imshow(bg_img, cmap="gray") # ax.axis("off") # col += 1 # if row == 0: # plt.title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": label = cf.bin_id2label[box["rg_bin"]] color = cf.aubergine legend_items.add(label) text_x, text_y = x2, y1 #id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" id_text = "fg: " text_str = '{}{:.0f}'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=0.5, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, bbox=text_settings, fontsize=title_fs // 2) edgecolor = color #label.color if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "ms: "+" ".join(["{:.1f}".format(x) for x in box["regression"]])+"" text_str = '{}'.format(id_text) #, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=0.5, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=cf.black if label.color==cf.yellow else cf.white, bbox=text_settings, fontsize=title_fs // 2) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) # ax.text(text_x, text_y, text_str, color=cf.white, # bbox=text_settings, fontsize=title_fs // 4) if split_ens_ics and "ens_ix" in box.keys(): n_aug = box["ens_ix"].split("_")[1] edgecolor = [c for c in cf.color_palette if not c == cf.green][ int(n_aug) % (len(cf.color_palette) - 1)] text_x, text_y = x1, y2 text_str = "{}".format(box["ens_ix"][2:]) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 6) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if show_info: plt.figtext(0, 0, "Batch content is of type\n{}\nand has shapes\n".format(btype) + \ "{} for 'data' and {} for 'seg'".format(data_init_shp, seg_init_shp)) if out_file is not None: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', tight_layout=True) if return_fig: return plt.gcf() if not (server_env or cf.server_env): plt.show() plt.clf() plt.close() def view_batch_thesis(cf, batch, res_dict=None, out_file=None, legend=True, has_colorchannels=False, isRGB=True, show_seg_ids="all", show_seg_pred=True, show_gt_boxes=True, show_gt_labels=False, show_cl_ids=True, roi_items="all", server_env=True, sample_picks=None, vol_slice_picks=None, fontsize=12, seg_cmap="class", patient_items=False, box_score_thres=None, plot_mods=True, dpi=400, vmin=None, return_fig=False, axes=None): r"""view data and target entries of a batch. batch expected as dic with entries 'data' and 'seg' holding tensors or nparrays of size :math:`batch\_size \times modalities \times h \times w` for data and :math:`batch\_size \times classes \times h \times w` or :math:`batch\_size \times 1 \times h \times w` for segs. Classes, even if just dummy, are always needed for plotting since they determine colors. :param cf: :param batch: :param res_dict: :param out_file: :param legend: :param show_info: :param has_colorchannels: :param isRGB: :param show_seg_ids: :param show_seg_pred: :param show_gt_boxes: :param show_gt_labels: :param roi_items: strings "all" or "targets" --> all roi items in cf.roi_items or only those which are targets, or list holding keys/names of entries in cf.roi_items to plot additionally on roi boxes. empty iterator to show none. :param split_ens_ics: :param server_env: :param sample_picks: which indices of the batch to display. None for all. :param vol_slice_picks: when batch elements are 3D: which slices to display. None for all, or tuples ("random", int: amt) / (float€[0,1]: fg_prob, int: amt) for random pick / fg_slices pick w probability fg_prob of amt slices. fg pick requires gt seg. :param patient_items: set to true if patient-wise batch items should be displayed (need to be contained in batch and marked via 'patient_' prefix. :param box_score_thres: plot only boxes with pred_score > box_score_thres. None or 0. for no thres. :param plot_mods: :param dpi: graphics resolution :param vmin: min value for gs cmap in imshow, set to fix inter-batch, or None for intra-batch. pyplot expects dimensions in order y,x,chans (height, width, chans) for imshow. show_seg_ids: "all" or None or list with seg classes to show (seg_ids) """ # pfix = prefix, ptfix = postfix pfix = 'patient_' if patient_items else '' ptfix = '_2d' if (patient_items and cf.dim==2) else '' # -------------- get data, set flags ----------------- btype = type(batch[pfix + 'data']) data = batch[pfix + 'data'].astype("float32") seg = batch[pfix + 'seg'] data_init_shp, seg_init_shp = data.shape, seg.shape fg_slices = np.where(np.sum(np.sum(np.squeeze(seg), axis=0), axis=0)>0)[0] if len(fg_slices)==0: print("skipping empty patient") return if vol_slice_picks is None: vol_slice_picks = fg_slices #print("data shp, seg shp", data_init_shp, seg_init_shp) plot_bg = batch['plot_bg'] if 'plot_bg' in batch.keys() and not isinstance(batch['plot_bg'], (int, float)) else None plot_bg_chan = batch['plot_bg'] if 'plot_bg' in batch.keys() and isinstance(batch['plot_bg'], (int, float)) else 0 gt_boxes = batch[pfix+'bb_target'+ptfix] if pfix+'bb_target'+ptfix in batch.keys() and show_gt_boxes else None class_targets = batch[pfix+'class_targets'+ptfix] if pfix+'class_targets'+ptfix in batch.keys() else None cl_targets_sa = batch[pfix+'class_targets_sa'+ptfix] if pfix+'class_targets_sa'+ptfix in batch.keys() else None cf_roi_items = [pfix+it+ptfix for it in cf.roi_items] if roi_items == "all": roi_items = [it for it in cf_roi_items] elif roi_items == "targets": roi_items = [it for it in cf_roi_items if 'targets' in it] else: roi_items = [it for it in cf_roi_items if it in roi_items] if res_dict is not None: seg_preds = res_dict["seg_preds"] if (show_seg_pred is not None and 'seg_preds' in res_dict.keys() and show_seg_ids) else None if '2D_boxes' in res_dict.keys(): assert cf.dim==2 pr_boxes = res_dict["2D_boxes"] elif 'boxes' in res_dict.keys(): pr_boxes = res_dict["boxes"] else: pr_boxes = None else: seg_preds = None pr_boxes = None # -------------- get shapes, apply sample selection ----------------- (n_samples, mods, h, w), d = data.shape[:4], 0 z_ics = [slice(None)] if has_colorchannels: #has to be 2D data = np.transpose(data, axes=(0, 2, 3, 1)) # now b,y,x,c mods = 1 else: if len(data.shape) == 5: # 3dim case d = data.shape[4] if vol_slice_picks is None: z_ics = np.arange(0, d) else: z_ics = vol_slice_picks sample_ics = range(n_samples) # 8000 approx value of pixels that are displayable in one figure dim (pyplot has a render limit), depends on dpi however if data.shape[0]*data.shape[2]*len(z_ics)>8000: n_picks = max(1, int(8000/(data.shape[2]*len(z_ics)))) if len(z_ics)>1 and vol_slice_picks is None: z_ics = np.random.choice(np.arange(0, data.shape[4]), size=min(data.shape[4], max(1,int(8000/(n_picks*data.shape[2])))), replace=False) if sample_picks is None: sample_picks = np.random.choice(data.shape[0], n_picks, replace=False) if sample_picks is not None: sample_ics = [s for s in sample_picks if s in sample_ics] n_samples = len(sample_ics) if not plot_mods: mods = 0 if show_seg_ids=="all": show_seg_ids = np.unique(seg) legend_items = set() # -------------- setup figure ----------------- if isRGB: data = RGB_to_rgb(data) if plot_bg is not None: plot_bg = RGB_to_rgb(plot_bg) n_cols = mods if seg is not None or gt_boxes is not None: n_cols += 1 if seg_preds is not None or pr_boxes is not None: n_cols += 1 n_rows = n_samples*len(z_ics) grid = gridspec.GridSpec(n_rows, n_cols, wspace=0.01, hspace=0.0) fig = plt.figure(figsize=((n_cols + 1)*2, n_rows*2), tight_layout=True) title_fs = fontsize # fontsize text_fs = title_fs * 2 / 3 sample_ics, z_ics = sorted(sample_ics), sorted(z_ics) row = 0 # current row for s_count, s_ix in enumerate(sample_ics): for z_ix in z_ics: col = 0 # current col # ----visualise input data ------------- if has_colorchannels: if plot_mods: ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix][...,z_ix]) ax.axis("off") if row == 0: plt.title("Input", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) if show_info else str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix][...,z_ix] else: for mod in range(mods): ax = fig.add_subplot(grid[row, col]) ax.imshow(data[s_ix, mod][...,z_ix], cmap="gray", vmin=vmin) suppress_axes_lines(ax) if row == 0: plt.title("Mod. " + str(mod), fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix)==slice else z_ix ylabel = str(specs[s_ix])[-5:]+"/"+str(intra_patient_ix) ax.set_ylabel("{:s}".format(ylabel), fontsize=title_fs) # show id-number col += 1 bg_img = plot_bg[s_ix][...,z_ix] if plot_bg is not None else data[s_ix, plot_bg_chan][...,z_ix] # ---evtly visualise groundtruths------------------- if seg is not None or gt_boxes is not None: # img as bg for gt if axes is not None and 'gt' in axes.keys(): ax = axes['gt'] else: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray", vmin=vmin) if row == 0: ax.set_title("Ground Truth", fontsize=title_fs) if col == 0: # key = "spec" if "spec" in batch.keys() else "pid" specs = batch.get('spec', batch['pid']) intra_patient_ix = s_ix if type(z_ix) == slice else z_ix ylabel = str(specs[s_ix])[-5:] + "/" + str(intra_patient_ix) # str(specs[s_ix])[-5:] ax.set_ylabel("{:s}".format(ylabel), fontsize=text_fs*1.3) # show id-number suppress_axes_lines(ax) else: ax.axis('off') col += 1 # gt bounding boxes if gt_boxes is not None and len(gt_boxes[s_ix]) > 0: for j, box in enumerate(gt_boxes[s_ix]): if d > 0: [z1, z2] = box[4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box[:4] # [x1,y1,x2,y2] = box[:4]#:return: coords (x1, y1, x2, y2) width, height = x2 - x1, y2 - y1 if class_targets is not None: try: label = cf.bin_id2label[cf.rg_val_to_bin_id(batch['patient_regression_targets'][s_ix][j])] except AttributeError: label = cf.class_id2label[class_targets[s_ix][j]] legend_items.add(label) if show_gt_labels and cf.plot_class_ids: bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=label.color, facecolor='none') if height<=text_fs*6: y1 -= text_fs*1.5 y2 += text_fs*2 text_poss, p = [(x1, y1), (x1, (y1+y2)//2)], 0 if roi_items is not None: for name in roi_items: if name in cf_roi_items and batch[name][s_ix][j] is not None: if 'class_targets' in name: text_str = '{}'.format(class_targets[s_ix][j]) text_x, text_y = (x2 + 0 * len(text_str) // 4, y2) elif 'regression_targets' in name: text_str = 'agg. MS: {:.2f}'.format(batch[name][s_ix][j][0]) text_x, text_y = (x2 + 0 * len(text_str) // 4, y2) elif 'rg_bin_targets_sa' in name: text_str = 'sa. MS: {}'.format(batch[name][s_ix][j]) text_x, text_y = (x2-0*len(text_str)*text_fs//4, y1) # elif 'rg_bin_targets' in name: # text_str = 'agg. ms:{}'.format(batch[name][s_ix][j]) # text_x, text_y = (x2+0*len(text_str)//4, y1) ax.text(text_x, text_y, text_str, color=cf.black if (label.color[:3]==cf.yellow or label.color[:3]==cf.green) else cf.white, fontsize=text_fs, bbox=dict(facecolor=label.color, alpha=0.7, edgecolor='none', clip_on=True, pad=0)) p+=1 ax.add_patch(bbox) if seg is not None and seg.shape[1] == 1: #cmap = {1: cf.orange} # cmap = {label_id: label.color for label_id, label in cf.bin_id2label.items()} # this whole function is totally only hacked together for a quick very specific case if seg_cmap == "rg" or seg_cmap=="regression": cmap = {1: cf.bin_id2label[cf.rg_val_to_bin_id(batch['patient_regression_targets'][s_ix][0])].color} else: cmap = cf.class_cmap ax.imshow(to_rgba(seg[s_ix][0][...,z_ix], cmap), alpha=0.8) # # -----evtly visualise predictions ------------- if pr_boxes is not None or seg_preds is not None: if axes is not None and 'pred' in axes.keys(): ax = axes['pred'] else: ax = fig.add_subplot(grid[row, col]) ax.imshow(bg_img, cmap="gray") ax.axis("off") col += 1 if row == 0: ax.set_title("Prediction", fontsize=title_fs) # ---------- pred boxes ------------------------- if pr_boxes is not None and len(pr_boxes[s_ix]) > 0: alpha = 0.7 box_score_thres = cf.min_det_thresh if box_score_thres is None else box_score_thres for j, box in enumerate(pr_boxes[s_ix]): plot_box = box["box_type"] in ["det", "prop"] # , "pos_anchor", "neg_anchor"] if box["box_type"] == "det" and (float(box["box_score"]) <= box_score_thres or box["box_pred_class_id"] == 0): plot_box = False if plot_box: if d > 0: [z1, z2] = box["box_coords"][4:] if not (z1<=z_ix and z_ix<=z2): box = [] if len(box) > 0: [y1, x1, y2, x2] = box["box_coords"][:4] width, height = x2 - x1, y2 - y1 if box["box_type"] == "det": try: label = cf.bin_id2label[cf.rg_val_to_bin_id(box['regression'])] except AttributeError: label = cf.class_id2label[box['box_pred_class_id']] # assert box["rg_bin"] == cf.rg_val_to_bin_id(box['regression']), \ # "box bin: {}, rg-bin {}".format(box["rg_bin"], cf.rg_val_to_bin_id(box['regression'])) color = label.color#cf.aubergine edgecolor = color # label.color text_color = cf.black if (color[:3]==cf.yellow or color[:3]==cf.green) else cf.white legend_items.add(label) bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=0.6, edgecolor=edgecolor, facecolor='none') if height<=text_fs*6: y1 -= text_fs*1.5 y2 += text_fs*2 text_x, text_y = x2, y1 #id_text = str(box["box_pred_class_id"]) + "|" if cf.plot_class_ids else "" id_text = "FG: " text_str = r'{}{:.0f}%'.format(id_text, box["box_score"] * 100) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs ) if 'regression' in box.keys(): text_x, text_y = x2, y2 id_text = "MS: "+" ".join(["{:.2f}".format(x) for x in box["regression"]])+"" text_str = '{}'.format(id_text) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs) if 'rg_bin' in box.keys(): text_x, text_y = x1, y2 text_str = '{}'.format(box["rg_bin"]) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0) # ax.text(text_x, text_y, text_str, color=cf.white, # bbox=text_settings, fontsize=title_fs // 4) if 'box_pred_class_id' in box.keys() and show_cl_ids: text_x, text_y = x2, y2 id_text = box["box_pred_class_id"] text_str = '{}'.format(id_text) text_settings = dict(facecolor=color, alpha=alpha, edgecolor='none', clip_on=True, pad=0.2) ax.text(text_x, text_y, text_str, color=text_color, bbox=text_settings, fontsize=text_fs) else: label = cf.box_type2label[box["box_type"]] legend_items.add(label) edgecolor = label.color ax.add_patch(bbox) row += 1 # -----actions for all batch entries---------- if legend and len(legend_items) > 0: patches = [] for label in legend_items: if cf.plot_class_ids and type(label) != type(cf.box_labels[0]): id_text = str(label.id) + ":" else: id_text = "" patches.append(mpatches.Patch(color=label.color, label="{}{:.10s}".format(id_text, label.name))) # assumes one image gives enough y-space for 5 legend items ncols = max(1, len(legend_items) // (5 * n_samples)) plt.figlegend(handles=patches, loc="upper center", bbox_to_anchor=(0.99, 0.86), borderaxespad=0., ncol=ncols, bbox_transform=fig.transFigure, fontsize=int(2/3*title_fs)) # fig.set_size_inches(mods+3+ncols-1,1.5+1.2*n_samples) if out_file is not None: plt.savefig(out_file, dpi=dpi, pad_inches=0.0, bbox_inches='tight', tight_layout=True) if return_fig: return plt.gcf() if not (server_env or cf.server_env): plt.show() plt.clf() plt.close() def view_slices(cf, img, seg=None, ids=None, title="", out_dir=None, legend=True, cmap=None, label_remap=None, instance_labels=False): """View slices of a 3D image overlayed with corresponding segmentations. :params img, seg: expected as 3D-arrays """ if isinstance(img, sitk.SimpleITK.Image): img = sitk.GetArrayViewFromImage(img) elif isinstance(img, np.ndarray): #assume channels dim is smallest and in either first or last place if np.argmin(img.shape)==2: img = np.moveaxis(img, 2,0) else: raise Exception("view_slices got unexpected img type.") if seg is not None: if isinstance(seg, sitk.SimpleITK.Image): seg = sitk.GetArrayViewFromImage(seg) elif isinstance(img, np.ndarray): if np.argmin(seg.shape)==2: seg = np.moveaxis(seg, 2,0) else: raise Exception("view_slices got unexpected seg type.") if label_remap is not None: for (key, val) in label_remap.items(): seg[seg==key] = val if instance_labels: class Label(): def __init__(self, id, name, color): self.id = id self.name = name self.color = color legend_items = {Label(seg_id, "instance_{}".format(seg_id), cf.color_palette[seg_id%len(cf.color_palette)]) for seg_id in np.unique(seg)} if cmap is None: cmap = {label.id : label.color for label in legend_items} else: legend_items = {cf.seg_id2label[seg_id] for seg_id in np.unique(seg)} if cmap is None: cmap = {label.id : label.color for label in legend_items} slices = img.shape[0] if seg is not None: assert slices==seg.shape[0], "Img and seg have different amt of slices." grid = gridspec.GridSpec(int(np.ceil(slices/4)),4) fig = plt.figure(figsize=(10, slices/4*2.5)) rng = np.arange(slices, dtype='uint8') if not ids is None: rng = rng[ids] for s in rng: ax = fig.add_subplot(grid[int(s/4),int(s%4)]) ax.imshow(img[s], cmap="gray") if not seg is None: ax.imshow(to_rgba(seg[s], cmap), alpha=0.9) if legend and int(s/4)==0 and int(s%4)==3: patches = [mpatches.Patch(color=label.color, label="{}".format(label.name)) for label in legend_items] ncols = 1 plt.legend(handles=patches,bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., ncol=ncols) plt.title("slice {}, {}".format(s, img[s].shape)) plt.axis('off') plt.suptitle(title) if out_dir is not None: plt.savefig(out_dir, dpi=300, pad_inches=0.0, bbox_inches='tight') if not cf.server_env: plt.show() plt.close() def plot_txt(cf, txts, labels=None, title="", x_label="", y_labels=["",""], y_ranges=(None,None), twin_axes=(), smooth=None, out_dir=None): """Read and plot txt data, either from file (txts is paths) or directly (txts is arrays). :param twin_axes: plot two y-axis over same x-axis. twin_axes expected as tuple defining which txt files (determined via indices) share the second y-axis. """ if isinstance(txts, str) or not hasattr(txts, '__iter__'): txts = [txts] fig = plt.figure() ax1 = fig.add_subplot(1,1,1) if len(twin_axes)>0: ax2 = ax1.twinx() for i, txt in enumerate(txts): if isinstance(txt, str): arr = np.genfromtxt(txt, delimiter=',',skip_header=1, usecols=(1,2)) else: arr = txt if i in twin_axes: ax = ax2 else: ax = ax1 if smooth is not None: spline_graph = interpol.UnivariateSpline(arr[:,0], arr[:,1], k=5, s=float(smooth)) ax.plot(arr[:, 0], spline_graph(arr[:,0]), color=cf.color_palette[i % len(cf.color_palette)], marker='', markersize=2, linestyle='solid') ax.plot(arr[:,0], arr[:,1], color=cf.color_palette[i%len(cf.color_palette)], marker='', markersize=2, linestyle='solid', label=labels[i], alpha=0.5 if smooth else 1.) plt.title(title) ax1.set_xlabel(x_label) ax1.set_ylabel(y_labels[0]) if y_ranges[0] is not None: ax1.set_ylim(y_ranges[0]) if len(twin_axes)>0: ax2.set_ylabel(y_labels[1]) if y_ranges[1] is not None: ax2.set_ylim(y_ranges[1]) plt.grid() if labels is not None: ax1.legend(loc="upper center") if len(twin_axes)>0: ax2.legend(loc=4) if out_dir is not None: plt.savefig(out_dir, dpi=200) return fig def plot_tboard_logs(cf, log_dir, tag_filters=[""], inclusive_filters=True, out_dir=None, x_label="", y_labels=["",""], y_ranges=(None,None), twin_axes=(), smooth=None): """Plot (only) tboard scalar logs from given log_dir for multiple runs sorted by tag. """ print("log dir", log_dir) mpl = EventMultiplexer().AddRunsFromDirectory(log_dir) #EventAccumulator(log_dir) mpl.Reload() # Print tags of contained entities, use these names to retrieve entities as below #print(mpl.Runs()) scalars = {runName : data['scalars'] for (runName, data) in mpl.Runs().items() if len(data['scalars'])>0} print("scalars", scalars) tags = {} tag_filters = [tag_filter.lower() for tag_filter in tag_filters] for (runName, runtags) in scalars.items(): print("rn", runName.lower()) check = np.any if inclusive_filters else np.all if np.any([tag_filter in runName.lower() for tag_filter in tag_filters]): for runtag in runtags: #if tag_filter in runtag.lower(): if runtag not in tags: tags[runtag] = [runName] else: tags[runtag].append(runName) print("tags ", tags) for (tag, runNames) in tags.items(): print("runnames ", runNames) print("tag", tag) tag_scalars = [] labels = [] for run in runNames: #mpl.Scalars returns ScalarEvents array holding wall_time, step, value per time step (shape series_length x 3) #print(mpl.Scalars(runName, tag)[0]) run_scalars = [(s.step, s.value) for s in mpl.Scalars(run, tag)] print(np.array(run_scalars).shape) tag_scalars.append(np.array(run_scalars)) print("run", run) labels.append("/".join(run.split("/")[-2:])) #print("tag scalars ", tag_scalars) if out_dir is not None: out_path = os.path.join(out_dir,tag.replace("/","_")) else: out_path = None plot_txt(txts=tag_scalars, labels=labels, title=tag, out_dir=out_path, cf=cf, x_label=x_label, y_labels=y_labels, y_ranges=y_ranges, twin_axes=twin_axes, smooth=smooth) def plot_box_legend(cf, box_coords=None, class_id=None, out_dir=None): """plot a blank box explaining box annotations. :param cf: :return: """ if class_id is None: class_id = 1 img = np.ones(cf.patch_size[:2]) dim_max = max(cf.patch_size[:2]) width, height = cf.patch_size[0] // 2, cf.patch_size[1] // 2 if box_coords is None: # lower left corner x1, y1 = width // 2, height // 2 x2, y2 = x1 + width, y1 + height else: y1, x1, y2, x2 = box_coords fig = plt.figure(tight_layout=True, dpi=300) ax = fig.add_subplot(111) title_fs = 36 label = cf.class_id2label[class_id] # legend_items.add(label) ax.set_facecolor(cf.beige) ax.imshow(img, cmap='gray', vmin=0., vmax=1., alpha=0) # ax.axis('off') # suppress_axes_lines(ax) ax.set_xticks([]) ax.set_yticks([]) text_x, text_y = x2 * 0.85, y1 id_text = "class id" + " | " if cf.plot_class_ids else "" text_str = '{}{}'.format(id_text, "confidence") text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color if any(['regression' in task for task in cf.prediction_tasks]): text_x, text_y = x2 * 0.85, y2 id_text = "regression" if any(['ken_gal' in task or 'feindt' in task for task in cf.prediction_tasks]): id_text += " | uncertainty" text_str = '{}'.format(id_text) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'regression_bin' in cf.prediction_tasks or hasattr(cf, "rg_val_to_bin_id"): text_x, text_y = x1, y2 text_str = 'Rg. Bin' ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) if 'lesion_gleasons' in cf.observables_rois: text_x, text_y = x1, y1 text_str = 'Gleason Score' ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) bbox = mpatches.Rectangle((x1, y1), width, height, linewidth=1., edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) if out_dir is not None: plt.savefig(os.path.join(out_dir, "box_legend.png")) def plot_boxes(cf, box_coords, patch_size=None, scores=None, class_ids=None, out_file=None, ax=None): if patch_size is None: patch_size = cf.patch_size[:2] if class_ids is None: class_ids = np.ones((len(box_coords),), dtype='uint8') if scores is None: scores = np.ones((len(box_coords),), dtype='uint8') img = np.ones(patch_size) y1, x1, y2, x2 = box_coords[:,0], box_coords[:,1], box_coords[:,2], box_coords[:,3] width, height = x2-x1, y2-y1 close = False if ax is None: fig = plt.figure(tight_layout=True, dpi=300) ax = fig.add_subplot(111) close = True title_fs = 56 ax.set_facecolor((*cf.gray,0.15)) ax.imshow(img, cmap='gray', vmin=0., vmax=1., alpha=0) #ax.axis('off') #suppress_axes_lines(ax) ax.set_xticks([]) ax.set_yticks([]) for bix, cl_id in enumerate(class_ids): label = cf.class_id2label[cl_id] text_x, text_y = x2[bix] -20, y1[bix] +5 id_text = class_ids[bix] if cf.plot_class_ids else "" text_str = '{}{}{:.0f}'.format(id_text, " | ", scores[bix] * 100) text_settings = dict(facecolor=label.color, alpha=0.5, edgecolor='none', clip_on=True, pad=0) ax.text(text_x, text_y, text_str, color=cf.white, bbox=text_settings, fontsize=title_fs // 4) edgecolor = label.color bbox = mpatches.Rectangle((x1[bix], y1[bix]), width[bix], height[bix], linewidth=1., edgecolor=edgecolor, facecolor='none') ax.add_patch(bbox) if out_file is not None: plt.savefig(out_file) if close: plt.close() if __name__=="__main__": cluster_exp_root = "/mnt/E132-Cluster-Projects" #dataset="prostate/" dataset = "lidc/" exp_name = "ms13_mrcnnal3d_rg_bs8_480k" #exp_dir = os.path.join("datasets", dataset, "experiments", exp_name) # exp_dir = os.path.join(cluster_exp_root, dataset, "experiments", exp_name) # log_dir = os.path.join(exp_dir, "logs") # sys.path.append(exp_dir) # from configs import Configs # cf = configs() # # #print("logdir", log_dir) # #out_dir = os.path.join(cf.source_dir, log_dir.replace("/", "_")) # #print("outdir", out_dir) # log_dir = os.path.join(cf.source_dir, log_dir) # plot_tboard_logs(cf, log_dir, tag_filters=["train/lesion_avp", "val/lesion_ap", "val/lesion_avp", "val/patient_lesion_avp"], smooth=2.2, out_dir=log_dir, # y_ranges=([0,900], [0,0.8]), # twin_axes=[1], y_labels=["counts",""], x_label="epoch") #plot_box_legend(cf, out_dir=exp_dir) diff --git a/utils/exp_utils.py b/utils/exp_utils.py index 138cdb2..5bf4b05 100644 --- a/utils/exp_utils.py +++ b/utils/exp_utils.py @@ -1,632 +1,659 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -#import plotting as plg +# import plotting as plg import sys import os import subprocess import threading import pickle import importlib.util import psutil import time import logging from torch.utils.tensorboard import SummaryWriter from collections import OrderedDict import numpy as np import pandas as pd import torch def import_module(name, path): """ correct way of importing a module dynamically in python 3. :param name: name given to module instance. :param path: path to module. :return: module: returned module instance. """ spec = importlib.util.spec_from_file_location(name, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module + def save_obj(obj, name): """Pickle a python object.""" with open(name + '.pkl', 'wb') as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) + def load_obj(file_path): with open(file_path, 'rb') as handle: return pickle.load(handle) + def IO_safe(func, *args, _tries=5, _raise=True, **kwargs): """ Wrapper calling function func with arguments args and keyword arguments kwargs to catch input/output errors on cluster. :param func: function to execute (intended to be read/write operation to a problematic cluster drive, but can be any function). :param args: positional args of func. :param kwargs: kw args of func. :param _tries: how many attempts to make executing func. """ for _try in range(_tries): try: return func(*args, **kwargs) except OSError as e: # to catch cluster issues with network drives if _raise: raise e else: - print("After attempting execution {} time{}, following error occurred:\n{}".format(_try+1,"" if _try==0 else "s", e)) + print("After attempting execution {} time{}, following error occurred:\n{}".format(_try + 1, + "" if _try == 0 else "s", + e)) continue + def query_nvidia_gpu(device_id, d_keyword=None, no_units=False): """ :param device_id: :param d_keyword: -d, --display argument (keyword(s) for selective display), all are selected if None :return: dict of gpu-info items """ cmd = ['nvidia-smi', '-i', str(device_id), '-q'] if d_keyword is not None: cmd += ['-d', d_keyword] outp = subprocess.check_output(cmd).strip().decode('utf-8').split("\n") - outp = [x for x in outp if len(x)>0] - headers = [ix for ix, item in enumerate(outp) if len(item.split(":"))==1] + [len(outp)] + outp = [x for x in outp if len(x) > 0] + headers = [ix for ix, item in enumerate(outp) if len(item.split(":")) == 1] + [len(outp)] out_dict = {} for lix, hix in enumerate(headers[:-1]): head = outp[hix].strip().replace(" ", "_").lower() out_dict[head] = {} - for lix2 in range(hix, headers[lix+1]): + for lix2 in range(hix, headers[lix + 1]): try: key, val = [x.strip().lower() for x in outp[lix2].split(":")] if no_units: val = val.split()[0] out_dict[head][key] = val except: pass return out_dict + class CombinedPrinter(object): """combined print function. prints to logger and/or file if given, to normal print if non given. """ + def __init__(self, logger=None, file=None): if logger is None and file is None: self.out = [print] elif logger is None: self.out = [file.write] elif file is None: self.out = [logger.info] else: self.out = [logger.info, file.write] def __call__(self, string): for fct in self.out: fct(string) + class Nvidia_GPU_Logger(object): def __init__(self): self.count = None def get_vals(self): cmd = ['nvidia-settings', '-t', '-q', 'GPUUtilization'] gpu_util = subprocess.check_output(cmd).strip().decode('utf-8').split(",") gpu_util = dict([f.strip().split("=") for f in gpu_util]) cmd[-1] = 'UsedDedicatedGPUMemory' gpu_used_mem = subprocess.check_output(cmd).strip().decode('utf-8') current_vals = {"gpu_mem_alloc": gpu_used_mem, "gpu_graphics_util": int(gpu_util['graphics']), - "gpu_mem_util": gpu_util['memory'], "time": time.time()} + "gpu_mem_util": gpu_util['memory'], "time": time.time()} return current_vals - def loop(self): + def loop(self, interval): i = 0 while True: self.get_vals() self.log["time"].append(time.time()) self.log["gpu_util"].append(self.current_vals["gpu_graphics_util"]) if self.count is not None: i += 1 if i == self.count: exit(0) time.sleep(self.interval) def start(self, interval=1.): self.interval = interval self.start_time = time.time() self.log = {"time": [], "gpu_util": []} if self.interval is not None: thread = threading.Thread(target=self.loop) thread.daemon = True thread.start() +class DummyLogger(): + def __init__(self): + pass + def info(self, *args): + print(*args) + return None + class CombinedLogger(object): """Combine console and tensorboard logger and record system metrics. """ - def __init__(self, name, log_dir, server_env=True, fold="", sysmetrics_interval=2): - self.pylogger = logging.getLogger(name) + + def __init__(self, name, log_dir, server_env=True, fold="", sysmetrics_interval=-1): + self.pylogger = DummyLogger()#logging.getLogger(name) self.tboard = SummaryWriter(log_dir=log_dir) self.times = {} self.fold = fold - self.pylogger.setLevel(logging.DEBUG) - self.log_file = os.path.join(log_dir, 'exec.log') - self.pylogger.addHandler(logging.FileHandler(self.log_file)) - if not server_env: - self.pylogger.addHandler(ColorHandler()) - else: - self.pylogger.addHandler(logging.StreamHandler()) - self.pylogger.propagate = False + # self.pylogger.setLevel(logging.DEBUG) + # self.log_file = os.path.join(log_dir, 'exec.log') + # self.pylogger.addHandler(logging.FileHandler(self.log_file)) + # if not server_env: + # self.pylogger.addHandler(ColorHandler()) + # else: + # self.pylogger.addHandler(logging.StreamHandler()) + # self.pylogger.propagate = False # monitor system metrics (cpu, mem, ...) - if not server_env and sysmetrics_interval>0: - self.sysmetrics = pd.DataFrame(columns=["global_step", "rel_time", r"CPU (%)", "mem_used (GB)", r"mem_used (%)", - r"swap_used (GB)", r"gpu_utilization (%)"], dtype="float16") + if not server_env and sysmetrics_interval > 0: + self.sysmetrics = pd.DataFrame( + columns=["global_step", "rel_time", r"CPU (%)", "mem_used (GB)", r"mem_used (%)", + r"swap_used (GB)", r"gpu_utilization (%)"], dtype="float16") for device in range(torch.cuda.device_count()): - self.sysmetrics["mem_allocd (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan - self.sysmetrics["mem_cached (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan + self.sysmetrics[ + "mem_allocd (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan + self.sysmetrics[ + "mem_cached (GB) by torch on {:10s}".format(torch.cuda.get_device_name(device))] = np.nan self.sysmetrics_start(sysmetrics_interval) + pass else: print("NOT logging sysmetrics") def __getattr__(self, attr): """delegate all undefined method requests to objects of this class in order pylogger, tboard (first find first serve). E.g., combinedlogger.add_scalars(...) should trigger self.tboard.add_scalars(...) """ for obj in [self.pylogger, self.tboard]: if attr in dir(obj): return getattr(obj, attr) - raise AttributeError("CombinedLogger has no attribute {}".format(attr)) - + print("logger attr not found") + #raise AttributeError("CombinedLogger has no attribute {}".format(attr)) def time(self, name, toggle=None): """record time-spans as with a stopwatch. :param name: :param toggle: True^=On: start time recording, False^=Off: halt rec. if None determine from current status. :return: either start-time or last recorded interval """ if toggle is None: if name in self.times.keys(): toggle = not self.times[name]["toggle"] else: toggle = True if toggle: if not name in self.times.keys(): - self.times[name] = {"total": 0, "last":0} + self.times[name] = {"total": 0, "last": 0} elif self.times[name]["toggle"] == toggle: self.info("restarting running stopwatch") self.times[name]["last"] = time.time() self.times[name]["toggle"] = toggle return time.time() else: if toggle == self.times[name]["toggle"]: self.info("WARNING: tried to stop stopped stop watch: {}.".format(name)) - self.times[name]["last"] = time.time()-self.times[name]["last"] + self.times[name]["last"] = time.time() - self.times[name]["last"] self.times[name]["total"] += self.times[name]["last"] self.times[name]["toggle"] = toggle return self.times[name]["last"] def get_time(self, name=None, kind="total", format=None, reset=False): """ :param name: :param kind: 'total' or 'last' :param format: None for float, "hms"/"ms" for (hours), mins, secs as string :param reset: reset time after retrieving :return: """ if name is None: times = self.times if reset: self.reset_time() return times else: if self.times[name]["toggle"]: self.time(name, toggle=False) time = self.times[name][kind] if format == "hms": m, s = divmod(time, 60) h, m = divmod(m, 60) time = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(m), int(s)) elif format == "ms": m, s = divmod(time, 60) time = "{:02d}m:{:02d}s".format(int(m), int(s)) if reset: self.reset_time(name) return time def reset_time(self, name=None): if name is None: self.times = {} else: del self.times[name] - def sysmetrics_update(self, global_step=None): if global_step is None: global_step = time.strftime("%x_%X") - mem = psutil.virtual_memory() - mem_used = (mem.total-mem.available) + mem = psutil.virtual_memory() + mem_used = (mem.total - mem.available) gpu_vals = self.gpu_logger.get_vals() - rel_time = time.time()-self.sysmetrics_start_time + rel_time = time.time() - self.sysmetrics_start_time self.sysmetrics.loc[len(self.sysmetrics)] = [global_step, rel_time, - psutil.cpu_percent(), mem_used/1024**3, mem_used/mem.total*100, - psutil.swap_memory().used/1024**3, int(gpu_vals['gpu_graphics_util']), - *[torch.cuda.memory_allocated(d)/1024**3 for d in range(torch.cuda.device_count())], - *[torch.cuda.memory_cached(d)/1024**3 for d in range(torch.cuda.device_count())] - ] - return self.sysmetrics.loc[len(self.sysmetrics)-1].to_dict() + psutil.cpu_percent(), mem_used / 1024 ** 3, + mem_used / mem.total * 100, + psutil.swap_memory().used / 1024 ** 3, + int(gpu_vals['gpu_graphics_util']), + *[torch.cuda.memory_allocated(d) / 1024 ** 3 for d in + range(torch.cuda.device_count())], + *[torch.cuda.memory_cached(d) / 1024 ** 3 for d in + range(torch.cuda.device_count())] + ] + return self.sysmetrics.loc[len(self.sysmetrics) - 1].to_dict() def sysmetrics2tboard(self, metrics=None, global_step=None, suptitle=None): tag = "per_time" if metrics is None: metrics = self.sysmetrics_update(global_step=global_step) tag = "per_epoch" if suptitle is not None: suptitle = str(suptitle) - elif self.fold!="": - suptitle = "Fold_"+str(self.fold) + elif self.fold != "": + suptitle = "Fold_" + str(self.fold) if suptitle is not None: - self.tboard.add_scalars(suptitle+"/System_Metrics/"+tag, {k:v for (k,v) in metrics.items() if (k!="global_step" - and k!="rel_time")}, global_step) + self.tboard.add_scalars(suptitle + "/System_Metrics/" + tag, + {k: v for (k, v) in metrics.items() if (k != "global_step" + and k != "rel_time")}, global_step) def sysmetrics_loop(self): try: os.nice(-19) self.info("Logging system metrics with superior process priority.") except: self.info("Logging system metrics WITHOUT superior process priority.") while True: metrics = self.sysmetrics_update() self.sysmetrics2tboard(metrics, global_step=metrics["rel_time"]) - #print("thread alive", self.thread.is_alive()) + # print("thread alive", self.thread.is_alive()) time.sleep(self.sysmetrics_interval) - + def sysmetrics_start(self, interval): - if interval is not None and interval>0: + if interval is not None and interval > 0: self.sysmetrics_interval = interval self.gpu_logger = Nvidia_GPU_Logger() self.sysmetrics_start_time = time.time() self.thread = threading.Thread(target=self.sysmetrics_loop) self.thread.daemon = True self.thread.start() def sysmetrics_save(self, out_file): self.sysmetrics.to_pickle(out_file) - def metrics2tboard(self, metrics, global_step=None, suptitle=None): """ :param metrics: {'train': dataframe, 'val':df}, df as produced in evaluator.py.evaluate_predictions """ - #print("metrics", metrics) + # print("metrics", metrics) if global_step is None: - global_step = len(metrics['train'][list(metrics['train'].keys())[0]])-1 + global_step = len(metrics['train'][list(metrics['train'].keys())[0]]) - 1 if suptitle is not None: suptitle = str(suptitle) else: - suptitle = "Fold_"+str(self.fold) + suptitle = "Fold_" + str(self.fold) for key in ['train', 'val']: - #series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k} + # series = {k:np.array(v[-1]) for (k,v) in metrics[key].items() if not np.isnan(v[-1]) and not 'Bin_Stats' in k} loss_series = {} unc_series = {} bin_stat_series = {} mon_met_series = {} - for tag,val in metrics[key].items(): - val = val[-1] #maybe remove list wrapping, recording in evaluator? + for tag, val in metrics[key].items(): + val = val[-1] # maybe remove list wrapping, recording in evaluator? if 'bin_stats' in tag.lower() and not np.isnan(val): bin_stat_series["{}".format(tag.split("/")[-1])] = val elif 'uncertainty' in tag.lower() and not np.isnan(val): unc_series["{}".format(tag)] = val elif 'loss' in tag.lower() and not np.isnan(val): loss_series["{}".format(tag)] = val elif not np.isnan(val): mon_met_series["{}".format(tag)] = val - self.tboard.add_scalars(suptitle+"/Binary_Statistics/{}".format(key), bin_stat_series, global_step) + self.tboard.add_scalars(suptitle + "/Binary_Statistics/{}".format(key), bin_stat_series, global_step) self.tboard.add_scalars(suptitle + "/Uncertainties/{}".format(key), unc_series, global_step) self.tboard.add_scalars(suptitle + "/Losses/{}".format(key), loss_series, global_step) - self.tboard.add_scalars(suptitle+"/Monitor_Metrics/{}".format(key), mon_met_series, global_step) + self.tboard.add_scalars(suptitle + "/Monitor_Metrics/{}".format(key), mon_met_series, global_step) self.tboard.add_scalars(suptitle + "/Learning_Rate", metrics["lr"], global_step) return - + def batchImgs2tboard(self, batch, results_dict, cmap, boxtype2color, img_bg=False, global_step=None): raise NotImplementedError("not up-to-date, problem with importing plotting-file, torchvision dependency.") - if len(batch["seg"].shape)==5: #3D imgs + if len(batch["seg"].shape) == 5: # 3D imgs slice_ix = np.random.randint(batch["seg"].shape[-1]) - seg_gt = plg.to_rgb(batch['seg'][:,0,:,:,slice_ix], cmap) - seg_pred = plg.to_rgb(results_dict['seg_preds'][:,0,:,:,slice_ix], cmap) - - mod_img = plg.mod_to_rgb(batch["data"][:,0,:,:,slice_ix]) if img_bg else None - - elif len(batch["seg"].shape)==4: - seg_gt = plg.to_rgb(batch['seg'][:,0,:,:], cmap) - seg_pred = plg.to_rgb(results_dict['seg_preds'][:,0,:,:], cmap) - mod_img = plg.mod_to_rgb(batch["data"][:,0]) if img_bg else None + seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :, slice_ix], cmap) + seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :, slice_ix], cmap) + + mod_img = plg.mod_to_rgb(batch["data"][:, 0, :, :, slice_ix]) if img_bg else None + + elif len(batch["seg"].shape) == 4: + seg_gt = plg.to_rgb(batch['seg'][:, 0, :, :], cmap) + seg_pred = plg.to_rgb(results_dict['seg_preds'][:, 0, :, :], cmap) + mod_img = plg.mod_to_rgb(batch["data"][:, 0]) if img_bg else None else: raise Exception("batch content has wrong format: {}".format(batch["seg"].shape)) - - #from here on only works in 2D - seg_gt = np.transpose(seg_gt, axes=(0,3,1,2)) #previous shp: b,x,y,c - seg_pred = np.transpose(seg_pred, axes=(0,3,1,2)) - - + + # from here on only works in 2D + seg_gt = np.transpose(seg_gt, axes=(0, 3, 1, 2)) # previous shp: b,x,y,c + seg_pred = np.transpose(seg_pred, axes=(0, 3, 1, 2)) + seg = np.concatenate((seg_gt, seg_pred), axis=0) # todo replace torchvision (tv) dependency seg = tv.utils.make_grid(torch.from_numpy(seg), nrow=2) - self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.", seg, global_step=global_step) - + self.tboard.add_image("Batch seg, 1st col: gt, 2nd: pred.", seg, global_step=global_step) + if img_bg: - bg_img = np.transpose(mod_img, axes=(0,3,1,2)) + bg_img = np.transpose(mod_img, axes=(0, 3, 1, 2)) else: bg_img = seg_gt box_imgs = plg.draw_boxes_into_batch(bg_img, results_dict["boxes"], boxtype2color) box_imgs = tv.utils.make_grid(torch.from_numpy(box_imgs), nrow=4) self.tboard.add_image("Batch bboxes", box_imgs, global_step=global_step) - + return - def __del__(self): # otherwise might produce multiple prints e.g. in ipython console - for hdlr in self.pylogger.handlers: - hdlr.close() + def __del__(self): # otherwise might produce multiple prints e.g. in ipython console + # for hdlr in self.pylogger.handlers: + # hdlr.close() + # #self.pylogger.handlers = [] + # del self.pylogger self.tboard.close() - self.pylogger.handlers = [] - del self.pylogger + def get_logger(exp_dir, server_env=False, sysmetrics_interval=-1): log_dir = os.path.join(exp_dir, "logs") logger = CombinedLogger('Reg R-CNN', os.path.join(log_dir, "tboard"), server_env=server_env, sysmetrics_interval=sysmetrics_interval) print("logging to {}".format(logger.log_file)) return logger + def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True): """ I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir. This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime. Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone). Provides robust structure for cloud deployment. :param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp) :param exp_path: path to experiment directory. :param server_env: boolean flag. pass to configs script for cloud deployment. :param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing experiment directory, else creates experiment directory on the fly using configs/model scripts from source code. :param is_training: boolean flag. distinguishes train vs. inference mode. :return: configs object. """ if is_training: if use_stored_settings: cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) cf = cf_file.Configs(server_env) # in this mode, previously saved model and backbone need to be found in exp dir. if not os.path.isfile(os.path.join(exp_path, 'model.py')) or \ not os.path.isfile(os.path.join(exp_path, 'backbone.py')): - raise Exception("Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.") + raise Exception( + "Selected use_stored_settings option but no model and/or backbone source files exist in exp dir.") cf.model_path = os.path.join(exp_path, 'model.py') cf.backbone_path = os.path.join(exp_path, 'backbone.py') - else: # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model + else: # this case overwrites settings files in exp dir, i.e., default_configs, configs, backbone, model if not os.path.exists(exp_path): os.mkdir(exp_path) # run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.) - subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), shell=True) - subprocess.call('cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), shell=True) + subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), + shell=True) + subprocess.call( + 'cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), + shell=True) cf_file = import_module('cf_file', os.path.join(dataset_path, 'configs.py')) cf = cf_file.Configs(server_env) subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True) subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True) if os.path.isfile(os.path.join(exp_path, "fold_ids.pickle")): subprocess.call('rm {}'.format(os.path.join(exp_path, "fold_ids.pickle")), shell=True) - else: # testing, use model and backbone stored in exp dir. + else: # testing, use model and backbone stored in exp dir. cf_file = import_module('cf', os.path.join(exp_path, 'configs.py')) cf = cf_file.Configs(server_env) cf.model_path = os.path.join(exp_path, 'model.py') cf.backbone_path = os.path.join(exp_path, 'backbone.py') cf.exp_dir = exp_path cf.test_dir = os.path.join(cf.exp_dir, 'test') cf.plot_dir = os.path.join(cf.exp_dir, 'plots') if not os.path.exists(cf.test_dir): os.mkdir(cf.test_dir) if not os.path.exists(cf.plot_dir): os.mkdir(cf.plot_dir) cf.experiment_name = exp_path.split("/")[-1] cf.dataset_name = dataset_path cf.server_env = server_env cf.created_fold_id_pickle = False return cf + class ModelSelector: ''' saves a checkpoint after each epoch as 'last_state' (can be loaded to continue interrupted training). saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be ensembled to improve performance. ''' def __init__(self, cf, logger): self.cf = cf self.saved_epochs = [-1] * cf.save_n_models self.logger = logger - def run_model_selection(self, net, optimizer, monitor_metrics, epoch): """rank epoch via weighted mean from self.cf.model_selection_criteria: {criterion : weight} :param net: :param optimizer: :param monitor_metrics: :param epoch: :return: """ - crita = self.cf.model_selection_criteria #shorter alias + crita = self.cf.model_selection_criteria # shorter alias non_nan_scores = {} for criterion in crita.keys(): - #exclude first entry bc its dummy None entry - non_nan_scores[criterion] = [0 if (ii is None or np.isnan(ii)) else ii for ii in monitor_metrics['val'][criterion]][1:] + # exclude first entry bc its dummy None entry + non_nan_scores[criterion] = [0 if (ii is None or np.isnan(ii)) else ii for ii in + monitor_metrics['val'][criterion]][1:] n_epochs = len(non_nan_scores[criterion]) epochs_scores = [] for e_ix in range(n_epochs): epochs_scores.append(np.sum([weight * non_nan_scores[criterion][e_ix] for - criterion,weight in crita.items()])/len(crita.keys())) + criterion, weight in crita.items()]) / len(crita.keys())) # ranking of epochs according to model_selection_criterion - epoch_ranking = np.argsort(epochs_scores)[::-1] + 1 #epochs start at 1 + epoch_ranking = np.argsort(epochs_scores)[::-1] + 1 # epochs start at 1 # if set in configs, epochs < min_save_thresh are discarded from saving process. epoch_ranking = epoch_ranking[epoch_ranking >= self.cf.min_save_thresh] # check if current epoch is among the top-k epchs. if epoch in epoch_ranking[:self.cf.save_n_models]: if self.cf.server_env: - IO_safe(torch.save, net.state_dict(), os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch))) + IO_safe(torch.save, net.state_dict(), + os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch))) # save epoch_ranking to keep info for inference. IO_safe(np.save, os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models]) else: torch.save(net.state_dict(), os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(epoch))) np.save(os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models]) self.logger.info( "saving current epoch {} at rank {}".format(epoch, np.argwhere(epoch_ranking == epoch))) # delete params of the epoch that just fell out of the top-k epochs. for se in [int(ii.split('_')[0]) for ii in os.listdir(self.cf.fold_dir) if 'best_params' in ii]: if se in epoch_ranking[self.cf.save_n_models:]: subprocess.call('rm {}'.format(os.path.join(self.cf.fold_dir, '{}_best_params.pth'.format(se))), shell=True) self.logger.info('deleting epoch {} at rank {}'.format(se, np.argwhere(epoch_ranking == se))) state = { 'epoch': epoch, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), } if self.cf.server_env: IO_safe(torch.save, state, os.path.join(self.cf.fold_dir, 'last_state.pth')) else: torch.save(state, os.path.join(self.cf.fold_dir, 'last_state.pth')) def load_checkpoint(checkpoint_path, net, optimizer): - checkpoint = torch.load(checkpoint_path) net.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) return checkpoint['epoch'] def prepare_monitoring(cf): """ creates dictionaries, where train/val metrics are stored. """ metrics = {} # first entry for loss dict accounts for epoch starting at 1. - metrics['train'] = OrderedDict()# [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) - metrics['val'] = OrderedDict()# [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) + metrics['train'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) + metrics['val'] = OrderedDict() # [(l_name, [np.nan]) for l_name in cf.losses_to_monitor] ) metric_classes = [] if 'rois' in cf.report_score_level: metric_classes.extend([v for k, v in cf.class_dict.items()]) if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: metric_classes.extend([v for k, v in cf.bin_dict.items()]) if 'patient' in cf.report_score_level: - metric_classes.extend(['patient_'+cf.class_dict[cf.patient_class_of_interest]]) + metric_classes.extend(['patient_' + cf.class_dict[cf.patient_class_of_interest]]) if hasattr(cf, "eval_bins_separately") and cf.eval_bins_separately: metric_classes.extend(['patient_' + cf.bin_dict[cf.patient_bin_of_interest]]) for cl in metric_classes: for m in cf.metrics: metrics['train'][cl + '_' + m] = [np.nan] metrics['val'][cl + '_' + m] = [np.nan] return metrics class _AnsiColorizer(object): """ A colorizer is an object that loosely wraps around a stream, allowing callers to write text to the stream in a particular color. Colorizer classes must implement C{supported()} and C{write(text, color)}. """ _colors = dict(black=30, red=31, green=32, yellow=33, blue=34, magenta=35, cyan=36, white=37, default=39) def __init__(self, stream): self.stream = stream @classmethod def supported(cls, stream=sys.stdout): """ A class method that returns True if the current platform supports coloring terminal output using this method. Returns False otherwise. """ if not stream.isatty(): return False # auto color only on TTYs try: import curses except ImportError: return False else: try: try: return curses.tigetnum("colors") > 2 except curses.error: curses.setupterm() return curses.tigetnum("colors") > 2 except: raise # guess false in case of error return False def write(self, text, color): """ Write the given text to the stream in the given color. @param text: Text to be written to the stream. @param color: A string label for a color. e.g. 'red', 'white'. """ color = self._colors[color] self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text)) -class ColorHandler(logging.StreamHandler): +class ColorHandler(logging.StreamHandler): def __init__(self, stream=sys.stdout): super(ColorHandler, self).__init__(_AnsiColorizer(stream)) def emit(self, record): msg_colors = { logging.DEBUG: "green", logging.INFO: "default", logging.WARNING: "red", logging.ERROR: "red" } color = msg_colors.get(record.levelno, "blue") self.stream.write(record.msg + "\n", color) - - - diff --git a/utils/model_utils.py b/utils/model_utils.py index d415f84..be6b451 100644 --- a/utils/model_utils.py +++ b/utils/model_utils.py @@ -1,1454 +1,1454 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn published under MIT license. """ import warnings warnings.filterwarnings('ignore', '.*From scipy 0.13.0, the output shape of zoom()*') import numpy as np import scipy.misc import scipy.ndimage from scipy.ndimage.measurements import label as lb import torch -from custom_extensions.nms import nms -from custom_extensions.roi_align import roi_align +#from custom_extensions.nms import nms +#from custom_extensions.roi_align import roi_align ############################################################ # Segmentation Processing ############################################################ def sum_tensor(input, axes, keepdim=False): axes = np.unique(axes) if keepdim: for ax in axes: input = input.sum(ax, keepdim=True) else: for ax in sorted(axes, reverse=True): input = input.sum(int(ax)) return input def get_one_hot_encoding(y, n_classes): """ transform a numpy label array to a one-hot array of the same shape. :param y: array of shape (b, 1, y, x, (z)). :param n_classes: int, number of classes to unfold in one-hot encoding. :return y_ohe: array of shape (b, n_classes, y, x, (z)) """ dim = len(y.shape) - 2 if dim == 2: y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3])).astype('int32') elif dim == 3: y_ohe = np.zeros((y.shape[0], n_classes, y.shape[2], y.shape[3], y.shape[4])).astype('int32') else: raise Exception("invalid dimensions {} encountered".format(y.shape)) for cl in np.arange(n_classes): y_ohe[:, cl][y[:, 0] == cl] = 1 return y_ohe def dice_per_batch_inst_and_class(pred, y, n_classes, convert_to_ohe=True, smooth=1e-8): ''' computes dice scores per batch instance and class. :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1) :param y: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes] :param n_classes: int :return: dice scores of shape (b, c) ''' if convert_to_ohe: pred = get_one_hot_encoding(pred, n_classes) y = get_one_hot_encoding(y, n_classes) axes = tuple(range(2, len(pred.shape))) intersect = np.sum(pred*y, axis=axes) denominator = np.sum(pred, axis=axes)+np.sum(y, axis=axes) dice = (2.0*intersect + smooth) / (denominator + smooth) return dice def dice_per_batch_and_class(pred, targ, n_classes, convert_to_ohe=True, smooth=1e-8): ''' computes dice scores per batch and class. :param pred: prediction array of shape (b, 1, y, x, (z)) (e.g. softmax prediction with argmax over dim 1) :param targ: ground truth array of shape (b, 1, y, x, (z)) (contains int [0, ..., n_classes]) :param n_classes: int :param smooth: Laplacian smooth, https://en.wikipedia.org/wiki/Additive_smoothing :return: dice scores of shape (b, c) ''' if convert_to_ohe: pred = get_one_hot_encoding(pred, n_classes) targ = get_one_hot_encoding(targ, n_classes) axes = (0, *list(range(2, len(pred.shape)))) #(0,2,3(,4)) intersect = np.sum(pred * targ, axis=axes) denominator = np.sum(pred, axis=axes) + np.sum(targ, axis=axes) dice = (2.0 * intersect + smooth) / (denominator + smooth) assert dice.shape==(n_classes,), "dice shp {}".format(dice.shape) return dice def batch_dice(pred, y, false_positive_weight=1.0, eps=1e-6): ''' compute soft dice over batch. this is a differentiable score and can be used as a loss function. only dice scores of foreground classes are returned, since training typically does not benefit from explicit background optimization. Pixels of the entire batch are considered a pseudo-volume to compute dice scores of. This way, single patches with missing foreground classes can not produce faulty gradients. :param pred: (b, c, y, x, (z)), softmax probabilities (network output). :param y: (b, c, y, x, (z)), one hote encoded segmentation mask. :param false_positive_weight: float [0,1]. For weighting of imbalanced classes, reduces the penalty for false-positive pixels. Can be beneficial sometimes in data with heavy fg/bg imbalances. :return: soft dice score (float).This function discards the background score and returns the mena of foreground scores. ''' # todo also use additive smooth here instead of eps? if len(pred.size()) == 4: axes = (0, 2, 3) intersect = sum_tensor(pred * y, axes, keepdim=False) denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False) return torch.mean((2 * intersect / (denom + eps))[1:]) #only fg dice here. if len(pred.size()) == 5: axes = (0, 2, 3, 4) intersect = sum_tensor(pred * y, axes, keepdim=False) denom = sum_tensor(false_positive_weight*pred + y, axes, keepdim=False) return torch.mean((2 * intersect / (denom + eps))[1:]) #only fg dice here. else: raise ValueError('wrong input dimension in dice loss') ############################################################ # Bounding Boxes ############################################################ def compute_iou_2D(box, boxes, box_area, boxes_area): """Calculates IoU of the given box with the array of the given boxes. box: 1D vector [y1, x1, y2, x2] THIS IS THE GT BOX boxes: [boxes_count, (y1, x1, y2, x2)] box_area: float. the area of 'box' boxes_area: array of length boxes_count. Note: the areas are passed in rather than calculated here for efficency. Calculate once in the caller to avoid duplicate work. """ # Calculate intersection areas y1 = np.maximum(box[0], boxes[:, 0]) y2 = np.minimum(box[2], boxes[:, 2]) x1 = np.maximum(box[1], boxes[:, 1]) x2 = np.minimum(box[3], boxes[:, 3]) intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) union = box_area + boxes_area[:] - intersection[:] iou = intersection / union return iou def compute_iou_3D(box, boxes, box_volume, boxes_volume): """Calculates IoU of the given box with the array of the given boxes. box: 1D vector [y1, x1, y2, x2, z1, z2] (typically gt box) boxes: [boxes_count, (y1, x1, y2, x2, z1, z2)] box_area: float. the area of 'box' boxes_area: array of length boxes_count. Note: the areas are passed in rather than calculated here for efficency. Calculate once in the caller to avoid duplicate work. """ # Calculate intersection areas y1 = np.maximum(box[0], boxes[:, 0]) y2 = np.minimum(box[2], boxes[:, 2]) x1 = np.maximum(box[1], boxes[:, 1]) x2 = np.minimum(box[3], boxes[:, 3]) z1 = np.maximum(box[4], boxes[:, 4]) z2 = np.minimum(box[5], boxes[:, 5]) intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) * np.maximum(z2 - z1, 0) union = box_volume + boxes_volume[:] - intersection[:] iou = intersection / union return iou def compute_overlaps(boxes1, boxes2): """Computes IoU overlaps between two sets of boxes. boxes1, boxes2: [N, (y1, x1, y2, x2)]. / 3D: (z1, z2)) For better performance, pass the largest set first and the smaller second. :return: (#boxes1, #boxes2), ious of each box of 1 machted with each of 2 """ # Areas of anchors and GT boxes if boxes1.shape[1] == 4: area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) # Compute overlaps to generate matrix [boxes1 count, boxes2 count] # Each cell contains the IoU value. overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0])) for i in range(overlaps.shape[1]): box2 = boxes2[i] #this is the gt box overlaps[:, i] = compute_iou_2D(box2, boxes1, area2[i], area1) return overlaps else: # Areas of anchors and GT boxes volume1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) * (boxes1[:, 5] - boxes1[:, 4]) volume2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) * (boxes2[:, 5] - boxes2[:, 4]) # Compute overlaps to generate matrix [boxes1 count, boxes2 count] # Each cell contains the IoU value. overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0])) for i in range(boxes2.shape[0]): box2 = boxes2[i] # this is the gt box overlaps[:, i] = compute_iou_3D(box2, boxes1, volume2[i], volume1) return overlaps def box_refinement(box, gt_box): """Compute refinement needed to transform box to gt_box. box and gt_box are [N, (y1, x1, y2, x2)] / 3D: (z1, z2)) """ height = box[:, 2] - box[:, 0] width = box[:, 3] - box[:, 1] center_y = box[:, 0] + 0.5 * height center_x = box[:, 1] + 0.5 * width gt_height = gt_box[:, 2] - gt_box[:, 0] gt_width = gt_box[:, 3] - gt_box[:, 1] gt_center_y = gt_box[:, 0] + 0.5 * gt_height gt_center_x = gt_box[:, 1] + 0.5 * gt_width dy = (gt_center_y - center_y) / height dx = (gt_center_x - center_x) / width dh = torch.log(gt_height / height) dw = torch.log(gt_width / width) result = torch.stack([dy, dx, dh, dw], dim=1) if box.shape[1] > 4: depth = box[:, 5] - box[:, 4] center_z = box[:, 4] + 0.5 * depth gt_depth = gt_box[:, 5] - gt_box[:, 4] gt_center_z = gt_box[:, 4] + 0.5 * gt_depth dz = (gt_center_z - center_z) / depth dd = torch.log(gt_depth / depth) result = torch.stack([dy, dx, dz, dh, dw, dd], dim=1) return result def unmold_mask_2D(mask, bbox, image_shape): """Converts a mask generated by the neural network into a format similar to it's original shape. mask: [height, width] of type float. A small, typically 28x28 mask. bbox: [y1, x1, y2, x2]. The box to fit the mask in. Returns a binary mask with the same size as the original image. """ y1, x1, y2, x2 = bbox out_zoom = [y2 - y1, x2 - x1] zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)] mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32) # Put the mask in the right location. full_mask = np.zeros(image_shape[:2]) #only y,x full_mask[y1:y2, x1:x2] = mask return full_mask def unmold_mask_2D_torch(mask, bbox, image_shape): """Converts a mask generated by the neural network into a format similar to it's original shape. mask: [height, width] of type float. A small, typically 28x28 mask. bbox: [y1, x1, y2, x2]. The box to fit the mask in. Returns a binary mask with the same size as the original image. """ y1, x1, y2, x2 = bbox out_zoom = [(y2 - y1).float(), (x2 - x1).float()] zoom_factor = [i / j for i, j in zip(out_zoom, mask.shape)] mask = mask.unsqueeze(0).unsqueeze(0) mask = torch.nn.functional.interpolate(mask, scale_factor=zoom_factor) mask = mask[0][0] #mask = scipy.ndimage.zoom(mask.cpu().numpy(), zoom_factor, order=1).astype(np.float32) #mask = torch.from_numpy(mask).cuda() # Put the mask in the right location. full_mask = torch.zeros(image_shape[:2]) # only y,x full_mask[y1:y2, x1:x2] = mask return full_mask def unmold_mask_3D(mask, bbox, image_shape): """Converts a mask generated by the neural network into a format similar to it's original shape. mask: [height, width] of type float. A small, typically 28x28 mask. bbox: [y1, x1, y2, x2, z1, z2]. The box to fit the mask in. Returns a binary mask with the same size as the original image. """ y1, x1, y2, x2, z1, z2 = bbox out_zoom = [y2 - y1, x2 - x1, z2 - z1] zoom_factor = [i/j for i,j in zip(out_zoom, mask.shape)] mask = scipy.ndimage.zoom(mask, zoom_factor, order=1).astype(np.float32) # Put the mask in the right location. full_mask = np.zeros(image_shape[:3]) full_mask[y1:y2, x1:x2, z1:z2] = mask return full_mask def nms_numpy(box_coords, scores, thresh): """ non-maximum suppression on 2D or 3D boxes in numpy. :param box_coords: [y1,x1,y2,x2 (,z1,z2)] with y1<=y2, x1<=x2, z1<=z2. :param scores: ranking scores (higher score == higher rank) of boxes. :param thresh: IoU threshold for clustering. :return: """ y1 = box_coords[:, 0] x1 = box_coords[:, 1] y2 = box_coords[:, 2] x2 = box_coords[:, 3] assert np.all(y1 <= y2) and np.all(x1 <= x2), """"the definition of the coordinates is crucially important here: coordinates of which maxima are taken need to be the lower coordinates""" areas = (x2 - x1 + 1) * (y2 - y1 + 1) is_3d = box_coords.shape[1] == 6 if is_3d: # 3-dim case z1 = box_coords[:, 4] z2 = box_coords[:, 5] assert np.all(z1<=z2), """"the definition of the coordinates is crucially important here: coordinates of which maxima are taken need to be the lower coordinates""" areas *= (z2 - z1 + 1) order = scores.argsort()[::-1] keep = [] while order.size > 0: # order is the sorted index. maps order to index: order[1] = 24 means (rank1, ix 24) i = order[0] # highest scoring element yy1 = np.maximum(y1[i], y1[order]) # highest scoring element still in >order<, is compared to itself, that is okay. xx1 = np.maximum(x1[i], x1[order]) yy2 = np.minimum(y2[i], y2[order]) xx2 = np.minimum(x2[i], x2[order]) h = np.maximum(0.0, yy2 - yy1 + 1) w = np.maximum(0.0, xx2 - xx1 + 1) inter = h * w if is_3d: zz1 = np.maximum(z1[i], z1[order]) zz2 = np.minimum(z2[i], z2[order]) d = np.maximum(0.0, zz2 - zz1 + 1) inter *= d iou = inter / (areas[i] + areas[order] - inter) non_matches = np.nonzero(iou <= thresh)[0] # get all elements that were not matched and discard all others. #print("iou keep {}: {}, non_matches {}".format(i, iou, order[non_matches])) order = order[non_matches] keep.append(i) #print("total keep", keep) return keep ############################################################ # M-RCNN ############################################################ def refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, batch_anchors, cf): """ Receives anchor scores and selects a subset to pass as proposals to the second stage. Filtering is done based on anchor scores and non-max suppression to remove overlaps. It also applies bounding box refinment details to anchors. :param rpn_pred_probs: (b, n_anchors, 2) :param rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d)))) :return: batch_normalized_props: Proposals in normalized coordinates (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score)) :return: batch_out_proposals: Box coords + RPN foreground scores for monitoring/plotting (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score)) """ std_dev = torch.from_numpy(cf.rpn_bbox_std_dev[None]).float().cuda() norm = torch.from_numpy(cf.scale).float().cuda() anchors = batch_anchors.clone() batch_scores = rpn_pred_probs[:, :, 1] # norm deltas batch_deltas = rpn_pred_deltas * std_dev batch_normalized_props = [] batch_out_proposals = [] # loop over batch dimension. for ix in range(batch_scores.shape[0]): scores = batch_scores[ix] deltas = batch_deltas[ix] # improve performance by trimming to top anchors by score # and doing the rest on the smaller subset. pre_nms_limit = min(cf.pre_nms_limit, anchors.size()[0]) scores, order = scores.sort(descending=True) order = order[:pre_nms_limit] scores = scores[:pre_nms_limit] deltas = deltas[order, :] # apply deltas to anchors to get refined anchors and filter with non-maximum suppression. if batch_deltas.shape[-1] == 4: boxes = apply_box_deltas_2D(anchors[order, :], deltas) boxes = clip_boxes_2D(boxes, cf.window) else: boxes = apply_box_deltas_3D(anchors[order, :], deltas) boxes = clip_boxes_3D(boxes, cf.window) # boxes are y1,x1,y2,x2, torchvision-nms requires x1,y1,x2,y2, but consistent swap x<->y is irrelevant. keep = nms.nms(boxes, scores, cf.rpn_nms_threshold) keep = keep[:proposal_count] boxes = boxes[keep, :] rpn_scores = scores[keep][:, None] # pad missing boxes with 0. if boxes.shape[0] < proposal_count: n_pad_boxes = proposal_count - boxes.shape[0] zeros = torch.zeros([n_pad_boxes, boxes.shape[1]]).cuda() boxes = torch.cat([boxes, zeros], dim=0) zeros = torch.zeros([n_pad_boxes, rpn_scores.shape[1]]).cuda() rpn_scores = torch.cat([rpn_scores, zeros], dim=0) # concat box and score info for monitoring/plotting. batch_out_proposals.append(torch.cat((boxes, rpn_scores), 1).cpu().data.numpy()) # normalize dimensions to range of 0 to 1. normalized_boxes = boxes / norm assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found" # add again batch dimension batch_normalized_props.append(torch.cat((normalized_boxes, rpn_scores), 1).unsqueeze(0)) batch_normalized_props = torch.cat(batch_normalized_props) batch_out_proposals = np.array(batch_out_proposals) return batch_normalized_props, batch_out_proposals def pyramid_roi_align(feature_maps, rois, pool_size, pyramid_levels, dim): """ Implements ROI Pooling on multiple levels of the feature pyramid. :param feature_maps: list of feature maps, each of shape (b, c, y, x , (z)) :param rois: proposals (normalized coords.) as returned by RPN. contain info about original batch element allocation. (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs) :param pool_size: list of poolsizes in dims: [x, y, (z)] :param pyramid_levels: list. [0, 1, 2, ...] :return: pooled: pooled feature map rois (n_proposals, c, poolsize_y, poolsize_x, (poolsize_z)) Output: Pooled regions in the shape: [num_boxes, height, width, channels]. The width and height are those specific in the pool_shape in the layer constructor. """ boxes = rois[:, :dim*2] batch_ixs = rois[:, dim*2] # Assign each ROI to a level in the pyramid based on the ROI area. if dim == 2: y1, x1, y2, x2 = boxes.chunk(4, dim=1) else: y1, x1, y2, x2, z1, z2 = boxes.chunk(6, dim=1) h = y2 - y1 w = x2 - x1 # Equation 1 in https://arxiv.org/abs/1612.03144. Account for # the fact that our coordinates are normalized here. # divide sqrt(h*w) by 1 instead image_area. roi_level = (4 + torch.log2(torch.sqrt(h*w))).round().int().clamp(pyramid_levels[0], pyramid_levels[-1]) # if Pyramid contains additional level P6, adapt the roi_level assignment accordingly. if len(pyramid_levels) == 5: roi_level[h*w > 0.65] = 5 # Loop through levels and apply ROI pooling to each. pooled = [] box_to_level = [] fmap_shapes = [f.shape for f in feature_maps] for level_ix, level in enumerate(pyramid_levels): ix = roi_level == level if not ix.any(): continue ix = torch.nonzero(ix)[:, 0] level_boxes = boxes[ix, :] # re-assign rois to feature map of original batch element. ind = batch_ixs[ix].int() # Keep track of which box is mapped to which level box_to_level.append(ix) # Stop gradient propogation to ROI proposals level_boxes = level_boxes.detach() if len(pool_size) == 2: # remap to feature map coordinate system y_exp, x_exp = fmap_shapes[level_ix][2:] # exp = expansion level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda()) pooled_features = roi_align.roi_align_2d(feature_maps[level_ix], torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1), pool_size) else: y_exp, x_exp, z_exp = fmap_shapes[level_ix][2:] level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda()) pooled_features = roi_align.roi_align_3d(feature_maps[level_ix], torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1), pool_size) pooled.append(pooled_features) # Pack pooled features into one tensor pooled = torch.cat(pooled, dim=0) # Pack box_to_level mapping into one array and add another # column representing the order of pooled boxes box_to_level = torch.cat(box_to_level, dim=0) # Rearrange pooled features to match the order of the original boxes _, box_to_level = torch.sort(box_to_level) pooled = pooled[box_to_level, :, :] return pooled def refine_detections(cf, batch_ixs, rois, deltas, scores, regressions): """ Refine classified proposals (apply deltas to rpn rois), filter overlaps (nms) and return final detections. :param rois: (n_proposals, 2 * dim) normalized boxes as proposed by RPN. n_proposals = batch_size * POST_NMS_ROIS :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by mrcnn bbox regressor. :param batch_ixs: (n_proposals) batch element assignment info for re-allocation. :param scores: (n_proposals, n_classes) probabilities for all classes per roi as predicted by mrcnn classifier. :param regressions: (n_proposals, n_classes, regression_features (+1 for uncertainty if predicted) regression vector :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, *regression vector features)) """ # class IDs per ROI. Since scores of all classes are of interest (not just max class), all are kept at this point. class_ids = [] fg_classes = cf.head_classes - 1 # repeat vectors to fill in predictions for all foreground classes. for ii in range(1, fg_classes + 1): class_ids += [ii] * rois.shape[0] class_ids = torch.from_numpy(np.array(class_ids)).cuda() batch_ixs = batch_ixs.repeat(fg_classes) rois = rois.repeat(fg_classes, 1) deltas = deltas.repeat(fg_classes, 1, 1) scores = scores.repeat(fg_classes, 1) regressions = regressions.repeat(fg_classes, 1, 1) # get class-specific scores and bounding box deltas idx = torch.arange(class_ids.size()[0]).long().cuda() # using idx instead of slice [:,] squashes first dimension. #len(class_ids)>scores.shape[1] --> probs is broadcasted by expansion from fg_classes-->len(class_ids) batch_ixs = batch_ixs[idx] deltas_specific = deltas[idx, class_ids] class_scores = scores[idx, class_ids] regressions = regressions[idx, class_ids] # apply bounding box deltas. re-scale to image coordinates. std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() scale = torch.from_numpy(cf.scale).float().cuda() refined_rois = apply_box_deltas_2D(rois, deltas_specific * std_dev) * scale if cf.dim == 2 else \ apply_box_deltas_3D(rois, deltas_specific * std_dev) * scale # round and cast to int since we're dealing with pixels now refined_rois = clip_to_window(cf.window, refined_rois) refined_rois = torch.round(refined_rois) # filter out low confidence boxes keep = idx keep_bool = (class_scores >= cf.model_min_confidence) if not 0 in torch.nonzero(keep_bool).size(): score_keep = torch.nonzero(keep_bool)[:, 0] pre_nms_class_ids = class_ids[score_keep] pre_nms_rois = refined_rois[score_keep] pre_nms_scores = class_scores[score_keep] pre_nms_batch_ixs = batch_ixs[score_keep] for j, b in enumerate(unique1d(pre_nms_batch_ixs)): bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] bix_class_ids = pre_nms_class_ids[bixs] bix_rois = pre_nms_rois[bixs] bix_scores = pre_nms_scores[bixs] for i, class_id in enumerate(unique1d(bix_class_ids)): ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] # nms expects boxes sorted by score. ix_rois = bix_rois[ixs] ix_scores = bix_scores[ixs] ix_scores, order = ix_scores.sort(descending=True) ix_rois = ix_rois[order, :] class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold) # map indices back. class_keep = keep[score_keep[bixs[ixs[order[class_keep]]]]] # merge indices over classes for current batch element b_keep = class_keep if i == 0 else unique1d(torch.cat((b_keep, class_keep))) # only keep top-k boxes of current batch-element top_ids = class_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] b_keep = b_keep[top_ids] # merge indices over batch elements. batch_keep = b_keep if j == 0 else unique1d(torch.cat((batch_keep, b_keep))) keep = batch_keep else: keep = torch.tensor([0]).long().cuda() # arrange output output = [refined_rois[keep], batch_ixs[keep].unsqueeze(1)] output += [class_ids[keep].unsqueeze(1).float(), class_scores[keep].unsqueeze(1)] output += [regressions[keep]] result = torch.cat(output, dim=1) # shape: (n_keeps, catted feats), catted feats: [0:dim*2] are box_coords, [dim*2] are batch_ics, # [dim*2+1] are class_ids, [dim*2+2] are scores, [dim*2+3:] are regression vector features (incl uncertainty) return result def loss_example_mining(cf, batch_proposals, batch_gt_boxes, batch_gt_masks, batch_roi_scores, batch_gt_class_ids, batch_gt_regressions): """ Subsamples proposals for mrcnn losses and generates targets. Sampling is done per batch element, seems to have positive effects on training, as opposed to sampling over entire batch. Negatives are sampled via stochastic hard-example mining (SHEM), where a number of negative proposals is drawn from larger pool of highest scoring proposals for stochasticity. Scoring is obtained here as the max over all foreground probabilities as returned by mrcnn_classifier (worked better than loss-based class-balancing methods like "online hard-example mining" or "focal loss".) Classification-regression duality: regressions can be given along with classes (at least fg/bg, only class scores are used for ranking). :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs). boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS. :param mrcnn_class_logits: (n_proposals, n_classes) :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates. :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c) :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels. if no classes predicted (only fg/bg from RPN): expected as pseudo classes [0, 1] for bg, fg. :param batch_gt_regressions: list over b elements. Each element is a regression target vector. if None--> pseudo :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions. :return: target_class_ids: (n_sampled_rois)containing target class labels of sampled proposals. :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement. :return: target_masks: (n_sampled_rois, y, x, (z)) containing target masks of sampled proposals. """ # normalization of target coordinates #global sample_regressions if cf.dim == 2: h, w = cf.patch_size scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda() else: h, w, z = cf.patch_size scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda() positive_count = 0 negative_count = 0 sample_positive_indices = [] sample_negative_indices = [] sample_deltas = [] sample_masks = [] sample_class_ids = [] if batch_gt_regressions is not None: sample_regressions = [] else: target_regressions = torch.FloatTensor().cuda() # loop over batch and get positive and negative sample rois. for b in range(len(batch_gt_boxes)): gt_masks = torch.from_numpy(batch_gt_masks[b]).float().cuda() gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda() if batch_gt_regressions is not None: gt_regressions = torch.from_numpy(batch_gt_regressions[b]).float().cuda() #if np.any(batch_gt_class_ids[b] > 0): # skip roi selection for no gt images. if np.any([len(coords)>0 for coords in batch_gt_boxes[b]]): gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale else: gt_boxes = torch.FloatTensor().cuda() # get proposals and indices of current batch element. proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1] batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1) # Compute overlaps matrix [proposals, gt_boxes] if not 0 in gt_boxes.size(): if gt_boxes.shape[1] == 4: assert cf.dim == 2, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim) overlaps = bbox_overlaps_2D(proposals, gt_boxes) else: assert cf.dim == 3, "gt_boxes shape {} doesnt match cf.dim{}".format(gt_boxes.shape, cf.dim) overlaps = bbox_overlaps_3D(proposals, gt_boxes) # Determine positive and negative ROIs roi_iou_max = torch.max(overlaps, dim=1)[0] # 1. Positive ROIs are those with >= 0.5 IoU with a GT box positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3) # 2. Negative ROIs are those with < 0.1 with every GT box. negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01) else: positive_roi_bool = torch.FloatTensor().cuda() negative_roi_bool = torch.from_numpy(np.array([1]*proposals.shape[0])).cuda() # Sample Positive ROIs if not 0 in torch.nonzero(positive_roi_bool).size(): positive_indices = torch.nonzero(positive_roi_bool).squeeze(1) positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio) rand_idx = torch.randperm(positive_indices.size()[0]) rand_idx = rand_idx[:positive_samples].cuda() positive_indices = positive_indices[rand_idx] positive_samples = positive_indices.size()[0] positive_rois = proposals[positive_indices, :] # Assign positive ROIs to GT boxes. positive_overlaps = overlaps[positive_indices, :] roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1] roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :] roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment] if batch_gt_regressions is not None: roi_gt_regressions = gt_regressions[roi_gt_box_assignment] # Compute bbox refinement targets for positive ROIs deltas = box_refinement(positive_rois, roi_gt_boxes) std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda() deltas /= std_dev roi_masks = gt_masks[roi_gt_box_assignment].unsqueeze(1) # .squeeze(-1) assert roi_masks.shape[-1] == 1 # Compute mask targets boxes = positive_rois box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).float() if len(cf.mask_shape) == 2: # todo what are the dims of roi_masks? (n_matched_boxes_with_gts, 1 (dummy channel dim), y,x, 1 (WHY?)) masks = roi_align.roi_align_2d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape) else: masks = roi_align.roi_align_3d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape) masks = masks.squeeze(1) # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with # binary cross entropy loss. masks = torch.round(masks) sample_positive_indices.append(batch_element_indices[positive_indices]) sample_deltas.append(deltas) sample_masks.append(masks) sample_class_ids.append(roi_gt_class_ids) if batch_gt_regressions is not None: sample_regressions.append(roi_gt_regressions) positive_count += positive_samples else: positive_samples = 0 # Sample negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM. if not 0 in torch.nonzero(negative_roi_bool).size(): negative_indices = torch.nonzero(negative_roi_bool).squeeze(1) r = 1.0 / cf.roi_positive_ratio b_neg_count = np.max((int(r * positive_samples - positive_samples), 1)) roi_scores_neg = batch_roi_scores[batch_element_indices[negative_indices]] raw_sampled_indices = shem(roi_scores_neg, b_neg_count, cf.shem_poolsize) sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]]) negative_count += raw_sampled_indices.size()[0] if len(sample_positive_indices) > 0: target_deltas = torch.cat(sample_deltas) target_masks = torch.cat(sample_masks) target_class_ids = torch.cat(sample_class_ids) if batch_gt_regressions is not None: target_regressions = torch.cat(sample_regressions) # Pad target information with zeros for negative ROIs. if positive_count > 0 and negative_count > 0: sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0) zeros = torch.zeros(negative_count, cf.dim * 2).cuda() target_deltas = torch.cat([target_deltas, zeros], dim=0) zeros = torch.zeros(negative_count, *cf.mask_shape).cuda() target_masks = torch.cat([target_masks, zeros], dim=0) zeros = torch.zeros(negative_count).int().cuda() target_class_ids = torch.cat([target_class_ids, zeros], dim=0) if batch_gt_regressions is not None: # regression targets need to have 0 as background/negative with below practice if 'regression_bin' in cf.prediction_tasks: zeros = torch.zeros(negative_count, dtype=torch.float).cuda() else: zeros = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda() target_regressions = torch.cat([target_regressions, zeros], dim=0) elif positive_count > 0: sample_indices = torch.cat(sample_positive_indices) elif negative_count > 0: sample_indices = torch.cat(sample_negative_indices) target_deltas = torch.zeros(negative_count, cf.dim * 2).cuda() target_masks = torch.zeros(negative_count, *cf.mask_shape).cuda() target_class_ids = torch.zeros(negative_count).int().cuda() if batch_gt_regressions is not None: if 'regression_bin' in cf.prediction_tasks: target_regressions = torch.zeros(negative_count, dtype=torch.float).cuda() else: target_regressions = torch.zeros(negative_count, cf.regression_n_features, dtype=torch.float).cuda() else: sample_indices = torch.LongTensor().cuda() target_class_ids = torch.IntTensor().cuda() target_deltas = torch.FloatTensor().cuda() target_masks = torch.FloatTensor().cuda() target_regressions = torch.FloatTensor().cuda() return sample_indices, target_deltas, target_masks, target_class_ids, target_regressions ############################################################ # Anchors ############################################################ def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride): """ scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] shape: [height, width] spatial shape of the feature map over which to generate anchors. feature_stride: Stride of the feature map relative to the image in pixels. anchor_stride: Stride of anchors on the feature map. For example, if the value is 2 then generate anchors for every other feature map pixel. """ # Get all combinations of scales and ratios scales, ratios = np.meshgrid(np.array(scales), np.array(ratios)) scales = scales.flatten() ratios = ratios.flatten() # Enumerate heights and widths from scales and ratios heights = scales / np.sqrt(ratios) widths = scales * np.sqrt(ratios) # Enumerate shifts in feature space shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y) # Enumerate combinations of shifts, widths, and heights box_widths, box_centers_x = np.meshgrid(widths, shifts_x) box_heights, box_centers_y = np.meshgrid(heights, shifts_y) # Reshape to get a list of (y, x) and a list of (h, w) box_centers = np.stack([box_centers_y, box_centers_x], axis=2).reshape([-1, 2]) box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2]) # Convert to corner coordinates (y1, x1, y2, x2) boxes = np.concatenate([box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1) return boxes def generate_anchors_3D(scales_xy, scales_z, ratios, shape, feature_stride_xy, feature_stride_z, anchor_stride): """ scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] shape: [height, width] spatial shape of the feature map over which to generate anchors. feature_stride: Stride of the feature map relative to the image in pixels. anchor_stride: Stride of anchors on the feature map. For example, if the value is 2 then generate anchors for every other feature map pixel. """ # Get all combinations of scales and ratios scales_xy, ratios_meshed = np.meshgrid(np.array(scales_xy), np.array(ratios)) scales_xy = scales_xy.flatten() ratios_meshed = ratios_meshed.flatten() # Enumerate heights and widths from scales and ratios heights = scales_xy / np.sqrt(ratios_meshed) widths = scales_xy * np.sqrt(ratios_meshed) depths = np.tile(np.array(scales_z), len(ratios_meshed)//np.array(scales_z)[..., None].shape[0]) # Enumerate shifts in feature space shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride_xy #translate from fm positions to input coords. shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride_xy shifts_z = np.arange(0, shape[2], anchor_stride) * (feature_stride_z) shifts_x, shifts_y, shifts_z = np.meshgrid(shifts_x, shifts_y, shifts_z) # Enumerate combinations of shifts, widths, and heights box_widths, box_centers_x = np.meshgrid(widths, shifts_x) box_heights, box_centers_y = np.meshgrid(heights, shifts_y) box_depths, box_centers_z = np.meshgrid(depths, shifts_z) # Reshape to get a list of (y, x, z) and a list of (h, w, d) box_centers = np.stack( [box_centers_y, box_centers_x, box_centers_z], axis=2).reshape([-1, 3]) box_sizes = np.stack([box_heights, box_widths, box_depths], axis=2).reshape([-1, 3]) # Convert to corner coordinates (y1, x1, y2, x2, z1, z2) boxes = np.concatenate([box_centers - 0.5 * box_sizes, box_centers + 0.5 * box_sizes], axis=1) boxes = np.transpose(np.array([boxes[:, 0], boxes[:, 1], boxes[:, 3], boxes[:, 4], boxes[:, 2], boxes[:, 5]]), axes=(1, 0)) return boxes def generate_pyramid_anchors(logger, cf): """Generate anchors at different levels of a feature pyramid. Each scale is associated with a level of the pyramid, but each ratio is used in all levels of the pyramid. from configs: :param scales: cf.RPN_ANCHOR_SCALES , for conformity with retina nets: scale entries need to be list, e.g. [[4], [8], [16], [32]] :param ratios: cf.RPN_ANCHOR_RATIOS , e.g. [0.5, 1, 2] :param feature_shapes: cf.BACKBONE_SHAPES , e.g. [array of shapes per feature map] [80, 40, 20, 10, 5] :param feature_strides: cf.BACKBONE_STRIDES , e.g. [2, 4, 8, 16, 32, 64] :param anchors_stride: cf.RPN_ANCHOR_STRIDE , e.g. 1 :return anchors: (N, (y1, x1, y2, x2, (z1), (z2)). All generated anchors in one array. Sorted with the same order of the given scales. So, anchors of scale[0] come first, then anchors of scale[1], and so on. """ scales = cf.rpn_anchor_scales ratios = cf.rpn_anchor_ratios feature_shapes = cf.backbone_shapes anchor_stride = cf.rpn_anchor_stride pyramid_levels = cf.pyramid_levels feature_strides = cf.backbone_strides logger.info("anchor scales {} and feature map shapes {}".format(scales, feature_shapes)) expected_anchors = [np.prod(feature_shapes[level]) * len(ratios) * len(scales['xy'][level]) for level in pyramid_levels] anchors = [] for lix, level in enumerate(pyramid_levels): if len(feature_shapes[level]) == 2: anchors.append(generate_anchors(scales['xy'][level], ratios, feature_shapes[level], feature_strides['xy'][level], anchor_stride)) elif len(feature_shapes[level]) == 3: anchors.append(generate_anchors_3D(scales['xy'][level], scales['z'][level], ratios, feature_shapes[level], feature_strides['xy'][level], feature_strides['z'][level], anchor_stride)) else: raise Exception("invalid feature_shapes[{}] size {}".format(level, feature_shapes[level])) logger.info("level {}: expected anchors {}, built anchors {}.".format(level, expected_anchors[lix], anchors[-1].shape)) out_anchors = np.concatenate(anchors, axis=0) logger.info("Total: expected anchors {}, built anchors {}.".format(np.sum(expected_anchors), out_anchors.shape)) return out_anchors def apply_box_deltas_2D(boxes, deltas): """Applies the given deltas to the given boxes. boxes: [N, 4] where each row is y1, x1, y2, x2 deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)] """ # Convert to y, x, h, w height = boxes[:, 2] - boxes[:, 0] width = boxes[:, 3] - boxes[:, 1] center_y = boxes[:, 0] + 0.5 * height center_x = boxes[:, 1] + 0.5 * width # Apply deltas center_y += deltas[:, 0] * height center_x += deltas[:, 1] * width height *= torch.exp(deltas[:, 2]) width *= torch.exp(deltas[:, 3]) # Convert back to y1, x1, y2, x2 y1 = center_y - 0.5 * height x1 = center_x - 0.5 * width y2 = y1 + height x2 = x1 + width result = torch.stack([y1, x1, y2, x2], dim=1) return result def apply_box_deltas_3D(boxes, deltas): """Applies the given deltas to the given boxes. boxes: [N, 6] where each row is y1, x1, y2, x2, z1, z2 deltas: [N, 6] where each row is [dy, dx, dz, log(dh), log(dw), log(dd)] """ # Convert to y, x, h, w height = boxes[:, 2] - boxes[:, 0] width = boxes[:, 3] - boxes[:, 1] depth = boxes[:, 5] - boxes[:, 4] center_y = boxes[:, 0] + 0.5 * height center_x = boxes[:, 1] + 0.5 * width center_z = boxes[:, 4] + 0.5 * depth # Apply deltas center_y += deltas[:, 0] * height center_x += deltas[:, 1] * width center_z += deltas[:, 2] * depth height *= torch.exp(deltas[:, 3]) width *= torch.exp(deltas[:, 4]) depth *= torch.exp(deltas[:, 5]) # Convert back to y1, x1, y2, x2 y1 = center_y - 0.5 * height x1 = center_x - 0.5 * width z1 = center_z - 0.5 * depth y2 = y1 + height x2 = x1 + width z2 = z1 + depth result = torch.stack([y1, x1, y2, x2, z1, z2], dim=1) return result def clip_boxes_2D(boxes, window): """ boxes: [N, 4] each col is y1, x1, y2, x2 window: [4] in the form y1, x1, y2, x2 """ boxes = torch.stack( \ [boxes[:, 0].clamp(float(window[0]), float(window[2])), boxes[:, 1].clamp(float(window[1]), float(window[3])), boxes[:, 2].clamp(float(window[0]), float(window[2])), boxes[:, 3].clamp(float(window[1]), float(window[3]))], 1) return boxes def clip_boxes_3D(boxes, window): """ boxes: [N, 6] each col is y1, x1, y2, x2, z1, z2 window: [6] in the form y1, x1, y2, x2, z1, z2 """ boxes = torch.stack( \ [boxes[:, 0].clamp(float(window[0]), float(window[2])), boxes[:, 1].clamp(float(window[1]), float(window[3])), boxes[:, 2].clamp(float(window[0]), float(window[2])), boxes[:, 3].clamp(float(window[1]), float(window[3])), boxes[:, 4].clamp(float(window[4]), float(window[5])), boxes[:, 5].clamp(float(window[4]), float(window[5]))], 1) return boxes from matplotlib import pyplot as plt def clip_boxes_numpy(boxes, window): """ boxes: [N, 4] each col is y1, x1, y2, x2 / [N, 6] in 3D. window: iamge shape (y, x, (z)) """ if boxes.shape[1] == 4: boxes = np.concatenate( (np.clip(boxes[:, 0], 0, window[0])[:, None], np.clip(boxes[:, 1], 0, window[0])[:, None], np.clip(boxes[:, 2], 0, window[1])[:, None], np.clip(boxes[:, 3], 0, window[1])[:, None]), 1 ) else: boxes = np.concatenate( (np.clip(boxes[:, 0], 0, window[0])[:, None], np.clip(boxes[:, 1], 0, window[0])[:, None], np.clip(boxes[:, 2], 0, window[1])[:, None], np.clip(boxes[:, 3], 0, window[1])[:, None], np.clip(boxes[:, 4], 0, window[2])[:, None], np.clip(boxes[:, 5], 0, window[2])[:, None]), 1 ) return boxes def bbox_overlaps_2D(boxes1, boxes2): """Computes IoU overlaps between two sets of boxes. boxes1, boxes2: [N, (y1, x1, y2, x2)]. """ # 1. Tile boxes2 and repeate boxes1. This allows us to compare # every boxes1 against every boxes2 without loops. # TF doesn't have an equivalent to np.repeate() so simulate it # using tf.tile() and tf.reshape. boxes1_repeat = boxes2.size()[0] boxes2_repeat = boxes1.size()[0] boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,4) boxes2 = boxes2.repeat(boxes2_repeat,1) # 2. Compute intersections b1_y1, b1_x1, b1_y2, b1_x2 = boxes1.chunk(4, dim=1) b2_y1, b2_x1, b2_y2, b2_x2 = boxes2.chunk(4, dim=1) y1 = torch.max(b1_y1, b2_y1)[:, 0] x1 = torch.max(b1_x1, b2_x1)[:, 0] y2 = torch.min(b1_y2, b2_y2)[:, 0] x2 = torch.min(b1_x2, b2_x2)[:, 0] #--> expects x11 produced in bbox_overlaps_2D" overlaps = iou.view(boxes2_repeat, boxes1_repeat) #--> per gt box: ious of all proposal boxes with that gt box return overlaps def bbox_overlaps_3D(boxes1, boxes2): """Computes IoU overlaps between two sets of boxes. boxes1, boxes2: [N, (y1, x1, y2, x2, z1, z2)]. """ # 1. Tile boxes2 and repeate boxes1. This allows us to compare # every boxes1 against every boxes2 without loops. # TF doesn't have an equivalent to np.repeate() so simulate it # using tf.tile() and tf.reshape. boxes1_repeat = boxes2.size()[0] boxes2_repeat = boxes1.size()[0] boxes1 = boxes1.repeat(1,boxes1_repeat).view(-1,6) boxes2 = boxes2.repeat(boxes2_repeat,1) # 2. Compute intersections b1_y1, b1_x1, b1_y2, b1_x2, b1_z1, b1_z2 = boxes1.chunk(6, dim=1) b2_y1, b2_x1, b2_y2, b2_x2, b2_z1, b2_z2 = boxes2.chunk(6, dim=1) y1 = torch.max(b1_y1, b2_y1)[:, 0] x1 = torch.max(b1_x1, b2_x1)[:, 0] y2 = torch.min(b1_y2, b2_y2)[:, 0] x2 = torch.min(b1_x2, b2_x2)[:, 0] z1 = torch.max(b1_z1, b2_z1)[:, 0] z2 = torch.min(b1_z2, b2_z2)[:, 0] zeros = torch.zeros(y1.size()[0], requires_grad=False) if y1.is_cuda: zeros = zeros.cuda() intersection = torch.max(x2 - x1, zeros) * torch.max(y2 - y1, zeros) * torch.max(z2 - z1, zeros) # 3. Compute unions b1_volume = (b1_y2 - b1_y1) * (b1_x2 - b1_x1) * (b1_z2 - b1_z1) b2_volume = (b2_y2 - b2_y1) * (b2_x2 - b2_x1) * (b2_z2 - b2_z1) union = b1_volume[:,0] + b2_volume[:,0] - intersection # 4. Compute IoU and reshape to [boxes1, boxes2] iou = intersection / union overlaps = iou.view(boxes2_repeat, boxes1_repeat) return overlaps def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None): """Given the anchors and GT boxes, compute overlaps and identify positive anchors and deltas to refine them to match their corresponding GT boxes. anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))] gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))] gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN, set all positive matches to 1 (foreground) Returns: anchor_class_matches: [N] (int32) matches between anchors and GT boxes. 1 = positive anchor, -1 = negative anchor, 0 = neutral anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas. """ anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32) anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim)) anchor_matching_iou = cf.anchor_matching_iou if gt_boxes is None: anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1) return anchor_class_matches, anchor_delta_targets # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground) if gt_class_ids is None: gt_class_ids = np.array([1] * len(gt_boxes)) # Compute overlaps [num_anchors, num_gt_boxes] overlaps = compute_overlaps(anchors, gt_boxes) # Match anchors to GT Boxes # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive. # If an anchor overlaps a GT box with IoU < 0.1 then it's negative. # Neutral anchors are those that don't match the conditions above, # and they don't influence the loss function. # However, don't keep any GT box unmatched (rare, but happens). Instead, # match it to the closest anchor (even if its max IoU is < 0.1). # 1. Set negative anchors first. They get overwritten below if a GT box is # matched to them. Skip boxes in crowd areas. anchor_iou_argmax = np.argmax(overlaps, axis=1) anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] if anchors.shape[1] == 4: anchor_class_matches[(anchor_iou_max < 0.1)] = -1 elif anchors.shape[1] == 6: anchor_class_matches[(anchor_iou_max < 0.01)] = -1 else: raise ValueError('anchor shape wrong {}'.format(anchors.shape)) # 2. Set an anchor for each GT box (regardless of IoU value). gt_iou_argmax = np.argmax(overlaps, axis=0) for ix, ii in enumerate(gt_iou_argmax): anchor_class_matches[ii] = gt_class_ids[ix] # 3. Set anchors with high overlap as positive. above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]] # Subsample to balance positive anchors. ids = np.where(anchor_class_matches > 0)[0] extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) if extra > 0: # Reset the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) anchor_class_matches[ids] = 0 # Leave all negative proposals negative for now and sample from them later in online hard example mining. # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. ids = np.where(anchor_class_matches > 0)[0] ix = 0 # index into anchor_delta_targets for i, a in zip(ids, anchors[ids]): # closest gt box (it might have IoU < anchor_matching_iou) gt = gt_boxes[anchor_iou_argmax[i]] # convert coordinates to center plus width/height. gt_h = gt[2] - gt[0] gt_w = gt[3] - gt[1] gt_center_y = gt[0] + 0.5 * gt_h gt_center_x = gt[1] + 0.5 * gt_w # Anchor a_h = a[2] - a[0] a_w = a[3] - a[1] a_center_y = a[0] + 0.5 * a_h a_center_x = a[1] + 0.5 * a_w if cf.dim == 2: anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, np.log(gt_h / a_h), np.log(gt_w / a_w), ] else: gt_d = gt[5] - gt[4] gt_center_z = gt[4] + 0.5 * gt_d a_d = a[5] - a[4] a_center_z = a[4] + 0.5 * a_d anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, (gt_center_z - a_center_z) / a_d, np.log(gt_h / a_h), np.log(gt_w / a_w), np.log(gt_d / a_d) ] # normalize. anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev ix += 1 return anchor_class_matches, anchor_delta_targets def clip_to_window(window, boxes): """ window: (y1, x1, y2, x2) / 3D: (z1, z2). The window in the image we want to clip to. boxes: [N, (y1, x1, y2, x2)] / 3D: (z1, z2) """ boxes[:, 0] = boxes[:, 0].clamp(float(window[0]), float(window[2])) boxes[:, 1] = boxes[:, 1].clamp(float(window[1]), float(window[3])) boxes[:, 2] = boxes[:, 2].clamp(float(window[0]), float(window[2])) boxes[:, 3] = boxes[:, 3].clamp(float(window[1]), float(window[3])) if boxes.shape[1] > 5: boxes[:, 4] = boxes[:, 4].clamp(float(window[4]), float(window[5])) boxes[:, 5] = boxes[:, 5].clamp(float(window[4]), float(window[5])) return boxes ############################################################ # Connected Componenent Analysis ############################################################ def get_coords(binary_mask, n_components, dim): """ loops over batch to perform connected component analysis on binary input mask. computes box coordinates around n_components - biggest components (rois). :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class. :param n_components: int. number of components to extract per batch element and class. :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2)) :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2)) """ assert len(binary_mask.shape)==dim+1 binary_mask = binary_mask.astype('uint8') batch_coords = [] batch_components = [] for ix,b in enumerate(binary_mask): clusters, n_cands = lb(b) # performs connected component analysis. uniques, counts = np.unique(clusters, return_counts=True) keep_uniques = uniques[1:][np.argsort(counts[1:])[::-1]][:n_components] #only keep n_components largest components p_components = np.array([(clusters == ii) * 1 for ii in keep_uniques]) # separate clusters and concat p_coords = [] if p_components.shape[0] > 0: for roi in p_components: mask_ixs = np.argwhere(roi != 0) # get coordinates around component. roi_coords = [np.min(mask_ixs[:, 0]) - 1, np.min(mask_ixs[:, 1]) - 1, np.max(mask_ixs[:, 0]) + 1, np.max(mask_ixs[:, 1]) + 1] if dim == 3: roi_coords += [np.min(mask_ixs[:, 2]), np.max(mask_ixs[:, 2])+1] p_coords.append(roi_coords) p_coords = np.array(p_coords) #clip coords. p_coords[p_coords < 0] = 0 p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2] if dim == 3: p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1] batch_coords.append(p_coords) batch_components.append(p_components) return batch_coords, batch_components # noinspection PyCallingNonCallable def get_coords_gpu(binary_mask, n_components, dim): """ loops over batch to perform connected component analysis on binary input mask. computes box coordiantes around n_components - biggest components (rois). :param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class. :param n_components: int. number of components to extract per batch element and class. :return: coords (b, n, (y1, x1, y2, x2 (,z1, z2)) :return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2)) """ raise Exception("throws floating point exception") assert len(binary_mask.shape)==dim+1 binary_mask = binary_mask.type(torch.uint8) batch_coords = [] batch_components = [] for ix,b in enumerate(binary_mask): clusters, n_cands = lb(b.cpu().data.numpy()) # peforms connected component analysis. clusters = torch.from_numpy(clusters).cuda() uniques = torch.unique(clusters) counts = torch.stack([(clusters==unique).sum() for unique in uniques]) keep_uniques = uniques[1:][torch.sort(counts[1:])[1].flip(0)][:n_components] #only keep n_components largest components p_components = torch.cat([(clusters == ii).unsqueeze(0) for ii in keep_uniques]).cuda() # separate clusters and concat p_coords = [] if p_components.shape[0] > 0: for roi in p_components: mask_ixs = torch.nonzero(roi) # get coordinates around component. roi_coords = [torch.min(mask_ixs[:, 0]) - 1, torch.min(mask_ixs[:, 1]) - 1, torch.max(mask_ixs[:, 0]) + 1, torch.max(mask_ixs[:, 1]) + 1] if dim == 3: roi_coords += [torch.min(mask_ixs[:, 2]), torch.max(mask_ixs[:, 2])+1] p_coords.append(roi_coords) p_coords = torch.tensor(p_coords) #clip coords. p_coords[p_coords < 0] = 0 p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2] if dim == 3: p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1] batch_coords.append(p_coords) batch_components.append(p_components) return batch_coords, batch_components ############################################################ # Pytorch Utility Functions ############################################################ def unique1d(tensor): """discard all elements of tensor that occur more than once; make tensor unique. :param tensor: :return: """ if tensor.size()[0] == 0 or tensor.size()[0] == 1: return tensor tensor = tensor.sort()[0] unique_bool = tensor[1:] != tensor[:-1] first_element = torch.tensor([True], dtype=torch.bool, requires_grad=False) if tensor.is_cuda: first_element = first_element.cuda() unique_bool = torch.cat((first_element, unique_bool), dim=0) return tensor[unique_bool.data] def intersect1d(tensor1, tensor2): aux = torch.cat((tensor1, tensor2), dim=0) aux = aux.sort(descending=True)[0] return aux[:-1][(aux[1:] == aux[:-1]).data] def shem(roi_probs_neg, negative_count, poolsize): """ stochastic hard example mining: from a list of indices (referring to non-matched predictions), determine a pool of highest scoring (worst false positives) of size negative_count*poolsize. Then, sample n (= negative_count) predictions of this pool as negative examples for loss. :param roi_probs_neg: tensor of shape (n_predictions, n_classes). :param negative_count: int. :param poolsize: int. :return: (negative_count). indices refer to the positions in roi_probs_neg. If pool smaller than expected due to limited negative proposals availabel, this function will return sampled indices of number < negative_count without throwing an error. """ # sort according to higehst foreground score. probs, order = roi_probs_neg[:, 1:].max(1)[0].sort(descending=True) select = torch.tensor((poolsize * int(negative_count), order.size()[0])).min().int() pool_indices = order[:select] rand_idx = torch.randperm(pool_indices.size()[0]) return pool_indices[rand_idx[:negative_count].cuda()] ############################################################ # Weight Init ############################################################ def initialize_weights(net): """Initialize model weights. Current Default in Pytorch (version 0.4.1) is initialization from a uniform distriubtion. Will expectably be changed to kaiming_uniform in future versions. """ init_type = net.cf.weight_init for m in [module for module in net.modules() if type(module) in [torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, torch.nn.Linear]]: if init_type == 'xavier_uniform': torch.nn.init.xavier_uniform_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif init_type == 'xavier_normal': torch.nn.init.xavier_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif init_type == "kaiming_uniform": torch.nn.init.kaiming_uniform_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0) if m.bias is not None: fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / np.sqrt(fan_out) torch.nn.init.uniform_(m.bias, -bound, bound) elif init_type == "kaiming_normal": torch.nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity=net.cf.relu, a=0) if m.bias is not None: fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(m.weight.data) bound = 1 / np.sqrt(fan_out) torch.nn.init.normal_(m.bias, -bound, bound) net.logger.info("applied {} weight init.".format(init_type)) \ No newline at end of file