diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..938c674
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). All rights reserved.
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
+      In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2018 Simon Kohl.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index 3fa444e..5d515c7 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,106 @@
 # Probabilistic U-Net
-Coming shortly: Re-implementation of the model described in `A Probabilistic U-Net for Segmentation of Ambiguous Images' ([arxiv.org/abs/1806.05034](https://arxiv.org/abs/1806.05034)).
+Re-implementation of the model described in 'A Probabilistic U-Net for Segmentation of Ambiguous Images' ([arxiv.org/abs/1806.05034](https://arxiv.org/abs/1806.05034)).
+The architecture of the Probabilistic U-Net is depicted below: subfigure a) shows sampling and b) the training setup:
+![](assets/architecture.png)
+
+Below are samples conditioned on held-out validation-set images from the (stochastic) Cityscapes dataset:
+![](assets/10_image_16_sample.gif)
+
+## Set up the package in a virtual environment
+
+```
+git clone https://phabricator.mitk.org/source/skohl/browse/master/prob_unet .
+cd prob_unet/
+virtualenv -p python3 venv
+source venv/bin/activate
+pip3 install -e .
+```
+
+## Install batch-generators for data augmentation
+```
+cd ..
+git clone https://github.com/MIC-DKFZ/batchgenerators
+cd batchgenerators
+pip3 install nilearn scikit-image nibabel
+pip3 install -e .
+cd prob_unet
+```
+
+## Download & preprocess the Cityscapes dataset
+
+1) Create a login account on the Cityscapes website: https://www.cityscapes-dataset.com/
+2) Once you've logged in, download the train, val and test annotations and images:
+    - Annotations: [gtFine_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=1) (241MB)
+    - Images: [leftImg8bit_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=3) (11GB)
+3) Unzip the data (`unzip _trainvaltest.zip`) and adjust `raw_data_dir` (full path to the unzipped files) and `out_dir` (full path to the desired output directory) in `preprocessing_config.py`
+4) Bilinearly rescale the data to a resolution of 256 x 512 and save it as numpy arrays by running
+```
+cd cityscapes
+python3 preprocessing.py
+cd ..
+```
+
+## Training
+
+[Skip to Evaluation if you only want to use the pretrained model.]
+Modify `data_dir` and `exp_dir` in `scripts/prob_unet_config.py`, then:
+```
+cd training
+python3 train_prob_unet.py --config prob_unet_config.py
+```
+
+## Evaluation
+
+Load your own trained model or use a pretrained one. A set of pretrained weights can be downloaded from [zenodo.org](https://zenodo.org/record/1419051#.W5utoOEzYUE) (187MB). After downloading, unpack the file via
+`tar -xvzf pretrained_weights.tar.gz`, e.g. in `/model`. In either case (using your own or the pretrained model), modify `data_dir` and
+`exp_dir` in `evaluation/cityscapes_eval_config.py` to match your paths.
+
+Then, first write samples (defaults to 16 segmentation samples for each of the 500 validation images):
+```
+cd ../evaluation
+python3 eval_cityscapes.py --write_samples
+```
+followed by their evaluation (which is multi-threaded and thus reasonably fast):
+```
+python3 eval_cityscapes.py --eval_samples
+```
+The evaluation produces a dictionary holding the results. These can be visualized by launching a Jupyter notebook:
+```
+jupyter notebook evaluation_plots.ipynb
+```
+The following results are obtained from the pretrained model using the above notebook:
+![](assets/validation_results.png)
+
+## Tests
+
+The evaluation metrics are under test coverage.
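+The central metric under test is the IoU-based squared generalized energy distance,
+D_GED^2 = 2 E[d(Y,S)] - E[d(S,S')] - E[d(Y,Y')] with d = 1 - IoU, computed between ground-truth modes Y and samples S.
+As a minimal numpy sketch (illustrative only; it mirrors the unweighted branch of `calc_energy_distances` in
+`evaluation/eval_cityscapes.py`, and the function name `ged_squared` is chosen here for clarity):
+```
+import numpy as np
+
+def ged_squared(d_YS, d_SS, d_YY):
+    # d_YS: (num_modes, num_samples), d_SS: (num_samples, num_samples),
+    # d_YY: (num_modes, num_modes); each entry holds d = 1 - IoU.
+    return 2 * np.mean(d_YS) - np.mean(d_SS) - np.mean(d_YY)
+```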
+Run the tests as follows:
+```
+cd ../tests/evaluation
+python3 -m pytest eval_tests.py
+```
+
+## Deviations from the original work
+
+The code found in this repository was not used for the original paper, and slight modifications apply:
+
+- training is done on a single GPU (Titan Xp); distributed training is not supported in this implementation
+- average-pooling rather than bilinear interpolation is used for down-sampling operations in the model
+- the number of conv kernels is kept constant after the 3rd scale, as opposed to strictly doubling it after each scale (to reduce the memory footprint)
+- HeNormal weight initialization worked better than an orthogonal weight initialization
+
+
+## How to cite this code
+Please cite the original publication:
+```
+@article{kohl2018probabilistic,
+  title={A Probabilistic U-Net for Segmentation of Ambiguous Images},
+  author={Kohl, Simon AA and Romera-Paredes, Bernardino and Meyer, Clemens and De Fauw, Jeffrey and Ledsam, Joseph R and Maier-Hein, Klaus H and Eslami, SM and Rezende, Danilo Jimenez and Ronneberger, Olaf},
+  journal={arXiv preprint arXiv:1806.05034},
+  year={2018}
+}
+```
+
+## License
+The code is published under the [Apache License Version 2.0](LICENSE).
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/assets/10_image_16_sample.gif b/assets/10_image_16_sample.gif
new file mode 100644
index 0000000..3793e85
Binary files /dev/null and b/assets/10_image_16_sample.gif differ
diff --git a/assets/architecture.png b/assets/architecture.png
new file mode 100644
index 0000000..2e685d7
Binary files /dev/null and b/assets/architecture.png differ
diff --git a/assets/validation_results.png b/assets/validation_results.png
new file mode 100644
index 0000000..36420df
Binary files /dev/null and b/assets/validation_results.png differ
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/cityscapes/__init__.py b/data/cityscapes/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/cityscapes/cityscapes_labels.py b/data/cityscapes/cityscapes_labels.py
new file mode 100644
index 0000000..be3e5a9
--- /dev/null
+++ b/data/cityscapes/cityscapes_labels.py
@@ -0,0 +1,120 @@
+#!/usr/bin/python
+#
+# Cityscapes labels, code from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
+#
+
+from collections import namedtuple
+
+
+#--------------------------------------------------------------------------------
+# Definitions
+#--------------------------------------------------------------------------------
+
+# a label and all meta information
+Label = namedtuple( 'Label' , [
+
+    'name'        , # The identifier of this label, e.g. 'car', 'person', ... .
+                    # We use them to uniquely name a class.
+
+    'id'          , # An integer ID that is associated with this label.
+                    # The IDs are used to represent the label in ground truth images.
+                    # An ID of -1 means that this label does not have an ID and thus
+                    # is ignored when creating ground truth images (e.g. license plate).
+                    # Do not modify these IDs, since exactly these IDs are expected by the
+                    # evaluation server.
+
+    'trainId'     , # Feel free to modify these IDs as suitable for your method. Then create
+                    # ground truth images with train IDs, using the tools provided in the
+                    # 'preparation' folder. However, make sure to validate or submit results
+                    # to our evaluation server using the regular IDs above!
+ # For trainIds, multiple labels might have the same ID. Then, these labels + # are mapped to the same class in the ground truth images. For the inverse + # mapping, we use the label that is defined first in the list below. + # For example, mapping all void-type classes to the same ID in training, + # might make sense for some approaches. + # Max value is 255! + + 'category' , # The name of the category that this label belongs to + + 'categoryId' , # The ID of this category. Used to create ground truth images + # on category level. + + 'hasInstances', # Whether this label distinguishes between single instances or not + + 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored + # during evaluations or not + + 'color' , # The color of this label + ] ) + + +#-------------------------------------------------------------------------------- +# A list of all labels +#-------------------------------------------------------------------------------- + +# Please adapt the train IDs as appropriate for your approach. +# Note that you might want to ignore labels with ID 255 during training. +# Further note that the current train IDs are only a suggestion. You can use whatever you like. +# Make sure to provide your results using the original IDs and not the training IDs. +# Note that many IDs are ignored in evaluation and thus you never need to predict these! + +labels = [ + # name id trainId category catId hasInstances ignoreInEval color + Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), + Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), + Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), + Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), + Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), + Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), + Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), + Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), + Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), + Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), + Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), + Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), + Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), + Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), + Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), + Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), + Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), + Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), + Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), + Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), + Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), + Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), + Label( 'car' , 26 , 13 , 
'vehicle' , 7 , True , False , ( 0, 0,142) ), + Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), + Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), + Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), + Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), + Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), + Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), + Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), + Label( 'license plate' , -1 , 255 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), +] + + +#-------------------------------------------------------------------------------- +# Create dictionaries for a fast lookup +#-------------------------------------------------------------------------------- + +# Please refer to the main method below for example usages! + +# name to label object +name2label = { label.name : label for label in labels } +# id to label object +id2label = { label.id : label for label in labels } +# trainId to label object +trainId2label = { label.trainId : label for label in reversed(labels) } +# category to list of label objects +category2labels = {} +for label in labels: + category = label.category + if category in category2labels: + category2labels[category].append(label) + else: + category2labels[category] = [label] diff --git a/data/cityscapes/data_loader.py b/data/cityscapes/data_loader.py new file mode 100644 index 0000000..2c56f67 --- /dev/null +++ b/data/cityscapes/data_loader.py @@ -0,0 +1,353 @@ +# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to serve the CityScapes dataset.""" + +import os +import sys, glob +import numpy as np +import imp +import logging + +from batchgenerators.dataloading.data_loader import SlimDataLoaderBase +from batchgenerators.transforms.spatial_transforms import MirrorTransform +from batchgenerators.transforms.abstract_transforms import Compose +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.transforms.spatial_transforms import SpatialTransform +from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform +from batchgenerators.transforms import AbstractTransform +from .cityscapes_labels import labels as cityscapes_labels_tuple + +def loadFiles(label_density, split, input_path, cities=None, instance=False): + """ + Assemble dict of file paths. 
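+    Walks the directory of the given split city by city and pairs each image path with its corresponding label path.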
+    :param label_density: string in ['gtFine', 'gtCoarse']
+    :param split: string in ['train', 'val', 'test', 'train_extra']
+    :param input_path: string
+    :param cities: list of strings or None
+    :param instance: bool
+    :return: dict
+    """
+    input_dir = os.path.join(input_path, split)
+    logging.info("Assembling file dict from {}.".format(input_dir))
+    paths_dict = {}
+    for path, dirs, files in os.walk(input_dir):
+
+        skip_city = False
+        if cities is not None:
+            current_city = path.rsplit('/', 1)[-1]
+            if current_city not in cities:
+                skip_city = True
+
+        if not skip_city:
+            logging.info('Reading from {}'.format(path))
+            label_paths, img_paths = searchFiles(path, label_density, instance)
+            paths_dict = {**paths_dict, **zipPaths(label_paths, img_paths)}
+
+    return paths_dict
+
+
+def searchFiles(path, label_density, instance=False):
+    """
+    Get file paths via wildcard search.
+    :param path: path to the files of a given city
+    :param label_density: string in ['gtFine', 'gtCoarse']
+    :param instance: bool
+    :return: 2 lists
+    """
+    if instance:
+        label_wildcard_search = os.path.join(path, "*{}_instanceIDs.npy".format(label_density))
+    else:
+        label_wildcard_search = os.path.join(path, "*{}_labelIds.npy".format(label_density))
+    label_paths = glob.glob(label_wildcard_search)
+    label_paths.sort()
+    img_wildcard_search = os.path.join(path, "*_leftImg8bit.npy")
+    img_paths = glob.glob(img_wildcard_search)
+    img_paths.sort()
+    return label_paths, img_paths
+
+
+def zipPaths(label_paths, img_paths):
+    """
+    Zip paths in the form of a dict.
+    :param label_paths: list of strings
+    :param img_paths: list of strings
+    :return: dict
+    """
+    try:
+        assert len(label_paths) == len(img_paths)
+    except:
+        raise Exception('Mismatch: {} label paths vs. {} img paths!'.format(len(label_paths), len(img_paths)))
+
+    paths_dict = {}
+    for i, img_path in enumerate(img_paths):
+        img_spec = ('_').join(img_paths[i].split('/')[-1].split('_')[:-1])
+        try:
+            assert img_spec in label_paths[i]
+        except:
+            raise Exception('img and label name mismatch: {} vs. {}'.format(img_paths[i], label_paths[i]))
+
+        paths_dict[img_spec] = {"data": img_paths[i], "seg": label_paths[i], 'img_spec': img_spec}
+    return paths_dict
+
+
+def augment_gamma(data, gamma_range=(0.5, 2), invert_image=False, epsilon=1e-7, per_channel=False, retain_stats=False, p_per_sample=0.3):
+    """code by Fabian Isensee, see MIC_DKFZ/batch_generators on github."""
+    for sample in range(data.shape[0]):
+        if np.random.uniform() < p_per_sample:
+            if invert_image:
+                data = - data
+            if not per_channel:
+                if retain_stats:
+                    mn = data[sample].mean()
+                    sd = data[sample].std()
+                if np.random.random() < 0.5 and gamma_range[0] < 1:
+                    gamma = np.random.uniform(gamma_range[0], 1)
+                else:
+                    gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
+                minm = data[sample].min()
+                rnge = data[sample].max() - minm
+                data[sample] = np.power(((data[sample] - minm) / float(rnge + epsilon)), gamma) * rnge + minm
+                if retain_stats:
+                    data[sample] = data[sample] - data[sample].mean() + mn
+                    data[sample] = data[sample] / (data[sample].std() + 1e-8) * sd
+            else:
+                for c in range(data.shape[1]):
+                    if retain_stats:
+                        mn = data[sample][c].mean()
+                        sd = data[sample][c].std()
+                    if np.random.random() < 0.5 and gamma_range[0] < 1:
+                        gamma = np.random.uniform(gamma_range[0], 1)
+                    else:
+                        gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
+                    minm = data[sample][c].min()
+                    rnge = data[sample][c].max() - minm
+                    data[sample][c] = np.power(((data[sample][c] - minm) / float(rnge + epsilon)), gamma) * rnge + minm
+                    if retain_stats:
+                        data[sample][c] = data[sample][c] - data[sample][c].mean() + mn
+                        data[sample][c] = data[sample][c] / (data[sample][c].std() + 1e-8) * sd
+    return data
+
+
+class GammaTransform(AbstractTransform):
+    """Augments by changing the 'gamma' of the image (same as gamma correction in photos or computer monitors).
+
+    Args:
+        gamma_range (tuple of float): range to sample gamma from. If one value is smaller than 1 and the other one is
+        larger, then half the samples will have gamma < 1 and the other half gamma > 1 (in the interval that was specified).
+
+        invert_image: whether to invert the image before applying gamma augmentation
+
+        retain_stats: the gamma transformation will alter the mean and std of the data in the patch. If retain_stats=True,
+        the data will be transformed to match the mean and standard deviation from before gamma augmentation.
+
+    """
+
+    def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, data_key="data", retain_stats=False,
+                 p_per_sample=0.3, mask_channel_in_seg=None):
+        self.mask_channel_in_seg = mask_channel_in_seg
+        self.p_per_sample = p_per_sample
+        self.retain_stats = retain_stats
+        self.per_channel = per_channel
+        self.data_key = data_key
+        self.gamma_range = gamma_range
+        self.invert_image = invert_image
+
+    def __call__(self, **data_dict):
+        data_dict[self.data_key] = augment_gamma(data_dict[self.data_key], self.gamma_range, self.invert_image,
+                                                 per_channel=self.per_channel, retain_stats=self.retain_stats,
+                                                 p_per_sample=self.p_per_sample)
+        return data_dict
+
+
+def map_labels_to_trainId(arr):
+    """Remap ids to their corresponding training Ids.
+    Note that the inplace mapping works because each assigned trainId is either 255 (never an original id)
+    or smaller than every id that is still to be mapped."""
+    id2trainId = {label.id:label.trainId for label in cityscapes_labels_tuple}
+    for id, trainId in id2trainId.items():
+        arr[arr == id] = trainId
+    return arr
+
+
+class AddLossMask(AbstractTransform):
+    """Adds a binary loss mask to the data dict, marking all pixels whose ground-truth label equals `label2mask`
+    (e.g. the ignore label), so that these pixels can be excluded from the loss.
+
+    Args:
+        label2mask (int): label to be masked out
+
+        output_key (string): key to use for the output mask. Default is 'loss_mask'.
+    """
+
+    def __init__(self, label2mask, output_key="loss_mask"):
+        self.output_key = output_key
+        self.label2mask = label2mask
+
+    def __call__(self, **data_dict):
+        seg = data_dict['seg']
+        if seg is not None:
+            data_dict[self.output_key] = (seg == self.label2mask).astype(np.uint8)
+        else:
+            from warnings import warn
+            warn("calling AddLossMask but there is no segmentation")
+        return data_dict
+
+
+class StochasticLabelSwitches(AbstractTransform):
+    """
+    Stochastically switches labels in a batch of integer-labeled segmentations.
+    """
+    def __init__(self, name2id, label_switches):
+        self._name2id = name2id
+        self._label_switches = label_switches
+
+    def __call__(self, **data_dict):
+
+        switched_seg = data_dict['seg']
+        batch_size = switched_seg.shape[0]
+
+        for c, p in self._label_switches.items():
+            init_id = self._name2id[c]
+            final_id = self._name2id[c + '_2']
+            switch_instances = np.random.binomial(1, p, batch_size)
+
+            for i in range(batch_size):
+                if switch_instances[i]:
+                    switched_seg[i][switched_seg[i] == init_id] = final_id
+
+        data_dict['seg'] = switched_seg
+        return data_dict
+
+
+class BatchGenerator(SlimDataLoaderBase):
+    """
+    Create the training/validation batch generator. Randomly samples batch_size images
+    from the data set (or cycles through it linearly if random=False) and merges the
+    images and segmentations to arrays.
+    :param batch_size: number of images to sample per batch
+    :param data_dir: root directory of the preprocessed data
+    :return: dictionary containing the batch data / seg / ids
+    """
+    def __init__(self, batch_size, data_dir, label_density='gtFine', data_split='train', resolution='quarter',
+                 cities=None, gt_instances=False, n_batches=None, random=True):
+        super(BatchGenerator, self).__init__(data=None, batch_size=batch_size)
+
+        data_dir = os.path.join(data_dir, resolution)
+        self._data_dir = data_dir
+        self._label_density = label_density
+        self._gt_instances = gt_instances
+        self._data_split = data_split
+        self._random = random
+        self._n_batches = n_batches
+        self._batches_generated = 0
+        self._data = loadFiles(label_density, data_split, data_dir, cities=cities, instance=gt_instances)
+        logging.info('{} set comprises {} files.'.format(data_split, len(self._data)))
+
+    def generate_train_batch(self):
+
+        if self._random:
+            img_ixs = np.random.choice(list(self._data.keys()), self.batch_size, replace=True)
+        else:
+            batch_no = self._batches_generated % self._n_batches
+            img_ixs = [list(self._data.keys())[i] for i in
+                       np.arange(batch_no * self.batch_size, (batch_no + 1) * self.batch_size)]
+        img_batch, seg_batch, ids_batch = [], [], []
+
+        for b in range(self.batch_size):
+
+            img = np.load(self._data[img_ixs[b]]['data']) / 255.
+            seg = np.load(self._data[img_ixs[b]]['seg'])
+            seg = map_labels_to_trainId(seg)
+            seg = seg[np.newaxis]
+            ids_batch.append(self._data[img_ixs[b]]['img_spec'])
+
+            img_batch.append(img)
+            seg_batch.append(seg)
+
+        self._batches_generated += 1
+        batch = {'data': np.array(img_batch).astype('float32'), 'seg': np.array(seg_batch).astype('uint8'),
+                 'id': ids_batch}
+        return batch
+
+
+def create_data_gen_pipeline(cf, cities=None, data_split='train', do_aug=True, random=True, n_batches=None):
+    """
+    Create a multi-threaded train/val/test batch generation and augmentation pipeline.
+    :param cf: config module providing paths, batch size and augmentation settings
+    :param cities: list of strings or None
+    :param data_split: string in ['train', 'val', 'test']
+    :param n_batches: (optional) number of batches to cycle through when random=False
+    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
+    :param random: bool, whether to draw random batches or go through the data linearly
+    :return: multithreaded_generator
+    """
+    data_gen = BatchGenerator(cities=cities, batch_size=cf.batch_size, data_dir=cf.data_dir,
+                              label_density=cf.label_density, data_split=data_split, resolution=cf.resolution,
+                              gt_instances=cf.gt_instances, n_batches=n_batches, random=random)
+    my_transforms = []
+    if do_aug:
+        mirror_transform = MirrorTransform(axes=(3,))
+        my_transforms.append(mirror_transform)
+        spatial_transform = SpatialTransform(patch_size=cf.patch_size[-2:],
+                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
+                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
+                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
+                                             do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
+                                             angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
+                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
+                                             random_crop=cf.da_kwargs['random_crop'],
+                                             border_mode_data=cf.da_kwargs['border_mode_data'],
+                                             border_mode_seg=cf.da_kwargs['border_mode_seg'],
+                                             border_cval_seg=cf.da_kwargs['border_cval_seg'])
+        my_transforms.append(spatial_transform)
+    else:
+        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[-2:]))
+
+    my_transforms.append(GammaTransform(cf.da_kwargs['gamma_range'], invert_image=False, per_channel=True,
+                                        retain_stats=cf.da_kwargs['gamma_retain_stats'],
+                                        p_per_sample=cf.da_kwargs['p_gamma']))
+    my_transforms.append(AddLossMask(cf.ignore_label))
+    if cf.label_switches is not None:
+        my_transforms.append(StochasticLabelSwitches(cf.name2trainId, cf.label_switches))
+    all_transforms = Compose(my_transforms)
+    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers,
+                                                     seeds=range(cf.n_workers))
+    return multithreaded_generator
+
+
+def get_train_generators(cf):
+    """
+    Wrapper function for creating the training batch generator pipeline. Returns the train/val generators.
+    """
+    batch_gen = {}
+    batch_gen['train'] = create_data_gen_pipeline(cf=cf, cities=cf.train_cities, data_split='train', do_aug=True,
+                                                  n_batches=cf.n_train_batches)
+    batch_gen['val'] = create_data_gen_pipeline(cf=cf, cities=cf.val_cities, data_split='train', do_aug=False,
+                                                random=False, n_batches=cf.n_val_batches)
+    return batch_gen
+
+def main():
+    """Main entry point for the script."""
+    logging.info("start loading.")
+    cf = imp.load_source('cf', 'config.py')
+    paths = loadFiles("gtFine", "train", cf.out_dir, instance=False)
+    logging.info('Contains {} elements.'.format(len(paths)))
+    logging.info(paths)
+    data_provider = BatchGenerator(8, cf.out_dir, data_split='val')
+    batch = next(data_provider)
+    logging.info('%s %s %s', batch['data'].shape, batch['seg'].shape, batch['id'])
+
+
+if __name__ == '__main__':
+    sys.exit(main())
diff --git a/data/cityscapes/preprocessing.py b/data/cityscapes/preprocessing.py
new file mode 100644
index 0000000..1b17046
--- /dev/null
+++ b/data/cityscapes/preprocessing.py
@@ -0,0 +1,99 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to preprocess the Cityscapes dataset."""
+
+import os
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+import imp
+
+resolution_map = {1.0: 'full', 0.5: 'half', 0.25: 'quarter'}
+
+def resample(img, scale_factor=1.0, interpolation=Image.BILINEAR):
+    """
+    Resample PIL.Image objects.
+    :param img: PIL.Image object
+    :param scale_factor: float
+    :param interpolation: PIL.Image interpolation method
+    :return: PIL.Image object
+    """
+    width, height = img.size
+    basewidth = width * scale_factor
+    basewidth = int(basewidth)
+    wpercent = (basewidth / float(width))
+    hsize = int((float(height) * wpercent))
+    return img.resize((basewidth, hsize), interpolation)
+
+def recursive_mkdir(nested_dir_list):
+    """
+    Make the full nested path of directories provided. Order in the list implies nesting depth.
+    :param nested_dir_list: list of strings
+    :return:
+    """
+    nested_dir = ''
+    for dir in nested_dir_list:
+        nested_dir = os.path.join(nested_dir, dir)
+        if not os.path.isdir(nested_dir):
+            os.mkdir(nested_dir)
+    return
+
+def preprocess(cf):
+
+    for set in list(cf.settings.keys()):
+        print('Processing {} set.'.format(set))
+
+        # image dir
+        image_dir = os.path.join(cf.raw_data_dir, 'leftImg8bit', set)
+        city_names = os.listdir(image_dir)
+
+        for city in city_names:
+            print('Processing {}'.format(city))
+            city_dir = os.path.join(image_dir, city)
+            image_names = os.listdir(city_dir)
+            image_specifiers = ['_'.join(img.split('_')[:3]) for img in image_names]
+
+            for img_spec in tqdm(image_specifiers):
+                for scale in cf.settings[set]['resolutions']:
+                    recursive_mkdir([cf.out_dir, resolution_map[scale], set, city])
+
+                    # image
+                    img_path = os.path.join(city_dir, img_spec + '_leftImg8bit.png')
+                    img = Image.open(img_path)
+                    if scale != 1.0:
+                        img = resample(img, scale_factor=scale, interpolation=Image.BILINEAR)
+                    img_out_path = os.path.join(cf.out_dir, resolution_map[scale], set, city, img_spec + '_leftImg8bit.npy')
+                    img_arr = np.array(img).astype(np.float32)
+
+                    channel_axis = 0 if img_arr.shape[0] == 3 else 2
+                    if cf.data_format == 'NCHW' and channel_axis != 0:
+                        img_arr = np.transpose(img_arr, axes=[2,0,1])
+                    np.save(img_out_path, img_arr)
+
+                    # labels
+                    for label_density in cf.settings[set]['label_densities']:
+                        label_dir = os.path.join(cf.raw_data_dir, label_density, set, city)
+                        for mod in cf.settings[set]['label_modalities']:
+                            label_spec = img_spec + '_{}_{}'.format(label_density, mod)
+                            label_path = os.path.join(label_dir, label_spec + '.png')
+                            label = Image.open(label_path)
+                            if scale != 1.0:
+                                label = resample(label, scale_factor=scale, interpolation=Image.NEAREST)
+                            label_out_path = os.path.join(cf.out_dir, resolution_map[scale], set, city, label_spec + '.npy')
+                            np.save(label_out_path, np.array(label).astype(np.uint8))
+
+if __name__ == "__main__":
+    cf = imp.load_source('cf', 'preprocessing_config.py')
+    preprocess(cf)
\ No newline at end of file
diff --git a/data/cityscapes/preprocessing_config.py b/data/cityscapes/preprocessing_config.py
new file mode 100644
index
0000000..548ae40 --- /dev/null +++ b/data/cityscapes/preprocessing_config.py @@ -0,0 +1,28 @@ +# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cityscapes preprocessing config.""" + +raw_data_dir = 'SET_INPUT_DIRECTORY_ABSOLUTE_PATH_HERE' +out_dir = 'SET_OUTPUT_DIRECTORY_ABSOLUTE_PATH_HERE' + +# settings: +settings = { + 'train': {'resolutions': [1.0, 0.5, 0.25], 'label_densities': ['gtFine'], + 'label_modalities': ['labelIds']}, + 'val': {'resolutions': [1.0, 0.5, 0.25], 'label_densities': ['gtFine'], + 'label_modalities': ['labelIds']}, + } + +data_format = 'NCHW' diff --git a/evaluation/__init__.py b/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evaluation/cityscapes_eval_config.py b/evaluation/cityscapes_eval_config.py new file mode 100644 index 0000000..5fa1792 --- /dev/null +++ b/evaluation/cityscapes_eval_config.py @@ -0,0 +1,92 @@ +# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cityscapes evaluation config.""" + +import os +from collections import OrderedDict +from data.cityscapes.cityscapes_labels import labels as cs_labels_tuple +from model import pretrained_weights + +config_path = os.path.realpath(__file__) + +######################################### +# data # +######################################### + +data_dir = 'PREPROCESSING_OUTPUT_DIRECTORY_ABSOLUTE_PATH' +resolution = 'quarter' +label_density = 'gtFine' +num_classes = 19 +one_hot_labels = False +ignore_label = 255 +cities = ['frankfurt', 'lindau', 'munster'] + +######################################### +# label-switches # +######################################### + +color_map = {label.trainId:label.color for label in cs_labels_tuple} +color_map[255] = (0.,0.,0.) 
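+
+# Each class in label_switches below is switched to an alternative label
+# independently with the given probability, so the stochastic ground truth has
+# 2**len(label_switches) = 32 modes; a mode's probability is the product of p
+# (switched) or 1 - p (unswitched) over the five classes (see
+# get_mode_statistics in evaluation/eval_cityscapes.py).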
+ +trainId2name = {labels.trainId: labels.name for labels in cs_labels_tuple} +name2trainId = {labels.name: labels.trainId for labels in cs_labels_tuple} + +label_switches = OrderedDict([('sidewalk', 8./17.), ('person', 7./17.), ('car', 6./17.), ('vegetation', 5./17.), ('road', 4./17.)]) +num_classes += len(label_switches) +switched_Id2name = {19+i:list(label_switches.keys())[i] + '_2' for i in range(len(label_switches))} +switched_name2Id = {list(label_switches.keys())[i] + '_2':19+i for i in range(len(label_switches))} +trainId2name = {**trainId2name, **switched_Id2name} +name2trainId = {**name2trainId, **switched_name2Id} + +switched_labels2color = {'road_2': (84, 86, 22), 'person_2': (167, 242, 242), 'vegetation_2': (242, 160, 19), + 'car_2': (30, 193, 252), 'sidewalk_2': (46, 247, 180)} +switched_cmap = {switched_name2Id[i]:switched_labels2color[i] for i in switched_name2Id.keys()} +color_map = {**color_map, **switched_cmap} + +exp_modes = len(label_switches) +num_modes = 2 ** exp_modes + +######################################### +# network # +######################################### + +cuda_visible_devices = '0' +cpu_device = '/cpu:0' +gpu_device = '/gpu:0' + +patch_size = [256, 512] +network_input_shape = (None, 3) +tuple(patch_size) +network_output_shape = (None, num_classes) + tuple(patch_size) +label_shape = (None, 1) + tuple(patch_size) +loss_mask_shape = label_shape + +base_channels = 32 +num_channels = [base_channels, 2*base_channels, 4*base_channels, + 6*base_channels, 6*base_channels, 6*base_channels, 6*base_channels] + +num_convs_per_block = 3 + +latent_dim = 6 +num_1x1_convs = 3 +analytic_kl = True +use_posterior_mean = False + +######################################### +# evaluation # +######################################### + +num_samples = 16 +exp_dir = '/'.join(os.path.abspath(pretrained_weights.__file__).split('/')[:-1]) +out_dir = 'EVALUATION_OUTPUT_DIRECTORY_ABSOLUTE_PATH' \ No newline at end of file diff --git a/evaluation/eval_cityscapes.py b/evaluation/eval_cityscapes.py new file mode 100644 index 0000000..68e5951 --- /dev/null +++ b/evaluation/eval_cityscapes.py @@ -0,0 +1,479 @@ +# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cityscapes evaluation script.""" + +import tensorflow as tf +import numpy as np +import os +import argparse +from tqdm import tqdm +from multiprocessing import Process, Queue +from importlib.machinery import SourceFileLoader +import logging +import pickle + +from model.probabilistic_unet import ProbUNet +from data.cityscapes.data_loader import loadFiles, map_labels_to_trainId +from utils import training_utils + + +def write_test_predictions(cf): + """ + Write samples as numpy arrays. 
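+    For each validation image, num_samples segmentations are drawn from the prior net and saved as .npy files to cf.out_dir.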
+ :param cf: config module + :return: + """ + # do not use all gpus + os.environ["CUDA_VISIBLE_DEVICES"] = cf.cuda_visible_devices + + data_dir = os.path.join(cf.data_dir, cf.resolution) + data_dict = loadFiles(label_density=cf.label_density, split='val', input_path=data_dir, + cities=None, instance=False) + # prepare out_dir + if not os.path.isdir(cf.out_dir): + os.mkdir(cf.out_dir) + + logging.info('Writing to {}'.format(cf.out_dir)) + + # initialize computation graph + prob_unet = ProbUNet(latent_dim=cf.latent_dim, num_channels=cf.num_channels, + num_1x1_convs=cf.num_1x1_convs, + num_classes=cf.num_classes, num_convs_per_block=cf.num_convs_per_block, + initializers={'w': training_utils.he_normal(), + 'b': tf.truncated_normal_initializer(stddev=0.001)}, + regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), + 'b': tf.contrib.layers.l2_regularizer(1.0)}) + x = tf.placeholder(tf.float32, shape=cf.network_input_shape) + + with tf.device(cf.gpu_device): + prob_unet(x, is_training=False, one_hot_labels=cf.one_hot_labels) + sampled_logits = prob_unet.sample() + + saver = tf.train.Saver(save_relative_paths=True) + with tf.train.MonitoredTrainingSession() as sess: + + print('EXP DIR', cf.exp_dir) + latest_ckpt_path = tf.train.latest_checkpoint(cf.exp_dir) + print('CKPT PATH', latest_ckpt_path) + saver.restore(sess, latest_ckpt_path) + + for k, v in tqdm(data_dict.items()): + img = np.load(v['data']) / 255. + # add batch dimensions + img = img[np.newaxis] + + for i in range(cf.num_samples): + sample = sess.run(sampled_logits, feed_dict={x: img}) + sample = np.argmax(sample, axis=1)[:, np.newaxis] + sample = sample.astype(np.uint8) + sample_path = os.path.join(cf.out_dir, '{}_sample{}_labelIds.npy'.format(k, i)) + np.save(sample_path, sample) + + +def get_array_of_modes(cf, seg): + """ + Assemble an array holding all label modes. + :param cf: config module + :param seg: 4D integer array + :return: 4D integer array + """ + mode_stats = get_mode_statistics(cf.label_switches, exp_modes=cf.exp_modes) + switch = mode_stats['switch'] + + # construct ground-truth modes + gt_seg_modes = np.zeros(shape=(cf.num_modes,) + seg.shape, dtype=np.uint8) + for mode in range(cf.num_modes): + switched_seg = seg.copy() + for i, c in enumerate(cf.label_switches.keys()): + if switch[mode, i]: + init_id = cf.name2trainId[c] + final_id = cf.name2trainId[c + '_2'] + switched_seg[switched_seg == init_id] = final_id + gt_seg_modes[mode] = switched_seg + + return gt_seg_modes + + +def get_array_of_samples(cf, img_key): + """ + Assemble an array holding all segmentation samples for a given image. + :param cf: config module + :param img_key: string + :return: 5D integer array + """ + seg_samples = np.zeros(shape=(cf.num_samples,1,1) + tuple(cf.patch_size), dtype=np.uint8) + for i in range(cf.num_samples): + sample_path = os.path.join(cf.out_dir, '{}_sample{}_labelIds.npy'.format(img_key, i)) + try: + seg_samples[i] = np.load(sample_path) + except: + print('Could not load {}'.format(sample_path)) + + return seg_samples + + +def get_mode_counts(d_matrix_YS): + """ + Calculate image-level mode counts. 
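+    Each sample is assigned to the closest ground-truth mode (minimal mean 1 - IoU across classes) and the assignments are counted per mode.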
+ :param d_matrix_YS: 3D array + :return: numpy array + """ + # assign each sample to a mode + mean_d = np.nanmean(d_matrix_YS, axis=-1) + sampled_modes = np.argmin(mean_d, axis=-2) + + # count the modes + num_modes = d_matrix_YS.shape[0] + mode_count = np.zeros(shape=(num_modes,), dtype=np.int) + for sampled_mode in sampled_modes: + mode_count[sampled_mode] += 1 + + return mode_count + + +def get_pixelwise_mode_counts(cf, seg, seg_samples): + """ + Calculate pixel-wise mode counts. + :param cf: config module + :param seg: 4D array of integer labeled segmentations + :param seg_samples: 5D array of integer labeled segmentations + :return: array of shape (switchable classes, 3) + """ + assert seg.shape == seg_samples.shape[1:] + num_samples = seg_samples.shape[0] + pixel_counts = np.zeros(shape=(len(cf.label_switches),3), dtype=np.int) + + # iterate all switchable classes + for i,c in enumerate(cf.label_switches.keys()): + c_id = cf.name2trainId[c] + alt_c_id = cf.name2trainId[c+'_2'] + c_ixs = np.where(seg == c_id) + + total_num_pixels = np.sum((seg == c_id).astype(np.uint8)) * num_samples + pixel_counts[i,0] = total_num_pixels + + # count the pixels of original class|original class and alternative class|original class + for j in range(num_samples): + sample = seg_samples[j] + sampled_original_pixels = np.sum((sample[c_ixs] == c_id).astype(np.uint8)) + sampled_alternative_pixels = np.sum((sample[c_ixs] == alt_c_id).astype(np.uint8)) + pixel_counts[i,1] += sampled_original_pixels + pixel_counts[i,2] += sampled_alternative_pixels + + return pixel_counts + + +def get_mode_statistics(label_switches, exp_modes=5): + """ + Calculate a binary matrix of switches as well as a vector of mode probabilities. + :param label_switches: dict specifying class names and their individual sampling probabilities + :param exp_modes: integer, number of independently switchable classes + :return: dict + """ + num_modes = 2 ** exp_modes + + # assemble a binary matrix of switch decisions + switch = np.zeros(shape=(num_modes, 5), dtype=np.uint8) + for i in range(exp_modes): + switch[:,i] = 2 ** i * (2 ** (exp_modes - 1 - i) * [0] + 2 ** (exp_modes - 1 - i) * [1]) + + # calculate the probability for each individual mode + mode_probs = np.zeros(shape=(num_modes,), dtype=np.float32) + for mode in range(num_modes): + prob = 1. + for i, c in enumerate(label_switches.keys()): + if switch[mode, i]: + prob *= label_switches[c] + else: + prob *= 1. - label_switches[c] + mode_probs[mode] = prob + assert np.sum(mode_probs) == 1. + + return {'switch': switch, 'mode_probs': mode_probs} + + +def get_energy_distance_components(gt_seg_modes, seg_samples, eval_class_ids, ignore_mask=None): + """ + Calculates the components for the IoU-based generalized energy distance given an array holding all segmentation + modes and an array holding all sampled segmentations. 
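+    All pairwise distances d = 1 - IoU are computed between modes and samples (YS), among modes (YY), and among samples (SS).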
+ :param gt_seg_modes: N-D array in format (num_modes,[...],H,W) + :param seg_samples: N-D array in format (num_samples,[...],H,W) + :param eval_class_ids: integer or list of integers specifying the classes to encode, if integer range() is applied + :param ignore_mask: N-D array in format ([...],H,W) + :return: dict + """ + num_modes = gt_seg_modes.shape[0] + num_samples = seg_samples.shape[0] + + if isinstance(eval_class_ids, int): + eval_class_ids = list(range(eval_class_ids)) + + d_matrix_YS = np.zeros(shape=(num_modes, num_samples, len(eval_class_ids)), dtype=np.float32) + d_matrix_YY = np.zeros(shape=(num_modes, num_modes, len(eval_class_ids)), dtype=np.float32) + d_matrix_SS = np.zeros(shape=(num_samples, num_samples, len(eval_class_ids)), dtype=np.float32) + + # iterate all ground-truth modes + for mode in range(num_modes): + + ########################################## + # Calculate d(Y,S) = [1 - IoU(Y,S)], # + # with S ~ P_pred, Y ~ P_gt # + ########################################## + + # iterate the samples S + for i in range(num_samples): + conf_matrix = training_utils.calc_confusion(gt_seg_modes[mode], seg_samples[i], + loss_mask=ignore_mask, class_ixs=eval_class_ids) + iou = training_utils.metrics_from_conf_matrix(conf_matrix)['iou'] + d_matrix_YS[mode, i] = 1. - iou + + ########################################### + # Calculate d(Y,Y') = [1 - IoU(Y,Y')], # + # with Y,Y' ~ P_gt # + ########################################### + + # iterate the ground-truth modes Y' while exploiting the pair-wise symmetries for efficiency + for mode_2 in range(mode, num_modes): + conf_matrix = training_utils.calc_confusion(gt_seg_modes[mode], gt_seg_modes[mode_2], + loss_mask=ignore_mask, class_ixs=eval_class_ids) + iou = training_utils.metrics_from_conf_matrix(conf_matrix)['iou'] + d_matrix_YY[mode, mode_2] = 1. - iou + d_matrix_YY[mode_2, mode] = 1. - iou + + ######################################### + # Calculate d(S,S') = 1 - IoU(S,S'), # + # with S,S' ~ P_pred # + ######################################### + + # iterate all samples S + for i in range(num_samples): + # iterate all samples S' + for j in range(i, num_samples): + conf_matrix = training_utils.calc_confusion(seg_samples[i], seg_samples[j], + loss_mask=ignore_mask, class_ixs=eval_class_ids) + iou = training_utils.metrics_from_conf_matrix(conf_matrix)['iou'] + d_matrix_SS[i, j] = 1. - iou + d_matrix_SS[j, i] = 1. - iou + + return {'YS': d_matrix_YS, 'SS': d_matrix_SS, 'YY': d_matrix_YY} + + +def calc_energy_distances(d_matrices, num_samples=None, probability_weighted=False, label_switches=None, exp_mode=5): + """ + Calculate the energy distance for each image based on matrices holding the combinatorial distances. 
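+    Computes D_GED^2 = 2 E[d(Y,S)] - E[d(S,S')] - E[d(Y,Y')] per image, with d = 1 - IoU, optionally weighting the ground-truth modes by their probabilities.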
+ :param d_matrices: dict holding 4D arrays of shape \ + (num_images, num_modes/num_samples, num_modes/num_samples, num_classes) + :param num_samples: integer or None + :param probability_weighted: bool + :param label_switches: None or dict + :param exp_mode: integer + :return: numpy array + """ + d_matrices = d_matrices.copy() + + if num_samples is None: + num_samples = d_matrices['SS'].shape[1] + + d_matrices['YS'] = d_matrices['YS'][:,:,:num_samples] + d_matrices['SS'] = d_matrices['SS'][:,:num_samples,:num_samples] + + # perform a nanmean over the class axis so as to not factor in classes that are not present in + # both the ground-truth mode as well as the sampled prediction + if probability_weighted: + mode_stats = get_mode_statistics(label_switches, exp_modes=exp_mode) + mode_probs = mode_stats['mode_probs'] + + mean_d_YS = np.nanmean(d_matrices['YS'], axis=-1) + mean_d_YS = np.mean(mean_d_YS, axis=2) + mean_d_YS = mean_d_YS * mode_probs[np.newaxis, :] + d_YS = np.sum(mean_d_YS, axis=1) + + mean_d_SS = np.nanmean(d_matrices['SS'], axis=-1) + d_SS = np.mean(mean_d_SS, axis=(1, 2)) + + mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1) + mean_d_YY = mean_d_YY * mode_probs[np.newaxis, :, np.newaxis] * mode_probs[np.newaxis, np.newaxis, :] + d_YY = np.sum(mean_d_YY, axis=(1, 2)) + + else: + mean_d_YS = np.nanmean(d_matrices['YS'], axis=-1) + d_YS = np.mean(mean_d_YS, axis=(1,2)) + + mean_d_SS = np.nanmean(d_matrices['SS'], axis=-1) + d_SS = np.mean(mean_d_SS, axis=(1, 2)) + + mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1) + d_YY = np.nanmean(mean_d_YY, axis=(1, 2)) + + return 2 * d_YS - d_SS - d_YY + + +def eval(cf, cities, queue=None, ixs=None): + """ + Perform evaluation w.r.t the generalized energy distance based on the IoU as well as image-level and pixel-level + mode frequencies (using samples written to file). 
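+    A sketch of single-process usage (without a queue the results dict is returned directly; assumes the
+    samples were previously written to cf.out_dir via --write_samples):
+        results = eval(cf, cities='lindau')
+        d_matrices = results['d_matrices']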
+    :param cf: config module
+    :param cities: string or list of strings
+    :param queue: instance of multiprocessing.Queue
+    :param ixs: None or 2-tuple of ints
+    :return: NoneType (results are put on the queue) or dict holding d_matrices and mode/pixel counts
+    """
+    data_dir = os.path.join(cf.data_dir, cf.resolution)
+    data_dict = loadFiles(label_density=cf.label_density, split='val', input_path=data_dir, cities=cities,
+                          instance=False)
+
+    num_modes = cf.num_modes
+    num_samples = cf.num_samples
+
+    # evaluate only the switchable classes and their alternatives, i.e. a total of 10 classes here
+    eval_class_names = list(cf.label_switches.keys()) + list(cf.switched_name2Id.keys())
+    eval_class_ids = [cf.name2trainId[n] for n in eval_class_names]
+    d_matrices = {'YS': np.zeros(shape=(len(data_dict), num_modes, num_samples, len(eval_class_ids)),
+                                 dtype=np.float32),
+                  'YY': np.ones(shape=(len(data_dict), num_modes, num_modes, len(eval_class_ids)),
+                                dtype=np.float32),
+                  'SS': np.ones(shape=(len(data_dict), num_samples, num_samples, len(eval_class_ids)),
+                                dtype=np.float32)}
+    sampled_mode_counts = np.zeros(shape=(num_modes,), dtype=np.int)
+    sampled_pixel_counts = np.zeros(shape=(len(cf.label_switches), 3), dtype=np.int)
+
+    logging.info('Evaluating class names: {} (corresponding to labels {})'.format(eval_class_names, eval_class_ids))
+
+    # allow for data selection by indexing via ixs
+    if ixs is None:
+        data_keys = list(data_dict.keys())
+    else:
+        data_keys = list(data_dict.keys())[ixs[0]:ixs[1]]
+        for k in d_matrices.keys():
+            d_matrices[k] = d_matrices[k][:ixs[1]-ixs[0]]
+
+    # iterate all validation images
+    for img_n, img_key in enumerate(tqdm(data_keys)):
+
+        seg = np.load(data_dict[img_key]['seg'])
+        seg = map_labels_to_trainId(seg)
+        seg = seg[np.newaxis, np.newaxis]
+        ignore_mask = (seg == cf.ignore_label).astype(np.uint8)
+
+        seg_samples = get_array_of_samples(cf, img_key)
+        gt_seg_modes = get_array_of_modes(cf, seg)
+
+        energy_dist = get_energy_distance_components(gt_seg_modes=gt_seg_modes, seg_samples=seg_samples,
+                                                     eval_class_ids=eval_class_ids, ignore_mask=ignore_mask)
+        sampled_mode_counts += get_mode_counts(energy_dist['YS'])
+        sampled_pixel_counts += get_pixelwise_mode_counts(cf, seg, seg_samples)
+
+        for k in d_matrices.keys():
+            d_matrices[k][img_n] = energy_dist[k]
+
+    results = {'d_matrices': d_matrices, 'sampled_pixel_counts': sampled_pixel_counts,
+               'sampled_mode_counts': sampled_mode_counts, 'total_num_samples': len(data_keys) * num_samples}
+
+    if queue is not None:
+        queue.put(results)
+        return
+    else:
+        return results
+
+
+def runInParallel(fns_args, queue):
+    """Run functions in parallel.
+    :param fns_args: list of tuples containing functions and a tuple of arguments each
+    :param queue: instance of multiprocessing.Queue()
+    :return: list of queue results
+    """
+    proc = []
+    for fn in fns_args:
+        p = Process(target=fn[0], args=fn[1])
+        p.start()
+        proc.append(p)
+    return [queue.get() for p in proc]
+
+
+def multiprocess_evaluation(cf):
+    """Evaluate the energy distance in multiprocessing.
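+    The validation set is partitioned by city and index range; each partition is evaluated in its own
+    process and the partial results are aggregated from a shared multiprocessing queue.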
+ :param cf: config module""" + q = Queue() + results = runInParallel([(eval, (cf, 'lindau', q)), + (eval, (cf, 'frankfurt', q, (0, 100))), + (eval, (cf, 'frankfurt', q, (100, 200))), + (eval, (cf, 'frankfurt', q, (200, 267))), + (eval, (cf, 'munster', q, (0, 100))), + (eval, (cf, 'munster', q, (100, 174)))], + queue=q) + total_num_samples = 0 + sampled_mode_counts = np.zeros(shape=(cf.num_modes,), dtype=np.int) + sampled_pixel_counts = np.zeros(shape=(len(cf.label_switches), 3), dtype=np.int) + d_matrices = {'YS':[], 'SS':[], 'YY':[]} + + # aggregate results from the queue + for result_dict in results: + for key in d_matrices.keys(): + d_matrices[key].append(result_dict['d_matrices'][key]) + + sampled_pixel_counts += result_dict['sampled_pixel_counts'] + sampled_mode_counts += result_dict['sampled_mode_counts'] + total_num_samples += result_dict['total_num_samples'] + + for key in d_matrices.keys(): + d_matrices[key] = np.concatenate(d_matrices[key], axis=0) + + # calculate frequencies + print('pixel frequencies', sampled_pixel_counts) + sampled_pixelwise_mode_per_class = sampled_pixel_counts[:,1:] + total_num_pixels_per_class = sampled_pixel_counts[:,0:1] + sampled_pixel_frequencies = sampled_pixelwise_mode_per_class / total_num_pixels_per_class + sampled_mode_frequencies = sampled_mode_counts / total_num_samples + + print('sampled pixel frequencies', sampled_pixel_frequencies) + print('sampled_mode_frequencies', sampled_mode_frequencies) + + results_dict = {'d_matrices': d_matrices, 'pixel_frequencies': sampled_pixel_frequencies, + 'mode_frequencies': sampled_mode_frequencies} + + results_file = os.path.join(cf.out_dir, 'eval_results.pkl') + with open(results_file, 'wb') as f: + pickle.dump(results_dict, f, pickle.HIGHEST_PROTOCOL) + logging.info('Wrote to {}'.format(results_file)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Evaluation step selection.') + parser.add_argument('--write_samples', dest='write_samples', action='store_true') + parser.add_argument('--eval_samples', dest='write_samples', action='store_false') + parser.set_defaults(write_samples=True) + parser.add_argument('-c', '--config_name', type=str, default='cityscapes_eval_config.py', + help='name of the python file that is loaded as config module') + args = parser.parse_args() + + # load evaluation config + cf = SourceFileLoader('cf', args.config_name).load_module() + + # prepare evaluation directory + if not os.path.isdir(cf.out_dir): + os.mkdir(cf.out_dir) + + # log to file and console + log_path = os.path.join(cf.out_dir, 'eval.log') + logging.basicConfig(filename=log_path, level=logging.INFO) + logging.getLogger().addHandler(logging.StreamHandler()) + logging.info('Logging to {}'.format(log_path)) + + if args.write_samples: + logging.info('Writing samples to {}'.format(cf.out_dir)) + write_test_predictions(cf) + else: + logging.info('Evaluating samples from {}'.format(cf.out_dir)) + multiprocess_evaluation(cf) \ No newline at end of file diff --git a/evaluation/evaluation_plots.ipynb b/evaluation/evaluation_plots.ipynb new file mode 100644 index 0000000..7eb0e37 --- /dev/null +++ b/evaluation/evaluation_plots.ipynb @@ -0,0 +1,590 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/utils/training_utils.py:21: UserWarning: \n", + "This call to matplotlib.use() has no effect because the backend has already\n", + 
"been chosen; matplotlib.use() must be called *before* pylab, matplotlib.pyplot,\n", + "or matplotlib.backends is imported for the first time.\n", + "\n", + "The backend was *originally* set to 'nbAgg' by the following code:\n", + " File \"/usr/local/lib/python3.6/runpy.py\", line 193, in _run_module_as_main\n", + " \"__main__\", mod_spec)\n", + " File \"/usr/local/lib/python3.6/runpy.py\", line 85, in _run_code\n", + " exec(code, run_globals)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel_launcher.py\", line 16, in \n", + " app.launch_new_instance()\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/traitlets/config/application.py\", line 658, in launch_instance\n", + " app.start()\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/kernelapp.py\", line 486, in start\n", + " self.io_loop.start()\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/tornado/platform/asyncio.py\", line 132, in start\n", + " self.asyncio_loop.run_forever()\n", + " File \"/usr/local/lib/python3.6/asyncio/base_events.py\", line 422, in run_forever\n", + " self._run_once()\n", + " File \"/usr/local/lib/python3.6/asyncio/base_events.py\", line 1432, in _run_once\n", + " handle._run()\n", + " File \"/usr/local/lib/python3.6/asyncio/events.py\", line 145, in _run\n", + " self._callback(*self._args)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/tornado/platform/asyncio.py\", line 122, in _handle_events\n", + " handler_func(fileobj, events)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/tornado/stack_context.py\", line 300, in null_wrapper\n", + " return fn(*args, **kwargs)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 450, in _handle_events\n", + " self._handle_recv()\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 480, in _handle_recv\n", + " self._run_callback(callback, msg)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 432, in _run_callback\n", + " callback(*args, **kwargs)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/tornado/stack_context.py\", line 300, in null_wrapper\n", + " return fn(*args, **kwargs)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 283, in dispatcher\n", + " return self.dispatch_shell(stream, msg)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 233, in dispatch_shell\n", + " handler(stream, idents, msg)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 399, in execute_request\n", + " user_expressions, allow_stdin)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/ipkernel.py\", line 208, in do_execute\n", + " res = shell.run_cell(code, store_history=store_history, silent=silent)\n", + " File 
\"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/zmqshell.py\", line 537, in run_cell\n", + " return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2662, in run_cell\n", + " raw_cell, store_history, silent, shell_futures)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2785, in _run_cell\n", + " interactivity=interactivity, compiler=compiler, result=result)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2901, in run_ast_nodes\n", + " if self.run_code(code, result):\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2961, in run_code\n", + " exec(code_obj, self.user_global_ns, self.user_ns)\n", + " File \"\", line 4, in \n", + " get_ipython().run_line_magic('matplotlib', 'notebook')\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2131, in run_line_magic\n", + " result = fn(*args,**kwargs)\n", + " File \"\", line 2, in matplotlib\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/magic.py\", line 187, in \n", + " call = lambda f, *a, **k: f(*a, **k)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/magics/pylab.py\", line 99, in matplotlib\n", + " gui, backend = self.shell.enable_matplotlib(args.gui)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 3049, in enable_matplotlib\n", + " pt.activate_matplotlib(backend)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/pylabtools.py\", line 311, in activate_matplotlib\n", + " matplotlib.pyplot.switch_backend(backend)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/pyplot.py\", line 231, in switch_backend\n", + " matplotlib.use(newbackend, warn=False, force=True)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/__init__.py\", line 1410, in use\n", + " reload(sys.modules['matplotlib.backends'])\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/importlib/__init__.py\", line 166, in reload\n", + " _bootstrap._exec(spec, module)\n", + " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/backends/__init__.py\", line 16, in \n", + " line for line in traceback.format_stack()\n", + "\n", + "\n", + " matplotlib.use('agg')\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:From /home/skohl/PycharmProjects/phabricator/skohl/prob_unet/utils/training_utils.py:270: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with distribution=normal is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "`normal` is a deprecated alias for `truncated_normal`\n" + ] + }, 
+ { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "import numpy as np\n", + "\n", + "%matplotlib notebook\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.gridspec as gridspec\n", + "import seaborn as sns\n", + "import pandas as pd\n", + "import imp \n", + "import pickle\n", + "\n", + "from utils.training_utils import to_rgb, calc_confusion, metrics_from_conf_matrix\n", + "from data.cityscapes.data_loader import loadFiles, map_labels_to_trainId\n", + "from evaluation.eval_cityscapes import calc_energy_distances, get_mode_statistics\n", + "\n", + "from IPython.core.display import display, HTML\n", + "display(HTML(\"\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/tmp/prob_unet_1.0beta_analyticKLTrue_6latents_31x1s_32chan_lr0.0001PieceWise_bs16_7blocks_QMeanFalse_henormal_noBiasReg_OneHotMinus0.5_240kepochs/munster_000062_000019_sample7_labelIds.npy\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cf = imp.load_source('cf', 'cityscapes_eval_config.py')\n", + "path = cf.out_dir\n", + "\n", + "ix = 18\n", + "img_path = os.path.join(path, os.listdir(path)[ix])\n", + "print(img_path)\n", + "img = np.load(img_path)\n", + "\n", + "f = plt.figure(figsize=(8,4))\n", + "plt.imshow(to_rgb(img[0,0], cmap=cf.color_map))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# evaluate for different number of samples" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/evaluation/eval_cityscapes.py:309: RuntimeWarning: Mean of empty slice\n", + " mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1)\n", + "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/evaluation/eval_cityscapes.py:309: RuntimeWarning: Mean of empty slice\n", + " mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1)\n", + "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/evaluation/eval_cityscapes.py:309: RuntimeWarning: Mean of empty slice\n", + " mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1)\n", + "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/evaluation/eval_cityscapes.py:309: RuntimeWarning: Mean of empty slice\n", + " mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1)\n" + ] + } + ], + "source": [ + "results_file = os.path.join(cf.out_dir, 'eval_results.pkl')\n", + "\n", + "with open(results_file, \"rb\") as f:\n", + " results = pickle.load(f)\n", + "\n", + "e_distances = []\n", + "e_means = []\n", + "samples_column = []\n", + " \n", + "for s in [1,4,8,16]:\n", + " e_dist = calc_energy_distances(results['d_matrices'], num_samples=s, probability_weighted=True, label_switches=cf.label_switches)\n", + " e_dist = e_dist[~np.isnan(e_dist)]\n", + " e_distances.extend(e_dist)\n", + " samples_column.extend([s] * len(e_dist))\n", + " e_means.append(np.mean(e_dist))\n", + "\n", + "energy = 
pd.DataFrame(data={'energy': e_distances, 'num_samples': samples_column})\n", + "means = pd.DataFrame(data={'energy': e_means, 'num_samples': [1,4,8,16]})" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "Text(0.5,0,'# samples')" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f = plt.figure(figsize=(5,5), dpi=150)\n", + "gs = gridspec.GridSpec(1,1)\n", + "ax = plt.subplot(gs[0,0])\n", + "\n", + "sns.stripplot(x=\"num_samples\", y=\"energy\", data=energy, alpha=0.5, s=2, ax=ax)\n", + "sns.stripplot(x=\"num_samples\", y=\"energy\", data=means, s=18, marker='^', color='k', ax=ax, jitter=False)\n", + "sns.stripplot(x=\"num_samples\", y=\"energy\", data=means, s=14, marker='^', ax=ax, jitter=False)\n", + "ax.set_title('Probabilistic U-Net (validation set)', y=1.03)\n", + "fs=12\n", + "ax.set_ylabel(r'$D_{ged}^{2}$', fontsize=fs)\n", + "ax.set_xlabel('# samples', fontsize=fs)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0.39207232 0.4350246 ]\n", + " [0.44614304 0.31564506]\n", + " [0.61882485 0.30369349]\n", + " [0.72800224 0.20223276]\n", + " [0.7578801 0.22053149]]\n", + "[0.163625 0.038375 0.036625 0.00875 0.06 0.0365 0.012375 0.00575\n", + " 0.065375 0.0125 0.02025 0.002375 0.024125 0.009625 0.007875 0.001875\n", + " 0.136125 0.03225 0.0425 0.006875 0.051125 0.032 0.01175 0.007\n", + " 0.0735 0.01125 0.03325 0.003125 0.02725 0.008625 0.014 0.003375]\n", + "[0.004732871, 0.0053244797, 0.0067612445, 0.0076063997, 0.00867693, 0.009761547, 0.011358891, 0.012395615, 0.012778752, 0.013945066, 0.015381831, 0.016226986, 0.01730456, 0.01825536, 0.020824632, 0.021974044, 0.023427712, 0.0247208, 0.028200023, 0.029749475, 0.031725027, 0.03346816, 0.036916394, 0.040285747, 0.041530944, 0.045321465, 0.052737705, 0.05932992, 0.06768005, 0.07614006, 0.09668579, 0.10877152]\n" + ] + } + ], + "source": [ + "print(results['pixel_frequencies'])\n", + "print(results['mode_frequencies'])\n", + "\n", + "stats = get_mode_statistics(cf.label_switches, exp_modes=5)\n", + "print(sorted(stats['mode_probs']))\n", + "\n", + "log_mode_frequencies = np.log(results['mode_frequencies'])\n", + "log_gt_mode_frequencies = np.log(stats['mode_probs'])" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f = plt.figure(figsize=(8,7))\n", + "gs = gridspec.GridSpec(2, 2, wspace=0.0, hspace=0.0, height_ratios=[32,6], width_ratios=[3,1])\n", + "\n", + "ax = plt.subplot(gs[0, 0])\n", + "sort_ixs = np.argsort(stats['mode_probs'])[::-1]\n", + "plt.bar(x=range(len(log_gt_mode_frequencies)), height=results['mode_frequencies'][sort_ixs], align='center', color='g', label='Estimated Frequencies')\n", + "plt.bar(x=range(len(log_gt_mode_frequencies)), height=stats['mode_probs'][sort_ixs], align='center', edgecolor='k', fc=(0,0,0,0), label='Ground truth')\n", + 
"plt.tick_params(\n", + " axis='x', # changes apply to the x-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " labelbottom=False) # labels along the bottom edge are off\n", + "\n", + "ax.legend()\n", + "ax.set_title('Probabilistic U-Net (validation set)', y=1.03)\n", + "ax = plt.gca()\n", + "ax.set_xlim([-.5,31.5])\n", + "ax.set_ylim([0,0.2])\n", + "yticks = ax.yaxis.get_major_ticks()\n", + "yticks[0].label1.set_visible(False)\n", + "ax.set_ylabel('P', fontdict={'size': 15})\n", + "\n", + "rows = list(cf.label_switches.keys())\n", + "cell_text = np.transpose(stats['switch'][sort_ixs], axes=[1,0])\n", + "\n", + "ax = plt.subplot(gs[1, 0])\n", + "the_table = ax.table(cellText=cell_text,\n", + " rowLabels=rows,\n", + " loc='center')\n", + "\n", + "ax.get_yaxis().set_visible(False)\n", + "plt.tick_params(\n", + " axis='x', # changes apply to the x-axis\n", + " which='both', # both major and minor ticks are affected\n", + " bottom=False, # ticks along the bottom edge are off\n", + " top=False, # ticks along the top edge are off\n", + " labelbottom=False) # labels along the bottom edge are off\n", + "ax.set_xlabel('Discrete Modes')\n", + "\n", + "ax = plt.subplot(gs[1, 1])\n", + "ax.set_title('Pixel-wise Frequency Estimation',fontdict={'size': 6})\n", + "plt.barh(y=range(5), width=results['pixel_frequencies'][:,1], align='center', color='g')\n", + "plt.barh(y=range(5), width=list(cf.label_switches.values()), align='center', edgecolor='k', fc=(0,0,0,0))\n", + "ax.get_yaxis().set_visible(False)\n", + "ax.set_xticks([0.1,0.2,0.3,0.4,0.5])\n", + "ax.set_xlabel('P')\n", + "ax.tick_params(axis='both', which='major', labelsize=6)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# prepare animated gif from samples" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved to /home/skohl/PycharmProjects/phabricator/skohl/prob_unet/tmp//prob_unet_0.0005beta_analyticKLTrue_6latents_31x1s_32chan_lr0.0001PieceWise_bs16_7blocks_QMeanFalse_henormal_noBiasReg_OneHotMinus0.5_240kepochs_TFP/gif_panel_sample_0.png\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved to /home/skohl/PycharmProjects/phabricator/skohl/prob_unet/tmp//prob_unet_0.0005beta_analyticKLTrue_6latents_31x1s_32chan_lr0.0001PieceWise_bs16_7blocks_QMeanFalse_henormal_noBiasReg_OneHotMinus0.5_240kepochs_TFP/gif_panel_sample_1.png\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "execute_result" + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandles\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlegend_handles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m9\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbbox_to_anchor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m0.03\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mncol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlegend_handles\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mframeon\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mout_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'gif_panel_sample_{}.png'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_num\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdpi\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbbox_inches\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'tight'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpad_inches\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Saved to {}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36msavefig\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 708\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 709\u001b[0m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgcf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 710\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 711\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw_idle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# need this if 'transparent=True' to reset colors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 712\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/figure.py\u001b[0m in \u001b[0;36msavefig\u001b[0;34m(self, fname, **kwargs)\u001b[0m\n\u001b[1;32m 2033\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_frameon\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframeon\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2034\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2035\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2036\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2037\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mframeon\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, **kwargs)\u001b[0m\n\u001b[1;32m 2210\u001b[0m \u001b[0morientation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morientation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2211\u001b[0m \u001b[0mdryrun\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2212\u001b[0;31m **kwargs)\n\u001b[0m\u001b[1;32m 2213\u001b[0m \u001b[0mrenderer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cachedRenderer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2214\u001b[0m \u001b[0mbbox_inches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_tightbbox\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/backends/backend_agg.py\u001b[0m in \u001b[0;36mprint_png\u001b[0;34m(self, filename_or_obj, *args, **kwargs)\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcbook\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen_file_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename_or_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"wb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfh\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 527\u001b[0m _png.write_png(renderer._renderer, fh,\n\u001b[0;32m--> 528\u001b[0;31m self.figure.dpi, metadata=metadata)\n\u001b[0m\u001b[1;32m 529\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 530\u001b[0m 
\u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdpi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moriginal_dpi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ], + "output_type": "error" + } + ], + "source": [ + "from matplotlib.lines import Line2D\n", + "from data.cityscapes.data_loader import map_labels_to_trainId\n", + "\n", + "# get dictionary of image files\n", + "data_dir = os.path.join(cf.data_dir, cf.resolution)\n", + "data_dict = loadFiles(label_density=cf.label_density, split='val', input_path=data_dir,\n", + " cities=None, instance=False)\n", + "\n", + "# list of image indices in data_dict\n", + "ixs = [0,2,75,85,100,145,350,390,460,470]\n", + "num_samples = 16\n", + "y_dim = 256\n", + "x_dim = 512\n", + "\n", + "img_arr = np.zeros(shape=(y_dim, len(ixs) * x_dim, 3)) \n", + "gt_arr = np.zeros(shape=(y_dim, len(ixs) * x_dim))\n", + "sample_arr = np.zeros(shape=(num_samples, y_dim, len(ixs) * x_dim))\n", + "\n", + "for i,ix in enumerate(ixs):\n", + "\n", + " img_key = list(data_dict.keys())[ix]\n", + " img_path = data_dict[img_key]['data']\n", + " seg_path = data_dict[img_key]['seg']\n", + " \n", + " # load images\n", + " img_arr[:, i * x_dim: (i + 1) * x_dim] = np.transpose(np.load(img_path), axes=[1,2,0]) / 255.\n", + " \n", + " # load gt segmentation\n", + " gt_arr[:, i * x_dim: (i + 1) * x_dim] = map_labels_to_trainId(np.load(seg_path))\n", + " \n", + " # load samples\n", + " for sample_num in range(num_samples):\n", + " sample_path = os.path.join(cf.out_dir, '{}_sample{}_labelIds.npy'.format(img_key, sample_num))\n", + " sample_arr[sample_num, :, i * x_dim: (i + 1) * x_dim] = np.load(sample_path)[0,0]\n", + "\n", + "for sample_num in range(num_samples):\n", + " f = plt.figure(figsize=(len(ixs) * 4, 9))\n", + "\n", + " arr = np.concatenate([img_arr, to_rgb(gt_arr, cmap=cf.color_map), to_rgb(sample_arr[sample_num], cmap=cf.color_map)], axis=0)\n", + " \n", + " plt.imshow(arr)\n", + " plt.text(-80,170, 'image', fontsize=12, rotation=90, rotation_mode='anchor')\n", + " plt.text(-80,470, 'deterministic', fontsize=12, rotation=90, rotation_mode='anchor')\n", + " plt.text(-40,460, 'groundtruth', fontsize=12, rotation=90, rotation_mode='anchor')\n", + " plt.text(-80,690, 'sample {}'.format(sample_num + 1), fontsize=12, rotation=90, rotation_mode='anchor')\n", + " \n", + " ax = plt.gca()\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + " \n", + " # custom legend\n", + " trainId2name = cf.trainId2name\n", + " trainId2name[255] = 'ignore'\n", + " legend_handles = [Line2D([0], [0], color=tuple([c / 255. 
for c in cf.color_map[trainID]]), lw=4, label=trainId2name[trainID]) for trainID in list(cf.color_map.keys())] \n", + " ax.legend(handles=legend_handles, loc=9, bbox_to_anchor=(0.5, -0.03), ncol=len(legend_handles), frameon=False)\n", + " out_dir = os.path.join(cf.out_dir, 'gif_panel_sample_{}.png'.format(sample_num))\n", + " plt.savefig(out_dir, dpi=200, bbox_inches='tight', pad_inches=0.0)\n", + " plt.close()\n", + " print('Saved to {}'.format(out_dir))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "venv" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3.0 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/probabilistic_unet.py b/model/probabilistic_unet.py new file mode 100644 index 0000000..cb5d93b --- /dev/null +++ b/model/probabilistic_unet.py @@ -0,0 +1,478 @@ +# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Probabilistic U-Net model.""" + +import tensorflow as tf +import sonnet as snt +import tensorflow_probability as tfp +tfd = tfp.distributions +from utils.training_utils import ce_loss, he_normal + +def down_block(features, + output_channels, + kernel_shape, + stride=1, + rate=1, + num_convs=2, + initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)}, + regularizers=None, + nonlinearity=tf.nn.relu, + down_sample_input=True, + down_sampling_op=lambda x, size: tf.image.resize_images(x, size, + method=tf.image.ResizeMethod.BILINEAR, align_corners=True), + data_format='NCHW', + name='down_block'): + """A block made up of a down-sampling step followed by several convolutional layers.""" + with tf.variable_scope(name): + if down_sample_input: + features = down_sampling_op(features, data_format) + + for _ in range(num_convs): + features = snt.Conv2D(output_channels, kernel_shape, stride, rate, data_format=data_format, + initializers=initializers, regularizers=regularizers)(features) + features = nonlinearity(features) + + return features + + +def up_block(lower_res_inputs, + same_res_inputs, + output_channels, + kernel_shape, + stride=1, + rate=1, + num_convs=2, + initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)}, + regularizers=None, + nonlinearity=tf.nn.relu, + up_sampling_op=lambda x, size: tf.image.resize_images(x, size, + method=tf.image.ResizeMethod.BILINEAR, align_corners=True), + data_format='NCHW', + name='up_block'): + """A block made up of an up-sampling step followed by several convolutional layers.""" + with tf.variable_scope(name): + spatial_shape = same_res_inputs.get_shape()[2:] + + if data_format=='NHWC': + features = up_sampling_op(lower_res_inputs, spatial_shape) + features = tf.concat([features, same_res_inputs], axis=-1) + else: + lower_res_inputs = tf.transpose(lower_res_inputs, perm=[0,2,3,1]) + features = up_sampling_op(lower_res_inputs, spatial_shape) + features = tf.transpose(features, perm=[0,3,1,2]) + features = tf.concat([features, same_res_inputs], axis=1) + + for _ in range(num_convs): + features = snt.Conv2D(output_channels, kernel_shape, stride, rate, data_format=data_format, + initializers=initializers, regularizers=regularizers)(features) + features = nonlinearity(features) + + return features + + +class VGG_Encoder(snt.AbstractModule): + """A quasi VGG-style convolutional net, made of M x (down-sampling, N x conv)-operations, + where M = len(num_channels), N = num_convs_per_block.""" + + def __init__(self, + num_channels, + nonlinearity=tf.nn.relu, + num_convs_per_block=3, + initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)}, + regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)}, + data_format='NCHW', + down_sampling_op=lambda x, df: tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2], + padding='SAME', data_format=df), + name="vgg_enc"): + super(VGG_Encoder, self).__init__(name=name) + self._num_channels = num_channels + self._nonlinearity = nonlinearity + self._num_convs = num_convs_per_block + self._initializers = initializers + self._regularizers = regularizers + self._data_format = data_format + self._down_sampling_op = down_sampling_op + + def _build(self, inputs): + """ + :param inputs: 4D tensor of shape NCHW or NWHC + :return: a list of 4D tensors of shape NCHW or NWHC + """ + features = [inputs] + + # iterate blocks 
(`processing scales') + for i, n_channels in enumerate(self._num_channels): + + if i == 0: + down_sample = False + else: + down_sample = True + tf.logging.info('encoder scale {}: {}'.format(i, features[-1].get_shape())) + features.append(down_block(features[-1], + output_channels=n_channels, + kernel_shape=(3,3), + num_convs=self._num_convs, + nonlinearity=self._nonlinearity, + initializers=self._initializers, + regularizers=self._regularizers, + down_sample_input=down_sample, + data_format=self._data_format, + down_sampling_op=self._down_sampling_op, + name='down_block_{}'.format(i))) + # return all features except for the input images + return features[1:] + + +class VGG_Decoder(snt.AbstractModule): + """A quasi VGG-style convolutional net, made of M x (up-sampling, N x conv)-operations, + where M = len(num_channels), N = num_convs_per_block.""" + + def __init__(self, + num_channels, + num_classes, + nonlinearity=tf.nn.relu, + num_convs_per_block=3, + initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)}, + regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)}, + data_format='NCHW', + up_sampling_op=lambda x, size: tf.image.resize_images(x, size, + method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True), + name="vgg_dec"): + super(VGG_Decoder, self).__init__(name=name) + self._num_channels = num_channels + self._num_classes = num_classes + self._nonlinearity = nonlinearity + self._num_convs = num_convs_per_block + self._initializers = initializers + self._regularizers = regularizers + self._data_format = data_format + self._up_sampling_op = up_sampling_op + + def _build(self, input_list): + """ + :param input_list: a list of 4D tensors of shape NCHW or NHWC + :return: 4D tensor + """ + try: + assert len(self._num_channels) == len(input_list) + except: + raise AssertionError('Missmatch: {} blocks vs. 
{} incoming feature-maps!'.format(len(self._num_channels), len(input_list)))
+
+        n = len(input_list) - 2
+        lower_res_features = input_list[-1]
+
+        # iterate input features & channels in reverse order, starting from the second-to-last feature map
+        for i in range(n, -1, -1):
+            same_res_features = input_list[i]
+            n_channels = self._num_channels[i]
+
+            tf.logging.info('decoder scale {}: {}'.format(i, lower_res_features.get_shape()))
+            lower_res_features = up_block(lower_res_features,
+                                          same_res_features,
+                                          output_channels=n_channels,
+                                          kernel_shape=(3, 3),
+                                          num_convs=self._num_convs,
+                                          nonlinearity=self._nonlinearity,
+                                          initializers=self._initializers,
+                                          regularizers=self._regularizers,
+                                          data_format=self._data_format,
+                                          up_sampling_op=self._up_sampling_op,
+                                          name='up_block_{}'.format(i))
+        return lower_res_features
+
+
+class UNet(snt.AbstractModule):
+    """A quasi standard U-Net, similar to `U-Net: Convolutional Networks for Biomedical Image Segmentation',
+    https://arxiv.org/abs/1505.04597."""
+
+    def __init__(self,
+                 num_channels,
+                 num_classes,
+                 nonlinearity=tf.nn.relu,
+                 num_convs_per_block=3,
+                 initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+                 regularizers=None,
+                 data_format='NCHW',
+                 down_sampling_op=lambda x, df: tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2],
+                                                               padding='SAME', data_format=df),
+                 up_sampling_op=lambda x, size: tf.image.resize_images(x, size,
+                                                                       method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
+                 name="unet"):
+        super(UNet, self).__init__(name=name)
+        with self._enter_variable_scope():
+            tf.logging.info('Building U-Net.')
+            self._encoder = VGG_Encoder(num_channels, nonlinearity, num_convs_per_block, initializers, regularizers,
+                                        data_format=data_format, down_sampling_op=down_sampling_op)
+            self._decoder = VGG_Decoder(num_channels, num_classes, nonlinearity, num_convs_per_block, initializers,
+                                        regularizers, data_format=data_format, up_sampling_op=up_sampling_op)
+
+    def _build(self, inputs):
+        """
+        :param inputs: 4D tensor of shape NCHW or NHWC
+        :return: 4D tensor
+        """
+        encoder_features = self._encoder(inputs)
+        predicted_logits = self._decoder(encoder_features)
+        return predicted_logits
+
+
+class Conv1x1Decoder(snt.AbstractModule):
+    """A stack of 1x1 convolutions that takes two tensors to be concatenated along their channel axes."""
+
+    def __init__(self,
+                 num_classes,
+                 num_channels,
+                 num_1x1_convs,
+                 nonlinearity=tf.nn.relu,
+                 initializers={'w': tf.orthogonal_initializer(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+                 regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
+                 data_format='NCHW',
+                 name='conv_decoder'):
+        super(Conv1x1Decoder, self).__init__(name=name)
+        self._num_classes = num_classes
+        self._num_channels = num_channels
+        self._num_1x1_convs = num_1x1_convs
+        self._nonlinearity = nonlinearity
+        self._initializers = initializers
+        self._regularizers = regularizers
+        self._data_format = data_format
+
+        if data_format == 'NCHW':
+            self._channel_axis = 1
+            self._spatial_axes = [2,3]
+        else:
+            self._channel_axis = -1
+            self._spatial_axes = [1,2]
+
+    def _build(self, features, z):
+        """
+        :param features: 4D tensor of shape NCHW or NHWC
+        :param z: tensor of shape NC11 or N11C, or a 2D tensor of shape (N, C) that is expanded accordingly
+        :return: 4D tensor
+        """
+        shp = features.get_shape()
+        spatial_shape = [shp[axis] for axis in self._spatial_axes]
+        multiples = [1] + spatial_shape
+        if self._channel_axis == -1:
+            # for channel-last data the singleton channel multiple must go last;
+            # list.insert(-1, 1) would place it before the last spatial dimension
+            multiples.append(1)
+        else:
+            multiples.insert(self._channel_axis, 1)
+
+        if len(z.get_shape()) == 2:
+            # expand the flat latent vector with singleton spatial dimensions (NC11 or N11C)
+            z = tf.expand_dims(z, axis=self._spatial_axes[0])
+            z = tf.expand_dims(z, axis=self._spatial_axes[1])
+
+        # broadcast
latent vector to spatial dimensions of the image/feature tensor + broadcast_z = tf.tile(z, multiples) + features = tf.concat([features, broadcast_z], axis=self._channel_axis) + for _ in range(self._num_1x1_convs): + features = snt.Conv2D(self._num_channels, kernel_shape=(1,1), stride=1, rate=1, + data_format=self._data_format, + initializers=self._initializers, regularizers=self._regularizers)(features) + features = self._nonlinearity(features) + logits = snt.Conv2D(self._num_classes, kernel_shape=(1,1), stride=1, rate=1, + data_format=self._data_format, + initializers=self._initializers, regularizers=None) + return logits(features) + + +class AxisAlignedConvGaussian(snt.AbstractModule): + """A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix.""" + + def __init__(self, + latent_dim, + num_channels, + nonlinearity=tf.nn.relu, + num_convs_per_block=3, + initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)}, + regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)}, + data_format='NCHW', + down_sampling_op=lambda x, df:\ + tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2], padding='SAME', data_format=df), + name="conv_dist"): + self._latent_dim = latent_dim + self._initializers = initializers + self._regularizers = regularizers + self._data_format = data_format + + if data_format == 'NCHW': + self._channel_axis = 1 + self._spatial_axes = [2,3] + else: + self._channel_axis = -1 + self._spatial_axes = [1,2] + + super(AxisAlignedConvGaussian, self).__init__(name=name) + with self._enter_variable_scope(): + tf.logging.info('Building ConvGaussian.') + self._encoder = VGG_Encoder(num_channels, nonlinearity, num_convs_per_block, initializers, regularizers, + data_format=data_format, down_sampling_op=down_sampling_op) + + def _build(self, img, seg=None): + """ + Evaluate mu and log_sigma of a Gaussian conditioned on an image + optionally, concatenated one-hot segmentation. 
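+        The final encoder feature map is global-average-pooled and mapped by a 1x1 convolution to
+        2*latent_dim channels, which parametrize loc and log-scale of a tfd.MultivariateNormalDiag.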
+        :param img: 4D array
+        :param seg: 4D array
+        :return: tfd.MultivariateNormalDiag object
+        """
+        if seg is not None:
+            seg = tf.cast(seg, tf.float32)
+            img = tf.concat([img, seg], axis=self._channel_axis)
+        encoding = self._encoder(img)[-1]
+        # global average pooling over the spatial dimensions
+        encoding = tf.reduce_mean(encoding, axis=self._spatial_axes, keepdims=True)
+
+        mu_log_sigma = snt.Conv2D(2 * self._latent_dim, (1,1), stride=1, rate=1, data_format=self._data_format,
+                                  initializers=self._initializers, regularizers=self._regularizers)(encoding)
+
+        mu_log_sigma = tf.squeeze(mu_log_sigma, axis=self._spatial_axes)
+        mu = mu_log_sigma[:, :self._latent_dim]
+        log_sigma = mu_log_sigma[:, self._latent_dim:]
+
+        return tfd.MultivariateNormalDiag(loc=mu, scale_diag=tf.exp(log_sigma))
+
+
+class ProbUNet(snt.AbstractModule):
+    """Probabilistic U-Net."""
+
+    def __init__(self,
+                 latent_dim,
+                 num_channels,
+                 num_classes,
+                 num_1x1_convs=3,
+                 nonlinearity=tf.nn.relu,
+                 num_convs_per_block=3,
+                 initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+                 regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
+                 data_format='NCHW',
+                 down_sampling_op=lambda x, df:\
+                     tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2], padding='SAME', data_format=df),
+                 up_sampling_op=lambda x, size:\
+                     tf.image.resize_images(x, size, method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
+                 name='prob_unet'):
+        super(ProbUNet, self).__init__(name=name)
+        self._data_format = data_format
+        self._num_classes = num_classes
+
+        with self._enter_variable_scope():
+            self._unet = UNet(num_channels=num_channels, num_classes=num_classes, nonlinearity=nonlinearity,
+                              num_convs_per_block=num_convs_per_block, initializers=initializers,
+                              regularizers=regularizers, data_format=data_format,
+                              down_sampling_op=down_sampling_op, up_sampling_op=up_sampling_op)
+
+            self._f_comb = Conv1x1Decoder(num_classes=num_classes, num_1x1_convs=num_1x1_convs,
+                                          num_channels=num_channels[0], nonlinearity=nonlinearity,
+                                          data_format=data_format, initializers=initializers, regularizers=regularizers)
+
+            self._prior =\
+                AxisAlignedConvGaussian(latent_dim=latent_dim, num_channels=num_channels,
+                                        nonlinearity=nonlinearity, num_convs_per_block=num_convs_per_block,
+                                        initializers=initializers, regularizers=regularizers, name='prior')
+
+            self._posterior =\
+                AxisAlignedConvGaussian(latent_dim=latent_dim, num_channels=num_channels,
+                                        nonlinearity=nonlinearity, num_convs_per_block=num_convs_per_block,
+                                        initializers=initializers, regularizers=regularizers, name='posterior')
+
+    def _build(self, img, seg=None, is_training=True, one_hot_labels=True):
+        """
+        Evaluate individual components of the net.
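+        Sets self._q (posterior; only when training with a segmentation provided), self._p (prior) and
+        self._unet_features, which reconstruct(), sample(), kl() and elbo() below consume.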
+        :param img: 4D image array
+        :param seg: 4D segmentation array
+        :param is_training: if False, refrain from evaluating the posterior
+        :param one_hot_labels: bool, if False expects integer labeled segmentation of shape N1HW or NHW1
+        :return: None
+        """
+        if is_training:
+            if seg is not None:
+                if not one_hot_labels:
+                    if self._data_format == 'NCHW':
+                        spatial_shape = img.get_shape()[-2:]
+                        class_axis = 1
+                        one_hot_shape = (-1, self._num_classes) + tuple(spatial_shape)
+                    elif self._data_format == 'NHWC':
+                        spatial_shape = img.get_shape()[-3:-1]
+                        one_hot_shape = (-1,) + tuple(spatial_shape) + (self._num_classes,)
+                        class_axis = 3
+
+                    seg = tf.reshape(seg, shape=[-1])
+                    seg = tf.one_hot(indices=seg, depth=self._num_classes, axis=class_axis)
+                    seg = tf.reshape(seg, shape=one_hot_shape)
+                # center the one-hot encoding around zero
+                seg -= 0.5
+                self._q = self._posterior(img, seg)
+
+        self._p = self._prior(img)
+        self._unet_features = self._unet(img)
+
+    def reconstruct(self, use_posterior_mean=False, z_q=None):
+        """
+        Reconstruct a given segmentation. Default settings result in decoding a posterior sample.
+        :param use_posterior_mean: use the posterior mean instead of sampling z_q
+        :param z_q: use provided latent sample z_q instead of sampling anew
+        :return: 4D logits tensor
+        """
+        if use_posterior_mean:
+            # tfd distributions expose their mean via .mean(); a private ._mu attribute is not part of the API
+            z_q = self._q.mean()
+        else:
+            if z_q is None:
+                z_q = self._q.sample()
+        return self._f_comb(self._unet_features, z_q)
+
+    def sample(self):
+        """
+        Sample a segmentation by reconstructing from a prior sample.
+        Only needs to re-evaluate the last 1x1-convolutions.
+        :return: 4D logits tensor
+        """
+        z_p = self._p.sample()
+        return self._f_comb(self._unet_features, z_p)
+
+    def kl(self, analytic=True, z_q=None):
+        """
+        Calculate the Kullback-Leibler divergence KL(Q||P) between two axis-aligned Gaussians,
+        i.e. the covariance matrices are assumed diagonal.
+        :param analytic: bool, if False, approximate the KL via sampling from the posterior
+        :param z_q: None or 2D tensor, if analytic=False the posterior sample can be provided instead of sampling anew
+        :return: 1D tensor of shape (batch size,)
+        """
+        if analytic:
+            kl = tfd.kl_divergence(self._q, self._p)
+        else:
+            if z_q is None:
+                z_q = self._q.sample()
+            log_q = self._q.log_prob(z_q)
+            log_p = self._p.log_prob(z_q)
+            kl = log_q - log_p
+        return kl
+
+    def elbo(self, seg, beta=1.0, analytic_kl=True, reconstruct_posterior_mean=False, z_q=None, one_hot_labels=True,
+             loss_mask=None):
+        """
+        Calculate the evidence lower bound (elbo) of the log-likelihood of P(Y|X).
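+        Concretely, the value computed below is -(rec_loss_sum + beta * mean(KL(Q||P))), so maximizing the
+        elbo trades off reconstruction fidelity against agreement of posterior and prior.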
+        :param seg: 4D tensor
+        :param beta: scalar weighting the KL divergence term
+        :param analytic_kl: bool, if False calculate the KL via sampling
+        :param reconstruct_posterior_mean: bool, if True decode the posterior mean instead of a posterior sample
+        :param z_q: None or 2D tensor holding a posterior sample
+        :param one_hot_labels: bool, if False expects integer labeled segmentation of shape N1HW or NHW1
+        :param loss_mask: 4D tensor, binary
+        :return: scalar tensor
+        """
+        if z_q is None:
+            z_q = self._q.sample()
+
+        self._kl = tf.reduce_mean(self.kl(analytic_kl, z_q))
+
+        self._rec_logits = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, z_q=z_q)
+        rec_loss = ce_loss(labels=seg, logits=self._rec_logits, n_classes=self._num_classes,
+                           loss_mask=loss_mask, one_hot_labels=one_hot_labels)
+        self._rec_loss = rec_loss['sum']
+        self._rec_loss_mean = rec_loss['mean']
+
+        return -(self._rec_loss + beta * self._kl)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..6dd1342
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+tensorflow_probability==0.3.0
+tensorflow-gpu==1.10.0
+dm-sonnet==1.23
+matplotlib==2.2.2
+numpy==1.14.5
+Pillow==5.1.0
+scipy==1.1.0
+SimpleITK==1.1.0
+tensorboard==1.10.0
+tqdm==4.23.4
+seaborn==0.9.0
+pytest==3.8.0
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..79ea0a1
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from distutils.core import setup
+from setuptools import find_packages
+
+req_file = "requirements.txt"
+
+def parse_requirements(filename):
+    lineiter = (line.strip() for line in open(filename))
+    return [line for line in lineiter if line and not line.startswith("#")]
+
+install_reqs = parse_requirements(req_file)
+
+setup(name='model',
+      version='latest',
+      packages=find_packages(exclude=['test', 'test.*']),
+      install_requires=install_reqs,
+      dependency_links=[],
+      )
\ No newline at end of file
diff --git a/tests/evaluation/eval_tests.py b/tests/evaluation/eval_tests.py
new file mode 100644
index 0000000..6b6c322
--- /dev/null
+++ b/tests/evaluation/eval_tests.py
@@ -0,0 +1,152 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/evaluation/eval_tests.py b/tests/evaluation/eval_tests.py
new file mode 100644
index 0000000..6b6c322
--- /dev/null
+++ b/tests/evaluation/eval_tests.py
@@ -0,0 +1,152 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation metrics tests."""
+
+import pytest
+import numpy as np
+from evaluation.eval_cityscapes import get_energy_distance_components, get_mode_counts, get_pixelwise_mode_counts
+from utils.training_utils import calc_confusion, metrics_from_conf_matrix
+
+
+def nan_safe_array_equal(a, b, nan_replacement=-1.):
+    """Replace NaNs to safely compare arrays elementwise."""
+    # copy first so the caller's arrays are not mutated in place
+    a = np.array(a, copy=True)
+    b = np.array(b, copy=True)
+    a[np.isnan(a)] = nan_replacement
+    b[np.isnan(b)] = nan_replacement
+    return (a == b).all()
+
+
+@pytest.mark.parametrize("test_input,expected", [
+    ([np.zeros(shape=(1,1,10,10)), np.zeros(shape=(1,1,10,10))],
+     {'tp_0': 100, 'tp_1': 0, 'fp_0': 0, 'fp_1': 0, 'fn_0': 0, 'fn_1': 0}),
+    ([np.ones(shape=(1,1,10,10)), np.ones(shape=(1,1,10,10))],
+     {'tp_0': 0, 'tp_1': 100, 'fp_0': 0, 'fp_1': 0, 'fn_0': 0, 'fn_1': 0}),
+    ([np.ones(shape=(1,1,10,10)), np.zeros(shape=(1,1,10,10))],
+     {'tp_0': 0, 'tp_1': 0, 'fp_0': 100, 'fp_1': 0, 'fn_0': 0, 'fn_1': 100}),
+    ([np.concatenate([np.ones(shape=(1,1,10,5)), np.zeros(shape=(1,1,10,5))], axis=-1), np.zeros(shape=(1,1,10,10))],
+     {'tp_0': 50, 'tp_1': 0, 'fp_0': 50, 'fp_1': 0, 'fn_0': 0, 'fn_1': 50}),
+])
+def test_confusion(test_input, expected):
+
+    gt_seg_modes = test_input[0]
+    seg_samples = test_input[1]
+
+    conf_matrix = calc_confusion(gt_seg_modes, seg_samples, class_ixs=[0,1])
+
+    tp_0 = conf_matrix[0,0]
+    tp_1 = conf_matrix[1,0]
+    fp_0 = conf_matrix[0,1]
+    fp_1 = conf_matrix[1,1]
+    fn_0 = conf_matrix[0,3]
+    fn_1 = conf_matrix[1,3]
+
+    assert tp_0 == expected['tp_0']
+    assert tp_1 == expected['tp_1']
+    assert fp_0 == expected['fp_0']
+    assert fp_1 == expected['fp_1']
+    assert fn_0 == expected['fn_0']
+    assert fn_1 == expected['fn_1']
+
+
+@pytest.mark.parametrize("test_input,expected,eval_f", [
+    ([np.zeros(shape=(1,1,10,10)), np.zeros(shape=(1,1,10,10))], {'iou_0': 1., 'iou_1': np.nan},
+     [lambda x,y: x == y, lambda x,y: np.isnan(x) and np.isnan(y)]),
+    ([np.ones(shape=(1,1,10,10)), np.ones(shape=(1,1,10,10))], {'iou_0': np.nan, 'iou_1': 1.},
+     [lambda x,y: np.isnan(x) and np.isnan(y), lambda x,y: x == y]),
+    ([np.ones(shape=(1,1,10,10)), np.zeros(shape=(1,1,10,10))], {'iou_0': 0., 'iou_1': 0.},
+     [lambda x,y: x == y, lambda x,y: x == y]),
+    ([np.concatenate([np.ones(shape=(1, 1, 10, 5)), np.zeros(shape=(1, 1, 10, 5))], axis=-1),
+      np.zeros(shape=(1, 1, 10, 10))], {'iou_0': 0.5, 'iou_1': 0.},
+     [lambda x,y: x == y, lambda x,y: x == y])
+])
+def test_metrics_from_confusion(test_input, expected, eval_f):
+
+    gt_seg_modes = test_input[0]
+    seg_samples = test_input[1]
+
+    conf_matrix = calc_confusion(gt_seg_modes, seg_samples, class_ixs=[0,1])
+    iou = metrics_from_conf_matrix(conf_matrix)['iou']
+
+    assert eval_f[0](iou[0], expected['iou_0'])
+    assert eval_f[1](iou[1], expected['iou_1'])
+
+
+@pytest.mark.parametrize("test_input,expected,eval_f", [
+    ([np.zeros(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10)), [0]], {'YS': 0., 'SS': 0., 'YY': 0.},
+     3 * [lambda x,y: x == y]),
+    ([np.zeros(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10)), [1]], {'YS': np.nan, 'SS': np.nan, 'YY': np.nan},
+     3 * [lambda x,y: np.isnan(x) and np.isnan(y)]),
+    ([np.ones(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10)), [0]], {'YS': 1., 'SS': 0., 'YY': np.nan},
+     2 * [lambda x,y: x == y] + [lambda x,y: np.isnan(x) and np.isnan(y)]),
+    ([np.concatenate([np.ones(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10))], axis=0),
+      np.concatenate([np.zeros(shape=(1,1,1,10,10)), np.ones(shape=(1,1,1,10,10))], axis=0), [0]],
+     {'YS': np.array([[[1.],[np.nan]], [[0.],[1.]]]), 'SS': np.array([[[0.],[1.]], [[1.],[np.nan]]]),
+      'YY': np.array([[[np.nan],[1.]], [[1.],[0.]]])},
+     3 * [lambda x,y: nan_safe_array_equal(x,y)]),
+    ([np.concatenate([np.ones(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10))], axis=0),
+      np.concatenate([np.zeros(shape=(1,1,1,10,10)), np.ones(shape=(1,1,1,10,10))], axis=0), 2],
+     {'YS': np.array([[[1.,1.], [np.nan,0.]], [[0.,np.nan], [1.,1.]]]),
+      'SS': np.array([[[0.,np.nan], [1.,1.]], [[1.,1.], [np.nan,0.]]]),
+      'YY': np.array([[[np.nan,0.], [1.,1.]], [[1.,1.], [0.,np.nan]]])},
+     3 * [lambda x,y: nan_safe_array_equal(x, y)]),
+])
+def test_energy_distance_components(test_input, expected, eval_f):
+
+    gt_seg_modes = test_input[0]
+    seg_samples = test_input[1]
+    eval_class_ids = test_input[2]
+
+    results = get_energy_distance_components(gt_seg_modes, seg_samples, eval_class_ids=eval_class_ids)
+
+    assert eval_f[0](results['YS'], expected['YS'])
+    assert eval_f[1](results['SS'], expected['SS'])
+    assert eval_f[2](results['YY'], expected['YY'])
+
+
+@pytest.mark.parametrize("test_input,expected,eval_f", [
+    (np.concatenate([np.zeros(shape=(1,5,5)), 0.1 * np.ones(shape=(1,5,5))]), [5,0], lambda x,y: (x==y).all())
+])
+def test_get_mode_counts(test_input, expected, eval_f):
+    mode_counts = get_mode_counts(test_input)
+
+    assert eval_f(mode_counts, expected)
+
+
+@pytest.mark.parametrize("test_input,expected,eval_f", [
+    ([np.zeros(shape=(1,1,10,10)),
+      np.concatenate([np.ones(shape=(1,1,1,10,5)), np.zeros(shape=(1,1,1,10,5))], axis=-1), 1],
+     [[100, 50, 50]], lambda x,y: (x==y).all()),
+    ([np.concatenate([np.ones(shape=(1,1,10,5)), np.zeros(shape=(1,1,10,5))], axis=-1),
+      np.concatenate([np.ones(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10))], axis=0), 2],
+     [[100, 50, 0],[100, 50, 0]], lambda x,y: (x==y).all()),
+    ([np.concatenate([np.ones(shape=(1, 1, 10, 5)), np.zeros(shape=(1, 1, 10, 5))], axis=-1),
+      np.concatenate([3 * np.ones(shape=(1, 1, 1, 10, 10)), 2 * np.ones(shape=(1, 1, 1, 10, 10)),
+                      np.ones(shape=(1, 1, 1, 10, 10)), np.zeros(shape=(1, 1, 1, 10, 10))], axis=0), 2],
+     [[200, 50, 50], [200, 50, 50]], lambda x, y: (x == y).all()),
+])
+def test_pixelwise_mode_counts(test_input, expected, eval_f):
+
+    class cf:
+        def __init__(self, num_classes):
+            self.label_switches = {'class_{}'.format(i): np.random.uniform(0.0, 1.0) for i in range(num_classes)}
+            self.name2trainId = {**{'class_{}'.format(i): i for i in range(num_classes)},
+                                 **{'class_{}_2'.format(i): i + num_classes for i in range(num_classes)}}
+
+    num_classes = test_input[2]
+    pixelwise_mode_counts = get_pixelwise_mode_counts(cf(num_classes), seg=test_input[0], seg_samples=test_input[1])
+
+    assert eval_f(pixelwise_mode_counts, expected)
\ No newline at end of file
diff --git a/training/prob_unet_config.py b/training/prob_unet_config.py
new file mode 100644
index 0000000..50a653a
--- /dev/null
+++ b/training/prob_unet_config.py
@@ -0,0 +1,127 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""CityScapes training config."""
+
+import os
+import numpy as np
+from collections import OrderedDict
+from data.cityscapes.cityscapes_labels import labels as cs_labels_tuple
+
+config_path = os.path.realpath(__file__)
+
+#########################################
+#             data-loader               #
+#########################################
+
+data_dir = 'PREPROCESSING_OUTPUT_DIRECTORY_ABSOLUTE_PATH'
+resolution = 'quarter'
+label_density = 'gtFine'
+gt_instances = False
+train_cities = ['aachen', 'bochum', 'bremen', 'cologne', 'dusseldorf', 'erfurt', 'hamburg', 'hanover',
+                'jena', 'krefeld', 'stuttgart', 'strasbourg', 'tubingen', 'weimar', 'zurich']
+val_cities = ['darmstadt', 'monchengladbach', 'ulm']
+
+num_classes = 19
+batch_size = 10
+pre_crop_size = [256, 512]
+patch_size = [256, 512]
+n_train_batches = None
+n_val_batches = 274 // batch_size
+n_workers = 5
+ignore_label = 255
+
+da_kwargs = {
+    'random_crop': True,
+    'rand_crop_dist': (patch_size[0] / 2., patch_size[1] / 2.),
+    'do_elastic_deform': True,
+    'alpha': (0., 800.),
+    'sigma': (25., 35.),
+    'do_rotation': True,
+    'angle_x': (-np.pi / 8., np.pi / 8.),
+    'angle_y': (0., 0.),
+    'angle_z': (0., 0.),
+    'do_scale': True,
+    'scale': (0.8, 1.2),
+    'border_mode_data': 'constant',
+    'border_mode_seg': 'constant',
+    'border_cval_seg': ignore_label,
+    'gamma_retain_stats': True,
+    'gamma_range': (0.7, 1.5),
+    'p_gamma': 0.3
+}
+
+data_format = 'NCHW'
+one_hot_labels = False
+
+#########################################
+#           label-switches              #
+#########################################
+
+color_map = {label.trainId: label.color for label in cs_labels_tuple}
+color_map[255] = (0., 0., 0.)
+
+trainId2name = {label.trainId: label.name for label in cs_labels_tuple}
+name2trainId = {label.name: label.trainId for label in cs_labels_tuple}
+
+label_switches = OrderedDict([('sidewalk', 8./17.), ('person', 7./17.), ('car', 6./17.), ('vegetation', 5./17.),
+                              ('road', 4./17.)])
+num_classes += len(label_switches)
+switched_Id2name = {19 + i: list(label_switches.keys())[i] + '_2' for i in range(len(label_switches))}
+switched_name2Id = {list(label_switches.keys())[i] + '_2': 19 + i for i in range(len(label_switches))}
+trainId2name = {**trainId2name, **switched_Id2name}
+name2trainId = {**name2trainId, **switched_name2Id}
+
+switched_labels2color = {'road_2': (84, 86, 22), 'person_2': (167, 242, 242), 'vegetation_2': (242, 160, 19),
+                         'car_2': (30, 193, 252), 'sidewalk_2': (46, 247, 180)}
+switched_cmap = {switched_name2Id[i]: switched_labels2color[i] for i in switched_name2Id.keys()}
+color_map = {**color_map, **switched_cmap}
+
+#########################################
+#         network & training            #
+#########################################
+
+cuda_visible_devices = '0'
+cpu_device = '/cpu:0'
+gpu_device = '/gpu:0'
+
+network_input_shape = (None, 3) + tuple(patch_size)
+network_output_shape = (None, num_classes) + tuple(patch_size)
+label_shape = (None, 1) + tuple(patch_size)
+loss_mask_shape = label_shape
+
+base_channels = 32
+num_channels = [base_channels, 2*base_channels, 4*base_channels,
+                6*base_channels, 6*base_channels, 6*base_channels, 6*base_channels]
+
+num_convs_per_block = 3
+
+n_training_batches = 240000
+validation = {'n_batches': n_val_batches, 'every_n_batches': 2000}
+
+learning_rate_schedule = 'piecewise_constant'
+learning_rate_kwargs = {'values': [1e-4, 0.5e-4, 1e-5, 0.5e-6],
+                        'boundaries': [80000, 160000, 240000],
+                        'name': 'piecewise_constant_lr_decay'}
+initial_learning_rate = learning_rate_kwargs['values'][0]
+
+regularization_weight = 1e-5
+latent_dim = 6
+num_1x1_convs = 3
+beta = 1.0
+analytic_kl = True
+use_posterior_mean = False
+save_every_n_steps = n_training_batches // 3 if n_training_batches >= 100000 else n_training_batches
+disable_progress_bar = False
+
+exp_dir = "EXPERIMENT_OUTPUT_DIRECTORY_ABSOLUTE_PATH"
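The label-switch section above defines flip probabilities that make the ground truth multimodal: each listed class can be relabeled to a '_2' variant with its configured probability. A sketch of the intended mechanism follows; the actual flipping is implemented in the data loader, which is not part of this diff excerpt, so the function name and sampling scheme here are assumptions:

    import numpy as np

    def apply_label_switches(seg, label_switches, name2trainId, rng=np.random):
        """Relabel each switchable class to its '_2' variant with its configured probability."""
        seg = seg.copy()
        for name, prob in label_switches.items():
            if rng.uniform() < prob:  # e.g. 'sidewalk' flips with probability 8/17
                seg[seg == name2trainId[name]] = name2trainId[name + '_2']
        return seg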
diff --git a/training/train_prob_unet.py b/training/train_prob_unet.py
new file mode 100644
index 0000000..82c2e92
--- /dev/null
+++ b/training/train_prob_unet.py
@@ -0,0 +1,188 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Probabilistic U-Net training script."""
+
+import tensorflow as tf
+import numpy as np
+import os
+import time
+from tqdm import tqdm
+import shutil
+import logging
+import argparse
+from importlib.machinery import SourceFileLoader
+
+from data.cityscapes.data_loader import get_train_generators
+from model.probabilistic_unet import ProbUNet
+import utils.training_utils as training_utils
+
+
+def train(cf):
+    """Perform training from scratch."""
+
+    # do not use all gpus
+    os.environ["CUDA_VISIBLE_DEVICES"] = cf.cuda_visible_devices
+
+    # initialize data providers
+    data_provider = get_train_generators(cf)
+    train_provider = data_provider['train']
+    val_provider = data_provider['val']
+
+    prob_unet = ProbUNet(latent_dim=cf.latent_dim, num_channels=cf.num_channels,
+                         num_1x1_convs=cf.num_1x1_convs,
+                         num_classes=cf.num_classes, num_convs_per_block=cf.num_convs_per_block,
+                         initializers={'w': training_utils.he_normal(),
+                                       'b': tf.truncated_normal_initializer(stddev=0.001)},
+                         regularizers={'w': tf.contrib.layers.l2_regularizer(1.0)})
+
+    x = tf.placeholder(tf.float32, shape=cf.network_input_shape)
+    y = tf.placeholder(tf.uint8, shape=cf.label_shape)
+    mask = tf.placeholder(tf.uint8, shape=cf.loss_mask_shape)
+
+    global_step = tf.train.get_or_create_global_step()
+
+    if cf.learning_rate_schedule == 'piecewise_constant':
+        learning_rate = tf.train.piecewise_constant(x=global_step, **cf.learning_rate_kwargs)
+    else:
+        learning_rate = tf.train.exponential_decay(learning_rate=cf.initial_learning_rate, global_step=global_step,
+                                                   **cf.learning_rate_kwargs)
+    with tf.device(cf.gpu_device):
+        prob_unet(x, y, is_training=True, one_hot_labels=cf.one_hot_labels)
+        elbo = prob_unet.elbo(y, reconstruct_posterior_mean=cf.use_posterior_mean, beta=cf.beta, loss_mask=mask,
+                              analytic_kl=cf.analytic_kl, one_hot_labels=cf.one_hot_labels)
+        reconstructed_logits = prob_unet._rec_logits
+        sampled_logits = prob_unet.sample()
+
+        reg_loss = cf.regularization_weight * tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+        loss = -elbo + reg_loss
+        rec_loss = prob_unet._rec_loss_mean
+        kl = prob_unet._kl
+
+    mean_val_rec_loss = tf.placeholder(tf.float32, shape=(), name="mean_val_rec_loss")
+    mean_val_kl = tf.placeholder(tf.float32, shape=(), name="mean_val_kl")
+
+    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)
+
+    # prepare tf summaries
+    train_elbo_summary = tf.summary.scalar('train_elbo', elbo)
+    train_kl_summary = tf.summary.scalar('train_kl', kl)
+    train_rec_loss_summary = tf.summary.scalar('rec_loss', rec_loss)
+    train_loss_summary = tf.summary.scalar('train_loss', loss)
+    reg_loss_summary = tf.summary.scalar('train_reg_loss', reg_loss)
+    lr_summary = tf.summary.scalar('learning_rate', learning_rate)
+    beta_summary = tf.summary.scalar('beta', cf.beta)
+    training_summary_op = tf.summary.merge([train_loss_summary, reg_loss_summary, lr_summary, train_elbo_summary,
+                                            train_kl_summary, train_rec_loss_summary, beta_summary])
+    batches_per_second = tf.placeholder(tf.float32, shape=(), name="batches_per_sec_placeholder")
+    timing_summary = tf.summary.scalar('batches_per_sec', batches_per_second)
+    val_rec_loss_summary = tf.summary.scalar('val_loss', mean_val_rec_loss)
+    val_kl_summary = tf.summary.scalar('val_kl', mean_val_kl)
+    validation_summary_op = tf.summary.merge([val_rec_loss_summary, val_kl_summary])
+
+    # no explicit variable-initialization op is required here: the
+    # MonitoredTrainingSession below initializes all variables itself
+
+    # Add ops to save and restore all the variables.
+    saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir=cf.exp_dir, save_steps=cf.save_every_n_steps,
+                                              saver=tf.train.Saver(save_relative_paths=True))
+    # save config
+    shutil.copyfile(cf.config_path, os.path.join(cf.exp_dir, 'used_config.py'))
+
+    with tf.train.MonitoredTrainingSession(hooks=[saver_hook]) as sess:
+        summary_writer = tf.summary.FileWriter(cf.exp_dir, sess.graph)
+        logging.info('Model: {}'.format(cf.exp_dir))
+
+        for i in tqdm(range(cf.n_training_batches), disable=cf.disable_progress_bar):
+
+            start_time = time.time()
+            train_batch = next(train_provider)
+            _, train_summary = sess.run([optimizer, training_summary_op],
+                                        feed_dict={x: train_batch['data'], y: train_batch['seg'],
+                                                   mask: train_batch['loss_mask']})
+            summary_writer.add_summary(train_summary, i)
+            time_delta = time.time() - start_time
+            train_speed = sess.run(timing_summary, feed_dict={batches_per_second: 1. / time_delta})
+            summary_writer.add_summary(train_speed, i)
+
+            # validation
+            if i % cf.validation['every_n_batches'] == 0:
+
+                train_rec = sess.run(reconstructed_logits, feed_dict={x: train_batch['data'], y: train_batch['seg']})
+                image_path = os.path.join(cf.exp_dir,
+                                          'batch_{}_train_reconstructions.png'.format(i // cf.validation['every_n_batches']))
+                training_utils.plot_batch(train_batch, train_rec, num_classes=cf.num_classes,
+                                          cmap=cf.color_map, out_dir=image_path)
+
+                running_mean_val_rec_loss = 0.
+                running_mean_val_kl = 0.
+
+                for j in range(cf.validation['n_batches']):
+                    val_batch = next(val_provider)
+                    val_rec, val_sample, val_rec_loss, val_kl =\
+                        sess.run([reconstructed_logits, sampled_logits, rec_loss, kl],
+                                 feed_dict={x: val_batch['data'], y: val_batch['seg'], mask: val_batch['loss_mask']})
+                    running_mean_val_rec_loss += val_rec_loss / cf.validation['n_batches']
+                    running_mean_val_kl += val_kl / cf.validation['n_batches']
+
+                    if j == 0:
+                        image_path = os.path.join(cf.exp_dir,
+                                                  'batch_{}_val_reconstructions.png'.format(i // cf.validation['every_n_batches']))
+                        training_utils.plot_batch(val_batch, val_rec, num_classes=cf.num_classes,
+                                                  cmap=cf.color_map, out_dir=image_path)
+                        image_path = os.path.join(cf.exp_dir,
+                                                  'batch_{}_val_samples.png'.format(i // cf.validation['every_n_batches']))
+
+                        # draw 3 additional prior samples and concatenate them in the channel axis
+                        for _ in range(3):
+                            val_sample_ = sess.run(sampled_logits, feed_dict={x: val_batch['data'], y: val_batch['seg']})
+                            val_sample = np.concatenate([val_sample, val_sample_], axis=1)
+
+                        training_utils.plot_batch(val_batch, val_sample, num_classes=cf.num_classes,
+                                                  cmap=cf.color_map, out_dir=image_path)
+
+                val_summary = sess.run(validation_summary_op, feed_dict={mean_val_rec_loss: running_mean_val_rec_loss,
+                                                                         mean_val_kl: running_mean_val_kl})
+                summary_writer.add_summary(val_summary, i)
+
+                if cf.disable_progress_bar:
+                    logging.info('Validation at batch {}/{}: validation loss={}, kl={}'\
+                                 .format(i, cf.n_training_batches, running_mean_val_rec_loss, running_mean_val_kl))
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Training of the Probabilistic U-Net')
+    parser.add_argument('-c', '--config', type=str, default='prob_unet_config.py',
+                        help='name of the python script defining the training configuration')
+    parser.add_argument('-d', '--data_dir', type=str, default='',
+                        help="full path to the data, if empty the config's data_dir attribute is used")
+    args = parser.parse_args()
+
+    # load config
+    cf = SourceFileLoader('cf', args.config).load_module()
+    if args.data_dir != '':
+        cf.data_dir = args.data_dir
+
+    # prepare experiment directory
+    if not os.path.isdir(cf.exp_dir):
+        os.mkdir(cf.exp_dir)
+
+    # log to file and console
+    log_path = os.path.join(cf.exp_dir, 'train.log')
+    logging.basicConfig(filename=log_path, level=logging.INFO)
+    logging.getLogger().addHandler(logging.StreamHandler())
+    logging.info('Logging to {}'.format(log_path))
+    tf.logging.set_verbosity(tf.logging.INFO)
+
+    train(cf)
\ No newline at end of file
diff --git a/utils/training_utils.py b/utils/training_utils.py
new file mode 100644
index 0000000..3571de6
--- /dev/null
+++ b/utils/training_utils.py
@@ -0,0 +1,249 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Training utilities."""
+
+import tensorflow as tf
+import numpy as np
+import matplotlib
+
+matplotlib.use('agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+from tensorflow.python.ops.init_ops import VarianceScaling
+
+
+def ce_loss(labels, logits, n_classes, loss_mask=None, data_format='NCHW', one_hot_labels=True, name='ce_loss'):
+    """
+    Cross-entropy loss.
+    :param labels: 4D tensor
+    :param logits: 4D tensor
+    :param n_classes: integer for number of classes
+    :param loss_mask: binary 4D tensor, pixels to mask should be marked by 1s
+    :param data_format: string
+    :param one_hot_labels: bool, indicator for whether labels are expected in one-hot representation
+    :param name: string
+    :return: dict of (pixel-wise) mean and sum of cross-entropy loss
+    """
+    with tf.variable_scope(name):
+        # permute class channels into last axis
+        if data_format == 'NCHW':
+            labels = tf.transpose(labels, [0, 2, 3, 1])
+            logits = tf.transpose(logits, [0, 2, 3, 1])
+        elif data_format == 'NCDHW':
+            labels = tf.transpose(labels, [0, 2, 3, 4, 1])
+            logits = tf.transpose(logits, [0, 2, 3, 4, 1])
+
+        batch_size = tf.cast(tf.shape(labels)[0], tf.float32)
+
+        if one_hot_labels:
+            flat_labels = tf.reshape(labels, [-1, n_classes])
+        else:
+            flat_labels = tf.reshape(labels, [-1])
+            flat_labels = tf.one_hot(indices=flat_labels, depth=n_classes, axis=-1)
+        flat_logits = tf.reshape(logits, [-1, n_classes])
+
+        # do not compute gradients wrt the labels
+        flat_labels = tf.stop_gradient(flat_labels)
+
+        ce_per_pixel = tf.nn.softmax_cross_entropy_with_logits_v2(labels=flat_labels, logits=flat_logits)
+
+        # optional element-wise masking with binary loss mask
+        if loss_mask is None:
+            ce_sum = tf.reduce_sum(ce_per_pixel) / batch_size
+            ce_mean = tf.reduce_mean(ce_per_pixel)
+        else:
+            # invert the mask: pixels marked 1 get weight 0, i.e. are excluded from the loss
+            loss_mask_flat = tf.reshape(loss_mask, [-1])
+            loss_mask_flat = (1. - tf.cast(loss_mask_flat, tf.float32))
+            ce_sum = tf.reduce_sum(loss_mask_flat * ce_per_pixel) / batch_size
+            n_valid_pixels = tf.reduce_sum(loss_mask_flat)
+            ce_mean = tf.reduce_sum(loss_mask_flat * ce_per_pixel) / n_valid_pixels
+
+        return {'sum': ce_sum, 'mean': ce_mean}
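To make the masking convention of ce_loss above concrete: pixels marked 1 in the loss mask are inverted to weight 0 and thus excluded, and the 'mean' entry averages only over valid pixels. A toy numpy check of the same arithmetic (values are illustrative):

    import numpy as np

    ce_per_pixel = np.array([0.5, 2.0, 1.0, 3.0])  # stand-in per-pixel CE values
    loss_mask = np.array([0, 1, 0, 1])             # 1 marks pixels to exclude
    weights = 1. - loss_mask.astype(np.float32)    # mask is inverted, as in ce_loss
    ce_mean = (weights * ce_per_pixel).sum() / weights.sum()  # (0.5 + 1.0) / 2 = 0.75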
+
+
+def softmax_2_onehot(arr):
+    """Transform a numpy array of softmax values into a one-hot encoded array (in place).
+    Assumes classes are encoded in axis 1.
+    :param arr: ND array
+    :return: ND array
+    """
+    num_classes = arr.shape[1]
+    arr_argmax = np.argmax(arr, axis=1)
+
+    for c in range(num_classes):
+        arr[:, c] = (arr_argmax == c).astype(np.uint8)
+    return arr
+
+
+def numpy_one_hot(label_arr, num_classes):
+    """One-hotify an integer-labeled numpy array. One-hot encoding is encoded in additional last axis.
+    :param label_arr: ND array
+    :param num_classes: integer
+    :return: (N+1)D array
+    """
+    # replace labels >= num_classes with 0 (note: modifies label_arr in place)
+    label_arr[label_arr >= num_classes] = 0
+
+    res = np.eye(num_classes)[np.array(label_arr).reshape(-1)]
+    return res.reshape(list(label_arr.shape) + [num_classes])
+
+
+def calc_confusion(labels, samples, class_ixs, loss_mask=None):
+    """
+    Compute confusion matrix for each class across the given arrays.
+    Assumes classes are given in integer-valued encoding.
+    :param labels: 4/5D array
+    :param samples: 4/5D array
+    :param class_ixs: integer or list of integers specifying the classes to evaluate
+    :param loss_mask: 4/5D array
+    :return: 2D array
+    """
+    if labels.shape != samples.shape:
+        raise AssertionError('shape mismatch {} vs. {}'.format(labels.shape, samples.shape))
+
+    if isinstance(class_ixs, int):
+        num_classes = class_ixs
+        class_ixs = range(class_ixs)
+    elif isinstance(class_ixs, list):
+        num_classes = len(class_ixs)
+    else:
+        raise TypeError('arg class_ixs needs to be int or list, not {}.'.format(type(class_ixs)))
+
+    if loss_mask is None:
+        shp = labels.shape
+        loss_mask = np.zeros(shape=(shp[0], 1, shp[2], shp[3]))
+
+    conf_matrix = np.zeros(shape=(num_classes, 4), dtype=np.float32)
+    for i, c in enumerate(class_ixs):
+
+        pred_ = (samples == c).astype(np.uint8)
+        labels_ = (labels == c).astype(np.uint8)
+
+        # masked pixels (loss_mask == 1) are excluded from all counts
+        conf_matrix[i, 0] = int(((pred_ != 0) * (labels_ != 0) * (loss_mask != 1)).sum())  # TP
+        conf_matrix[i, 1] = int(((pred_ != 0) * (labels_ == 0) * (loss_mask != 1)).sum())  # FP
+        conf_matrix[i, 2] = int(((pred_ == 0) * (labels_ == 0) * (loss_mask != 1)).sum())  # TN
+        conf_matrix[i, 3] = int(((pred_ == 0) * (labels_ != 0) * (loss_mask != 1)).sum())  # FN
+
+    return conf_matrix
+
+
+def metrics_from_conf_matrix(conf_matrix):
+    """
+    Calculate IoU per class from a confusion matrix.
+    :param conf_matrix: 2D array of shape (num_classes, 4)
+    :return: dict holding 1D-vectors of metrics
+    """
+    tps = conf_matrix[:, 0]
+    fps = conf_matrix[:, 1]
+    fns = conf_matrix[:, 3]
+
+    metrics = {}
+    metrics['iou'] = np.zeros_like(tps, dtype=np.float32)
+
+    # iterate classes
+    for c in range(tps.shape[0]):
+        # unless both the prediction and the ground truth are empty, calculate a finite IoU
+        if tps[c] + fps[c] + fns[c] != 0:
+            metrics['iou'][c] = tps[c] / (tps[c] + fps[c] + fns[c])
+        else:
+            metrics['iou'][c] = np.nan
+
+    return metrics
+
+
+def he_normal(seed=None):
+    """He normal initializer.
+    It draws samples from a truncated normal distribution centered on 0
+    with `stddev = sqrt(2 / fan_in)`,
+    where `fan_in` is the number of input units in the weight tensor.
+    Arguments:
+        seed: A Python integer. Used to seed the random generator.
+    Returns:
+        An initializer.
+    References:
+        He et al., http://arxiv.org/abs/1502.01852
+    Code:
+        https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/python/keras/initializers.py
+    """
+    return VarianceScaling(scale=2., mode='fan_in', distribution='normal', seed=seed)
+
+
+def plot_batch(batch, prediction, cmap, num_classes, out_dir=None, clip_range=True):
+    """
+    Plots a batch of images, segmentations & samples and optionally saves it to disk.
+    :param batch: dict holding images and gt labels for a batch
+    :param prediction: logit prediction of the corresponding batch; multiple predictions may be
+                       concatenated in the channel axis
+    :param cmap: dictionary as colormap
+    :param num_classes: integer number of classes per prediction
+    :param out_dir: full path to save png image to
+    :param clip_range: bool, clip image values into [0, 1] to suppress matplotlib range warnings
+    :return: None
+    """
+    img_arr = batch['data']
+    seg_arr = batch['seg']
+
+    num_predictions = prediction.shape[1] // num_classes
+    num_y_tiles = 2 + num_predictions
+    batch_size = img_arr.shape[0]
+
+    f = plt.figure(figsize=(batch_size * 4, num_y_tiles * 2))
+    gs = gridspec.GridSpec(num_y_tiles, batch_size, wspace=0.0, hspace=0.0)
+
+    # suppress matplotlib range warnings
+    if clip_range:
+        img_arr[img_arr < 0.] = 0.
+        img_arr[img_arr > 1.] = 1.
+
+    for tile in range(batch_size):
+        # image
+        ax = plt.subplot(gs[0, tile])
+        ax.get_xaxis().set_visible(False)
+        ax.get_yaxis().set_visible(False)
+        plt.imshow(np.transpose(img_arr[tile], axes=[1, 2, 0]))
+
+        # (here sampled) gt segmentation
+        ax = plt.subplot(gs[1, tile])
+        ax.get_xaxis().set_visible(False)
+        ax.get_yaxis().set_visible(False)
+        if seg_arr.shape[1] == 1:
+            gt_seg = np.squeeze(seg_arr[tile], axis=0)
+        else:
+            gt_seg = np.argmax(seg_arr[tile], axis=0)
+        plt.imshow(to_rgb(gt_seg, cmap))
+
+        # multiple predictions can be concatenated in channel axis, iterate all predictions
+        for i in range(num_predictions):
+            ax = plt.subplot(gs[2 + i, tile])
+            ax.get_xaxis().set_visible(False)
+            ax.get_yaxis().set_visible(False)
+            single_prediction = prediction[tile][i * num_classes: (i+1) * num_classes]
+            pred_seg = np.argmax(single_prediction, axis=0)
+            plt.imshow(to_rgb(pred_seg, cmap))
+
+    if out_dir is not None:
+        plt.savefig(out_dir, dpi=200, bbox_inches='tight', pad_inches=0.0)
+    plt.close(f)
+
+
+def to_rgb(arr, cmap):
+    """
+    Transform an integer-labeled segmentation map using an rgb color-map.
+    :param arr: img_arr w/o a color-channel
+    :param cmap: dictionary mapping from integer class labels to rgb values
+    :return: 3D array holding an rgb image with values in [0, 1]
+    """
+    new_arr = np.zeros(shape=arr.shape + (3,))
+    for c in cmap.keys():
+        ixs = np.where(arr == c)
+        new_arr[ixs] = [cmap[c][i] / 255. for i in range(3)]
+    return new_arr
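Finally, a small usage check for to_rgb above; the two-class colormap is an illustrative assumption, and the output values come out scaled to [0, 1]:

    import numpy as np
    from utils.training_utils import to_rgb

    cmap = {0: (128, 64, 128), 1: (244, 35, 232)}  # toy colormap: two RGB triples
    seg = np.array([[0, 1],
                    [1, 0]])
    rgb = to_rgb(seg, cmap)
    assert rgb.shape == (2, 2, 3)
    assert (rgb[0, 0] == np.array([128., 64., 128.]) / 255.).all()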