diff --git a/datasets/toy/configs.py b/datasets/toy/configs.py index b4759c7..da900eb 100644 --- a/datasets/toy/configs.py +++ b/datasets/toy/configs.py @@ -1,495 +1,495 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import sys import os sys.path.append(os.path.dirname(os.path.realpath(__file__))) import numpy as np from default_configs import DefaultConfigs from collections import namedtuple boxLabel = namedtuple('boxLabel', ["name", "color"]) Label = namedtuple("Label", ['id', 'name', 'shape', 'radius', 'color', 'regression', 'ambiguities', 'gt_distortion']) binLabel = namedtuple("binLabel", ['id', 'name', 'color', 'bin_vals']) class Configs(DefaultConfigs): def __init__(self, server_env=None): super(Configs, self).__init__(server_env) ######################### # Prepro # ######################### self.pp_rootdir = os.path.join('/mnt/HDD2TB/Documents/data/toy', "cyl1ps_dev") self.pp_npz_dir = self.pp_rootdir+"_npz" self.pre_crop_size = [320,320,8] #y,x,z; determines pp data shape (2D easily implementable, but only 3D for now) self.min_2d_radius = 6 #in pixels self.n_train_samples, self.n_test_samples = 80, 80 # not actually real one-hot encoding (ohe) but contains more info: roi-overlap only within classes. self.pp_create_ohe_seg = False self.pp_empty_samples_ratio = 0.1 self.pp_place_radii_mid_bin = True self.pp_only_distort_2d = True # outer-most intensity of blurred radii, relative to inner-object intensity. <1 for decreasing, > 1 for increasing. # e.g.: setting 0.1 means blurred edge has min intensity 10% as large as inner-object intensity. self.pp_blur_min_intensity = 0.2 self.max_instances_per_sample = 1 #how many max instances over all classes per sample (img if 2d, vol if 3d) self.max_instances_per_class = self.max_instances_per_sample # how many max instances per image per class self.noise_scale = 0. # std-dev of gaussian noise self.ambigs_sampling = "gaussian" #"gaussian" or "uniform" """ radius_calib: gt distort for calibrating uncertainty. Range of gt distortion is inferable from image by distinguishing it from the rest of the object. blurring width around edge will be shifted so that symmetric rel to orig radius. blurring scale: if self.ambigs_sampling is uniform, distribution's non-zero range (b-a) will be sqrt(12)*scale since uniform dist has variance (b-a)²/12. b,a will be placed symmetrically around unperturbed radius. if sampling is gaussian, then scale parameter sets one std dev, i.e., blurring width will be orig_radius * std_dev * 2. """ self.ambiguities = { #set which classes to apply which ambs to below in class labels #choose out of: 'outer_radius', 'inner_radius', 'radii_relations'. 
#kind #probability #scale (gaussian std, relative to unperturbed value) #"outer_radius": (1., 0.5), #"outer_radius_xy": (1., 0.5), #"inner_radius": (0.5, 0.1), #"radii_relations": (0.5, 0.1), "radius_calib": (1., 1./6) } # shape choices: 'cylinder', 'block' self.pp_classes = [Label(1, 'cylinder', 'cylinder', ((6,6,1),(40,40,8)), (*self.blue, 1.), "radius_2d", (), ('radius_calib',)), #Label(2, 'block', 'block', ((6,6,1),(40,40,8)), (*self.aubergine,1.), "radii_2d", (), ('radius_calib',)) ] ######################### # I/O # ######################### self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_dev' #self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_exact' self.data_sourcedir = '/mnt/HDD2TB/Documents/data/toy/cyl1ps_ambig_beyond_bin' if server_env: #self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_exact_npz' self.data_sourcedir = '/datasets/data_ramien/toy/cyl1ps_ambig_beyond_bin_npz' self.test_data_sourcedir = os.path.join(self.data_sourcedir, 'test') self.data_sourcedir = os.path.join(self.data_sourcedir, "train") self.info_df_name = 'info_df.pickle' # one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_fpn']. - self.model = 'retina_unet' + self.model = 'retina_net' self.model_path = 'models/{}.py'.format(self.model if not 'retina' in self.model else 'retina_net') self.model_path = os.path.join(self.source_dir, self.model_path) ######################### # Architecture # ######################### # one out of [2, 3]. dimension the model operates in. - self.dim = 3 + self.dim = 2 # 'class', 'regression', 'regression_bin', 'regression_ken_gal' - # currently only tested mode is a single-task at a time (i.e., only one task in below list) + # currently only tested mode is a single task at a time (i.e., only one task in below list) # but, in principle, tasks could be combined (e.g., object classes and regression per class) self.prediction_tasks = ['class'] self.start_filts = 48 if self.dim == 2 else 18 self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2 self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50' self.norm = 'instance_norm' # one of None, 'instance_norm', 'batch_norm' self.relu = 'relu' # one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform') self.weight_init = None self.regression_n_features = 1 # length of regressor target vector ######################### # Data Loader # ######################### self.num_epochs = 1 self.num_train_batches = 10 if self.dim == 2 else 16 self.batch_size = 16 if self.dim == 2 else 8 self.n_cv_splits = 4 # select modalities from preprocessed data self.channels = [0] self.n_channels = len(self.channels) # which channel (mod) to show as bg in plotting, will be extra added to batch if not in self.channels self.plot_bg_chan = 0 self.crop_margin = [20, 20, 1] # has to be smaller than respective patch_size//2 self.patch_size_2D = self.pre_crop_size[:2] self.patch_size_3D = self.pre_crop_size[:2]+[8] # patch_size to be used for training. pre_crop_size is the patch_size before data augmentation. self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D # ratio of free sampled batch elements before class balancing is triggered # (>0 to include "empty"/background patches.) 
self.batch_random_ratio = 0.2 self.balance_target = "class_targets" if 'class' in self.prediction_tasks else "rg_bin_targets" self.observables_patient = [] self.observables_rois = [] self.seed = 3 #for generating folds ############################# # Colors, Classes, Legends # ############################# self.plot_frequency = 5 binary_bin_labels = [binLabel(1, 'r<=25', (*self.green, 1.), (1,25)), binLabel(2, 'r>25', (*self.red, 1.), (25,))] quintuple_bin_labels = [binLabel(1, 'r2-10', (*self.green, 1.), (2,10)), binLabel(2, 'r10-20', (*self.yellow, 1.), (10,20)), binLabel(3, 'r20-30', (*self.orange, 1.), (20,30)), binLabel(4, 'r30-40', (*self.bright_red, 1.), (30,40)), binLabel(5, 'r>40', (*self.red, 1.), (40,))] # choose here if to do 2-way or 5-way regression-bin classification task_spec_bin_labels = quintuple_bin_labels self.class_labels = [ # regression: regression-task label, either value or "(x,y,z)_radius" or "radii". # ambiguities: name of above defined ambig to apply to image data (not gt); need to be iterables! # gt_distortion: name of ambig to apply to gt only; needs to be iterable! # #id #name #shape #radius #color #regression #ambiguities #gt_distortion Label( 0, 'bg', None, (0, 0, 0), (*self.white, 0.), (0, 0, 0), (), ())] if "class" in self.prediction_tasks: self.class_labels += self.pp_classes else: self.class_labels += [Label(1, 'object', 'object', ('various',), (*self.orange, 1.), ('radius_2d',), ("various",), ('various',))] if any(['regression' in task for task in self.prediction_tasks]): self.bin_labels = [binLabel(0, 'bg', (*self.white, 1.), (0,))] self.bin_labels += task_spec_bin_labels self.bin_id2label = {label.id: label for label in self.bin_labels} bins = [(min(label.bin_vals), max(label.bin_vals)) for label in self.bin_labels] self.bin_id2rg_val = {ix: [np.mean(bin)] for ix, bin in enumerate(bins)} self.bin_edges = [(bins[i][1] + bins[i + 1][0]) / 2 for i in range(len(bins) - 1)] self.bin_dict = {label.id: label.name for label in self.bin_labels if label.id != 0} if self.class_specific_seg: self.seg_labels = self.class_labels self.box_type2label = {label.name: label for label in self.box_labels} self.class_id2label = {label.id: label for label in self.class_labels} self.class_dict = {label.id: label.name for label in self.class_labels if label.id != 0} self.seg_id2label = {label.id: label for label in self.seg_labels} self.cmap = {label.id: label.color for label in self.seg_labels} self.plot_prediction_histograms = True self.plot_stat_curves = False self.has_colorchannels = False self.plot_class_ids = True self.num_classes = len(self.class_dict) self.num_seg_classes = len(self.seg_labels) ######################### # Data Augmentation # ######################### self.do_aug = True self.da_kwargs = { 'mirror': True, 'mirror_axes': tuple(np.arange(0, self.dim, 1)), 'do_elastic_deform': False, 'alpha': (500., 1500.), 'sigma': (40., 45.), 'do_rotation': False, 'angle_x': (0., 2 * np.pi), 'angle_y': (0., 0), 'angle_z': (0., 0), 'do_scale': False, 'scale': (0.8, 1.1), 'random_crop': False, 'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3), 'border_mode_data': 'constant', 'border_cval_data': 0, 'order_data': 1 } if self.dim == 3: self.da_kwargs['do_elastic_deform'] = False self.da_kwargs['angle_x'] = (0, 0.0) self.da_kwargs['angle_y'] = (0, 0.0) # must be 0!! 
self.da_kwargs['angle_z'] = (0., 2 * np.pi) ######################### # Schedule / Selection # ######################### # decide whether to validate on entire patient volumes (like testing) or sampled patches (like training) # the former is more accurate, while the latter is faster (depending on volume size) self.val_mode = 'val_patient' # one of 'val_sampling' , 'val_patient' if self.val_mode == 'val_patient': self.max_val_patients = 220 # if 'all' iterates over entire val_set once. if self.val_mode == 'val_sampling': self.num_val_batches = 200 if self.dim==2 else 100 self.save_n_models = 2 self.min_save_thresh = 1 if self.dim == 2 else 1 # =wait time in epochs if "class" in self.prediction_tasks: self.model_selection_criteria = {name + "_ap": 1. for name in self.class_dict.values()} elif any("regression" in task for task in self.prediction_tasks): self.model_selection_criteria = {name + "_ap": 0.2 for name in self.class_dict.values()} self.model_selection_criteria.update({name + "_avp": 0.8 for name in self.class_dict.values()}) self.lr_decay_factor = 0.5 self.scheduling_patience = int(self.num_epochs / 5) self.weight_decay = 1e-5 self.clip_norm = None # number or None ######################### # Testing / Plotting # ######################### self.test_aug_axes = (0,1,(0,1)) # None or list: choices are 0,1,(0,1) self.held_out_test_set = True self.max_test_patients = "all" # number or "all" for all self.test_against_exact_gt = not 'exact' in self.data_sourcedir self.val_against_exact_gt = False # True is an unrealistic --> irrelevant scenario. self.report_score_level = ['rois'] # 'patient' or 'rois' (incl) self.patient_class_of_interest = 1 self.patient_bin_of_interest = 2 self.eval_bins_separately = False#"additionally" if not 'class' in self.prediction_tasks else False self.metrics = ['ap', 'auc', 'dice'] if any(['regression' in task for task in self.prediction_tasks]): self.metrics += ['avp', 'rg_MAE_weighted', 'rg_MAE_weighted_tp', 'rg_bin_accuracy_weighted', 'rg_bin_accuracy_weighted_tp'] if 'aleatoric' in self.model: self.metrics += ['rg_uncertainty', 'rg_uncertainty_tp', 'rg_uncertainty_tp_weighted'] self.evaluate_fold_means = True self.ap_match_ious = [0.5] # threshold(s) for considering a prediction as true positive self.min_det_thresh = 0.3 self.model_max_iou_resolution = 0.9 # aggregation method for test and val_patient predictions. # wbc = weighted box clustering as in https://arxiv.org/pdf/1811.08661.pdf, # nms = standard non-maximum suppression, or None = no clustering self.clustering = 'wbc' # iou thresh (exclusive!)
for regarding two preds as concerning the same ROI self.clustering_iou = self.model_max_iou_resolution # has to be larger than desired possible overlap iou of model predictions self.merge_2D_to_3D_preds = False self.merge_3D_iou = self.model_max_iou_resolution self.n_test_plots = 1 # per fold and rank self.test_n_epochs = self.save_n_models # should be called n_test_ens, since it is the number of models to ensemble over during testing # is multiplied by (1 + nr of test augs) #self.losses_to_monitor += ['class_loss', 'rg_loss'] ######################### # Assertions # ######################### if not 'class' in self.prediction_tasks: assert self.num_classes == 1 ######################### # Add model specifics # ######################### {'mrcnn': self.add_mrcnn_configs, 'mrcnn_aleatoric': self.add_mrcnn_configs, 'retina_net': self.add_mrcnn_configs, 'retina_unet': self.add_mrcnn_configs, 'detection_unet': self.add_det_unet_configs, 'detection_fpn': self.add_det_fpn_configs }[self.model]() def rg_val_to_bin_id(self, rg_val): #only meant for isotropic radii!! # only 2D radii (x and y dims) or 1D (x or y) are expected return np.round(np.digitize(rg_val, self.bin_edges).mean()) def add_det_fpn_configs(self): self.learning_rate = [5 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = 'torch_loss' self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' self.n_roi_candidates = 4 if self.dim == 2 else 6 # max number of roi candidates to identify per image (slice in 2D, volume in 3D) # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice'), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1] self.fp_dice_weight = 1 if self.dim == 2 else 1 # if <1, false positive predictions in foreground are penalized less. self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' def add_det_unet_configs(self): self.learning_rate = [5 * 1e-4] * self.num_epochs self.dynamic_lr_scheduling = True self.scheduling_criterion = "torch_loss" self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # max number of roi candidates to identify per image (slice in 2D, volume in 3D) self.n_roi_candidates = 4 if self.dim == 2 else 6 # loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice'), or the sum of both ('dice_wce') self.seg_loss_mode = 'wce' self.wce_weights = [1] * self.num_seg_classes if 'dice' in self.seg_loss_mode else [0.1, 1, 1] # if <1, false positive predictions in foreground are penalized less.
self.fp_dice_weight = 1 if self.dim == 2 else 1 self.detection_min_confidence = 0.05 # how to determine score of roi: 'max' or 'median' self.score_det = 'max' self.init_filts = 32 self.kernel_size = 3 # ks for horizontal, normal convs self.kernel_size_m = 2 # ks for max pool self.pad = "same" # "same" or integer, padding of horizontal convs def add_mrcnn_configs(self): self.learning_rate = [1e-4] * self.num_epochs self.dynamic_lr_scheduling = True # with scheduler set in exec self.scheduling_criterion = max(self.model_selection_criteria, key=self.model_selection_criteria.get) self.scheduling_mode = 'min' if "loss" in self.scheduling_criterion else 'max' # number of classes for network heads: n_foreground_classes + 1 (background) self.head_classes = self.num_classes + 1 if 'class' in self.prediction_tasks else 2 # feed +/- n neighbouring slices into channel dimension. set to None for no context. self.n_3D_context = None if self.n_3D_context is not None and self.dim == 2: self.n_channels *= (self.n_3D_context * 2 + 1) self.detect_while_training = True # disable the re-sampling of mask proposals to original size for speed-up. # since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching), # mask outputs are optional. self.return_masks_in_train = True self.return_masks_in_val = True self.return_masks_in_test = True # feature map strides per pyramid level are inferred from architecture. anchor scales are set accordingly. self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]} # anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale # per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.) self.rpn_anchor_scales = {'xy': [[4], [8], [16], [32]], 'z': [[1], [2], [4], [8]]} # choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3. self.pyramid_levels = [0, 1, 2, 3] # number of feature maps in rpn. typically lowered in 3D to save gpu-memory. self.n_rpn_features = 512 if self.dim == 2 else 64 # anchor ratios and strides per position in feature maps. self.rpn_anchor_ratios = [0.5, 1., 2.] self.rpn_anchor_stride = 1 # Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION self.rpn_nms_threshold = max(0.8, self.model_max_iou_resolution) # loss sampling settings. self.rpn_train_anchors_per_image = 4 self.train_rois_per_image = 6 # per batch_instance self.roi_positive_ratio = 0.5 self.anchor_matching_iou = 0.8 # k negative example candidates are drawn from a pool of size k*shem_poolsize (stochastic hard-example mining), # where k<=#positive examples. 
self.shem_poolsize = 2 self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3) self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5) self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10) self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2]) self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]]) self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1], self.patch_size_3D[2], self.patch_size_3D[2]]) # y1,x1,y2,x2,z1,z2 if self.dim == 2: self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4] self.bbox_std_dev = self.bbox_std_dev[:4] self.window = self.window[:4] self.scale = self.scale[:4] self.plot_y_max = 1.5 self.n_plot_rpn_props = 5 if self.dim == 2 else 30 # per batch_instance (slice in 2D / patient in 3D) # pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element. self.pre_nms_limit = 2000 if self.dim == 2 else 4000 # n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True, # since proposals of the entire batch are forwarded through second stage as one "batch". self.roi_chunk_size = 1300 if self.dim == 2 else 500 self.post_nms_rois_training = 200 * (self.head_classes-1) if self.dim == 2 else 400 self.post_nms_rois_inference = 200 * (self.head_classes-1) # Final selection of detections (refine_detections) self.model_max_instances_per_batch_element = 9 if self.dim == 2 else 18 # per batch element and class. self.detection_nms_threshold = self.model_max_iou_resolution # needs to be > 0, otherwise all predictions are one cluster. self.model_min_confidence = 0.2 # iou for nms in box refining (directly after heads), should be >0 since ths>=x in mrcnn.py if self.dim == 2: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride))] for stride in self.backbone_strides['xy']]) else: self.backbone_shapes = np.array( [[int(np.ceil(self.patch_size[0] / stride)), int(np.ceil(self.patch_size[1] / stride)), int(np.ceil(self.patch_size[2] / stride_z))] for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z'] )]) if self.model == 'retina_net' or self.model == 'retina_unet': # whether to use focal loss (True) or hard-example mining (set focal_loss to False) self.focal_loss = True # implement extra anchor-scales according to https://arxiv.org/abs/1708.02002 self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['xy']] self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in self.rpn_anchor_scales['z']] self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3 #self.n_rpn_features = 256 if self.dim == 2 else 64 # pre-selection of detections for NMS-speedup. per entire batch. self.pre_nms_limit = (500 if self.dim == 2 else 6250) * self.batch_size # anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002 self.anchor_matching_iou = 0.7 if self.model == 'retina_unet': self.operate_stride1 = True diff --git a/exec.py b/exec.py index 6dc00b5..9e9073b 100644 --- a/exec.py +++ b/exec.py @@ -1,348 +1,348 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ execution script. this is where all routines come together and the only script you need to call. refer to parse args below to see options for execution. """ import plotting as plg import os import warnings import argparse import time import torch import utils.exp_utils as utils from evaluator import Evaluator from predictor import Predictor for msg in ["Attempting to set identical bottom==top results", "This figure includes Axes that are not compatible with tight_layout", "Data has no positive values, and therefore cannot be log-scaled.", ".*invalid value encountered in true_divide.*"]: warnings.filterwarnings("ignore", msg) def train(cf, logger): """ performs the training routine for a given fold. saves plots and selected parameters to the experiment dir specified in the configs. """ logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format( cf.dim, cf.fold, cf.exp_dir, cf.model)) logger.time("train_val") # -------------- inits and settings ----------------- net = model.net(cf, logger).cuda() if cf.optimizer == "ADAM": optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay) elif cf.optimizer == "SGD": optimizer = torch.optim.SGD(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay, momentum=0.3) if cf.dynamic_lr_scheduling: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=cf.scheduling_mode, factor=cf.lr_decay_factor, patience=cf.scheduling_patience) model_selector = utils.ModelSelector(cf, logger) starting_epoch = 1 if cf.resume_from_checkpoint: starting_epoch = utils.load_checkpoint(cf.resume_from_checkpoint, net, optimizer) logger.info('resumed from checkpoint {} at epoch {}'.format(cf.resume_from_checkpoint, starting_epoch)) # prepare monitoring monitor_metrics = utils.prepare_monitoring(cf) logger.info('loading dataset and initializing batch generators...') batch_gen = data_loader.get_train_generators(cf, logger) # -------------- training ----------------- for epoch in range(starting_epoch, cf.num_epochs + 1): logger.info('starting training epoch {}/{}'.format(epoch, cf.num_epochs)) logger.time("train_epoch") net.train() train_results_list = [] train_evaluator = Evaluator(cf, logger, mode='train') for i in range(cf.num_train_batches): logger.time("train_batch_loadfw") batch = next(batch_gen['train']) batch_gen['train'].generator.stats['roi_counts'] += batch['roi_counts'] batch_gen['train'].generator.stats['empty_samples_count'] += batch['empty_samples_count'] logger.time("train_batch_loadfw") logger.time("train_batch_netfw") results_dict = net.train_forward(batch) logger.time("train_batch_netfw") logger.time("train_batch_bw") optimizer.zero_grad() results_dict['torch_loss'].backward() if cf.clip_norm: torch.nn.utils.clip_grad_norm_(net.parameters(), cf.clip_norm, norm_type=2) #gradient clipping optimizer.step() train_results_list.append(({k:v for k,v in
results_dict.items() if k != "seg_preds"}, batch["pid"])) #slim res dict if not cf.server_env: print("\rFinished training batch " + "{}/{} in {:.1f}s ({:.2f}/{:.2f} forw load/net, {:.2f} backw).".format(i+1, cf.num_train_batches, logger.get_time("train_batch_loadfw")+ logger.get_time("train_batch_netfw") +logger.time("train_batch_bw"), logger.get_time("train_batch_loadfw",reset=True), logger.get_time("train_batch_netfw", reset=True), logger.get_time("train_batch_bw", reset=True)), end="", flush=True) print() #--------------- train eval ---------------- if (epoch-1)%cf.plot_frequency==0: # view an example batch plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, out_file=os.path.join(cf.plot_dir, 'batch_example_train_{}.png'.format(cf.fold))) logger.time("evals") _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train']) #np_loss, torch_loss = train_loss_running_mean / cf.num_train_batches, monitor_metrics['train']["loss"][-1] #assert np_loss/torch_loss-1<0.005, "{} vs {}".format(np_loss, torch_loss) logger.time("evals") logger.time("train_epoch", toggle=False) del train_results_list #----------- validation ------------ logger.info('starting validation in mode {}.'.format(cf.val_mode)) logger.time("val_epoch") with torch.no_grad(): net.eval() val_results_list = [] val_evaluator = Evaluator(cf, logger, mode=cf.val_mode) val_predictor = Predictor(cf, net, logger, mode='val') for i in range(batch_gen['n_val']): logger.time("val_batch") batch = next(batch_gen[cf.val_mode]) if cf.val_mode == 'val_patient': results_dict = val_predictor.predict_patient(batch) elif cf.val_mode == 'val_sampling': results_dict = net.train_forward(batch, is_validation=True) val_results_list.append([results_dict, batch["pid"]]) if not cf.server_env: print("\rFinished validation {} {}/{} in {:.1f}s.".format('patient' if cf.val_mode=='val_patient' else 'batch', i + 1, batch_gen['n_val'], logger.time("val_batch")), end="", flush=True) print() #------------ val eval ------------- logger.time("val_plot") if (epoch - 1) % cf.plot_frequency == 0: plg.view_batch(cf, batch, results_dict, has_colorchannels=cf.has_colorchannels, show_gt_labels=True, out_file=os.path.join(cf.plot_dir, 'batch_example_val_{}.png'.format(cf.fold))) logger.time("val_plot") logger.time("evals") _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val']) model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch) del val_results_list #----------- monitoring ------------- monitor_metrics.update({"lr": {str(g) : group['lr'] for (g, group) in enumerate(optimizer.param_groups)}}) logger.metrics2tboard(monitor_metrics, global_step=epoch) logger.time("evals") logger.info('finished epoch {}/{}, took {:.2f}s. train total: {:.2f}s, average: {:.2f}s. 
val total: {:.2f}s, average: {:.2f}s.'.format( epoch, cf.num_epochs, logger.get_time("train_epoch")+logger.time("val_epoch"), logger.get_time("train_epoch"), logger.get_time("train_epoch", reset=True)/cf.num_train_batches, logger.get_time("val_epoch"), logger.get_time("val_epoch", reset=True)/batch_gen["n_val"])) logger.info("time for evals: {:.2f}s, val plot {:.2f}s".format(logger.get_time("evals", reset=True), logger.get_time("val_plot", reset=True))) #-------------- scheduling ----------------- if not cf.dynamic_lr_scheduling: for param_group in optimizer.param_groups: param_group['lr'] = cf.learning_rate[epoch-1] else: scheduler.step(monitor_metrics["val"][cf.scheduling_criterion][-1]) logger.time("train_val") logger.info("Training and validating over {} epochs took {}".format(cf.num_epochs, logger.get_time("train_val", format="hms", reset=True))) batch_gen['train'].generator.print_stats(logger, plot=True) def test(cf, logger, max_fold=None): """performs testing for a given fold (or held out set). saves stats in evaluator. """ logger.time("test_fold") logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir)) net = model.net(cf, logger).cuda() batch_gen = data_loader.get_test_generator(cf, logger) test_predictor = Predictor(cf, net, logger, mode='test') test_results_list = test_predictor.predict_test_set(batch_gen, return_results = not hasattr( cf, "eval_test_separately") or not cf.eval_test_separately) if test_results_list is not None: test_evaluator = Evaluator(cf, logger, mode='test') test_evaluator.evaluate_predictions(test_results_list) test_evaluator.score_test_df(max_fold=max_fold) mins, secs = divmod(logger.get_time("test_fold"), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) logger.info('Testing of fold {} took {}.'.format(cf.fold, t)) if __name__ == '__main__': stime = time.time() parser = argparse.ArgumentParser() parser.add_argument('-m', '--mode', type=str, default='train_test', help='one out of: create_exp, analysis, train, train_test, or test') parser.add_argument('-f', '--folds', nargs='+', type=int, default=None, help='None runs over all folds in CV. otherwise specify list of folds.') - parser.add_argument('--exp_dir', type=str, default='/home/gregor/Documents/RegRCNN/datasets/toy/experiments/dev', + parser.add_argument('--exp_dir', type=str, default='datasets/toy/experiments/dev', help='path to experiment dir. will be created if non existent.') parser.add_argument('--server_env', default=False, action='store_true', help='change IO settings to deploy models on a cluster.') parser.add_argument('--data_dest', type=str, default=None, help="path to final data folder if different from config") parser.add_argument('--use_stored_settings', default=False, action='store_true', help='load configs from existing exp_dir instead of source dir. always done for testing, ' 'but can be set to true to do the same for training. useful in job scheduler environment, ' 'where source code might change before the job actually runs.') parser.add_argument('--resume_from_checkpoint', type=str, default=None, help='path to checkpoint. 
if resuming from checkpoint, the desired fold still needs to be parsed via --folds.') parser.add_argument('--dataset_name', type=str, default='toy', help="path to the dataset-specific code in source_dir/datasets") parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything") args = parser.parse_args() args.dataset_name = os.path.join("datasets", args.dataset_name) if not "datasets" in args.dataset_name else args.dataset_name folds = args.folds resume_from_checkpoint = None if args.resume_from_checkpoint in ['None', 'none'] else args.resume_from_checkpoint if args.mode == 'create_exp': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=False) logger = utils.get_logger(cf.exp_dir, cf.server_env) logger.info('created experiment directory at {}'.format(args.exp_dir)) elif args.mode == 'train' or args.mode == 'train_test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, args.use_stored_settings) if args.dev: folds = [0,1] cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 1 cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 7, 1, 1 cf.test_n_epochs = cf.save_n_models cf.max_test_patients = 1 torch.backends.cudnn.benchmark = cf.dim==3 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) if folds is None: folds = range(cf.n_cv_splits) for fold in folds: """k-fold cross-validation: the dataset is split into k equally-sized folds, one used for validation, one for testing, the rest for training. This loop iterates k-times over the dataset, cyclically moving the splits. k==folds, fold in [0,folds) says which split is used for testing. 
""" cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold)) cf.fold, logger.fold = fold, fold cf.resume_from_checkpoint = resume_from_checkpoint if not os.path.exists(cf.fold_dir): os.mkdir(cf.fold_dir) train(cf, logger) cf.resume_from_checkpoint = None if args.mode == 'train_test': test(cf, logger) elif args.mode == 'test': cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) if args.data_dest is not None: cf.data_dest = args.data_dest logger = utils.get_logger(cf.exp_dir, cf.server_env) data_loader = utils.import_module('data_loader', os.path.join(args.dataset_name, 'data_loader.py')) model = utils.import_module('model', cf.model_path) logger.info("loaded model from {}".format(cf.model_path)) fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if folds is None: folds = range(cf.n_cv_splits) if args.dev: folds = folds[:2] cf.batch_size, cf.num_test_patients, cf.test_n_epochs = 1 if cf.dim==2 else 1, 2, 2 else: torch.backends.cudnn.benchmark = cf.cuda_benchmark for fold in folds: cf.fold = fold cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) if cf.fold_dir in fold_dirs: test(cf, logger, max_fold=max([int(f[-1]) for f in fold_dirs])) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation. elif args.mode == 'analysis': """ analyse already saved predictions. """ cf = utils.prep_exp(args.dataset_name, args.exp_dir, args.server_env, use_stored_settings=True, is_training=False) logger = utils.get_logger(cf.exp_dir, cf.server_env) if cf.held_out_test_set and not cf.eval_test_fold_wise: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() logger.info('starting evaluation...') cf.fold = 0 evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) evaluator.score_test_df(max_fold=0) else: fold_dirs = sorted([os.path.join(cf.exp_dir, f) for f in os.listdir(cf.exp_dir) if os.path.isdir(os.path.join(cf.exp_dir, f)) and f.startswith("fold")]) if args.dev: fold_dirs = fold_dirs[:1] if folds is None: folds = range(cf.n_cv_splits) for fold in folds: cf.fold = fold cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(cf.fold)) if cf.fold_dir in fold_dirs: predictor = Predictor(cf, net=None, logger=logger, mode='analysis') results_list = predictor.load_saved_predictions() # results_list[x][1] is pid, results_list[x][0] is list of len samples-per-patient, each entry hlds # list of boxes per that sample, i.e., len(results_list[x][y][0]) would be nr of boxes in sample y of patient x logger.info('starting evaluation...') evaluator = Evaluator(cf, logger, mode='test') evaluator.evaluate_predictions(results_list) max_fold = max([int(f[-1]) for f in fold_dirs]) evaluator.score_test_df(max_fold=max_fold) else: logger.info("Skipping fold {} since no model parameters found.".format(fold)) else: raise ValueError('mode "{}" specified in args is not implemented.'.format(args.mode)) mins, secs = divmod((time.time() - stime), 60) h, mins = divmod(mins, 60) t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs)) logger.info("{} total runtime: {}".format(os.path.split(__file__)[1], t)) del logger torch.cuda.empty_cache() diff --git a/models/retina_net.py b/models/retina_net.py index ac4e17e..d25a343 100644 
--- a/models/retina_net.py +++ b/models/retina_net.py @@ -1,782 +1,782 @@ #!/usr/bin/env python # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ). # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Retina Net. According to https://arxiv.org/abs/1708.02002""" import utils.model_utils as mutils import utils.exp_utils as utils import sys sys.path.append('../') from cuda_functions.nms_2D.pth_nms import nms_gpu as nms_2D from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.utils class Classifier(nn.Module): def __init__(self, cf, conv): """ Builds the classifier sub-network. """ super(Classifier, self).__init__() self.dim = conv.dim self.n_classes = cf.head_classes n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * cf.head_classes anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: class_logits (b, n_anchors, n_classes) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) class_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) class_logits = class_logits.permute(*axes) class_logits = class_logits.contiguous() class_logits = class_logits.view(x.shape[0], -1, self.n_classes) return [class_logits] class BBRegressor(nn.Module): def __init__(self, cf, conv): """ Builds the bb-regression sub-network. 
""" super(BBRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features n_output_channels = cf.n_anchors_per_pos * self.dim * 2 anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) bb_logits = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) bb_logits = bb_logits.permute(*axes) bb_logits = bb_logits.contiguous() bb_logits = bb_logits.view(x.shape[0], -1, self.dim * 2) return [bb_logits] class RoIRegressor(nn.Module): def __init__(self, cf, conv, rg_feats): """ Builds the RoI-item-regression sub-network. Regression items can be, e.g., malignancy scores of tumors. """ super(RoIRegressor, self).__init__() self.dim = conv.dim n_input_channels = cf.end_filts n_features = cf.n_rpn_features self.rg_feats = rg_feats n_output_channels = cf.n_anchors_per_pos * self.rg_feats anchor_stride = cf.rpn_anchor_stride self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu, norm=cf.norm) self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None) def forward(self, x): """ :param x: input feature map (b, in_c, y, x, (z)) :return: bb_logits (b, n_anchors, dim * 2) """ x = self.conv_1(x) x = self.conv_2(x) x = self.conv_3(x) x = self.conv_4(x) x = self.conv_final(x) axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1) x = x.permute(*axes) x = x.contiguous() x = x.view(x.shape[0], -1, self.rg_feats) return [x] ############################################################ # Loss Functions ############################################################ # def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20): """ :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. :param shem_poolsize: int. factor of top-k candidates to draw from per negative sample (online-hard-example-mining). :return: loss: torch tensor :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0) neg_indices = torch.nonzero(anchor_matches == -1) # get positive samples and calucalte loss. 
if not 0 in pos_indices.size(): pos_indices = pos_indices.squeeze(1) roi_logits_pos = class_pred_logits[pos_indices] targets_pos = anchor_matches[pos_indices].detach() pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long()) else: pos_loss = torch.FloatTensor([0]).cuda() # get negative samples, such that the amount matches the number of positive samples, but at least 1. # get high scoring negatives by applying online-hard-example-mining. if not 0 in neg_indices.size(): neg_indices = neg_indices.squeeze(1) roi_logits_neg = class_pred_logits[neg_indices] negative_count = np.max((1, pos_indices.cpu().data.numpy().size)) roi_probs_neg = F.softmax(roi_logits_neg, dim=1) neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize) neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda()) # return the indices of negative samples, which contributed to the loss, for monitoring plots. np_neg_ix = neg_ix.cpu().data.numpy() else: neg_loss = torch.FloatTensor([0]).cuda() np_neg_ix = np.array([]).astype('int32') loss = (pos_loss + neg_loss) / 2 return loss, np_neg_ix def compute_bbox_loss(target_deltas, pred_deltas, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unused bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param anchor_matches: tensor (n_anchors). value in [-1, 0, class_ids] for negative, neutral, and positive matched anchors. i.e., positively matched anchors are marked by class_id >0 :return: loss: torch 1D tensor. """ if not 0 in torch.nonzero(anchor_matches>0).shape: indices = torch.nonzero(anchor_matches>0).squeeze(1) # Pick bbox deltas that contribute to the loss pred_deltas = pred_deltas[indices] # Trim target bounding box deltas to the same length as pred_deltas. target_deltas = target_deltas[:pred_deltas.shape[0], :].detach() # Smooth L1 loss loss = F.smooth_l1_loss(pred_deltas, target_deltas) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_rg_loss(tasks, target, pred, anchor_matches): """ :param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))). Uses 0 padding to fill in unused bbox deltas. :param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))) :param anchor_matches: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors. :return: loss: torch 1D tensor. """ if not 0 in target.shape and not 0 in torch.nonzero(anchor_matches>0).shape: indices = torch.nonzero(anchor_matches>0).squeeze(1) # Pick rgs that contribute to the loss pred = pred[indices] # Trim target target = target[:pred.shape[0]].detach() if 'regression_bin' in tasks: loss = F.cross_entropy(pred, target.long()) else: loss = F.smooth_l1_loss(pred, target) else: loss = torch.FloatTensor([0]).cuda() return loss def compute_focal_class_loss(anchor_matches, class_pred_logits, gamma=2.): """ Focal Loss FL = -(1-q)^g log(q) with q = pred class probability. :param anchor_matches: (n_anchors). [-1, 0, class] for negative, neutral, and positive matched anchors. :param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network. - :return: loss: torch tensor - :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training. + :param gamma: g in above formula, good results with g=2 in original paper.
+ :return: focal loss """ # Positive and Negative anchors contribute to the loss, # but neutral anchors (match value = 0) don't. pos_indices = torch.nonzero(anchor_matches > 0).squeeze(-1) # dim=-1 instead of 1 or 0 to cover empty matches. neg_indices = torch.nonzero(anchor_matches == -1).squeeze(-1) target_classes = torch.cat( (anchor_matches[pos_indices].long(), torch.LongTensor([0] * neg_indices.shape[0]).cuda()) ) non_neutral_indices = torch.cat( (pos_indices, neg_indices) ) q = F.softmax(class_pred_logits[non_neutral_indices], dim=1) # q shape: (n_non_neutral_anchors, n_classes) # one-hot encoded target classes: keep only the pred probs of the correct class. it will receive incentive to be maximized. # log(q_i) where i = target class --> FL shape (n_anchors,) # need to transform to indices into flattened tensor to use torch.take target_locs_flat = q.shape[1] * torch.arange(q.shape[0]).cuda() + target_classes q = torch.take(q, target_locs_flat) FL = torch.log(q) # element-wise log FL *= -(1-q)**gamma # take mean over all considered anchors FL = FL.sum() / FL.shape[0] return FL def refine_detections(anchors, probs, deltas, regressions, batch_ixs, cf): """Refine classified proposals, filter overlaps and return final detections. n_proposals here is typically a very large number: batch_size * n_anchors. This function is hence optimized for trimming down n_proposals. :param anchors: (n_anchors, 2 * dim) :param probs: (n_proposals, n_classes) softmax probabilities for all rois as predicted by classifier head. :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by bbox regressor head. :param regressions: (n_proposals, n_classes, n_rg_feats) :param batch_ixs: (n_proposals) batch element assignment info for re-allocation. :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regr)) """ anchors = anchors.repeat(len(np.unique(batch_ixs)), 1) #flatten foreground probabilities, sort and trim down to highest confidences by pre_nms limit. fg_probs = probs[:, 1:].contiguous() flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True) keep_ix = flat_probs_order[:cf.pre_nms_limit] # reshape indices to 2D index array with shape like fg_probs. keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1) pre_nms_scores = flat_probs[:cf.pre_nms_limit] pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again. pre_nms_batch_ixs = batch_ixs[keep_arr[:, 0]] pre_nms_anchors = anchors[keep_arr[:, 0]] pre_nms_deltas = deltas[keep_arr[:, 0]] pre_nms_regressions = regressions[keep_arr[:, 0]] keep = torch.arange(pre_nms_scores.size()[0]).long().cuda() # apply bounding box deltas. re-scale to image coordinates.
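# note on the delta application below (descriptive comment, exact internals of mutils.apply_box_deltas_* not shown here):
# pred_deltas come out of the regressor in units standardized by cf.rpn_bbox_std_dev (the same scaling that
# gt_anchor_matching applies to its targets further down), so they are multiplied by std_dev before being applied;
# anchors are divided by cf.scale to work in normalized image coordinates, and the refined boxes are scaled back up.
# the parameterization is assumed to be the usual (dy, dx, (dz), log(dh), log(dw), (log(dd))) encoding,
# i.e. roughly center = anchor_center + delta * anchor_size and size = anchor_size * exp(log_delta) per axis.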
std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda() scale = torch.from_numpy(cf.scale).float().cuda() refined_rois = mutils.apply_box_deltas_2D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale \ if cf.dim == 2 else mutils.apply_box_deltas_3D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale # round and cast to int since we're dealing with pixels now refined_rois = mutils.clip_to_window(cf.window, refined_rois) pre_nms_rois = torch.round(refined_rois) for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)): bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0] bix_class_ids = pre_nms_class_ids[bixs] bix_rois = pre_nms_rois[bixs] bix_scores = pre_nms_scores[bixs] for i, class_id in enumerate(mutils.unique1d(bix_class_ids)): ixs = torch.nonzero(bix_class_ids == class_id)[:, 0] # nms expects boxes sorted by score. ix_rois = bix_rois[ixs] ix_scores = bix_scores[ixs] ix_scores, order = ix_scores.sort(descending=True) ix_rois = ix_rois[order, :] ix_scores = ix_scores if cf.dim == 2: class_keep = nms_2D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) else: class_keep = nms_3D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold) # map indices back. class_keep = keep[bixs[ixs[order[class_keep]]]] # merge indices over classes for current batch element b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep))) # only keep top-k boxes of current batch-element. top_ids = pre_nms_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element] b_keep = b_keep[top_ids] # merge indices over batch elements. batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep))) keep = batch_keep # arrange output. result = torch.cat((pre_nms_rois[keep], pre_nms_batch_ixs[keep].unsqueeze(1).float(), pre_nms_class_ids[keep].unsqueeze(1).float(), pre_nms_scores[keep].unsqueeze(1), pre_nms_regressions[keep]), dim=1) return result def gt_anchor_matching(cf, anchors, gt_boxes, gt_class_ids=None, gt_regressions=None): """Given the anchors and GT boxes, compute overlaps and identify positive anchors and deltas to refine them to match their corresponding GT boxes. anchors: [num_anchors, (y1, x1, y2, x2, (z1), (z2))] gt_boxes: [num_gt_boxes, (y1, x1, y2, x2, (z1), (z2))] gt_class_ids (optional): [num_gt_boxes] Integer class IDs for one stage detectors. in RPN case of Mask R-CNN, set all positive matches to 1 (foreground) gt_regressions: [num_gt_rgs, n_rg_feats], if None empty rg_targets are returned Returns: anchor_class_matches: [N] (int32) matches between anchors and GT boxes. class_id = positive anchor, -1 = negative anchor, 0 = neutral. i.e., positively matched anchors are marked by class_id (which is >0). anchor_delta_targets: [N, (dy, dx, (dz), log(dh), log(dw), (log(dd)))] Anchor bbox deltas.
anchor_rg_targets: [n_anchors, n_rg_feats] """ anchor_class_matches = np.zeros([anchors.shape[0]], dtype=np.int32) anchor_delta_targets = np.zeros((cf.rpn_train_anchors_per_image, 2*cf.dim)) if gt_regressions is not None: if 'regression_bin' in cf.prediction_tasks: anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image,)) else: anchor_rg_targets = np.zeros((cf.rpn_train_anchors_per_image, cf.regression_n_features)) else: anchor_rg_targets = np.array([]) anchor_matching_iou = cf.anchor_matching_iou if gt_boxes is None: anchor_class_matches = np.full(anchor_class_matches.shape, fill_value=-1) return anchor_class_matches, anchor_delta_targets, anchor_rg_targets # for mrcnn: anchor matching is done for RPN loss, so positive labels are all 1 (foreground) if gt_class_ids is None: gt_class_ids = np.array([1] * len(gt_boxes)) # Compute overlaps [num_anchors, num_gt_boxes] overlaps = mutils.compute_overlaps(anchors, gt_boxes) # Match anchors to GT Boxes # If an anchor overlaps a GT box with IoU >= anchor_matching_iou then it's positive. # If an anchor overlaps a GT box with IoU < 0.1 then it's negative. # Neutral anchors are those that don't match the conditions above, # and they don't influence the loss function. # However, don't keep any GT box unmatched (rare, but happens). Instead, # match it to the closest anchor (even if its max IoU is < 0.1). # 1. Set negative anchors first. They get overwritten below if a GT box is # matched to them. Skip boxes in crowd areas. anchor_iou_argmax = np.argmax(overlaps, axis=1) anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax] if anchors.shape[1] == 4: anchor_class_matches[(anchor_iou_max < 0.1)] = -1 elif anchors.shape[1] == 6: anchor_class_matches[(anchor_iou_max < 0.01)] = -1 else: raise ValueError('anchor shape wrong {}'.format(anchors.shape)) # 2. Set an anchor for each GT box (regardless of IoU value). gt_iou_argmax = np.argmax(overlaps, axis=0) for ix, ii in enumerate(gt_iou_argmax): anchor_class_matches[ii] = gt_class_ids[ix] # 3. Set anchors with high overlap as positive. above_thresh_ixs = np.argwhere(anchor_iou_max >= anchor_matching_iou) anchor_class_matches[above_thresh_ixs] = gt_class_ids[anchor_iou_argmax[above_thresh_ixs]] # Subsample to balance positive anchors. ids = np.where(anchor_class_matches > 0)[0] extra = len(ids) - (cf.rpn_train_anchors_per_image // 2) if extra > 0: # Reset the extra ones to neutral ids = np.random.choice(ids, extra, replace=False) anchor_class_matches[ids] = 0 # Leave all negative proposals negative for now and sample from them later in online hard example mining. # For positive anchors, compute shift and scale needed to transform them to match the corresponding GT boxes. ids = np.where(anchor_class_matches > 0)[0] ix = 0 # index into anchor_delta_targets for i, a in zip(ids, anchors[ids]): # closest gt box (it might have IoU < anchor_matching_iou) gt = gt_boxes[anchor_iou_argmax[i]] # convert coordinates to center plus width/height. 
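# the delta targets computed below follow the standard box encoding relative to the matched anchor:
# dy = (gt_center_y - a_center_y) / a_h, dx = (gt_center_x - a_center_x) / a_w,
# dh = log(gt_h / a_h), dw = log(gt_w / a_w) (plus dz and log(dd) analogously in 3D),
# and are afterwards divided by cf.rpn_bbox_std_dev for normalization.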
gt_h = gt[2] - gt[0] gt_w = gt[3] - gt[1] gt_center_y = gt[0] + 0.5 * gt_h gt_center_x = gt[1] + 0.5 * gt_w # Anchor a_h = a[2] - a[0] a_w = a[3] - a[1] a_center_y = a[0] + 0.5 * a_h a_center_x = a[1] + 0.5 * a_w if cf.dim == 2: anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, np.log(gt_h / a_h), np.log(gt_w / a_w)] else: gt_d = gt[5] - gt[4] gt_center_z = gt[4] + 0.5 * gt_d a_d = a[5] - a[4] a_center_z = a[4] + 0.5 * a_d anchor_delta_targets[ix] = [ (gt_center_y - a_center_y) / a_h, (gt_center_x - a_center_x) / a_w, (gt_center_z - a_center_z) / a_d, np.log(gt_h / a_h), np.log(gt_w / a_w), np.log(gt_d / a_d)] # normalize. anchor_delta_targets[ix] /= cf.rpn_bbox_std_dev if gt_regressions is not None: anchor_rg_targets[ix] = gt_regressions[anchor_iou_argmax[i]] ix += 1 return anchor_class_matches, anchor_delta_targets, anchor_rg_targets ############################################################ # RetinaNet Class ############################################################ class net(nn.Module): """Encapsulates the RetinaNet model functionality. """ def __init__(self, cf, logger): """ cf: A Sub-class of the cf class model_dir: Directory to save training logs and trained weights """ super(net, self).__init__() self.cf = cf self.logger = logger self.build() if self.cf.weight_init is not None: logger.info("using pytorch weight init of type {}".format(self.cf.weight_init)) mutils.initialize_weights(self) else: logger.info("using default pytorch weight init") self.debug_acm = [] def build(self): """Build Retina Net architecture.""" # Image size must be dividable by 2 multiple times. h, w = self.cf.patch_size[:2] if h / 2 ** 5 != int(h / 2 ** 5) or w / 2 ** 5 != int(w / 2 ** 5): raise Exception("Image size must be divisible by 2 at least 5 times " "to avoid fractions when downscaling and upscaling." "For example, use 256, 320, 384, 448, 512, ... etc. ") backbone = utils.import_module('bbone', self.cf.backbone_path) self.logger.info("loaded backbone from {}".format(self.cf.backbone_path)) conv = backbone.ConvGenerator(self.cf.dim) # build Anchors, FPN, Classifier / Bbox-Regressor -head self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf) self.anchors = torch.from_numpy(self.np_anchors).float().cuda() self.fpn = backbone.FPN(self.cf, conv, operate_stride1=self.cf.operate_stride1).cuda() self.classifier = Classifier(self.cf, conv).cuda() self.bb_regressor = BBRegressor(self.cf, conv).cuda() if 'regression' in self.cf.prediction_tasks: self.roi_regressor = RoIRegressor(self.cf, conv, self.cf.regression_n_features).cuda() elif 'regression_bin' in self.cf.prediction_tasks: # classify into bins of regression values self.roi_regressor = RoIRegressor(self.cf, conv, len(self.cf.bin_labels)).cuda() else: self.roi_regressor = lambda x: [torch.tensor([]).cuda()] if self.cf.model == 'retina_unet': self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=self.cf.norm, relu=None) def forward(self, img): """ :param img: input img (b, c, y, x, (z)). 
""" # Feature extraction fpn_outs = self.fpn(img) if self.cf.model == 'retina_unet': seg_logits = self.final_conv(fpn_outs[0]) selected_fmaps = [fpn_outs[i + 1] for i in self.cf.pyramid_levels] else: seg_logits = None selected_fmaps = [fpn_outs[i] for i in self.cf.pyramid_levels] # Loop through pyramid layers class_layer_outputs, bb_reg_layer_outputs, roi_reg_layer_outputs = [], [], [] # list of lists for p in selected_fmaps: class_layer_outputs.append(self.classifier(p)) bb_reg_layer_outputs.append(self.bb_regressor(p)) roi_reg_layer_outputs.append(self.roi_regressor(p)) # Concatenate layer outputs # Convert from list of lists of level outputs to list of lists # of outputs across levels. # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]] class_logits = list(zip(*class_layer_outputs)) class_logits = [torch.cat(list(o), dim=1) for o in class_logits][0] bb_outputs = list(zip(*bb_reg_layer_outputs)) bb_outputs = [torch.cat(list(o), dim=1) for o in bb_outputs][0] if not 0 == roi_reg_layer_outputs[0][0].shape[0]: rg_outputs = list(zip(*roi_reg_layer_outputs)) rg_outputs = [torch.cat(list(o), dim=1) for o in rg_outputs][0] else: if self.cf.dim == 2: n_feats = np.array([p.shape[-2] * p.shape[-1] * self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() else: n_feats = np.array([p.shape[-3]*p.shape[-2]*p.shape[-1]*self.cf.n_anchors_per_pos for p in selected_fmaps]).sum() rg_outputs = torch.zeros((selected_fmaps[0].shape[0], n_feats, self.cf.regression_n_features), dtype=torch.float32).fill_(float('NaN')).cuda() # merge batch_dimension and store info in batch_ixs for re-allocation. batch_ixs = torch.arange(class_logits.shape[0]).unsqueeze(1).repeat(1, class_logits.shape[1]).view(-1).cuda() flat_class_softmax = F.softmax(class_logits.view(-1, class_logits.shape[-1]), 1) flat_bb_outputs = bb_outputs.view(-1, bb_outputs.shape[-1]) flat_rg_outputs = rg_outputs.view(-1, rg_outputs.shape[-1]) detections = refine_detections(self.anchors, flat_class_softmax, flat_bb_outputs, flat_rg_outputs, batch_ixs, self.cf) return detections, class_logits, bb_outputs, rg_outputs, seg_logits def get_results(self, img_shape, detections, seg_logits, box_results_list=None): """ Restores batch dimension of merged detections, unmolds detections, creates and fills results dict. :param img_shape: :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, pred_regression) :param box_results_list: None or list of output boxes for monitoring/plotting. each element is a list of boxes per batch element. :return: results_dict: dictionary with keys: 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary: [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...] 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1] only fg. vs. bg for now. class-specific return of masks will come with implementation of instance segmentation evaluation. """ detections = detections.cpu().data.numpy() batch_ixs = detections[:, self.cf.dim*2] detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])] if box_results_list == None: # for test_forward, where no previous list exists. 
            box_results_list = [[] for _ in range(img_shape[0])]

        for ix in range(img_shape[0]):
            if 0 not in detections[ix].shape:
                boxes = detections[ix][:, :2 * self.cf.dim].astype(np.int32)
                class_ids = detections[ix][:, 2 * self.cf.dim + 1].astype(np.int32)
                scores = detections[ix][:, 2 * self.cf.dim + 2]
                regressions = detections[ix][:, 2 * self.cf.dim + 3:]

                # Filter out detections with zero area. Often only happens in early
                # stages of training when the network weights are still a bit random.
                if self.cf.dim == 2:
                    exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
                else:
                    exclude_ix = np.where(
                        (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) *
                        (boxes[:, 5] - boxes[:, 4]) <= 0)[0]
                if exclude_ix.shape[0] > 0:
                    boxes = np.delete(boxes, exclude_ix, axis=0)
                    class_ids = np.delete(class_ids, exclude_ix, axis=0)
                    scores = np.delete(scores, exclude_ix, axis=0)
                    regressions = np.delete(regressions, exclude_ix, axis=0)

                if 0 not in boxes.shape:
                    for ix2, score in enumerate(scores):
                        if score >= self.cf.model_min_confidence:
                            box = {'box_type': 'det', 'box_coords': boxes[ix2],
                                   'box_score': score, 'box_pred_class_id': class_ids[ix2]}
                            if "regression_bin" in self.cf.prediction_tasks:
                                # in this case, regression preds are actually the rg_bin_ids
                                # --> map to rg value the bin stands for
                                box['rg_bin'] = regressions[ix2].argmax()
                                box['regression'] = self.cf.bin_id2rg_val[box['rg_bin']]
                            else:
                                box['regression'] = regressions[ix2]
                                if hasattr(self.cf, "rg_val_to_bin_id") and \
                                        any(['regression' in task for task in self.cf.prediction_tasks]):
                                    box['rg_bin'] = self.cf.rg_val_to_bin_id(regressions[ix2])
                            box_results_list[ix].append(box)

        results_dict = {}
        results_dict['boxes'] = box_results_list
        if seg_logits is None:
            # output dummy segmentation for retina_net.
            out_logits_shape = list(img_shape)
            out_logits_shape[1] = self.cf.num_seg_classes
            results_dict['seg_preds'] = np.zeros(out_logits_shape, dtype=np.float16)
            # todo: try with seg_preds=None? as to not carry heavy dummy preds.
        else:
            # output label maps for retina_unet.
            results_dict['seg_preds'] = F.softmax(seg_logits, 1).cpu().data.numpy()

        return results_dict

    def train_forward(self, batch, is_validation=False):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares
        input data for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :return: results_dict: dictionary with keys:
            'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
            'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes].
            'torch_loss': 1D torch tensor for backprop.
            'class_loss': classification loss for monitoring.
        """
        img = batch['data']
        gt_class_ids = batch['class_targets']
        gt_boxes = batch['bb_target']
        if 'regression' in self.cf.prediction_tasks:
            gt_regressions = batch["regression_targets"]
        elif 'regression_bin' in self.cf.prediction_tasks:
            gt_regressions = batch["rg_bin_targets"]
        else:
            gt_regressions = None

        var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
        var_seg = torch.LongTensor(batch['seg']).cuda()

        img = torch.from_numpy(img).float().cuda()
        torch_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
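        # The per-sample loop below generates anchor targets on the fly via gt_anchor_matching and accumulates the
        # batch-averaged loss: class loss (focal loss or SHEM cross-entropy) + bbox-delta loss + optional
        # regression loss; retina_unet additionally adds a segmentation (dice + CE) loss after the loop.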
        box_results_list = [[] for _ in range(img.shape[0])]
        detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(img)

        # loop over batch
        for b in range(img.shape[0]):
            # add gt boxes to results dict for monitoring.
            if len(gt_boxes[b]) > 0:
                for tix in range(len(gt_boxes[b])):
                    gt_box = {'box_type': 'gt', 'box_coords': batch['bb_target'][b][tix]}
                    for name in self.cf.roi_items:
                        gt_box.update({name: batch[name][b][tix]})
                    box_results_list[b].append(gt_box)

                # match gt boxes with anchors to generate targets.
                anchor_class_match, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching(
                    self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b],
                    gt_regressions[b] if gt_regressions is not None else None)

                # add positive anchors used for loss to results_dict for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(
                    self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
            else:
                anchor_class_match = np.array([-1] * self.np_anchors.shape[0])
                anchor_target_deltas = np.array([])
                anchor_target_rgs = np.array([])

            anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
            anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda()
            anchor_target_rgs = torch.from_numpy(anchor_target_rgs).float().cuda()

            if self.cf.focal_loss:
                # compute class loss as focal loss as suggested in original publication, but multi-class.
                class_loss = compute_focal_class_loss(anchor_class_match, class_logits[b],
                                                      gamma=self.cf.focal_loss_gamma)
                # sparing appendix of negative anchors for monitoring as not really relevant
            else:
                # compute class loss with SHEM.
                class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b])
                # add negative anchors used for loss to results_dict for monitoring.
                neg_anchors = mutils.clip_boxes_numpy(
                    self.np_anchors[np.argwhere(anchor_class_match == -1)][0, neg_anchor_ix], img.shape[2:])
                for n in neg_anchors:
                    box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

            rg_loss = compute_rg_loss(self.cf.prediction_tasks, anchor_target_rgs, pred_rgs[b], anchor_class_match)
            bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match)
            torch_loss += (class_loss + bbox_loss + rg_loss) / img.shape[0]

        results_dict = self.get_results(img.shape, detections, seg_logits, box_results_list)
        results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:, np.newaxis]

        if self.cf.model == 'retina_unet':
            seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe)
            seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
            torch_loss += (seg_loss_dice + seg_loss_ce) / 2
            # self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, "
            #                  "mean pixel preds: {5:.5f}".format(torch_loss.item(), batch_class_loss.item(),
            #                  batch_bbox_loss.item(), seg_loss_dice.item(), seg_loss_ce.item(),
            #                  np.mean(results_dict['seg_preds'])))
            if 'dice' in self.cf.metrics:
                results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                    results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)
        # else:
        #     self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}".format(
        #         torch_loss.item(), class_loss.item(), bbox_loss.item()))

        results_dict['torch_loss'] = torch_loss
        results_dict['class_loss'] = class_loss.item()

        return results_dict

    def test_forward(self, batch, **kwargs):
        """
        test method.
        wrapper around forward pass of network without usage of any ground truth information.
        prepares input data for processing and stores outputs in a dictionary.
        :param batch: dictionary containing 'data'
        :return: results_dict: dictionary with keys:
            'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
            'seg_preds': actually contains seg probabilities, since they are evaluated to seg_preds (via argmax)
                in the predictor; or dummy seg logits for the real retina net (detection only).
        """
        img = torch.from_numpy(batch['data']).float().cuda()
        detections, _, _, _, seg_logits = self.forward(img)
        results_dict = self.get_results(img.shape, detections, seg_logits)

        return results_dict
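
# ----------------------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). Assumes a configs object `cf` (e.g. from
# datasets/toy/configs.py), a logger, and a batch dict from the data loader providing at least the keys used in
# train_forward above ('data', 'seg', 'class_targets', 'bb_target', plus regression targets / roi_items depending
# on cf.prediction_tasks). The optimizer choice and learning rate are assumptions for illustration.
#
#     model = net(cf, logger)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#
#     results = model.train_forward(batch)           # training (or validation monitoring)
#     optimizer.zero_grad()
#     results['torch_loss'].backward()
#     optimizer.step()
#
#     with torch.no_grad():
#         test_results = model.test_forward({'data': batch['data']})   # inference, no ground truth needed
# ----------------------------------------------------------------------------------------------------------------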