diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b70cb0d --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.idea/* +apex +example_network +*.pdf +*.pyc +runs diff --git a/README.md b/README.md new file mode 100644 index 0000000..578de7e --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +# ROBUSTMIS 2019 +Scripts for the robust medical instrument segmentation challenge 2019 + +## Challenge Robust Medical Instrument Segmentation Challenge 2019 +If you want to learn more about the challenge please visit the following two pages: +* https://robustmis2019.grand-challenge.org/ +* https://www.synapse.org/#!Synapse:syn18779625/wiki/591267 + +## Helpful scripts +In this repository we provide useful scripts to enable a faster start with the challenge. Currently we only provide a dataloader for pytorch + + + diff --git a/evaluation/dice_calculations.py b/evaluation/dice_calculations.py new file mode 100644 index 0000000..e28bd79 --- /dev/null +++ b/evaluation/dice_calculations.py @@ -0,0 +1,94 @@ +import numpy as np +from scipy.optimize import linear_sum_assignment as hungarian_algorithm + + +def compute_dice_coefficient(mask_gt, mask_pred): + """Compute soerensen-dice coefficient. + + compute the soerensen-dice coefficient between the ground truth mask `mask_gt` + and the predicted mask `mask_pred`. + + Args: + mask_gt: 3-dim Numpy array of type bool. The ground truth mask. + mask_pred: 3-dim Numpy array of type bool. The predicted mask. + + Returns: + the dice coeffcient as float. If both masks are empty, the result is NaN + """ + volume_sum = mask_gt.sum() + mask_pred.sum() + if volume_sum == 0: + return np.NaN + volume_intersect = (mask_gt & mask_pred).sum() + return 2 * volume_intersect / volume_sum + + +def compute_dice_coefficient_per_instance(mask_gt, mask_pred): + """Compute instance soerensen-dice coefficient. + + compute the soerensen-dice coefficient between the ground truth mask `mask_gt` + and the predicted mask `mask_pred` for multiple instances. + + Args: + mask_gt: 3-dim Numpy array of type int. The ground truth image, where 0 means background and 1-N is an + instrument instance. + mask_pred: 3-dim Numpy array of type int. The predicted mask, where 0 means background and 1-N is an + instrument instance. + + Returns: + a instance dictionary with the dice coeffcient as float. + """ + # get number of labels in image + instances_gt = np.unique(mask_gt) + instances_pred = np.unique(mask_pred) + + # create performance matrix + performance_matrix = np.zeros((len(instances_gt), len(instances_pred))) + masks = [] + + # calculate dice score for each ground truth to predicted instance + for counter_gt, instance_gt in enumerate(instances_gt): + + # create binary mask for current gt instance + gt = mask_gt.copy() + gt[mask_gt != instance_gt] = 0 + gt[mask_gt == instance_gt] = 1 + + masks_row = [] + for counter_pred, instance_pred in enumerate(instances_pred): + # make binary mask for current predicted instance + prediction = mask_pred.copy() + prediction[mask_pred != instance_pred] = 0 + prediction[mask_pred == instance_pred] = 1 + + # calculate dice + performance_matrix[counter_gt, counter_pred] = compute_dice_coefficient(gt, prediction) + masks_row.append([gt, prediction]) + masks.append(masks_row) + + # assign instrument instances according to hungarian algorithm + label_assignment = hungarian_algorithm(performance_matrix * -1) + label_nr_gt, label_nr_pred = label_assignment + + # get performance per instance + image_performance = [] + for i in range(len(label_nr_gt)): + instance_dice = performance_matrix[label_nr_gt[i], label_nr_pred[i]] + image_performance.append(instance_dice) + + missing_pred = np.absolute(len(instances_pred) - len(image_performance)) + missing_gt = np.absolute(len(instances_gt) - len(image_performance)) + n_missing = np.max([missing_gt, missing_pred]) + + if n_missing > 0: + for i in range(n_missing): + image_performance.append(0) + + output = dict() + for i, performance in enumerate(image_performance): + if i > 0: + output["instrument_{}".format(i - 1)] = performance + else: + output["background"] = performance + + return output + diff --git a/evaluation/distance_calculations.py b/evaluation/distance_calculations.py new file mode 100644 index 0000000..1545566 --- /dev/null +++ b/evaluation/distance_calculations.py @@ -0,0 +1,475 @@ +import numpy as np +import scipy.ndimage + +# neighbour_code_to_normals is a lookup table. +# For every binary neighbour code +# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes) +# it contains the surface normals of the triangles (called "surfel" for +# "surface element" in the following). The length of the normal +# vector encodes the surfel area. +# +# created by compute_surface_area_lookup_table.ipynb using the +# marching_cube algorithm, see e.g. https://en.wikipedia.org/wiki/Marching_cubes +# +neighbour_code_to_normals = [ + [[0,0,0]], + [[0.125,0.125,0.125]], + [[-0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[0.125,-0.125,0.125]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25]], + [[0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[-0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,0.125,0.125]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25]], + [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]], + [[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[0.5,0.0,0.0],[0.5,0.0,0.0]], + [[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25]], + [[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]], + [[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]], + [[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]], + [[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]], + [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]], + [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]], + [[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]], + [[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]], + [[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]], + [[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]], + [[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]], + [[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0]], + [[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]], + [[0.0,0.5,0.0],[0.0,-0.5,0.0]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]], + [[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]], + [[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]], + [[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]], + [[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]], + [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]], + [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]], + [[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25]], + [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]], + [[-0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]], + [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]], + [[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]], + [[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]], + [[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]], + [[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]], + [[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[-0.0,0.0,0.5],[0.0,0.0,0.5]], + [[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]], + [[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]], + [[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]], + [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]], + [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[-0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]], + [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,-0.125]], + [[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]], + [[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]], + [[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]], + [[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]], + [[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]], + [[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,0.125,0.125]], + [[0.125,0.125,0.125]], + [[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]], + [[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]], + [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]], + [[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]], + [[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]], + [[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]], + [[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]], + [[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]], + [[-0.0,0.0,0.5],[0.0,0.0,0.5]], + [[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]], + [[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]], + [[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]], + [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]], + [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25]], + [[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]], + [[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]], + [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]], + [[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]], + [[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]], + [[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]], + [[0.0,0.5,0.0],[0.0,-0.5,0.0]], + [[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]], + [[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0]], + [[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]], + [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]], + [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]], + [[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]], + [[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]], + [[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,-0.125,0.125]], + [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]], + [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]], + [[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]], + [[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]], + [[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25]], + [[0.125,-0.125,-0.125]], + [[0.5,0.0,0.0],[0.5,0.0,0.0]], + [[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]], + [[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25]], + [[0.125,0.125,0.125],[-0.125,0.125,0.125]], + [[-0.125,0.125,0.125]], + [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25]], + [[0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125]], + [[0,0,0]]] + + +def compute_surface_distances(mask_gt, mask_pred, spacing_mm): + """Compute closest distances from all surface points to the other surface. + + Finds all surface elements "surfels" in the ground truth mask `mask_gt` and + the predicted mask `mask_pred`, computes their area in mm^2 and the distance + to the closest point on the other surface. It returns two sorted lists of + distances together with the corresponding surfel areas. If one of the masks + is empty, the corresponding lists are empty and all distances in the other + list are `inf` + + Args: + mask_gt: 3-dim Numpy array of type bool. The ground truth mask. + mask_pred: 3-dim Numpy array of type bool. The predicted mask. + spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2 + direction + + Returns: + A dict with + "distances_gt_to_pred": 1-dim numpy array of type float. The distances in mm + from all ground truth surface elements to the predicted surface, + sorted from smallest to largest + "distances_pred_to_gt": 1-dim numpy array of type float. The distances in mm + from all predicted surface elements to the ground truth surface, + sorted from smallest to largest + "surfel_areas_gt": 1-dim numpy array of type float. The area in mm^2 of + the ground truth surface elements in the same order as + distances_gt_to_pred + "surfel_areas_pred": 1-dim numpy array of type float. The area in mm^2 of + the predicted surface elements in the same order as + distances_pred_to_gt + + """ + + # compute the area for all 256 possible surface elements + # (given a 2x2x2 neighbourhood) according to the spacing_mm + neighbour_code_to_surface_area = np.zeros([256]) + for code in range(256): + normals = np.array(neighbour_code_to_normals[code]) + sum_area = 0 + for normal_idx in range(normals.shape[0]): + # normal vector + n = np.zeros([3]) + n[0] = normals[normal_idx, 0] * spacing_mm[1] * spacing_mm[2] + n[1] = normals[normal_idx, 1] * spacing_mm[0] * spacing_mm[2] + n[2] = normals[normal_idx, 2] * spacing_mm[0] * spacing_mm[1] + area = np.linalg.norm(n) + sum_area += area + neighbour_code_to_surface_area[code] = sum_area + + # compute the bounding box of the masks to trim + # the volume to the smallest possible processing subvolume + mask_all = mask_gt | mask_pred + bbox_min = np.zeros(3, np.int64) + bbox_max = np.zeros(3, np.int64) + + # max projection to the x0-axis + proj_0 = np.max(np.max(mask_all, axis=2), axis=1) + idx_nonzero_0 = np.nonzero(proj_0)[0] + if len(idx_nonzero_0) == 0: + return {"distances_gt_to_pred": np.array([]), + "distances_pred_to_gt": np.array([]), + "surfel_areas_gt": np.array([]), + "surfel_areas_pred": np.array([])} + + bbox_min[0] = np.min(idx_nonzero_0) + bbox_max[0] = np.max(idx_nonzero_0) + + # max projection to the x1-axis + proj_1 = np.max(np.max(mask_all, axis=2), axis=0) + idx_nonzero_1 = np.nonzero(proj_1)[0] + bbox_min[1] = np.min(idx_nonzero_1) + bbox_max[1] = np.max(idx_nonzero_1) + + # max projection to the x2-axis + proj_2 = np.max(np.max(mask_all, axis=1), axis=0) + idx_nonzero_2 = np.nonzero(proj_2)[0] + bbox_min[2] = np.min(idx_nonzero_2) + bbox_max[2] = np.max(idx_nonzero_2) + + print("bounding box min = {}".format(bbox_min)) + print("bounding box max = {}".format(bbox_max)) + + # crop the processing subvolume. + # we need to zeropad the cropped region with 1 voxel at the lower, + # the right and the back side. This is required to obtain the "full" + # convolution result with the 2x2x2 kernel + cropmask_gt = np.zeros((bbox_max - bbox_min) + 2, np.uint8) + cropmask_pred = np.zeros((bbox_max - bbox_min) + 2, np.uint8) + + cropmask_gt[0:-1, 0:-1, 0:-1] = mask_gt[bbox_min[0]:bbox_max[0] + 1, + bbox_min[1]:bbox_max[1] + 1, + bbox_min[2]:bbox_max[2] + 1] + + cropmask_pred[0:-1, 0:-1, 0:-1] = mask_pred[bbox_min[0]:bbox_max[0] + 1, + bbox_min[1]:bbox_max[1] + 1, + bbox_min[2]:bbox_max[2] + 1] + + # compute the neighbour code (local binary pattern) for each voxel + # the resultsing arrays are spacially shifted by minus half a voxel in each axis. + # i.e. the points are located at the corners of the original voxels + kernel = np.array([[[128, 64], + [32, 16]], + [[8, 4], + [2, 1]]]) + neighbour_code_map_gt = scipy.ndimage.filters.correlate(cropmask_gt.astype(np.uint8), kernel, mode="constant", + cval=0) + neighbour_code_map_pred = scipy.ndimage.filters.correlate(cropmask_pred.astype(np.uint8), kernel, mode="constant", + cval=0) + + # create masks with the surface voxels + borders_gt = ((neighbour_code_map_gt != 0) & (neighbour_code_map_gt != 255)) + borders_pred = ((neighbour_code_map_pred != 0) & (neighbour_code_map_pred != 255)) + + # compute the distance transform (closest distance of each voxel to the surface voxels) + if borders_gt.any(): + distmap_gt = scipy.ndimage.morphology.distance_transform_edt(~borders_gt, sampling=spacing_mm) + else: + distmap_gt = np.Inf * np.ones(borders_gt.shape) + + if borders_pred.any(): + distmap_pred = scipy.ndimage.morphology.distance_transform_edt(~borders_pred, sampling=spacing_mm) + else: + distmap_pred = np.Inf * np.ones(borders_pred.shape) + + # compute the area of each surface element + surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt] + surface_area_map_pred = neighbour_code_to_surface_area[neighbour_code_map_pred] + + # create a list of all surface elements with distance and area + distances_gt_to_pred = distmap_pred[borders_gt] + distances_pred_to_gt = distmap_gt[borders_pred] + surfel_areas_gt = surface_area_map_gt[borders_gt] + surfel_areas_pred = surface_area_map_pred[borders_pred] + + # sort them by distance + if distances_gt_to_pred.shape != (0,): + sorted_surfels_gt = np.array(sorted(zip(distances_gt_to_pred, surfel_areas_gt))) + distances_gt_to_pred = sorted_surfels_gt[:, 0] + surfel_areas_gt = sorted_surfels_gt[:, 1] + + if distances_pred_to_gt.shape != (0,): + sorted_surfels_pred = np.array(sorted(zip(distances_pred_to_gt, surfel_areas_pred))) + distances_pred_to_gt = sorted_surfels_pred[:, 0] + surfel_areas_pred = sorted_surfels_pred[:, 1] + + return {"distances_gt_to_pred": distances_gt_to_pred, + "distances_pred_to_gt": distances_pred_to_gt, + "surfel_areas_gt": surfel_areas_gt, + "surfel_areas_pred": surfel_areas_pred} + + +def compute_average_surface_distance(surface_distances): + distances_gt_to_pred = surface_distances["distances_gt_to_pred"] + distances_pred_to_gt = surface_distances["distances_pred_to_gt"] + surfel_areas_gt = surface_distances["surfel_areas_gt"] + surfel_areas_pred = surface_distances["surfel_areas_pred"] + average_distance_gt_to_pred = np.sum(distances_gt_to_pred * surfel_areas_gt) / np.sum(surfel_areas_gt) + average_distance_pred_to_gt = np.sum(distances_pred_to_gt * surfel_areas_pred) / np.sum(surfel_areas_pred) + return (average_distance_gt_to_pred, average_distance_pred_to_gt) + + +def compute_robust_hausdorff(surface_distances, percent): + distances_gt_to_pred = surface_distances["distances_gt_to_pred"] + distances_pred_to_gt = surface_distances["distances_pred_to_gt"] + surfel_areas_gt = surface_distances["surfel_areas_gt"] + surfel_areas_pred = surface_distances["surfel_areas_pred"] + if len(distances_gt_to_pred) > 0: + surfel_areas_cum_gt = np.cumsum(surfel_areas_gt) / np.sum(surfel_areas_gt) + idx = np.searchsorted(surfel_areas_cum_gt, percent / 100.0) + perc_distance_gt_to_pred = distances_gt_to_pred[min(idx, len(distances_gt_to_pred) - 1)] + else: + perc_distance_gt_to_pred = np.Inf + + if len(distances_pred_to_gt) > 0: + surfel_areas_cum_pred = np.cumsum(surfel_areas_pred) / np.sum(surfel_areas_pred) + idx = np.searchsorted(surfel_areas_cum_pred, percent / 100.0) + perc_distance_pred_to_gt = distances_pred_to_gt[min(idx, len(distances_pred_to_gt) - 1)] + else: + perc_distance_pred_to_gt = np.Inf + + return max(perc_distance_gt_to_pred, perc_distance_pred_to_gt) + + +def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm): + distances_gt_to_pred = surface_distances["distances_gt_to_pred"] + distances_pred_to_gt = surface_distances["distances_pred_to_gt"] + surfel_areas_gt = surface_distances["surfel_areas_gt"] + surfel_areas_pred = surface_distances["surfel_areas_pred"] + rel_overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) / np.sum(surfel_areas_gt) + rel_overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) / np.sum(surfel_areas_pred) + return (rel_overlap_gt, rel_overlap_pred) + + +def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm): + distances_gt_to_pred = surface_distances["distances_gt_to_pred"] + distances_pred_to_gt = surface_distances["distances_pred_to_gt"] + surfel_areas_gt = surface_distances["surfel_areas_gt"] + surfel_areas_pred = surface_distances["surfel_areas_pred"] + overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) + overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) + surface_dice = (overlap_gt + overlap_pred) / ( + np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred)) + return surface_dice \ No newline at end of file diff --git a/evaluation/mean_average_precision_calculations.py b/evaluation/mean_average_precision_calculations.py new file mode 100644 index 0000000..5de1cd2 --- /dev/null +++ b/evaluation/mean_average_precision_calculations.py @@ -0,0 +1,183 @@ +import numpy as np +from pandas import DataFrame +from scipy.optimize import linear_sum_assignment as hungarian_algorithm + + +def compute_iou(mask_gt, mask_pred): + """ + Compute the intersection over union (https://en.wikipedia.org/wiki/Jaccard_index) + + compute the intersectin over union between the ground truth mask `mask_gt` + and the predicted mask `mask_pred`. + + Args: + mask_gt: 3-dim Numpy array of type bool. The ground truth mask. + mask_pred: 3-dim Numpy array of type bool. The predicted mask. + + Returns: + the iou coeffcient as float. If both masks are empty, the result is 0 + """ + mask_gt = mask_gt.astype('bool') + mask_pred = mask_pred.astype('bool') + overlap = mask_gt * mask_pred # Logical AND + union = mask_gt + mask_pred # Logical OR + iou = overlap.sum() / float(union.sum()) # Treats "True" as 1, + return iou + + +def compute_statistics(mask_gt, mask_pred): + """ + Compute Statistic + + compute statistics (TP, FP, FN, precision, recall) between the ground truth mask `mask_gt` + and the predicted mask `mask_pred`. + TP = True positive (defined as an iou>=0.03) + FP = False positive + FN = False negative + precision = true_positive / (true_positive + false_positive) + recall = true_positive / (true_positive + false_negative) + + Args: + mask_gt: 3-dim Numpy array of type bool. The ground truth mask. + mask_pred: 3-dim Numpy array of type bool. The predicted mask. + + Returns: + output = dict( + true_positive=true_positive, + false_positive=false_positive, + false_negative=false_negative, + precision=precision, + recall=recall + ) + """ + # define constants + min_iou_for_match = 0.03 + + # get number of labels in image + instances_gt = list(np.unique(mask_gt)) + instances_pred = list(np.unique(mask_pred)) + + # remove background + instances_gt = instances_gt[1:] + instances_pred = instances_pred[1:] + + # create performance matrix + performance_matrix = np.zeros((len(instances_gt), len(instances_pred))) + masks = [] + + # calculate dice score for each ground truth to predicted instance + for counter_gt, instance_gt in enumerate(instances_gt): + + # create binary mask for current gt instance + gt = mask_gt.copy() + gt[mask_gt != instance_gt] = 0 + gt[mask_gt == instance_gt] = 1 + + masks_row = [] + for counter_pred, instance_pred in enumerate(instances_pred): + # make binary mask for current predicted instance + prediction = mask_pred.copy() + prediction[mask_pred != instance_pred] = 0 + prediction[mask_pred == instance_pred] = 1 + + # calculate iou + # show_image(gt, prediction) + iou = compute_iou(gt, prediction) + performance_matrix[counter_gt, counter_pred] = iou + masks_row.append([gt, prediction]) + masks.append(masks_row) + + # delete all matches smaller than threshold + performance_matrix[performance_matrix < min_iou_for_match] = 0 + + # assign instrument instances according to hungarian algorithm + label_assignment = hungarian_algorithm(performance_matrix * -1) + label_nr_gt, label_nr_pred = label_assignment + + # get performance per instance + + true_positive_list = [] + for i in range(len(label_nr_gt)): + instance_iou = performance_matrix[label_nr_gt[i], label_nr_pred[i]] + true_positive_list.append(instance_iou) + true_positive_list = list(filter(lambda a: a != 0, true_positive_list)) # delete all 0s assigned to a label + + true_positive = len(true_positive_list) + false_negative = len(instances_gt) - true_positive + false_positive = len(instances_pred) - true_positive + + try: + precision = true_positive / (true_positive + false_positive) + except ZeroDivisionError: + precision = 0 + try: + recall = true_positive / (true_positive + false_negative) + except ZeroDivisionError: + recall = 0 + + output = dict( + true_positive=true_positive, + false_positive=false_positive, + false_negative=false_negative, + precision=precision, + recall=recall + ) + + return output + + +def compute_mean_average_precision(statistic_list): + """ + Compute the mean average precision: + (https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision) + + We define average precision as Area under Curve AUC) + https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173 + + Args: + statistic_list: 1-dim list, containing statistics dicts (dict definition, see function compute_statistics). + + Returns: + the area_under_curve as float + ) + """ + + # create data frame + data_frame = DataFrame(columns=["true_positive", "false_positive", "false_negative", "precision", "recall"]) + + # add data + data_frame = data_frame.append(statistic_list) + data_frame = data_frame.reset_index() + + # interpolate precision with highest recall for precision + data_frame = data_frame.sort_values(by="recall", ascending=False) + precision_interpolated = [] + current_highest_value = 0 + for index, row in data_frame.iterrows(): + if row.precision > current_highest_value: + current_highest_value = row.precision + precision_interpolated.append(current_highest_value) + data_frame['precision_interpolated'] = precision_interpolated + + # get changes in interpolated precision curve + data_frame_grouped = data_frame.groupby("recall") + changes = [] + for item in data_frame_grouped.groups.items(): + current_recall = item[0] + idx_precision = item[1][0] + current_precision_interpolated = data_frame.loc[idx_precision].precision_interpolated + change = dict(recall=current_recall, precision_interpolated=current_precision_interpolated) + changes.append(change) + # add end and starting point + if changes[0]["recall"] != 0.0: + changes.insert(0, dict(recall=0, precision_interpolated=changes[0]["precision_interpolated"])) + if current_recall < 1: + changes.append(dict(recall=1, precision_interpolated=current_precision_interpolated)) + + # calculate area under curve + area_under_curve = 0 + for i in range(1, len(changes)): + precision_area = (changes[i]["recall"] - changes[i - 1]["recall"]) * changes[i]["precision_interpolated"] + area_under_curve += precision_area + + return area_under_curve diff --git a/evaluation/test/images/img1/instrument_instances.png b/evaluation/test/images/img1/instrument_instances.png new file mode 100644 index 0000000..d8c0dec Binary files /dev/null and b/evaluation/test/images/img1/instrument_instances.png differ diff --git a/evaluation/test/images/img1/raw.png b/evaluation/test/images/img1/raw.png new file mode 100644 index 0000000..15296cb Binary files /dev/null and b/evaluation/test/images/img1/raw.png differ diff --git a/evaluation/test/images/img2/instrument_instances.png b/evaluation/test/images/img2/instrument_instances.png new file mode 100644 index 0000000..e5ad9ef Binary files /dev/null and b/evaluation/test/images/img2/instrument_instances.png differ diff --git a/evaluation/test/images/img2/raw.png b/evaluation/test/images/img2/raw.png new file mode 100644 index 0000000..4b90477 Binary files /dev/null and b/evaluation/test/images/img2/raw.png differ diff --git a/evaluation/test/images/img3/instrument_instances.png b/evaluation/test/images/img3/instrument_instances.png new file mode 100644 index 0000000..d84c3c2 Binary files /dev/null and b/evaluation/test/images/img3/instrument_instances.png differ diff --git a/evaluation/test/images/img3/raw.png b/evaluation/test/images/img3/raw.png new file mode 100644 index 0000000..c42c19e Binary files /dev/null and b/evaluation/test/images/img3/raw.png differ diff --git a/evaluation/test/images/test_map/annotation_1.png b/evaluation/test/images/test_map/annotation_1.png new file mode 100644 index 0000000..51b0dbb Binary files /dev/null and b/evaluation/test/images/test_map/annotation_1.png differ diff --git a/evaluation/test/images/test_map/annotation_2.png b/evaluation/test/images/test_map/annotation_2.png new file mode 100644 index 0000000..4ed473b Binary files /dev/null and b/evaluation/test/images/test_map/annotation_2.png differ diff --git a/evaluation/test/images/test_map/annotation_3.png b/evaluation/test/images/test_map/annotation_3.png new file mode 100644 index 0000000..6dd275b Binary files /dev/null and b/evaluation/test/images/test_map/annotation_3.png differ diff --git a/evaluation/test/images/test_map/gt_1.png b/evaluation/test/images/test_map/gt_1.png new file mode 100644 index 0000000..2cb7c10 Binary files /dev/null and b/evaluation/test/images/test_map/gt_1.png differ diff --git a/evaluation/test/images/test_map/gt_2.png b/evaluation/test/images/test_map/gt_2.png new file mode 100644 index 0000000..60f582e Binary files /dev/null and b/evaluation/test/images/test_map/gt_2.png differ diff --git a/evaluation/test/images/test_map/gt_3.png b/evaluation/test/images/test_map/gt_3.png new file mode 100644 index 0000000..0a31d18 Binary files /dev/null and b/evaluation/test/images/test_map/gt_3.png differ diff --git a/evaluation/test/test_detection_metric.py b/evaluation/test/test_detection_metric.py new file mode 100644 index 0000000..2759407 --- /dev/null +++ b/evaluation/test/test_detection_metric.py @@ -0,0 +1,333 @@ +import unittest +import numpy as np + +from evaluation.mean_average_precision_calculations import compute_mean_average_precision, compute_iou, \ + compute_statistics + + +class TestMAPCalculation(unittest.TestCase): + + def test_intersection_over_union(self): + # define ground truth + gt_1 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + gt_2 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + # full intersection + expected_iou_1 = 1 + pred_1 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + # one missing + expected_iou_2 = 0.8888888 + pred_2 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + # just one intersected + expected_iou_3 = 0.11111 + pred_3 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + # no intersection + expected_iou_4 = 0 + pred_4 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 1]], np.uint8) + + # empty prediction + expected_iou_5 = 0 + pred_5 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + delta = 0.0005 + self.assertAlmostEqual(compute_iou(mask_gt=gt_1, mask_pred=pred_1), expected_iou_1, delta=delta) + self.assertAlmostEqual(compute_iou(mask_gt=gt_1, mask_pred=pred_2), expected_iou_2, delta=delta) + self.assertAlmostEqual(compute_iou(mask_gt=gt_1, mask_pred=pred_3), expected_iou_3, delta=delta) + self.assertAlmostEqual(compute_iou(mask_gt=gt_1, mask_pred=pred_4), expected_iou_4, delta=delta) + self.assertAlmostEqual(compute_iou(mask_gt=gt_1, mask_pred=pred_5), expected_iou_5, delta=delta) + self.assertAlmostEqual(compute_iou(mask_gt=gt_2, mask_pred=pred_4), expected_iou_4, delta=delta) + self.assertAlmostEqual(compute_iou(mask_gt=gt_2, mask_pred=pred_5), expected_iou_5, delta=delta) + self.assertTrue(np.isnan(compute_iou(mask_gt=pred_5, mask_pred=pred_5))) + + def test_detection_statistics(self): + # define images + img_1 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + img_2 = np.array([[0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + img_3 = np.array([[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 0], + [0, 0, 0, 0, 2, 2, 0], + [0, 0, 0, 0, 0, 0, 0]], np.uint8) + + img_4 = np.array([[1, 1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 3, 3, 0], + [0, 0, 0, 0, 3, 3, 0], + [0, 4, 4, 0, 0, 0, 0], + [0, 4, 4, 0, 2, 2, 0], + [5, 5, 0, 0, 2, 2, 0], + [5, 5, 0, 0, 0, 0, 0]], np.uint8) + + # statistics + # precision = true_positive / (true_positive + false_positive) + # recall = true_positive / (true_positive + false_negative) + result_test_case_1 = compute_statistics(mask_gt=img_1, mask_pred=img_1) + expectation_test_case_1 = dict( + true_positive=0, + false_positive=0, + false_negative=0, + precision=0, + recall=0 + ) + result_test_case_2 = compute_statistics(mask_gt=img_1, mask_pred=img_2) + expectation_test_case_2 = dict( + true_positive=0, + false_positive=1, + false_negative=0, + precision=0, + recall=0 + ) + result_test_case_3 = compute_statistics(mask_gt=img_2, mask_pred=img_1) + expectation_test_case_3 = dict( + true_positive=0, + false_positive=0, + false_negative=1, + precision=0, + recall=0 + ) + result_test_case_4 = compute_statistics(mask_gt=img_3, mask_pred=img_2) + expectation_test_case_4 = dict( + true_positive=0, + false_positive=1, + false_negative=2, + precision=0, + recall=0 + ) + result_test_case_5 = compute_statistics(mask_gt=img_3, mask_pred=img_4) + expectation_test_case_5 = dict( + true_positive=2, + false_positive=3, + false_negative=0, + precision=2 / (2 + 3), + recall=2 / (2 + 0) + ) + result_test_case_6 = compute_statistics(mask_gt=img_4, mask_pred=img_4) # expect fp=0, tp=2, fn=3 + expectation_test_case_6 = dict( + true_positive=5, + false_positive=0, + false_negative=0, + precision=1, + recall=1 + ) + + self.assertDictEqual(result_test_case_1, expectation_test_case_1) + self.assertDictEqual(result_test_case_2, expectation_test_case_2) + self.assertDictEqual(result_test_case_3, expectation_test_case_3) + self.assertDictEqual(result_test_case_4, expectation_test_case_4) + self.assertDictEqual(result_test_case_5, expectation_test_case_5) + self.assertDictEqual(result_test_case_6, expectation_test_case_6) + + def test_mean_average_precision(self): + statistics_list_1 = [ + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=1.0, + recall=0.2), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=1.0, + recall=0.4), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.67, + recall=0.4), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.5, + recall=0.4), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.4, + recall=0.4), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.5, + recall=0.6), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.57, + recall=0.8), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.5, + recall=0.8), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.44, + recall=0.8), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.5, + recall=1.0) + ] + + statistics_list_2 = [ + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=1.0, + recall=0.090909), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.5, + recall=0.090909), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.666667, + recall=0.166667), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.75, + recall=0.230769), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.6, + recall=0.230769), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.666667, + recall=0.285714), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.714286, + recall=0.33333), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.75, + recall=0.375), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.66667, + recall=0.375), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.7, + recall=0.411765), + ] + + statistics_list_3 = [ + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=1.0, + recall=0.33), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.5, + recall=0.33), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.67, + recall=0.67), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.5, + recall=0.67), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.4, + recall=0.67), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.5, + recall=1.0), + dict(true_positive=0, + false_positive=0, + false_negative=0, + precision=0.43, + recall=1.0), + ] + + delta = 0.0005 + self.assertAlmostEqual(compute_mean_average_precision(statistics_list_1), (0.4*1.0+0.4*0.57+0.2*0.5), delta=delta) + self.assertAlmostEqual(compute_mean_average_precision(statistics_list_2), (0.09*1+0.285*0.75+0.625*0.7), delta=delta) + self.assertAlmostEqual(compute_mean_average_precision(statistics_list_3), 0.33*1+0.34*0.67+0.33*0.5, delta=delta) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/evaluation/test/test_distances.py b/evaluation/test/test_distances.py new file mode 100644 index 0000000..3177894 --- /dev/null +++ b/evaluation/test/test_distances.py @@ -0,0 +1,147 @@ +import unittest +import numpy as np +from imageio import imread + +# single pixels, 2mm away +from evaluation.dice_calculations import compute_dice_coefficient +from evaluation.distance_calculations import compute_surface_distances, compute_surface_dice_at_tolerance, \ + compute_average_surface_distance, compute_robust_hausdorff, compute_surface_overlap_at_tolerance + + +class TestDiceCalculation(unittest.TestCase): + + def setUp(self): + self.delta = 0.0005 + + def test_surface_dice(self): + path = "images\\img{}\\instrument_instances.png".format(3) + + # read image + x = imread(path) + + # make image binary + x[x < 0.5] = 0 + x[x >= 0.5] = 1 + + mask_gt = np.reshape(x,x.shape + (1,)) + + surface_distances = compute_surface_distances(mask_gt, mask_gt, (1,1,1)) + surface_dice = compute_surface_dice_at_tolerance(surface_distances, 1) + + self.assertAlmostEqual(surface_dice, 1.0, delta=self.delta) + + def test_single_pixels_2mm_away(self): + mask_gt = np.zeros((128, 128, 128), np.uint8) + mask_pred = np.zeros((128, 128, 128), np.uint8) + mask_gt[50, 60, 70] = 1 + mask_pred[50, 60, 72] = 1 + surface_distances = compute_surface_distances(mask_gt, mask_pred, spacing_mm=(3, 2, 1)) + surface_dice_1mm = compute_surface_dice_at_tolerance(surface_distances, 1) + volumetric_dice = compute_dice_coefficient(mask_gt, mask_pred) + print("surface dice at 1mm: {}".format(surface_dice_1mm)) + print("volumetric dice: {}".format(volumetric_dice)) + self.assertAlmostEqual(surface_dice_1mm, 0.5, delta=self.delta) + self.assertAlmostEqual(volumetric_dice, 0.0, delta=self.delta) + + def test_two_cubes(self): + # two cubes. cube 1 is 100x100x100 mm^3 and cube 2 is 102x100x100 mm^3 + mask_gt = np.zeros((100, 100, 100), np.uint8) + mask_pred = np.zeros((100, 100, 100), np.uint8) + spacing_mm = (2, 1, 1) + mask_gt[0:50, :, :] = 1 + mask_pred[0:51, :, :] = 1 + surface_distances = compute_surface_distances(mask_gt, mask_pred, spacing_mm) + expected_average_distance_gt_to_pred = 0.836145008498 + expected_volumetric_dice = 2. * 100 * 100 * 100 / (100 * 100 * 100 + 102 * 100 * 100) + + surface_dice_1mm = compute_surface_dice_at_tolerance(surface_distances, 1) + volumetric_dice = compute_dice_coefficient(mask_gt, mask_pred) + + print("surface dice at 1mm: {}".format(compute_surface_dice_at_tolerance(surface_distances, 1))) + print("volumetric dice: {}".format(compute_dice_coefficient(mask_gt, mask_pred))) + + self.assertAlmostEqual(surface_dice_1mm, expected_average_distance_gt_to_pred, delta=self.delta) + self.assertAlmostEqual(volumetric_dice, expected_volumetric_dice, delta=self.delta) + + def test_empty_mask_in_pred(self): + # test empty mask in prediction + mask_gt = np.zeros((128, 128, 128), np.uint8) + mask_pred = np.zeros((128, 128, 128), np.uint8) + mask_gt[50, 60, 70] = 1 + # mask_pred[50,60,72] = 1 + + surface_distances = compute_surface_distances(mask_gt, mask_pred, spacing_mm=(3, 2, 1)) + + average_surface_distance = compute_average_surface_distance(surface_distances) + hausdorf_100 = compute_robust_hausdorff(surface_distances, 100) + hausdorf_95 = compute_robust_hausdorff(surface_distances, 95) + + surface_overlap_1_mm = compute_surface_overlap_at_tolerance(surface_distances, 1) + surface_dice_1mm = compute_surface_dice_at_tolerance(surface_distances, 1) + volumetric_dice = compute_dice_coefficient(mask_gt, mask_pred) + + print("average surface distance: {} mm".format(average_surface_distance)) + print("hausdorff (100%): {} mm".format(hausdorf_100)) + print("hausdorff (95%): {} mm".format(hausdorf_95)) + print("surface overlap at 1mm: {}".format(surface_overlap_1_mm)) + print("surface dice at 1mm: {}".format(surface_dice_1mm)) + print("volumetric dice: {}".format(volumetric_dice)) + + self.assertAlmostEqual(surface_dice_1mm, 0.0, delta=self.delta) + self.assertAlmostEqual(volumetric_dice, 0.0, delta=self.delta) + + def test_empty_mask_in_gt(self): + # test empty mask in ground truth + mask_gt = np.zeros((128, 128, 128), np.uint8) + mask_pred = np.zeros((128, 128, 128), np.uint8) + # mask_gt[50,60,70] = 1 + mask_pred[50, 60, 72] = 1 + + surface_distances = compute_surface_distances(mask_gt, mask_pred, spacing_mm=(3, 2, 1)) + + average_surface_distance = compute_average_surface_distance(surface_distances) + hausdorf_100 = compute_robust_hausdorff(surface_distances, 100) + hausdorf_95 = compute_robust_hausdorff(surface_distances, 95) + + surface_overlap_1_mm = compute_surface_overlap_at_tolerance(surface_distances, 1) + surface_dice_1mm = compute_surface_dice_at_tolerance(surface_distances, 1) + volumetric_dice = compute_dice_coefficient(mask_gt, mask_pred) + + print("average surface distance: {} mm".format(average_surface_distance)) + print("hausdorff (100%): {} mm".format(hausdorf_100)) + print("hausdorff (95%): {} mm".format(hausdorf_95)) + print("surface overlap at 1mm: {}".format(surface_overlap_1_mm)) + print("surface dice at 1mm: {}".format(surface_dice_1mm)) + print("volumetric dice: {}".format(volumetric_dice)) + + self.assertAlmostEqual(surface_dice_1mm, 0.0, delta=self.delta) + self.assertAlmostEqual(volumetric_dice, 0.0, delta=self.delta) + + def test_empty_mask_in_gt_and_pred(self): + # test both masks empty + mask_gt = np.zeros((128, 128, 128), np.uint8) + mask_pred = np.zeros((128, 128, 128), np.uint8) + # mask_gt[50,60,70] = 1 + # mask_pred[50,60,72] = 1 + surface_distances = compute_surface_distances(mask_gt, mask_pred, spacing_mm=(3, 2, 1)) + + average_surface_distance = compute_average_surface_distance(surface_distances) + hausdorf_100 = compute_robust_hausdorff(surface_distances, 100) + hausdorf_95 = compute_robust_hausdorff(surface_distances, 95) + + surface_overlap_1_mm = compute_surface_overlap_at_tolerance(surface_distances, 1) + surface_dice_1mm = compute_surface_dice_at_tolerance(surface_distances, 1) + volumetric_dice = compute_dice_coefficient(mask_gt, mask_pred) + + print("average surface distance: {} mm".format(average_surface_distance)) + print("hausdorff (100%): {} mm".format(hausdorf_100)) + print("hausdorff (95%): {} mm".format(hausdorf_95)) + print("surface overlap at 1mm: {}".format(surface_overlap_1_mm)) + print("surface dice at 1mm: {}".format(surface_dice_1mm)) + print("volumetric dice: {}".format(volumetric_dice)) + + self.assertTrue(np.isnan(surface_dice_1mm)) + self.assertTrue(np.isnan(volumetric_dice)) + +if __name__ == '__main__': + unittest.main() diff --git a/evaluation/test/test_instance_dice.py b/evaluation/test/test_instance_dice.py new file mode 100644 index 0000000..d764f3b --- /dev/null +++ b/evaluation/test/test_instance_dice.py @@ -0,0 +1,54 @@ +import unittest + +from imageio import imread +from evaluation.dice_calculations import compute_dice_coefficient, compute_dice_coefficient_per_instance + + +class TestDiceCalculation(unittest.TestCase): + def test_dice_coefficient(self): + # paths + x_path = "images/img{}/instrument_instances.png".format(1) + y_path = "images/img{}/instrument_instances.png".format(2) + + # read images + x = imread(x_path) + y = imread(y_path) + + # make images binary + x[x < 0.5] = 0 + x[x >= 0.5] = 1 + y[y < 0.5] = 0 + y[y >= 0.5] = 1 + + # calculate dice + dice = compute_dice_coefficient(x, y) + + # check if correct + expected_dice = 0.011 + delta = 0.0005 + self.assertAlmostEqual(dice, expected_dice, delta=delta) + + def test_multiple_instance_dice_coefficient(self): + # paths + x_path = "images/img{}/instrument_instances.png".format(2) + y_path = "images/img{}/instrument_instances.png".format(3) + + # read images + x = imread(x_path) + y = imread(y_path) + + # calculate instance dice + instance_dice_scores = compute_dice_coefficient_per_instance(x, y) + + # check if correct + expected_dice_scores = dict(background=0.8789, instrument_0=0, instrument_1=0.1676) + delta = 0.0005 + + for dice_key, expected_dice_key in zip(instance_dice_scores, expected_dice_scores): + dice = instance_dice_scores[dice_key] + expected_dice = expected_dice_scores[expected_dice_key] + self.assertAlmostEqual(dice, expected_dice, delta=delta) + + +if __name__ == '__main__': + unittest.main() diff --git a/synapse/download_data.py b/synapse/download_data.py new file mode 100644 index 0000000..d52fb4a --- /dev/null +++ b/synapse/download_data.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +"""DOWNLOAD ROBUST-MIS CHALLENGE DATA + +1. Create an account for synapse +2. Register for the Challenge: https://www.synapse.org/#!Synapse:syn18779624/wiki/591266 +3. Install the synapse client: `pip install synapseclient` +4. Insert the files_synapse_id as project_id below +5. Add your credentials at the end of the script +6. Run the script and get the data +7. Have fun :-) + +""" + +import synapseclient +import synapseutils + +def download_data(email, password, local_folder,project_id): + print("Start downloading") + + # login to Synapse + syn = synapseclient.login(email=email, password=password, rememberMe=True) + + # download all the files in folder files_synapse_id to a local folder + all_files = synapseutils.syncFromSynapse(syn, entity=project_id, path=local_folder) + + print("Finished downloading") + + +if __name__ is "main": + + # settings + project_id = "syn20575265" + local_folder = "" + email = "" + password = "" + + # download data + download_data(email, password, local_folder, project_id) \ No newline at end of file diff --git a/synapse/download_docker.py b/synapse/download_docker.py new file mode 100644 index 0000000..4cb610e --- /dev/null +++ b/synapse/download_docker.py @@ -0,0 +1,26 @@ +import synapseclient as sc +import docker + +email = "" +password = "" +project_id = "syn20575265" +evaluation_queues = [9614245, 9614272, 9614273] + +# login to synapse +syn = sc.login(email=email, password=password) + +# get docker env +client = docker.from_env() + +# download dockers +for evaluation_queue in evaluation_queues: + for submission in syn.getSubmissions(evaluation_queue): + try: + docker_name=submission["dockerRepositoryName"] + print(f"Download {docker_name}") + client.images.pull(docker_name) + except KeyError: + pass #print("Invalid submission, skip") + except docker.errors.ImageNotFound: + print(f"Not allowed to download the docker {docker_name}") + diff --git a/synapse/run_docker.sh b/synapse/run_docker.sh new file mode 100755 index 0000000..43e544e --- /dev/null +++ b/synapse/run_docker.sh @@ -0,0 +1,6 @@ +INPUT_FOLDER="" +OUTPUT_FOLDER="" +DOCKER_NAME="docker.synapse.org/syn20685953/cami_siat_8.30:latest" # example docker name +#DOCKER_COMMAND='nvidia-docker run --ipc=host -v "'$INPUT_FOLDER'/:/input" -v "'$OUTPUT_FOLDER':/output" $DOCKER_NAME /usr/local/bin/run_network.sh' +#echo $DOCKER_COMMAND +sudo docker run --ipc=host -v $INPUT_FOLDER/:/input -v $OUTPUT_FOLDER:/output $DOCKER_NAME /usr/local/bin/run_network.sh