diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..0df94ea
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..ccb5991
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..5ba649c
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/probunet.iml b/.idea/probunet.iml
new file mode 100644
index 0000000..6711606
--- /dev/null
+++ b/.idea/probunet.iml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
new file mode 100644
index 0000000..80f1354
--- /dev/null
+++ b/.idea/workspace.xml
@@ -0,0 +1,403 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ TODO
+ imp
+
+
+
+
+
+
+
+
+
+
+
+ true
+ DEFINITION_ORDER
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 1536748173508
+
+
+ 1536748173508
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..938c674
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ). All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2018 Simon Kohl.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/README.md b/README.md
index 3fa444e..5d515c7 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,106 @@
# Probabilistic U-Net
-Coming shortly: Re-implementation of the model described in `A Probabilistic U-Net for Segmentation of Ambiguous Images' ([arxiv.org/abs/1806.05034](https://arxiv.org/abs/1806.05034)).
+Re-implementation of the model described in "A Probabilistic U-Net for Segmentation of Ambiguous Images" ([arxiv.org/abs/1806.05034](https://arxiv.org/abs/1806.05034)).
+The architecture of the Probabilistic U-Net is depicted below: subfigure a) shows sampling and b) the training setup:
+![](assets/architecture.png)
+
+Below see samples conditioned on held-out validation set images from the (stochastic) CityScapes data set:
+![](assets/10_image_16_sample.gif)
+
+## Setup package in virtual environment
+
+```
+git clone https://phabricator.mitk.org/source/skohl/browse/master/prob_unet .
+cd prob_unet/
+virtualenv -p python3 venv
+source venv/bin/activate
+pip3 install -e .
+```
+
+## Install batch-generators for data augmentation
+```
+cd ..
+git clone https://github.com/MIC-DKFZ/batchgenerators
+cd batchgenerators
+pip3 install nilearn scikit-image nibabel
+pip3 install -e .
+cd prob_unet
+```
+
+## Download & preprocess the Cityscapes dataset
+
+1) Create a login account on the Cityscapes website: https://www.cityscapes-dataset.com/
+2) Once you've logged in, download the train, val and test annotations and images:
+ - Annotations: [gtFine_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=1) (241MB)
+ - Images: [leftImg8bit_trainvaltest.zip](https://www.cityscapes-dataset.com/file-handling/?packageID=3) (11GB)
+3) unzip the data (unzip _trainvaltest.zip) and adjust `raw_data_dir` (full path to unzipped files) and `out_dir` (full path to desired output directory) in `preprocessing_config.py`
+4) bilinearly rescale the data to a resolution of 256 x 512 and save as numpy arrays by running
+```
+cd cityscapes
+python3 preprocessing.py
+cd ..
+```
+
+## Training
+
+[skip to evaluation in case you only want to use the pretrained model.]
+modify `data_dir` and `exp_dir` in `scripts/prob_unet_config.py` then:
+```
+cd training
+python3 train_prob_unet.py --config prob_unet_config.py
+```
+
+## Evaluation
+
+Load your own trained model or use a pretrained model. A set of pretrained weights can be downloaded from [zenodo.org](https://zenodo.org/record/1419051#.W5utoOEzYUE) (187MB). After downloading, unpack the file via
+`tar -xvzf pretrained_weights.tar.gz`, e.g. in `/model`. In either case (using your own or the pretrained model), modify the `data_dir` and
+`exp_dir` in `evaluation/cityscapes_eval_config.py` to match your paths.
+
+then first write samples (defaults to 16 segmentation samples for each of the 500 validation images):
+```
+cd ../evaluation
+python3 eval_cityscapes.py --write_samples
+```
+followed by their evaluation (which is multi-threaded and thus reasonably fast):
+```
+python3 eval_cityscapes.py --eval_samples
+```
+The evaluation produces a dictionary holding the results. These can be visualized by launching an IPython notebook:
+```
+jupyter notebook evaluation_plots.ipynb
+```
+The following results are obtained from the pretrained model using above notebook:
+![](assets/validation_results.png)
+
+## Tests
+
+The evaluation metrics are under test-coverage. Run the tests as follows:
+```
+cd ../tests/evaluation
+python3 -m pytest eval_tests.py
+```
+
+## Deviations from original work
+
+The code found in this repository was not used in the original paper and slight modifications apply:
+
+- training on a single gpu (Titan Xp) instead of distributed training, which is not supported in this implementation
+- average-pooling rather than bilinear interpolation is used for down-sampling operations in the model
+- the number of conv kernels is kept constant after the 3rd scale as opposed to strictly doubling it after each scale (for reduction of memory footprint)
+- HeNormal weight initialization worked better than an orthogonal weight initialization
+
+
+## How to cite this code
+Please cite the original publication:
+```
+@article{kohl2018probabilistic,
+ title={A Probabilistic U-Net for Segmentation of Ambiguous Images},
+ author={Kohl, Simon AA and Romera-Paredes, Bernardino and Meyer, Clemens and De Fauw, Jeffrey and Ledsam, Joseph R and Maier-Hein, Klaus H and Eslami, SM and Rezende, Danilo Jimenez and Ronneberger, Olaf},
+ journal={arXiv preprint arXiv:1806.05034},
+ year={2018}
+}
+```
+
+## License
+The code is published under the [Apache License Version 2.0](LICENSE).
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/assets/10_image_16_sample.gif b/assets/10_image_16_sample.gif
new file mode 100644
index 0000000..3793e85
Binary files /dev/null and b/assets/10_image_16_sample.gif differ
diff --git a/assets/architecture.png b/assets/architecture.png
new file mode 100644
index 0000000..2e685d7
Binary files /dev/null and b/assets/architecture.png differ
diff --git a/assets/validation_results.png b/assets/validation_results.png
new file mode 100644
index 0000000..36420df
Binary files /dev/null and b/assets/validation_results.png differ
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/cityscapes/__init__.py b/data/cityscapes/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/cityscapes/cityscapes_labels.py b/data/cityscapes/cityscapes_labels.py
new file mode 100644
index 0000000..be3e5a9
--- /dev/null
+++ b/data/cityscapes/cityscapes_labels.py
@@ -0,0 +1,120 @@
+#!/usr/bin/python
+#
+# Cityscapes labels, code from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py
+#
+
+from collections import namedtuple
+
+
+#--------------------------------------------------------------------------------
+# Definitions
+#--------------------------------------------------------------------------------
+
+# a label and all meta information
+Label = namedtuple( 'Label' , [
+
+ 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
+ # We use them to uniquely name a class
+
+ 'id' , # An integer ID that is associated with this label.
+ # The IDs are used to represent the label in ground truth images
+ # An ID of -1 means that this label does not have an ID and thus
+ # is ignored when creating ground truth images (e.g. license plate).
+ # Do not modify these IDs, since exactly these IDs are expected by the
+ # evaluation server.
+
+ 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
+ # ground truth images with train IDs, using the tools provided in the
+ # 'preparation' folder. However, make sure to validate or submit results
+ # to our evaluation server using the regular IDs above!
+ # For trainIds, multiple labels might have the same ID. Then, these labels
+ # are mapped to the same class in the ground truth images. For the inverse
+ # mapping, we use the label that is defined first in the list below.
+ # For example, mapping all void-type classes to the same ID in training,
+ # might make sense for some approaches.
+ # Max value is 255!
+
+ 'category' , # The name of the category that this label belongs to
+
+ 'categoryId' , # The ID of this category. Used to create ground truth images
+ # on category level.
+
+ 'hasInstances', # Whether this label distinguishes between single instances or not
+
+ 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
+ # during evaluations or not
+
+ 'color' , # The color of this label
+ ] )
+
+
+#--------------------------------------------------------------------------------
+# A list of all labels
+#--------------------------------------------------------------------------------
+
+# Please adapt the train IDs as appropriate for your approach.
+# Note that you might want to ignore labels with ID 255 during training.
+# Further note that the current train IDs are only a suggestion. You can use whatever you like.
+# Make sure to provide your results using the original IDs and not the training IDs.
+# Note that many IDs are ignored in evaluation and thus you never need to predict these!
+
+labels = [
+ # name id trainId category catId hasInstances ignoreInEval color
+ Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
+ Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
+ Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
+ Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
+ Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
+ Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
+ Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
+ Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
+ Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
+ Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
+ Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
+ Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
+ Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
+ Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
+ Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
+ Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
+ Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
+ Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
+ Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
+ Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
+ Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
+ Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
+ Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
+ Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
+ Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
+ Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
+ Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
+ Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
+ Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
+ Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
+ Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
+ Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
+ Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
+ Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
+ Label( 'license plate' , -1 , 255 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
+]
+
+
+#--------------------------------------------------------------------------------
+# Create dictionaries for a fast lookup
+#--------------------------------------------------------------------------------
+
+# Please refer to the main method below for example usages!
+
+# name to label object
+name2label = { label.name : label for label in labels }
+# id to label object
+id2label = { label.id : label for label in labels }
+# trainId to label object
+trainId2label = { label.trainId : label for label in reversed(labels) }
+# category to list of label objects
+category2labels = {}
+for label in labels:
+ category = label.category
+ if category in category2labels:
+ category2labels[category].append(label)
+ else:
+ category2labels[category] = [label]
diff --git a/data/cityscapes/data_loader.py b/data/cityscapes/data_loader.py
new file mode 100644
index 0000000..2c56f67
--- /dev/null
+++ b/data/cityscapes/data_loader.py
@@ -0,0 +1,353 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to serve the CityScapes dataset."""
+
+import os
+import sys, glob
+import numpy as np
+import imp
+import logging
+
+from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
+from batchgenerators.transforms.spatial_transforms import MirrorTransform
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.transforms.spatial_transforms import SpatialTransform
+from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
+from batchgenerators.transforms import AbstractTransform
+from .cityscapes_labels import labels as cityscapes_labels_tuple
+
+def loadFiles(label_density, split, input_path, cities=None, instance=False):
+ """
+ Assemble dict of file paths.
+ :param label_density: string in ['gtFine', 'gtCoarse']
+ :param split: string in ['train', 'val', 'test', 'train_extra']
+ :param input_path: string
+ :param cities: list of strings or None
+ :param instance: bool
+ :return: dict
+ """
+ input_dir = os.path.join(input_path, split)
+ logging.info("Assembling file dict from {}.".format(input_dir))
+ paths_dict = {}
+ for path, dirs, files in os.walk(input_dir):
+
+ skip_city = False
+ if cities is not None:
+ current_city = path.rsplit('/', 1)[-1]
+ if current_city not in cities:
+ skip_city = True
+
+ if not skip_city:
+ logging.info('Reading from {}'.format(path))
+ label_paths, img_paths = searchFiles(path, label_density, instance)
+ paths_dict = {**paths_dict, **zipPaths(label_paths, img_paths)}
+
+ return paths_dict
+
+
+def searchFiles(path, label_density, instance=False):
+ """
+ Get file paths via wildcard search.
+ :param path: path to files for each city
+ :param label_density: string in ['gtFine', 'gtCoarse']
+ :param instance: bool
+ :return: 2 lists
+ """
+ if (instance == True):
+ label_wildcard_search = os.path.join(path, "*{}_instanceIDs.npy".format(label_density))
+ else:
+ label_wildcard_search = os.path.join(path, "*{}_labelIds.npy".format(label_density))
+ label_paths = glob.glob(label_wildcard_search)
+ label_paths.sort()
+ img_wildcard_search = os.path.join(path, "*_leftImg8bit.npy")
+ img_paths = glob.glob(img_wildcard_search)
+ img_paths.sort()
+ return label_paths, img_paths
+
+
+def zipPaths(label_paths, img_paths):
+ """
+ zip paths in form of dict.
+ :param label_paths: list of strings
+ :param img_paths: list of strings
+ :return: dict
+ """
+ try:
+ assert len(label_paths) == len(img_paths)
+ except:
+ raise Exception('Missmatch: {} label paths vs. {} img paths!'.format(len(label_paths), len(img_paths)))
+
+ paths_dict = {}
+ for i, img_path in enumerate(img_paths):
+ img_spec = ('_').join(img_paths[i].split('/')[-1].split('_')[:-1])
+ try:
+ assert img_spec in label_paths[i]
+ except:
+ raise Exception('img and label name mismatch: {} vs. {}'.format(img_paths[i], label_paths[i]))
+
+ paths_dict[img_spec] = {"data": img_paths[i], "seg": label_paths[i], 'img_spec': img_spec}
+ return paths_dict
+
+
+def augment_gamma(data, gamma_range=(0.5, 2), invert_image=False, epsilon=1e-7, per_channel=False, retain_stats=False, p_per_sample=0.3):
+ """code by Fabian Isensee, see MIC_DKFZ/batch_generators on github."""
+ for sample in range(data.shape[0]):
+ if np.random.uniform() < p_per_sample:
+ if invert_image:
+ data = - data
+ if not per_channel:
+ if retain_stats:
+ mn = data[sample].mean()
+ sd = data[sample].std()
+ if np.random.random() < 0.5 and gamma_range[0] < 1:
+ gamma = np.random.uniform(gamma_range[0], 1)
+ else:
+ gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
+ minm = data[sample].min()
+ rnge = data[sample].max() - minm
+ data[sample] = np.power(((data[sample] - minm) / float(rnge + epsilon)), gamma) * rnge + minm
+ if retain_stats:
+ data[sample] = data[sample] - data[sample].mean() + mn
+ data[sample] = data[sample] / (data[sample].std() + 1e-8) * sd
+ else:
+ for c in range(data.shape[1]):
+ if retain_stats:
+ mn = data[sample][c].mean()
+ sd = data[sample][c].std()
+ if np.random.random() < 0.5 and gamma_range[0] < 1:
+ gamma = np.random.uniform(gamma_range[0], 1)
+ else:
+ gamma = np.random.uniform(max(gamma_range[0], 1), gamma_range[1])
+ minm = data[sample][c].min()
+ rnge = data[sample][c].max() - minm
+ data[sample][c] = np.power(((data[sample][c] - minm) / float(rnge + epsilon)), gamma) * rnge + minm
+ if retain_stats:
+ data[sample][c] = data[sample][c] - data[sample][c].mean() + mn
+ data[sample][c] = data[sample][c] / (data[sample][c].std() + 1e-8) * sd
+ return data
+
+
+class GammaTransform(AbstractTransform):
+ """Augments by changing 'gamma' of the image (same as gamma correction in photos or computer monitors
+
+ Args:
+ gamma_range (tuple of float): range to sample gamma from. If one value is smaller than 1 and the other one is
+ larger then half the samples will have gamma <1 and the other >1 (in the inverval that was specified)
+
+ invert_image: whether to invert the image before applying gamma augmentation
+
+ retain_stats: Gamma transformation will alter the mean and std of the data in the patch. If retain_stats=True,
+ the data will be transformed to match the mean and standard deviation before gamma augmentation
+
+ """
+
+ def __init__(self, gamma_range=(0.5, 2), invert_image=False, per_channel=False, data_key="data", retain_stats=False,
+ p_per_sample=0.3, mask_channel_in_seg=None):
+ self.mask_channel_in_seg = mask_channel_in_seg
+ self.p_per_sample = p_per_sample
+ self.retain_stats = retain_stats
+ self.per_channel = per_channel
+ self.data_key = data_key
+ self.gamma_range = gamma_range
+ self.invert_image = invert_image
+
+ def __call__(self, **data_dict):
+ data_dict[self.data_key] = augment_gamma(data_dict[self.data_key], self.gamma_range, self.invert_image,
+ per_channel=self.per_channel, retain_stats=self.retain_stats,
+ p_per_sample=self.p_per_sample)
+ return data_dict
+
+
+def map_labels_to_trainId(arr):
+ """Remap ids to corresponding training Ids. Note that the inplace mapping works because id > trainId here!"""
+ id2trainId = {label.id:label.trainId for label in cityscapes_labels_tuple}
+ for id, trainId in id2trainId.items():
+ arr[arr == id] = trainId
+ return arr
+
+
class AddLossMask(AbstractTransform):
    """Derives a binary mask from the integer segmentation in the data dict.

    The mask is 1 wherever the segmentation equals `label2mask` and 0 elsewhere,
    and is stored under `output_key`. As used in this file it marks the ignore
    label (see create_data_gen_pipeline, which passes cf.ignore_label).

    Args:
        label2mask (int): label value to flag in the mask
        output_key (string): key under which the mask is written into the data dict
    """

    def __init__(self, label2mask, output_key="loss_mask"):
        self.label2mask = label2mask
        self.output_key = output_key

    def __call__(self, **data_dict):
        seg = data_dict['seg']
        if seg is None:
            # No segmentation present: warn rather than fail.
            from warnings import warn
            warn("calling AddLossMask but there is no segmentation")
        else:
            data_dict[self.output_key] = (seg == self.label2mask).astype(np.uint8)
        return data_dict
+
+
class StochasticLabelSwitches(AbstractTransform):
    """
    Stochastically switches labels in a batch of integer-labeled segmentations.

    For every class listed in `label_switches`, each batch element independently has
    its pixels of that class relabeled to the class' "_2" counterpart with the
    configured probability.
    """

    def __init__(self, name2id, label_switches):
        self._name2id = name2id
        self._label_switches = label_switches

    def __call__(self, **data_dict):
        seg_batch = data_dict['seg']
        n_in_batch = seg_batch.shape[0]

        for class_name, switch_prob in self._label_switches.items():
            source_id = self._name2id[class_name]
            target_id = self._name2id[class_name + '_2']
            # one Bernoulli draw per batch element for this class
            draws = np.random.binomial(1, switch_prob, n_in_batch)

            for sample_ix in range(n_in_batch):
                if draws[sample_ix]:
                    seg_batch[sample_ix][seg_batch[sample_ix] == source_id] = target_id

        data_dict['seg'] = seg_batch
        return data_dict
+
+
class BatchGenerator(SlimDataLoaderBase):
    """
    Training/validation batch generator for the preprocessed Cityscapes arrays.
    Draws `batch_size` images (randomly, or linearly in order when random=False),
    maps their raw label ids to train ids and stacks them into one batch array.
    :param batch_size: number of images per batch
    :param data_dir: preprocessed data root; the resolution subdirectory is appended
    :param label_density: label type, e.g. 'gtFine'
    :param data_split: 'train' / 'val' subdirectory to read
    :param resolution: resolution subdirectory name, e.g. 'quarter'
    :param cities: optional list of city names to restrict the dataset
    :param gt_instances: whether loadFiles should include instance ground truth
    :param n_batches: number of batches per epoch; required when random=False
    :param random: draw random batches vs. iterate the data linearly
    """
    def __init__(self, batch_size, data_dir, label_density='gtFine', data_split='train', resolution='quarter',
                 cities=None, gt_instances=False, n_batches=None, random=True):
        super(BatchGenerator, self).__init__(data=None, batch_size=batch_size)

        # data lives under <data_dir>/<resolution>/
        data_dir = os.path.join(data_dir, resolution)
        self._data_dir = data_dir
        self._label_density = label_density
        self._gt_instances = gt_instances
        self._data_split = data_split
        self._random = random
        # NOTE(review): _n_batches is only used when random=False; passing
        # random=False with n_batches=None raises a TypeError on the modulo below —
        # confirm all non-random callers supply n_batches.
        self._n_batches = n_batches
        self._batches_generated = 0
        self._data = loadFiles(label_density, data_split, data_dir, cities=cities, instance=gt_instances)
        logging.info('{} set comprises {} files.'.format(data_split, len(self._data)))

    def generate_train_batch(self):
        # Pick the image keys for this batch: random with replacement, or a
        # deterministic slice that cycles through the data set linearly.
        if self._random:
            img_ixs = np.random.choice(list(self._data.keys()), self.batch_size, replace=True)
        else:
            batch_no = self._batches_generated % self._n_batches
            img_ixs = [list(self._data.keys())[i] for i in\
                       np.arange(batch_no * self.batch_size, (batch_no + 1) * self.batch_size)]
        img_batch, seg_batch, ids_batch = [], [], []

        for b in range(self.batch_size):

            # images were saved as 0..255 float arrays; rescale to [0, 1]
            img = np.load(self._data[img_ixs[b]]['data']) / 255.
            seg = np.load(self._data[img_ixs[b]]['seg'])
            seg = map_labels_to_trainId(seg)
            # add a channel axis so seg is (1, H, W)
            seg = seg[np.newaxis]
            ids_batch.append(self._data[img_ixs[b]]['img_spec'])

            img_batch.append(img)
            seg_batch.append(seg)

        self._batches_generated += 1
        batch = {'data': np.array(img_batch).astype('float32'), 'seg': np.array(seg_batch).astype('uint8'),
                 'id': ids_batch}
        return batch
+
+
def create_data_gen_pipeline(cf, cities=None, data_split='train', do_aug=True, random=True, n_batches=None):
    """
    Build the multi-threaded batch generation and augmentation pipeline.
    :param cf: config module (batch/patch sizes, data paths, da_kwargs, worker count)
    :param cities: list of city names or None for all cities
    :param data_split: which split directory to read ('train' / 'val')
    :param do_aug: apply spatial/mirror augmentation (training) or center-crop only
    :param random: draw random batches vs. iterate the data linearly
    :param n_batches: number of batches per epoch when random=False
    :return: a MultiThreadedAugmenter wrapping the generator and transforms
    """
    batch_gen = BatchGenerator(cities=cities, batch_size=cf.batch_size, data_dir=cf.data_dir,
                               label_density=cf.label_density, data_split=data_split, resolution=cf.resolution,
                               gt_instances=cf.gt_instances, n_batches=n_batches, random=random)

    transforms = []
    if do_aug:
        transforms.append(MirrorTransform(axes=(3,)))
        da = cf.da_kwargs
        transforms.append(SpatialTransform(
            patch_size=cf.patch_size[-2:],
            patch_center_dist_from_border=da['rand_crop_dist'],
            do_elastic_deform=da['do_elastic_deform'],
            alpha=da['alpha'], sigma=da['sigma'],
            do_rotation=da['do_rotation'], angle_x=da['angle_x'],
            angle_y=da['angle_y'], angle_z=da['angle_z'],
            do_scale=da['do_scale'], scale=da['scale'],
            random_crop=da['random_crop'],
            border_mode_data=da['border_mode_data'],
            border_mode_seg=da['border_mode_seg'],
            border_cval_seg=da['border_cval_seg']))
    else:
        # validation/testing: deterministic center crop instead of augmentation
        transforms.append(CenterCropTransform(crop_size=cf.patch_size[-2:]))

    transforms.append(GammaTransform(cf.da_kwargs['gamma_range'], invert_image=False, per_channel=True,
                                     retain_stats=cf.da_kwargs['gamma_retain_stats'],
                                     p_per_sample=cf.da_kwargs['p_gamma']))
    transforms.append(AddLossMask(cf.ignore_label))
    if cf.label_switches is not None:
        transforms.append(StochasticLabelSwitches(cf.name2trainId, cf.label_switches))

    return MultiThreadedAugmenter(batch_gen, Compose(transforms), num_processes=cf.n_workers,
                                  seeds=range(cf.n_workers))
+
+
def get_train_generators(cf):
    """
    Create the training batch-generator pipelines.
    :param cf: config module
    :return: dict with 'train' (augmented, random) and 'val' (linear, no aug) generators
    """
    # NB: the validation generator also reads the 'train' split, only without augmentation.
    return {
        'train': create_data_gen_pipeline(cf=cf, cities=cf.train_cities, data_split='train', do_aug=True,
                                          n_batches=cf.n_train_batches),
        'val': create_data_gen_pipeline(cf=cf, cities=cf.val_cities, data_split='train', do_aug=False,
                                        random=False, n_batches=cf.n_val_batches),
    }
+
def main():
    """Smoke-test entry point: load the config, list the files and draw one batch."""
    logging.info("start loading.")
    cf = imp.load_source('cf', 'config.py')
    # Fixes: 'dict' shadowed the builtin, and 'False' was previously passed
    # positionally into loadFiles' 'cities' parameter instead of 'instance'.
    files = loadFiles("gtFine", "train", cf.out_dir, instance=False)
    logging.info('Contains {} elements.'.format(len(files)))
    logging.info(files)
    data_provider = BatchGenerator(8, cf.out_dir, data_split='val')
    batch = next(data_provider)
    # logging.info takes a format string; the previous call passed three positional
    # objects, which the logging module treats as %-args and fails to format.
    logging.info('data shape %s, seg shape %s, ids %s',
                 batch['data'].shape, batch['seg'].shape, batch['id'])
diff --git a/data/cityscapes/preprocessing.py b/data/cityscapes/preprocessing.py
new file mode 100644
index 0000000..1b17046
--- /dev/null
+++ b/data/cityscapes/preprocessing.py
@@ -0,0 +1,99 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to preprocess the Cityscapes dataset."""
+
+import os
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+import imp
+
+resolution_map = {1.0: 'full', 0.5: 'half', 0.25: 'quarter'}
+
def resample(img, scale_factor=1.0, interpolation=Image.BILINEAR):
    """
    Resample a PIL image by a scale factor, preserving the aspect ratio.
    :param img: PIL.Image object
    :param scale_factor: float scale applied to the width
    :param interpolation: PIL.Image interpolation method
    :return: resized PIL.Image object
    """
    width, height = img.size
    target_width = int(width * scale_factor)
    # Derive the height from the truncated integer width so both dimensions
    # shrink by exactly the same effective ratio.
    ratio = target_width / float(width)
    target_height = int(float(height) * ratio)
    return img.resize((target_width, target_height), interpolation)
+
def recursive_mkdir(nested_dir_list):
    """
    Create the nested directory path given as a list of components.
    Order in the list implies nesting depth; existing directories are kept.
    :param nested_dir_list: list of strings (empty list is a no-op)
    :return: None
    """
    if not nested_dir_list:
        return
    # os.makedirs creates all intermediate levels and, with exist_ok=True, avoids
    # the isdir-then-mkdir race of the previous per-level loop (whose loop variable
    # 'dir' also shadowed the builtin).
    os.makedirs(os.path.join(*nested_dir_list), exist_ok=True)
+
def preprocess(cf):
    """Resample the raw Cityscapes images and label maps and store them as .npy arrays.

    Iterates the splits configured in cf.settings; for every image and configured
    resolution it writes the image (float32, optionally transposed to NCHW) and each
    label map (uint8, nearest-neighbor resampled) under
    cf.out_dir/<resolution>/<split>/<city>/.
    :param cf: config module providing raw_data_dir, out_dir, settings and data_format
    """
    # renamed loop variable from 'set' (shadowed the builtin) to 'split'
    for split in list(cf.settings.keys()):
        print('Processing {} set.'.format(split))

        # image dir
        image_dir = os.path.join(cf.raw_data_dir, 'leftImg8bit', split)
        city_names = os.listdir(image_dir)

        for city in city_names:
            print('Processing {}'.format(city))
            city_dir = os.path.join(image_dir, city)
            image_names = os.listdir(city_dir)
            # 'aachen_000000_000019_leftImg8bit.png' -> 'aachen_000000_000019'
            image_specifiers = ['_'.join(img.split('_')[:3]) for img in image_names]

            for img_spec in tqdm(image_specifiers):
                for scale in cf.settings[split]['resolutions']:
                    recursive_mkdir([cf.out_dir, resolution_map[scale], split, city])

                    # image: bilinear resampling, saved as float32
                    img_path = os.path.join(city_dir, img_spec + '_leftImg8bit.png')
                    img = Image.open(img_path)
                    if scale != 1.0:
                        img = resample(img, scale_factor=scale, interpolation=Image.BILINEAR)
                    img_out_path = os.path.join(cf.out_dir, resolution_map[scale], split, city,
                                                img_spec + '_leftImg8bit.npy')
                    img_arr = np.array(img).astype(np.float32)

                    # move channels to the front when NCHW layout is requested
                    channel_axis = 0 if img_arr.shape[0] == 3 else 2
                    if cf.data_format == 'NCHW' and channel_axis != 0:
                        img_arr = np.transpose(img_arr, axes=[2, 0, 1])
                    np.save(img_out_path, img_arr)

                    # labels: nearest-neighbor resampling keeps label ids intact
                    for label_density in cf.settings[split]['label_densities']:
                        label_dir = os.path.join(cf.raw_data_dir, label_density, split, city)
                        for mod in cf.settings[split]['label_modalities']:
                            label_spec = img_spec + '_{}_{}'.format(label_density, mod)
                            label_path = os.path.join(label_dir, label_spec + '.png')
                            label = Image.open(label_path)
                            if scale != 1.0:
                                label = resample(label, scale_factor=scale, interpolation=Image.NEAREST)
                            label_out_path = os.path.join(cf.out_dir, resolution_map[scale], split, city,
                                                          label_spec + '.npy')
                            np.save(label_out_path, np.array(label).astype(np.uint8))
+
+if __name__ == "__main__":
+ cf = imp.load_source('cf', 'preprocessing_config.py')
+ preprocess(cf)
\ No newline at end of file
diff --git a/data/cityscapes/preprocessing_config.py b/data/cityscapes/preprocessing_config.py
new file mode 100644
index 0000000..548ae40
--- /dev/null
+++ b/data/cityscapes/preprocessing_config.py
@@ -0,0 +1,28 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cityscapes preprocessing config."""
+
# Absolute path to the raw Cityscapes download (must contain 'leftImg8bit' and the label dirs).
raw_data_dir = 'SET_INPUT_DIRECTORY_ABSOLUTE_PATH_HERE'
# Absolute path under which the preprocessed .npy arrays are written.
out_dir = 'SET_OUTPUT_DIRECTORY_ABSOLUTE_PATH_HERE'

# settings: per split, the resolutions to generate and the label types/modalities to convert
settings = {
    'train': {'resolutions': [1.0, 0.5, 0.25], 'label_densities': ['gtFine'],
              'label_modalities': ['labelIds']},
    'val': {'resolutions': [1.0, 0.5, 0.25], 'label_densities': ['gtFine'],
            'label_modalities': ['labelIds']},
    }

# Channel layout of the saved image arrays; 'NCHW' moves channels to the front.
data_format = 'NCHW'
diff --git a/evaluation/__init__.py b/evaluation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/evaluation/cityscapes_eval_config.py b/evaluation/cityscapes_eval_config.py
new file mode 100644
index 0000000..5fa1792
--- /dev/null
+++ b/evaluation/cityscapes_eval_config.py
@@ -0,0 +1,92 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cityscapes evaluation config."""
+
+import os
+from collections import OrderedDict
+from data.cityscapes.cityscapes_labels import labels as cs_labels_tuple
+from model import pretrained_weights
+
config_path = os.path.realpath(__file__)

#########################################
#                 data                  #
#########################################

data_dir = 'PREPROCESSING_OUTPUT_DIRECTORY_ABSOLUTE_PATH'
resolution = 'quarter'
label_density = 'gtFine'
num_classes = 19
one_hot_labels = False
ignore_label = 255
# validation cities of the Cityscapes val split
cities = ['frankfurt', 'lindau', 'munster']

#########################################
#            label-switches             #
#########################################

# base color map keyed by trainId; 255 (ignore) is rendered black
color_map = {label.trainId:label.color for label in cs_labels_tuple}
color_map[255] = (0.,0.,0.)

trainId2name = {labels.trainId: labels.name for labels in cs_labels_tuple}
name2trainId = {labels.name: labels.trainId for labels in cs_labels_tuple}

# per-class probabilities of stochastically switching to the '_2' variant
label_switches = OrderedDict([('sidewalk', 8./17.), ('person', 7./17.), ('car', 6./17.), ('vegetation', 5./17.), ('road', 4./17.)])
# the switched variants are appended as additional classes with ids 19..23
num_classes += len(label_switches)
switched_Id2name = {19+i:list(label_switches.keys())[i] + '_2' for i in range(len(label_switches))}
switched_name2Id = {list(label_switches.keys())[i] + '_2':19+i for i in range(len(label_switches))}
trainId2name = {**trainId2name, **switched_Id2name}
name2trainId = {**name2trainId, **switched_name2Id}

# distinct colors for the switched classes, merged into the main color map
switched_labels2color = {'road_2': (84, 86, 22), 'person_2': (167, 242, 242), 'vegetation_2': (242, 160, 19),
                         'car_2': (30, 193, 252), 'sidewalk_2': (46, 247, 180)}
switched_cmap = {switched_name2Id[i]:switched_labels2color[i] for i in switched_name2Id.keys()}
color_map = {**color_map, **switched_cmap}

# with 5 independently switchable classes there are 2^5 = 32 ground-truth modes
exp_modes = len(label_switches)
num_modes = 2 ** exp_modes

#########################################
#               network                 #
#########################################

cuda_visible_devices = '0'
cpu_device = '/cpu:0'
gpu_device = '/gpu:0'

# shapes are NCHW with a free batch dimension
patch_size = [256, 512]
network_input_shape = (None, 3) +tuple(patch_size)
network_output_shape = (None, num_classes) + tuple(patch_size)
label_shape = (None, 1) + tuple(patch_size)
loss_mask_shape = label_shape

# U-Net encoder/decoder channel widths per resolution level
base_channels = 32
num_channels = [base_channels, 2*base_channels, 4*base_channels,
                6*base_channels, 6*base_channels, 6*base_channels, 6*base_channels]

num_convs_per_block = 3

# probabilistic U-Net latent space / combination layers
latent_dim = 6
num_1x1_convs = 3
analytic_kl = True
use_posterior_mean = False

#########################################
#              evaluation               #
#########################################

# number of segmentation samples drawn per image
num_samples = 16
# directory containing the pretrained checkpoint (package directory of pretrained_weights)
exp_dir = '/'.join(os.path.abspath(pretrained_weights.__file__).split('/')[:-1])
out_dir = 'EVALUATION_OUTPUT_DIRECTORY_ABSOLUTE_PATH'
\ No newline at end of file
diff --git a/evaluation/eval_cityscapes.py b/evaluation/eval_cityscapes.py
new file mode 100644
index 0000000..68e5951
--- /dev/null
+++ b/evaluation/eval_cityscapes.py
@@ -0,0 +1,479 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cityscapes evaluation script."""
+
+import tensorflow as tf
+import numpy as np
+import os
+import argparse
+from tqdm import tqdm
+from multiprocessing import Process, Queue
+from importlib.machinery import SourceFileLoader
+import logging
+import pickle
+
+from model.probabilistic_unet import ProbUNet
+from data.cityscapes.data_loader import loadFiles, map_labels_to_trainId
+from utils import training_utils
+
+
def write_test_predictions(cf):
    """
    Draw cf.num_samples segmentation samples per validation image from the
    probabilistic U-Net and write each as a uint8 .npy array into cf.out_dir.
    :param cf: config module (network shapes, checkpoint dir, output dir)
    :return: None
    """
    # do not use all gpus
    os.environ["CUDA_VISIBLE_DEVICES"] = cf.cuda_visible_devices

    data_dir = os.path.join(cf.data_dir, cf.resolution)
    data_dict = loadFiles(label_density=cf.label_density, split='val', input_path=data_dir,
                          cities=None, instance=False)
    # prepare out_dir
    if not os.path.isdir(cf.out_dir):
        os.mkdir(cf.out_dir)

    logging.info('Writing to {}'.format(cf.out_dir))

    # initialize computation graph (TF1-style static graph)
    prob_unet = ProbUNet(latent_dim=cf.latent_dim, num_channels=cf.num_channels,
                         num_1x1_convs=cf.num_1x1_convs,
                         num_classes=cf.num_classes, num_convs_per_block=cf.num_convs_per_block,
                         initializers={'w': training_utils.he_normal(),
                                       'b': tf.truncated_normal_initializer(stddev=0.001)},
                         regularizers={'w': tf.contrib.layers.l2_regularizer(1.0),
                                       'b': tf.contrib.layers.l2_regularizer(1.0)})
    x = tf.placeholder(tf.float32, shape=cf.network_input_shape)

    with tf.device(cf.gpu_device):
        prob_unet(x, is_training=False, one_hot_labels=cf.one_hot_labels)
        sampled_logits = prob_unet.sample()

    saver = tf.train.Saver(save_relative_paths=True)
    with tf.train.MonitoredTrainingSession() as sess:

        # restore the latest checkpoint from the experiment directory
        print('EXP DIR', cf.exp_dir)
        latest_ckpt_path = tf.train.latest_checkpoint(cf.exp_dir)
        print('CKPT PATH', latest_ckpt_path)
        saver.restore(sess, latest_ckpt_path)

        for k, v in tqdm(data_dict.items()):
            # images were saved as 0..255 arrays; rescale to [0, 1]
            img = np.load(v['data']) / 255.
            # add batch dimensions
            img = img[np.newaxis]

            for i in range(cf.num_samples):
                # each sess.run draws a fresh latent sample -> a new segmentation
                sample = sess.run(sampled_logits, feed_dict={x: img})
                sample = np.argmax(sample, axis=1)[:, np.newaxis]
                sample = sample.astype(np.uint8)
                sample_path = os.path.join(cf.out_dir, '{}_sample{}_labelIds.npy'.format(k, i))
                np.save(sample_path, sample)
+
+
def get_array_of_modes(cf, seg):
    """
    Assemble an array holding every ground-truth mode of a segmentation.
    Each mode applies one combination of the configured class switches to `seg`.
    :param cf: config module (label_switches, exp_modes, num_modes, name2trainId)
    :param seg: 4D integer array
    :return: 5D uint8 array of shape (num_modes,) + seg.shape
    """
    switch = get_mode_statistics(cf.label_switches, exp_modes=cf.exp_modes)['switch']

    # construct ground-truth modes by applying each switch combination
    gt_seg_modes = np.zeros(shape=(cf.num_modes,) + seg.shape, dtype=np.uint8)
    for mode_ix in range(cf.num_modes):
        mode_seg = seg.copy()
        for class_ix, class_name in enumerate(cf.label_switches.keys()):
            if switch[mode_ix, class_ix]:
                mode_seg[mode_seg == cf.name2trainId[class_name]] = cf.name2trainId[class_name + '_2']
        gt_seg_modes[mode_ix] = mode_seg

    return gt_seg_modes
+
+
def get_array_of_samples(cf, img_key):
    """
    Assemble an array holding all segmentation samples for a given image.
    Missing or unreadable sample files leave their slot as zeros (a message is printed).
    :param cf: config module (uses num_samples, patch_size, out_dir)
    :param img_key: string image identifier
    :return: 5D uint8 array of shape (num_samples, 1, 1, H, W)
    """
    seg_samples = np.zeros(shape=(cf.num_samples,1,1) + tuple(cf.patch_size), dtype=np.uint8)
    for i in range(cf.num_samples):
        sample_path = os.path.join(cf.out_dir, '{}_sample{}_labelIds.npy'.format(img_key, i))
        try:
            seg_samples[i] = np.load(sample_path)
        # was a bare `except:` — that also swallowed KeyboardInterrupt/SystemExit;
        # OSError covers missing/corrupt files, ValueError covers shape mismatches
        except (OSError, ValueError):
            print('Could not load {}'.format(sample_path))

    return seg_samples
+
+
def get_mode_counts(d_matrix_YS):
    """
    Count, per ground-truth mode, how many samples are closest to it.
    :param d_matrix_YS: 3D array (num_modes, num_samples, num_classes) of distances
    :return: 1D integer array of length num_modes
    """
    # assign each sample to the mode with the smallest class-averaged distance
    mean_d = np.nanmean(d_matrix_YS, axis=-1)
    sampled_modes = np.argmin(mean_d, axis=-2)

    # count the modes; np.int was removed in NumPy >= 1.24, the builtin is equivalent
    num_modes = d_matrix_YS.shape[0]
    mode_count = np.zeros(shape=(num_modes,), dtype=int)
    for sampled_mode in sampled_modes:
        mode_count[sampled_mode] += 1

    return mode_count
+
+
def get_pixelwise_mode_counts(cf, seg, seg_samples):
    """
    Calculate pixel-wise mode counts for the switchable classes.
    For each switchable class c: column 0 holds the total number of (pixel, sample)
    pairs where the ground truth is c, column 1 counts sampled pixels that kept c,
    column 2 counts sampled pixels switched to the alternative class.
    :param cf: config module (label_switches, name2trainId)
    :param seg: 4D array of integer labeled segmentations
    :param seg_samples: 5D array of integer labeled segmentations
    :return: integer array of shape (num switchable classes, 3)
    """
    assert seg.shape == seg_samples.shape[1:]
    num_samples = seg_samples.shape[0]
    # np.int was removed in NumPy >= 1.24; the builtin int is equivalent here
    pixel_counts = np.zeros(shape=(len(cf.label_switches), 3), dtype=int)

    # iterate all switchable classes
    for i, c in enumerate(cf.label_switches.keys()):
        c_id = cf.name2trainId[c]
        alt_c_id = cf.name2trainId[c + '_2']
        c_ixs = np.where(seg == c_id)

        total_num_pixels = np.sum((seg == c_id).astype(np.uint8)) * num_samples
        pixel_counts[i, 0] = total_num_pixels

        # count pixels of original class|original class and alternative class|original class
        for j in range(num_samples):
            sample = seg_samples[j]
            sampled_original_pixels = np.sum((sample[c_ixs] == c_id).astype(np.uint8))
            sampled_alternative_pixels = np.sum((sample[c_ixs] == alt_c_id).astype(np.uint8))
            pixel_counts[i, 1] += sampled_original_pixels
            pixel_counts[i, 2] += sampled_alternative_pixels

    return pixel_counts
+
+
def get_mode_statistics(label_switches, exp_modes=5):
    """
    Calculate a binary matrix of switch decisions and a vector of mode probabilities.
    :param label_switches: dict mapping class name -> individual switch probability
    :param exp_modes: integer, number of independently switchable classes
    :return: dict with 'switch' (num_modes, exp_modes) uint8 array and 'mode_probs'
    """
    num_modes = 2 ** exp_modes

    # assemble a binary matrix of switch decisions; the column count was previously
    # hard-coded to 5, now it follows exp_modes
    switch = np.zeros(shape=(num_modes, exp_modes), dtype=np.uint8)
    for i in range(exp_modes):
        # list repetition yields the standard binary counting pattern per column
        switch[:, i] = 2 ** i * (2 ** (exp_modes - 1 - i) * [0] + 2 ** (exp_modes - 1 - i) * [1])

    # probability of each mode = product of per-class switch/no-switch probabilities
    mode_probs = np.zeros(shape=(num_modes,), dtype=np.float32)
    for mode in range(num_modes):
        prob = 1.
        for i, c in enumerate(label_switches.keys()):
            if switch[mode, i]:
                prob *= label_switches[c]
            else:
                prob *= 1. - label_switches[c]
        mode_probs[mode] = prob
    # exact float equality was fragile here; allow for rounding error
    assert np.isclose(np.sum(mode_probs), 1.), 'mode probabilities must sum to 1'

    return {'switch': switch, 'mode_probs': mode_probs}
+
+
def get_energy_distance_components(gt_seg_modes, seg_samples, eval_class_ids, ignore_mask=None):
    """
    Calculates the components for the IoU-based generalized energy distance given an array holding all segmentation
    modes and an array holding all sampled segmentations. All distances are d = 1 - IoU,
    computed per evaluation class.
    :param gt_seg_modes: N-D array in format (num_modes,[...],H,W)
    :param seg_samples: N-D array in format (num_samples,[...],H,W)
    :param eval_class_ids: integer or list of integers specifying the classes to encode, if integer range() is applied
    :param ignore_mask: N-D array in format ([...],H,W); passed as loss_mask to calc_confusion
    :return: dict with per-class distance matrices 'YS', 'SS' and 'YY'
    """
    num_modes = gt_seg_modes.shape[0]
    num_samples = seg_samples.shape[0]

    if isinstance(eval_class_ids, int):
        eval_class_ids = list(range(eval_class_ids))

    d_matrix_YS = np.zeros(shape=(num_modes, num_samples, len(eval_class_ids)), dtype=np.float32)
    d_matrix_YY = np.zeros(shape=(num_modes, num_modes, len(eval_class_ids)), dtype=np.float32)
    d_matrix_SS = np.zeros(shape=(num_samples, num_samples, len(eval_class_ids)), dtype=np.float32)

    # iterate all ground-truth modes
    for mode in range(num_modes):

        ##########################################
        #   Calculate d(Y,S) = [1 - IoU(Y,S)],	 #
        #   with S ~ P_pred, Y ~ P_gt  			 #
        ##########################################

        # iterate the samples S
        for i in range(num_samples):
            conf_matrix = training_utils.calc_confusion(gt_seg_modes[mode], seg_samples[i],
                                                        loss_mask=ignore_mask, class_ixs=eval_class_ids)
            # iou is a per-class vector here — one distance entry per evaluation class
            iou = training_utils.metrics_from_conf_matrix(conf_matrix)['iou']
            d_matrix_YS[mode, i] = 1. - iou

        ###########################################
        #   Calculate d(Y,Y') = [1 - IoU(Y,Y')],  #
        #   with Y,Y' ~ P_gt  	   				  #
        ###########################################

        # iterate the ground-truth modes Y' while exploiting the pair-wise symmetries for efficiency
        for mode_2 in range(mode, num_modes):
            conf_matrix = training_utils.calc_confusion(gt_seg_modes[mode], gt_seg_modes[mode_2],
                                                        loss_mask=ignore_mask, class_ixs=eval_class_ids)
            iou = training_utils.metrics_from_conf_matrix(conf_matrix)['iou']
            d_matrix_YY[mode, mode_2] = 1. - iou
            d_matrix_YY[mode_2, mode] = 1. - iou

    #########################################
    #   Calculate d(S,S') = 1 - IoU(S,S'),  #
    #   with S,S' ~ P_pred        			#
    #########################################

    # iterate all samples S
    for i in range(num_samples):
        # iterate all samples S'
        for j in range(i, num_samples):
            conf_matrix = training_utils.calc_confusion(seg_samples[i], seg_samples[j],
                                                        loss_mask=ignore_mask, class_ixs=eval_class_ids)
            iou = training_utils.metrics_from_conf_matrix(conf_matrix)['iou']
            d_matrix_SS[i, j] = 1. - iou
            d_matrix_SS[j, i] = 1. - iou

    return {'YS': d_matrix_YS, 'SS': d_matrix_SS, 'YY': d_matrix_YY}
+
+
def calc_energy_distances(d_matrices, num_samples=None, probability_weighted=False, label_switches=None, exp_mode=5):
    """
    Calculate the per-image generalized energy distance D^2 = 2*E[d(Y,S)] - E[d(S,S')] - E[d(Y,Y')]
    from matrices holding the combinatorial pair-wise distances.
    A nanmean over the class axis excludes classes absent from both members of a pair.
    :param d_matrices: dict of 4D arrays of shape
        (num_images, num_modes/num_samples, num_modes/num_samples, num_classes)
    :param num_samples: integer or None (None uses all available samples)
    :param probability_weighted: weight ground-truth modes by their occurrence probability
    :param label_switches: None or dict (required when probability_weighted=True)
    :param exp_mode: integer, number of independently switchable classes
    :return: numpy array of energy distances, one per image
    """
    matrices = d_matrices.copy()

    if num_samples is None:
        num_samples = matrices['SS'].shape[1]
    # restrict to the requested number of samples
    matrices['YS'] = matrices['YS'][:, :, :num_samples]
    matrices['SS'] = matrices['SS'][:, :num_samples, :num_samples]

    # average out the class axis first (nanmean skips classes absent in a pair)
    per_class_YS = np.nanmean(matrices['YS'], axis=-1)
    per_class_SS = np.nanmean(matrices['SS'], axis=-1)
    per_class_YY = np.nanmean(matrices['YY'], axis=-1)

    if probability_weighted:
        probs = get_mode_statistics(label_switches, exp_modes=exp_mode)['mode_probs']

        # expectation over modes uses the mode probabilities as weights
        d_YS = np.sum(np.mean(per_class_YS, axis=2) * probs[np.newaxis, :], axis=1)
        d_SS = np.mean(per_class_SS, axis=(1, 2))
        weighted_YY = per_class_YY * probs[np.newaxis, :, np.newaxis] * probs[np.newaxis, np.newaxis, :]
        d_YY = np.sum(weighted_YY, axis=(1, 2))
    else:
        d_YS = np.mean(per_class_YS, axis=(1, 2))
        d_SS = np.mean(per_class_SS, axis=(1, 2))
        d_YY = np.nanmean(per_class_YY, axis=(1, 2))

    return 2 * d_YS - d_SS - d_YY
+
+
def eval(cf, cities, queue=None, ixs=None):
    """
    Perform evaluation w.r.t the generalized energy distance based on the IoU as well as image-level and pixel-level
    mode frequencies (using samples written to file).
    :param cf: config module
    :param cities: string or list of strings
    :param queue: instance of multiprocessing.Queue; if given, results are put on it and None is returned
    :param ixs: None or 2-tuple of ints selecting a slice of the data for sharding
    :return: NoneType (queue mode) or results dict
    """
    # NOTE(review): this function shadows the builtin `eval`; callers in this file
    # reference it by this name, so renaming needs a coordinated change.
    data_dir = os.path.join(cf.data_dir, cf.resolution)
    data_dict = loadFiles(label_density=cf.label_density, split='val', input_path=data_dir, cities=cities,
                          instance=False)

    num_modes = cf.num_modes
    num_samples = cf.num_samples

    # evaluate only switchable classes, so a total of 10 here
    eval_class_names = list(cf.label_switches.keys()) + list(cf.switched_name2Id.keys())
    eval_class_ids = [cf.name2trainId[n] for n in eval_class_names]
    d_matrices = {'YS': np.zeros(shape=(len(data_dict), num_modes, num_samples, len(eval_class_ids)),
                                 dtype=np.float32),
                  'YY': np.ones(shape=(len(data_dict), num_modes, num_modes, len(eval_class_ids)),
                                dtype=np.float32),
                  'SS': np.ones(shape=(len(data_dict), num_samples, num_samples, len(eval_class_ids)),
                                dtype=np.float32)}
    # NOTE(review): np.int was removed in NumPy >= 1.24 — these two lines raise an
    # AttributeError on modern NumPy and should use the builtin `int`.
    sampled_mode_counts = np.zeros(shape=(num_modes,), dtype=np.int)
    sampled_pixel_counts = np.zeros(shape=(len(cf.label_switches), 3), dtype=np.int)

    logging.info('Evaluating class names: {} (corresponding to labels {})'.format(eval_class_names, eval_class_ids))

    # allow for data selection by indexing via ixs
    if ixs is None:
        data_keys = list(data_dict.keys())
    else:
        data_keys = list(data_dict.keys())[ixs[0]:ixs[1]]
        # shrink the result matrices to the shard size
        for k in d_matrices.keys():
            d_matrices[k] = d_matrices[k][:ixs[1]-ixs[0]]

    # iterate all validation images
    for img_n, img_key in enumerate(tqdm(data_keys)):

        seg = np.load(data_dict[img_key]['seg'])
        seg = map_labels_to_trainId(seg)
        # add batch and channel axes -> (1, 1, H, W)
        seg = seg[np.newaxis, np.newaxis]
        # pixels with the ignore label are masked out of the IoU computation
        ignore_mask = (seg == cf.ignore_label).astype(np.uint8)

        seg_samples = get_array_of_samples(cf, img_key)
        gt_seg_modes = get_array_of_modes(cf, seg)

        energy_dist = get_energy_distance_components(gt_seg_modes=gt_seg_modes, seg_samples=seg_samples,
                                                     eval_class_ids=eval_class_ids, ignore_mask=ignore_mask)
        sampled_mode_counts += get_mode_counts(energy_dist['YS'])
        sampled_pixel_counts += get_pixelwise_mode_counts(cf, seg, seg_samples)

        for k in d_matrices.keys():
            d_matrices[k][img_n] = energy_dist[k]

    results = {'d_matrices': d_matrices, 'sampled_pixel_counts': sampled_pixel_counts,
               'sampled_mode_counts': sampled_mode_counts, 'total_num_samples': len(data_keys) * num_samples}

    if queue is not None:
        queue.put(results)
        return
    else:
        return results
+
+
def runInParallel(fns_args, queue):
    """Run functions in parallel processes and collect one queue result per process.
    :param fns_args: list of (function, args-tuple) pairs
    :param queue: instance of multiprocessing.Queue each function puts its result on
    :return: list of queue results (one per started process, in completion order)
    """
    procs = []
    for fn, args in fns_args:
        p = Process(target=fn, args=args)
        p.start()
        procs.append(p)
    # Drain the queue BEFORE joining: joining first can deadlock when a child
    # blocks on a full pipe buffer while putting a large result.
    results = [queue.get() for _ in procs]
    # Join afterwards so no zombie processes are left behind (was missing).
    for p in procs:
        p.join()
    return results
+
+
def multiprocess_evaluation(cf):
    """Evaluate the energy distance with one worker process per data shard.
    Aggregates the per-shard results from a shared queue, derives image- and
    pixel-level mode frequencies and pickles everything to cf.out_dir.
    :param cf: config module"""
    q = Queue()
    # shard the validation cities so each process handles <= ~100 images
    results = runInParallel([(eval, (cf, 'lindau', q)),
                             (eval, (cf, 'frankfurt', q, (0, 100))),
                             (eval, (cf, 'frankfurt', q, (100, 200))),
                             (eval, (cf, 'frankfurt', q, (200, 267))),
                             (eval, (cf, 'munster', q, (0, 100))),
                             (eval, (cf, 'munster', q, (100, 174)))],
                            queue=q)
    total_num_samples = 0
    # np.int was removed in NumPy >= 1.24; the builtin int is equivalent here
    sampled_mode_counts = np.zeros(shape=(cf.num_modes,), dtype=int)
    sampled_pixel_counts = np.zeros(shape=(len(cf.label_switches), 3), dtype=int)
    d_matrices = {'YS': [], 'SS': [], 'YY': []}

    # aggregate results from the queue
    for result_dict in results:
        for key in d_matrices.keys():
            d_matrices[key].append(result_dict['d_matrices'][key])

        sampled_pixel_counts += result_dict['sampled_pixel_counts']
        sampled_mode_counts += result_dict['sampled_mode_counts']
        total_num_samples += result_dict['total_num_samples']

    for key in d_matrices.keys():
        d_matrices[key] = np.concatenate(d_matrices[key], axis=0)

    # calculate frequencies
    print('pixel frequencies', sampled_pixel_counts)
    sampled_pixelwise_mode_per_class = sampled_pixel_counts[:, 1:]
    total_num_pixels_per_class = sampled_pixel_counts[:, 0:1]
    sampled_pixel_frequencies = sampled_pixelwise_mode_per_class / total_num_pixels_per_class
    sampled_mode_frequencies = sampled_mode_counts / total_num_samples

    print('sampled pixel frequencies', sampled_pixel_frequencies)
    print('sampled_mode_frequencies', sampled_mode_frequencies)

    results_dict = {'d_matrices': d_matrices, 'pixel_frequencies': sampled_pixel_frequencies,
                    'mode_frequencies': sampled_mode_frequencies}

    results_file = os.path.join(cf.out_dir, 'eval_results.pkl')
    with open(results_file, 'wb') as f:
        pickle.dump(results_dict, f, pickle.HIGHEST_PROTOCOL)
    logging.info('Wrote to {}'.format(results_file))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Evaluation step selection.')
+ parser.add_argument('--write_samples', dest='write_samples', action='store_true')
+ parser.add_argument('--eval_samples', dest='write_samples', action='store_false')
+ parser.set_defaults(write_samples=True)
+ parser.add_argument('-c', '--config_name', type=str, default='cityscapes_eval_config.py',
+ help='name of the python file that is loaded as config module')
+ args = parser.parse_args()
+
+ # load evaluation config
+ cf = SourceFileLoader('cf', args.config_name).load_module()
+
+ # prepare evaluation directory
+ if not os.path.isdir(cf.out_dir):
+ os.mkdir(cf.out_dir)
+
+ # log to file and console
+ log_path = os.path.join(cf.out_dir, 'eval.log')
+ logging.basicConfig(filename=log_path, level=logging.INFO)
+ logging.getLogger().addHandler(logging.StreamHandler())
+ logging.info('Logging to {}'.format(log_path))
+
+ if args.write_samples:
+ logging.info('Writing samples to {}'.format(cf.out_dir))
+ write_test_predictions(cf)
+ else:
+ logging.info('Evaluating samples from {}'.format(cf.out_dir))
+ multiprocess_evaluation(cf)
\ No newline at end of file
diff --git a/evaluation/evaluation_plots.ipynb b/evaluation/evaluation_plots.ipynb
new file mode 100644
index 0000000..7eb0e37
--- /dev/null
+++ b/evaluation/evaluation_plots.ipynb
@@ -0,0 +1,590 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/utils/training_utils.py:21: UserWarning: \n",
+ "This call to matplotlib.use() has no effect because the backend has already\n",
+ "been chosen; matplotlib.use() must be called *before* pylab, matplotlib.pyplot,\n",
+ "or matplotlib.backends is imported for the first time.\n",
+ "\n",
+ "The backend was *originally* set to 'nbAgg' by the following code:\n",
+ " File \"/usr/local/lib/python3.6/runpy.py\", line 193, in _run_module_as_main\n",
+ " \"__main__\", mod_spec)\n",
+ " File \"/usr/local/lib/python3.6/runpy.py\", line 85, in _run_code\n",
+ " exec(code, run_globals)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel_launcher.py\", line 16, in \n",
+ " app.launch_new_instance()\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/traitlets/config/application.py\", line 658, in launch_instance\n",
+ " app.start()\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/kernelapp.py\", line 486, in start\n",
+ " self.io_loop.start()\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/tornado/platform/asyncio.py\", line 132, in start\n",
+ " self.asyncio_loop.run_forever()\n",
+ " File \"/usr/local/lib/python3.6/asyncio/base_events.py\", line 422, in run_forever\n",
+ " self._run_once()\n",
+ " File \"/usr/local/lib/python3.6/asyncio/base_events.py\", line 1432, in _run_once\n",
+ " handle._run()\n",
+ " File \"/usr/local/lib/python3.6/asyncio/events.py\", line 145, in _run\n",
+ " self._callback(*self._args)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/tornado/platform/asyncio.py\", line 122, in _handle_events\n",
+ " handler_func(fileobj, events)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/tornado/stack_context.py\", line 300, in null_wrapper\n",
+ " return fn(*args, **kwargs)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 450, in _handle_events\n",
+ " self._handle_recv()\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 480, in _handle_recv\n",
+ " self._run_callback(callback, msg)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 432, in _run_callback\n",
+ " callback(*args, **kwargs)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/tornado/stack_context.py\", line 300, in null_wrapper\n",
+ " return fn(*args, **kwargs)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 283, in dispatcher\n",
+ " return self.dispatch_shell(stream, msg)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 233, in dispatch_shell\n",
+ " handler(stream, idents, msg)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 399, in execute_request\n",
+ " user_expressions, allow_stdin)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/ipkernel.py\", line 208, in do_execute\n",
+ " res = shell.run_cell(code, store_history=store_history, silent=silent)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/ipykernel/zmqshell.py\", line 537, in run_cell\n",
+ " return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2662, in run_cell\n",
+ " raw_cell, store_history, silent, shell_futures)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2785, in _run_cell\n",
+ " interactivity=interactivity, compiler=compiler, result=result)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2901, in run_ast_nodes\n",
+ " if self.run_code(code, result):\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2961, in run_code\n",
+ " exec(code_obj, self.user_global_ns, self.user_ns)\n",
+ " File \"\", line 4, in \n",
+ " get_ipython().run_line_magic('matplotlib', 'notebook')\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2131, in run_line_magic\n",
+ " result = fn(*args,**kwargs)\n",
+ " File \"\", line 2, in matplotlib\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/magic.py\", line 187, in \n",
+ " call = lambda f, *a, **k: f(*a, **k)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/magics/pylab.py\", line 99, in matplotlib\n",
+ " gui, backend = self.shell.enable_matplotlib(args.gui)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 3049, in enable_matplotlib\n",
+ " pt.activate_matplotlib(backend)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/IPython/core/pylabtools.py\", line 311, in activate_matplotlib\n",
+ " matplotlib.pyplot.switch_backend(backend)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/pyplot.py\", line 231, in switch_backend\n",
+ " matplotlib.use(newbackend, warn=False, force=True)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/__init__.py\", line 1410, in use\n",
+ " reload(sys.modules['matplotlib.backends'])\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/importlib/__init__.py\", line 166, in reload\n",
+ " _bootstrap._exec(spec, module)\n",
+ " File \"/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/backends/__init__.py\", line 16, in \n",
+ " line for line in traceback.format_stack()\n",
+ "\n",
+ "\n",
+ " matplotlib.use('agg')\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /home/skohl/PycharmProjects/phabricator/skohl/prob_unet/utils/training_utils.py:270: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with distribution=normal is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "`normal` is a deprecated alias for `truncated_normal`\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import numpy as np\n",
+ "\n",
+ "%matplotlib notebook\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.gridspec as gridspec\n",
+ "import seaborn as sns\n",
+ "import pandas as pd\n",
+ "import imp \n",
+ "import pickle\n",
+ "\n",
+ "from utils.training_utils import to_rgb, calc_confusion, metrics_from_conf_matrix\n",
+ "from data.cityscapes.data_loader import loadFiles, map_labels_to_trainId\n",
+ "from evaluation.eval_cityscapes import calc_energy_distances, get_mode_statistics\n",
+ "\n",
+ "from IPython.core.display import display, HTML\n",
+ "display(HTML(\"\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/tmp/prob_unet_1.0beta_analyticKLTrue_6latents_31x1s_32chan_lr0.0001PieceWise_bs16_7blocks_QMeanFalse_henormal_noBiasReg_OneHotMinus0.5_240kepochs/munster_000062_000019_sample7_labelIds.npy\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cf = imp.load_source('cf', 'cityscapes_eval_config.py')\n",
+ "path = cf.out_dir\n",
+ "\n",
+ "ix = 18\n",
+ "img_path = os.path.join(path, os.listdir(path)[ix])\n",
+ "print(img_path)\n",
+ "img = np.load(img_path)\n",
+ "\n",
+ "f = plt.figure(figsize=(8,4))\n",
+ "plt.imshow(to_rgb(img[0,0], cmap=cf.color_map))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# evaluate for different number of samples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/evaluation/eval_cityscapes.py:309: RuntimeWarning: Mean of empty slice\n",
+ " mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1)\n",
+ "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/evaluation/eval_cityscapes.py:309: RuntimeWarning: Mean of empty slice\n",
+ " mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1)\n",
+ "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/evaluation/eval_cityscapes.py:309: RuntimeWarning: Mean of empty slice\n",
+ " mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1)\n",
+ "/home/skohl/PycharmProjects/phabricator/skohl/prob_unet/evaluation/eval_cityscapes.py:309: RuntimeWarning: Mean of empty slice\n",
+ " mean_d_YY = np.nanmean(d_matrices['YY'], axis=-1)\n"
+ ]
+ }
+ ],
+ "source": [
+ "results_file = os.path.join(cf.out_dir, 'eval_results.pkl')\n",
+ "\n",
+ "with open(results_file, \"rb\") as f:\n",
+ " results = pickle.load(f)\n",
+ "\n",
+ "e_distances = []\n",
+ "e_means = []\n",
+ "samples_column = []\n",
+ " \n",
+ "for s in [1,4,8,16]:\n",
+ " e_dist = calc_energy_distances(results['d_matrices'], num_samples=s, probability_weighted=True, label_switches=cf.label_switches)\n",
+ " e_dist = e_dist[~np.isnan(e_dist)]\n",
+ " e_distances.extend(e_dist)\n",
+ " samples_column.extend([s] * len(e_dist))\n",
+ " e_means.append(np.mean(e_dist))\n",
+ "\n",
+ "energy = pd.DataFrame(data={'energy': e_distances, 'num_samples': samples_column})\n",
+ "means = pd.DataFrame(data={'energy': e_means, 'num_samples': [1,4,8,16]})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "Text(0.5,0,'# samples')"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "f = plt.figure(figsize=(5,5), dpi=150)\n",
+ "gs = gridspec.GridSpec(1,1)\n",
+ "ax = plt.subplot(gs[0,0])\n",
+ "\n",
+ "sns.stripplot(x=\"num_samples\", y=\"energy\", data=energy, alpha=0.5, s=2, ax=ax)\n",
+ "sns.stripplot(x=\"num_samples\", y=\"energy\", data=means, s=18, marker='^', color='k', ax=ax, jitter=False)\n",
+ "sns.stripplot(x=\"num_samples\", y=\"energy\", data=means, s=14, marker='^', ax=ax, jitter=False)\n",
+ "ax.set_title('Probabilistic U-Net (validation set)', y=1.03)\n",
+ "fs=12\n",
+ "ax.set_ylabel(r'$D_{ged}^{2}$', fontsize=fs)\n",
+ "ax.set_xlabel('# samples', fontsize=fs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[0.39207232 0.4350246 ]\n",
+ " [0.44614304 0.31564506]\n",
+ " [0.61882485 0.30369349]\n",
+ " [0.72800224 0.20223276]\n",
+ " [0.7578801 0.22053149]]\n",
+ "[0.163625 0.038375 0.036625 0.00875 0.06 0.0365 0.012375 0.00575\n",
+ " 0.065375 0.0125 0.02025 0.002375 0.024125 0.009625 0.007875 0.001875\n",
+ " 0.136125 0.03225 0.0425 0.006875 0.051125 0.032 0.01175 0.007\n",
+ " 0.0735 0.01125 0.03325 0.003125 0.02725 0.008625 0.014 0.003375]\n",
+ "[0.004732871, 0.0053244797, 0.0067612445, 0.0076063997, 0.00867693, 0.009761547, 0.011358891, 0.012395615, 0.012778752, 0.013945066, 0.015381831, 0.016226986, 0.01730456, 0.01825536, 0.020824632, 0.021974044, 0.023427712, 0.0247208, 0.028200023, 0.029749475, 0.031725027, 0.03346816, 0.036916394, 0.040285747, 0.041530944, 0.045321465, 0.052737705, 0.05932992, 0.06768005, 0.07614006, 0.09668579, 0.10877152]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(results['pixel_frequencies'])\n",
+ "print(results['mode_frequencies'])\n",
+ "\n",
+ "stats = get_mode_statistics(cf.label_switches, exp_modes=5)\n",
+ "print(sorted(stats['mode_probs']))\n",
+ "\n",
+ "log_mode_frequencies = np.log(results['mode_frequencies'])\n",
+ "log_gt_mode_frequencies = np.log(stats['mode_probs'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "f = plt.figure(figsize=(8,7))\n",
+ "gs = gridspec.GridSpec(2, 2, wspace=0.0, hspace=0.0, height_ratios=[32,6], width_ratios=[3,1])\n",
+ "\n",
+ "ax = plt.subplot(gs[0, 0])\n",
+ "sort_ixs = np.argsort(stats['mode_probs'])[::-1]\n",
+ "plt.bar(x=range(len(log_gt_mode_frequencies)), height=results['mode_frequencies'][sort_ixs], align='center', color='g', label='Estimated Frequencies')\n",
+ "plt.bar(x=range(len(log_gt_mode_frequencies)), height=stats['mode_probs'][sort_ixs], align='center', edgecolor='k', fc=(0,0,0,0), label='Ground truth')\n",
+ "plt.tick_params(\n",
+ " axis='x', # changes apply to the x-axis\n",
+ " which='both', # both major and minor ticks are affected\n",
+ " bottom=False, # ticks along the bottom edge are off\n",
+ " top=False, # ticks along the top edge are off\n",
+ " labelbottom=False) # labels along the bottom edge are off\n",
+ "\n",
+ "ax.legend()\n",
+ "ax.set_title('Probabilistic U-Net (validation set)', y=1.03)\n",
+ "ax = plt.gca()\n",
+ "ax.set_xlim([-.5,31.5])\n",
+ "ax.set_ylim([0,0.2])\n",
+ "yticks = ax.yaxis.get_major_ticks()\n",
+ "yticks[0].label1.set_visible(False)\n",
+ "ax.set_ylabel('P', fontdict={'size': 15})\n",
+ "\n",
+ "rows = list(cf.label_switches.keys())\n",
+ "cell_text = np.transpose(stats['switch'][sort_ixs], axes=[1,0])\n",
+ "\n",
+ "ax = plt.subplot(gs[1, 0])\n",
+ "the_table = ax.table(cellText=cell_text,\n",
+ " rowLabels=rows,\n",
+ " loc='center')\n",
+ "\n",
+ "ax.get_yaxis().set_visible(False)\n",
+ "plt.tick_params(\n",
+ " axis='x', # changes apply to the x-axis\n",
+ " which='both', # both major and minor ticks are affected\n",
+ " bottom=False, # ticks along the bottom edge are off\n",
+ " top=False, # ticks along the top edge are off\n",
+ " labelbottom=False) # labels along the bottom edge are off\n",
+ "ax.set_xlabel('Discrete Modes')\n",
+ "\n",
+ "ax = plt.subplot(gs[1, 1])\n",
+ "ax.set_title('Pixel-wise Frequency Estimation',fontdict={'size': 6})\n",
+ "plt.barh(y=range(5), width=results['pixel_frequencies'][:,1], align='center', color='g')\n",
+ "plt.barh(y=range(5), width=list(cf.label_switches.values()), align='center', edgecolor='k', fc=(0,0,0,0))\n",
+ "ax.get_yaxis().set_visible(False)\n",
+ "ax.set_xticks([0.1,0.2,0.3,0.4,0.5])\n",
+ "ax.set_xlabel('P')\n",
+ "ax.tick_params(axis='both', which='major', labelsize=6)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# prepare animated gif from samples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 156,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saved to /home/skohl/PycharmProjects/phabricator/skohl/prob_unet/tmp//prob_unet_0.0005beta_analyticKLTrue_6latents_31x1s_32chan_lr0.0001PieceWise_bs16_7blocks_QMeanFalse_henormal_noBiasReg_OneHotMinus0.5_240kepochs_TFP/gif_panel_sample_0.png\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saved to /home/skohl/PycharmProjects/phabricator/skohl/prob_unet/tmp//prob_unet_0.0005beta_analyticKLTrue_6latents_31x1s_32chan_lr0.0001PieceWise_bs16_7blocks_QMeanFalse_henormal_noBiasReg_OneHotMinus0.5_240kepochs_TFP/gif_panel_sample_1.png\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0max\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandles\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlegend_handles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m9\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbbox_to_anchor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m0.03\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mncol\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlegend_handles\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mframeon\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mout_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'gif_panel_sample_{}.png'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_num\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 57\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdpi\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbbox_inches\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'tight'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpad_inches\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 58\u001b[0m 
\u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Saved to {}'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_dir\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36msavefig\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 708\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 709\u001b[0m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgcf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 710\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msavefig\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 711\u001b[0m \u001b[0mfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdraw_idle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# need this if 'transparent=True' to reset colors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 712\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/figure.py\u001b[0m in \u001b[0;36msavefig\u001b[0;34m(self, fname, **kwargs)\u001b[0m\n\u001b[1;32m 2033\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_frameon\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframeon\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2034\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2035\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanvas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_figure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2036\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2037\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mframeon\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/backend_bases.py\u001b[0m in \u001b[0;36mprint_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, **kwargs)\u001b[0m\n\u001b[1;32m 2210\u001b[0m \u001b[0morientation\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0morientation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2211\u001b[0m \u001b[0mdryrun\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2212\u001b[0;31m **kwargs)\n\u001b[0m\u001b[1;32m 2213\u001b[0m \u001b[0mrenderer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cachedRenderer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2214\u001b[0m \u001b[0mbbox_inches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_tightbbox\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrenderer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/PycharmProjects/phabricator/skohl/prob_unet/venv/lib/python3.6/site-packages/matplotlib/backends/backend_agg.py\u001b[0m in \u001b[0;36mprint_png\u001b[0;34m(self, filename_or_obj, *args, **kwargs)\u001b[0m\n\u001b[1;32m 526\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcbook\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen_file_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename_or_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"wb\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfh\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 527\u001b[0m _png.write_png(renderer._renderer, fh,\n\u001b[0;32m--> 528\u001b[0;31m self.figure.dpi, metadata=metadata)\n\u001b[0m\u001b[1;32m 529\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mrenderer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdpi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moriginal_dpi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ],
+ "output_type": "error"
+ }
+ ],
+ "source": [
+ "from matplotlib.lines import Line2D\n",
+ "from data.cityscapes.data_loader import map_labels_to_trainId\n",
+ "\n",
+ "# get dictionary of image files\n",
+ "data_dir = os.path.join(cf.data_dir, cf.resolution)\n",
+ "data_dict = loadFiles(label_density=cf.label_density, split='val', input_path=data_dir,\n",
+ " cities=None, instance=False)\n",
+ "\n",
+ "# list of image indices in data_dict\n",
+ "ixs = [0,2,75,85,100,145,350,390,460,470]\n",
+ "num_samples = 16\n",
+ "y_dim = 256\n",
+ "x_dim = 512\n",
+ "\n",
+ "img_arr = np.zeros(shape=(y_dim, len(ixs) * x_dim, 3)) \n",
+ "gt_arr = np.zeros(shape=(y_dim, len(ixs) * x_dim))\n",
+ "sample_arr = np.zeros(shape=(num_samples, y_dim, len(ixs) * x_dim))\n",
+ "\n",
+ "for i,ix in enumerate(ixs):\n",
+ "\n",
+ " img_key = list(data_dict.keys())[ix]\n",
+ " img_path = data_dict[img_key]['data']\n",
+ " seg_path = data_dict[img_key]['seg']\n",
+ " \n",
+ " # load images\n",
+ " img_arr[:, i * x_dim: (i + 1) * x_dim] = np.transpose(np.load(img_path), axes=[1,2,0]) / 255.\n",
+ " \n",
+ " # load gt segmentation\n",
+ " gt_arr[:, i * x_dim: (i + 1) * x_dim] = map_labels_to_trainId(np.load(seg_path))\n",
+ " \n",
+ " # load samples\n",
+ " for sample_num in range(num_samples):\n",
+ " sample_path = os.path.join(cf.out_dir, '{}_sample{}_labelIds.npy'.format(img_key, sample_num))\n",
+ " sample_arr[sample_num, :, i * x_dim: (i + 1) * x_dim] = np.load(sample_path)[0,0]\n",
+ "\n",
+ "for sample_num in range(num_samples):\n",
+ " f = plt.figure(figsize=(len(ixs) * 4, 9))\n",
+ "\n",
+ " arr = np.concatenate([img_arr, to_rgb(gt_arr, cmap=cf.color_map), to_rgb(sample_arr[sample_num], cmap=cf.color_map)], axis=0)\n",
+ " \n",
+ " plt.imshow(arr)\n",
+ " plt.text(-80,170, 'image', fontsize=12, rotation=90, rotation_mode='anchor')\n",
+ " plt.text(-80,470, 'deterministic', fontsize=12, rotation=90, rotation_mode='anchor')\n",
+ " plt.text(-40,460, 'groundtruth', fontsize=12, rotation=90, rotation_mode='anchor')\n",
+ " plt.text(-80,690, 'sample {}'.format(sample_num + 1), fontsize=12, rotation=90, rotation_mode='anchor')\n",
+ " \n",
+ " ax = plt.gca()\n",
+ " ax.get_xaxis().set_visible(False)\n",
+ " ax.get_yaxis().set_visible(False)\n",
+ " \n",
+ " # custom legend\n",
+ " trainId2name = cf.trainId2name\n",
+ " trainId2name[255] = 'ignore'\n",
+ " legend_handles = [Line2D([0], [0], color=tuple([c / 255. for c in cf.color_map[trainID]]), lw=4, label=trainId2name[trainID]) for trainID in list(cf.color_map.keys())] \n",
+ " ax.legend(handles=legend_handles, loc=9, bbox_to_anchor=(0.5, -0.03), ncol=len(legend_handles), frameon=False)\n",
+ " out_dir = os.path.join(cf.out_dir, 'gif_panel_sample_{}.png'.format(sample_num))\n",
+ " plt.savefig(out_dir, dpi=200, bbox_inches='tight', pad_inches=0.0)\n",
+ " plt.close()\n",
+ " print('Saved to {}'.format(out_dir))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "venv",
+ "language": "python",
+ "name": "venv"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3.0
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/model/probabilistic_unet.py b/model/probabilistic_unet.py
new file mode 100644
index 0000000..cb5d93b
--- /dev/null
+++ b/model/probabilistic_unet.py
@@ -0,0 +1,478 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Probabilistic U-Net model."""
+
+import tensorflow as tf
+import sonnet as snt
+import tensorflow_probability as tfp
+tfd = tfp.distributions
+from utils.training_utils import ce_loss, he_normal
+
+def down_block(features,
+ output_channels,
+ kernel_shape,
+ stride=1,
+ rate=1,
+ num_convs=2,
+ initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+ regularizers=None,
+ nonlinearity=tf.nn.relu,
+ down_sample_input=True,
+ down_sampling_op=lambda x, size: tf.image.resize_images(x, size,
+ method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
+ data_format='NCHW',
+ name='down_block'):
+ """A block made up of a down-sampling step followed by several convolutional layers."""
+ with tf.variable_scope(name):
+ if down_sample_input:
+ features = down_sampling_op(features, data_format)
+
+ for _ in range(num_convs):
+ features = snt.Conv2D(output_channels, kernel_shape, stride, rate, data_format=data_format,
+ initializers=initializers, regularizers=regularizers)(features)
+ features = nonlinearity(features)
+
+ return features
+
+
+def up_block(lower_res_inputs,
+ same_res_inputs,
+ output_channels,
+ kernel_shape,
+ stride=1,
+ rate=1,
+ num_convs=2,
+ initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+ regularizers=None,
+ nonlinearity=tf.nn.relu,
+ up_sampling_op=lambda x, size: tf.image.resize_images(x, size,
+ method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
+ data_format='NCHW',
+ name='up_block'):
+ """A block made up of an up-sampling step followed by several convolutional layers."""
+ with tf.variable_scope(name):
+ spatial_shape = same_res_inputs.get_shape()[2:]
+
+ if data_format=='NHWC':
+ features = up_sampling_op(lower_res_inputs, spatial_shape)
+ features = tf.concat([features, same_res_inputs], axis=-1)
+ else:
+ lower_res_inputs = tf.transpose(lower_res_inputs, perm=[0,2,3,1])
+ features = up_sampling_op(lower_res_inputs, spatial_shape)
+ features = tf.transpose(features, perm=[0,3,1,2])
+ features = tf.concat([features, same_res_inputs], axis=1)
+
+ for _ in range(num_convs):
+ features = snt.Conv2D(output_channels, kernel_shape, stride, rate, data_format=data_format,
+ initializers=initializers, regularizers=regularizers)(features)
+ features = nonlinearity(features)
+
+ return features
+
+
+class VGG_Encoder(snt.AbstractModule):
+    """A quasi VGG-style convolutional net, made of M x (down-sampling, N x conv)-operations,
+    where M = len(num_channels), N = num_convs_per_block."""
+
+    def __init__(self,
+                 num_channels,
+                 nonlinearity=tf.nn.relu,
+                 num_convs_per_block=3,
+                 initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+                 regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
+                 data_format='NCHW',
+                 down_sampling_op=lambda x, df: tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2],
+                                                               padding='SAME', data_format=df),
+                 name="vgg_enc"):
+        super(VGG_Encoder, self).__init__(name=name)
+        self._num_channels = num_channels  # channels per scale; len() fixes the number of scales
+        self._nonlinearity = nonlinearity
+        self._num_convs = num_convs_per_block
+        self._initializers = initializers
+        self._regularizers = regularizers
+        self._data_format = data_format
+        self._down_sampling_op = down_sampling_op  # called as op(features, data_format)
+
+    def _build(self, inputs):
+        """
+        Run the encoder and collect the feature map of every scale.
+        :param inputs: 4D tensor of shape NCHW or NHWC
+        :return: a list of 4D tensors of shape NCHW or NHWC, ordered fine-to-coarse
+        """
+        features = [inputs]
+
+        # iterate blocks (`processing scales')
+        for i, n_channels in enumerate(self._num_channels):
+
+            if i == 0:
+                down_sample = False  # keep the first scale at input resolution
+            else:
+                down_sample = True
+            tf.logging.info('encoder scale {}: {}'.format(i, features[-1].get_shape()))
+            features.append(down_block(features[-1],
+                                       output_channels=n_channels,
+                                       kernel_shape=(3,3),
+                                       num_convs=self._num_convs,
+                                       nonlinearity=self._nonlinearity,
+                                       initializers=self._initializers,
+                                       regularizers=self._regularizers,
+                                       down_sample_input=down_sample,
+                                       data_format=self._data_format,
+                                       down_sampling_op=self._down_sampling_op,
+                                       name='down_block_{}'.format(i)))
+        # return all features except for the input images
+        return features[1:]
+
+
+class VGG_Decoder(snt.AbstractModule):
+ """A quasi VGG-style convolutional net, made of M x (up-sampling, N x conv)-operations,
+ where M = len(num_channels), N = num_convs_per_block."""
+
+ def __init__(self,
+ num_channels,
+ num_classes,
+ nonlinearity=tf.nn.relu,
+ num_convs_per_block=3,
+ initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+ regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
+ data_format='NCHW',
+ up_sampling_op=lambda x, size: tf.image.resize_images(x, size,
+ method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=True),
+ name="vgg_dec"):
+ super(VGG_Decoder, self).__init__(name=name)
+ self._num_channels = num_channels
+ self._num_classes = num_classes
+ self._nonlinearity = nonlinearity
+ self._num_convs = num_convs_per_block
+ self._initializers = initializers
+ self._regularizers = regularizers
+ self._data_format = data_format
+ self._up_sampling_op = up_sampling_op
+
+ def _build(self, input_list):
+ """
+ :param input_list: a list of 4D tensors of shape NCHW or NHWC
+ :return: 4D tensor
+ """
+ try:
+ assert len(self._num_channels) == len(input_list)
+ except:
+ raise AssertionError('Missmatch: {} blocks vs. {} incoming feature-maps!')
+
+ n = len(input_list) - 2
+ lower_res_features = input_list[-1]
+
+ # iterate input features & channels in reverse order, starting from the second last (here, =nth) features
+ for i in range(n, -1, -1):
+ same_res_features = input_list[i]
+ n_channels = self._num_channels[i]
+
+ tf.logging.info('decoder scale {}: {}'.format(i, lower_res_features.get_shape()))
+ lower_res_features = up_block(lower_res_features,
+ same_res_features,
+ output_channels=n_channels,
+ kernel_shape=(3, 3),
+ num_convs=self._num_convs,
+ nonlinearity=self._nonlinearity,
+ initializers=self._initializers,
+ regularizers=self._regularizers,
+ data_format=self._data_format,
+ up_sampling_op=self._up_sampling_op,
+ name='up_block_{}'.format(i))
+ return lower_res_features
+
+
+class UNet(snt.AbstractModule):
+    """A quasi standard U-Net, similar to `U-Net: Convolutional Networks for Biomedical Image Segmentation',
+    https://arxiv.org/abs/1505.04597."""
+
+    def __init__(self,
+                 num_channels,
+                 num_classes,
+                 nonlinearity=tf.nn.relu,
+                 num_convs_per_block=3,
+                 initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+                 regularizers=None,
+                 data_format='NCHW',
+                 down_sampling_op=lambda x, df: tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2],
+                                                               padding='SAME', data_format=df),
+                 up_sampling_op=lambda x, size: tf.image.resize_images(x, size,
+                                  method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
+                 name="unet"):
+        super(UNet, self).__init__(name=name)
+        with self._enter_variable_scope():
+            tf.logging.info('Building U-Net.')
+            # encoder yields one feature map per scale, the decoder fuses them back up
+            self._encoder = VGG_Encoder(num_channels, nonlinearity, num_convs_per_block, initializers, regularizers,
+                                        data_format=data_format, down_sampling_op=down_sampling_op)
+            self._decoder = VGG_Decoder(num_channels, num_classes, nonlinearity, num_convs_per_block, initializers,
+                                        regularizers, data_format=data_format, up_sampling_op=up_sampling_op)
+
+    def _build(self, inputs):
+        """
+        :param inputs: 4D tensor of shape NCHW or NHWC
+        :return: 4D tensor
+        """
+        encoder_features = self._encoder(inputs)
+        # NOTE(review): despite the name, these are decoder feature maps rather than class
+        # logits -- VGG_Decoder applies no classification layer; that happens in Conv1x1Decoder
+        predicted_logits = self._decoder(encoder_features)
+        return predicted_logits
+
+
+class Conv1x1Decoder(snt.AbstractModule):
+    """A stack of 1x1 convolutions that takes two tensors to be concatenated along their channel axes."""
+
+    def __init__(self,
+                 num_classes,
+                 num_channels,
+                 num_1x1_convs,
+                 nonlinearity=tf.nn.relu,
+                 initializers={'w': tf.orthogonal_initializer(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+                 regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
+                 data_format='NCHW',
+                 name='conv_decoder'):
+        super(Conv1x1Decoder, self).__init__(name=name)
+        self._num_classes = num_classes      # channels of the final logits layer
+        self._num_channels = num_channels    # channels of the intermediate 1x1 convs
+        self._num_1x1_convs = num_1x1_convs
+        self._nonlinearity = nonlinearity
+        self._initializers = initializers
+        self._regularizers = regularizers
+        self._data_format = data_format
+
+        # cache axis indices for the chosen memory layout
+        if data_format == 'NCHW':
+            self._channel_axis = 1
+            self._spatial_axes = [2,3]
+        else:
+            self._channel_axis = -1
+            self._spatial_axes = [1,2]
+
+    def _build(self, features, z):
+        """
+        Broadcast the latent z over the spatial grid of `features`, concatenate along the
+        channel axis and run 1x1 convolutions ending in a linear classification layer.
+        :param features: 4D tensor of shape NCHW or NHWC
+        :param z: 4D tensor of shape NC11 or N11C
+        :return: 4D tensor
+        """
+        shp = features.get_shape()
+        spatial_shape = [shp[axis] for axis in self._spatial_axes]
+        # tile multiples: 1 for batch and channel axes, the spatial extents elsewhere
+        multiples = [1] + spatial_shape
+        multiples.insert(self._channel_axis, 1)
+
+        if len(z.get_shape()) == 2:
+            # NOTE(review): this expansion yields NC11, i.e. it assumes NCHW; for NHWC
+            # (and insert at channel_axis=-1 above) the layout looks off -- confirm
+            # before using this module with data_format='NHWC'
+            z = tf.expand_dims(z, axis=2)
+            z = tf.expand_dims(z, axis=2)
+
+        # broadcast latent vector to spatial dimensions of the image/feature tensor
+        broadcast_z = tf.tile(z, multiples)
+        features = tf.concat([features, broadcast_z], axis=self._channel_axis)
+        for _ in range(self._num_1x1_convs):
+            features = snt.Conv2D(self._num_channels, kernel_shape=(1,1), stride=1, rate=1,
+                                  data_format=self._data_format,
+                                  initializers=self._initializers, regularizers=self._regularizers)(features)
+            features = self._nonlinearity(features)
+        # final linear (un-regularized) 1x1 conv producing the class logits
+        logits = snt.Conv2D(self._num_classes, kernel_shape=(1,1), stride=1, rate=1,
+                            data_format=self._data_format,
+                            initializers=self._initializers, regularizers=None)
+        return logits(features)
+
+
+class AxisAlignedConvGaussian(snt.AbstractModule):
+    """A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix."""
+
+    def __init__(self,
+                 latent_dim,
+                 num_channels,
+                 nonlinearity=tf.nn.relu,
+                 num_convs_per_block=3,
+                 initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+                 regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
+                 data_format='NCHW',
+                 down_sampling_op=lambda x, df:\
+                     tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2], padding='SAME', data_format=df),
+                 name="conv_dist"):
+        # NOTE(review): attributes are assigned before super().__init__ is invoked below;
+        # calling super first is the usual convention -- confirm this is intentional
+        self._latent_dim = latent_dim
+        self._initializers = initializers
+        self._regularizers = regularizers
+        self._data_format = data_format
+
+        # cache axis indices for the chosen memory layout
+        if data_format == 'NCHW':
+            self._channel_axis = 1
+            self._spatial_axes = [2,3]
+        else:
+            self._channel_axis = -1
+            self._spatial_axes = [1,2]
+
+        super(AxisAlignedConvGaussian, self).__init__(name=name)
+        with self._enter_variable_scope():
+            tf.logging.info('Building ConvGaussian.')
+            self._encoder = VGG_Encoder(num_channels, nonlinearity, num_convs_per_block, initializers, regularizers,
+                                        data_format=data_format, down_sampling_op=down_sampling_op)
+
+    def _build(self, img, seg=None):
+        """
+        Evaluate mu and log_sigma of a Gaussian conditioned on an image + optionally, concatenated one-hot segmentation.
+        :param img: 4D array
+        :param seg: 4D array
+        :return: tfd.MultivariateNormalDiag distribution over the latent space
+        """
+        if seg is not None:
+            seg = tf.cast(seg, tf.float32)
+            img = tf.concat([img, seg], axis=self._channel_axis)
+        # deepest encoder feature map, global-average-pooled over the spatial axes
+        encoding = self._encoder(img)[-1]
+        encoding = tf.reduce_mean(encoding, axis=self._spatial_axes, keepdims=True)
+
+        # a single 1x1 conv yields mu and log_sigma stacked along the channel axis
+        mu_log_sigma = snt.Conv2D(2 * self._latent_dim, (1,1), stride=1, rate=1, data_format=self._data_format,
+                                  initializers=self._initializers, regularizers=self._regularizers)(encoding)
+
+        mu_log_sigma = tf.squeeze(mu_log_sigma, axis=self._spatial_axes)
+        mu = mu_log_sigma[:, :self._latent_dim]
+        log_sigma = mu_log_sigma[:, self._latent_dim:]
+
+        # exponentiate the predicted log_sigma to obtain a positive scale
+        return tfd.MultivariateNormalDiag(loc=mu, scale_diag=tf.exp(log_sigma))
+
+
+class ProbUNet(snt.AbstractModule):
+ """Probabilistic U-Net."""
+
+ def __init__(self,
+ latent_dim,
+ num_channels,
+ num_classes,
+ num_1x1_convs=3,
+ nonlinearity=tf.nn.relu,
+ num_convs_per_block=3,
+ initializers={'w': he_normal(), 'b': tf.truncated_normal_initializer(stddev=0.001)},
+ regularizers={'w': tf.contrib.layers.l2_regularizer(1.0), 'b': tf.contrib.layers.l2_regularizer(1.0)},
+ data_format='NCHW',
+ down_sampling_op=lambda x, df:\
+ tf.nn.avg_pool(x, ksize=[1,1,2,2], strides=[1,1,2,2], padding='SAME', data_format=df),
+ up_sampling_op=lambda x, size:\
+ tf.image.resize_images(x, size, method=tf.image.ResizeMethod.BILINEAR, align_corners=True),
+ name='prob_unet'):
+ super(ProbUNet, self).__init__(name=name)
+ self._data_format = data_format
+ self._num_classes = num_classes
+
+ with self._enter_variable_scope():
+ self._unet = UNet(num_channels=num_channels, num_classes=num_classes, nonlinearity=nonlinearity,
+ num_convs_per_block=num_convs_per_block, initializers=initializers,
+ regularizers=regularizers, data_format=data_format,
+ down_sampling_op=down_sampling_op, up_sampling_op=up_sampling_op)
+
+ self._f_comb = Conv1x1Decoder(num_classes=num_classes, num_1x1_convs=num_1x1_convs,
+ num_channels=num_channels[0], nonlinearity=nonlinearity,
+ data_format=data_format, initializers=initializers, regularizers=regularizers)
+
+ self._prior =\
+ AxisAlignedConvGaussian(latent_dim=latent_dim, num_channels=num_channels,
+ nonlinearity=nonlinearity, num_convs_per_block=num_convs_per_block,
+ initializers=initializers, regularizers=regularizers, name='prior')
+
+ self._posterior =\
+ AxisAlignedConvGaussian(latent_dim=latent_dim, num_channels=num_channels,
+ nonlinearity=nonlinearity, num_convs_per_block=num_convs_per_block,
+ initializers=initializers, regularizers=regularizers, name='posterior')
+
+ def _build(self, img, seg=None, is_training=True, one_hot_labels=True):
+ """
+ Evaluate individual components of the net.
+ :param img: 4D image array
+ :param seg: 4D segmentation array
+ :param is_training: if False, refrain from evaluating the posterior
+ :param one_hot_labels: bool, if False expects integer labeled segmentation of shape N1HW or NHW1
+ :return: None
+ """
+ if is_training:
+ if seg is not None:
+ if not one_hot_labels:
+ if self._data_format == 'NCHW':
+ spatial_shape = img.get_shape()[-2:]
+ class_axis = 1
+ one_hot_shape = (-1, self._num_classes) + tuple(spatial_shape)
+ elif self._data_format == 'NHWC':
+ spatial_shape = img.get_shape()[-3:-1]
+ one_hot_shape = (-1,) + tuple(spatial_shape) + (self._num_classes,)
+ class_axis = 3
+
+ seg = tf.reshape(seg, shape=[-1])
+ seg = tf.one_hot(indices=seg, depth=self._num_classes, axis=class_axis)
+ seg = tf.reshape(seg, shape=one_hot_shape)
+ seg -= 0.5
+ self._q = self._posterior(img, seg)
+
+ self._p = self._prior(img)
+ self._unet_features = self._unet(img)
+
+ def reconstruct(self, use_posterior_mean=False, z_q=None):
+ """
+ Reconstruct a given segmentation. Default settings result in decoding a posterior sample.
+ :param use_posterior_mean: use posterior_mean instead of sampling z_q
+ :param z_q: use provided latent sample z_q instead of sampling anew
+ :return: 4D logits tensor
+ """
+ if use_posterior_mean:
+ z_q = self._q._mu
+ else:
+ if z_q is None:
+ z_q = self._q.sample()
+ return self._f_comb(self._unet_features, z_q)
+
+ def sample(self):
+ """
+ Sample a segmentation by reconstructing from a prior sample.
+ Only needs to re-evaluate the last 1x1-convolutions.
+ :return: 4D logits tensor
+ """
+ z_p = self._p.sample()
+ return self._f_comb(self._unet_features, z_p)
+
+ def kl(self, analytic=True, z_q=None):
+ """
+ Calculate the Kullback-Leibler divergence KL(Q||P) between 2 axis-aligned gaussians,
+ i.e. the variance sigma is assumed diagonal.
+ :param analytic: bool, if False, approximate the KL via sampling from the posterior
+ :param z_q: None or 2D tensor, if analytic=False the posterior sample can be provided instead of sampling anew
+ :return: 4D tensor
+ """
+ if analytic:
+ kl = tfd.kl_divergence(self._q, self._p)
+ else:
+ if z_q is None:
+ z_q = self._q.sample()
+ log_q = self._q.log_prob(z_q)
+ log_p = self._p.log_prob(z_q)
+ kl = log_q - log_p
+ return kl
+
+ def elbo(self, seg, beta=1.0, analytic_kl=True, reconstruct_posterior_mean=False, z_q=None, one_hot_labels=True,
+ loss_mask=None):
+ """
+ Calculate the evidence lower bound (elbo) of the log-likelihood of P(Y|X).
+ :param seg: 4D tensor
+ :param analytic_kl: bool, if False calculate the KL via sampling
+ :param z_q: 4D tensor
+ :param one_hot_labels: bool, if False expects integer labeled segmentation of shape N1HW or NHW1
+ :param loss_mask: 4D tensor, binary
+ :return: 1D tensor
+ """
+ if z_q is None:
+ z_q = self._q.sample()
+
+ self._kl = tf.reduce_mean(self.kl(analytic_kl, z_q))
+
+ self._rec_logits = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, z_q=z_q)
+ rec_loss = ce_loss(labels=seg, logits=self._rec_logits, n_classes=self._num_classes,
+ loss_mask=loss_mask, one_hot_labels=one_hot_labels)
+ self._rec_loss = rec_loss['sum']
+ self._rec_loss_mean = rec_loss['mean']
+
+ return -(self._rec_loss + beta * self._kl)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..6dd1342
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+tensorflow_probability==0.3.0
+tensorflow-gpu==1.10.0
+dm-sonnet==1.23
+matplotlib==2.2.2
+numpy==1.14.5
+Pillow==5.1.0
+scipy==1.1.0
+SimpleITK==1.1.0
+tensorboard==1.10.0
+tqdm==4.23.4
+seaborn==0.9.0
+pytest==3.8.0
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..79ea0a1
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from distutils.core import setup
+from setuptools import find_packages
+
+req_file = "requirements.txt"
+
+def parse_requirements(filename):
+ lineiter = (line.strip() for line in open(filename))
+ return [line for line in lineiter if line and not line.startswith("#")]
+
+install_reqs = parse_requirements(req_file)
+
+setup(name='model',
+ version='latest',
+ packages=find_packages(exclude=['test', 'test.*']),
+ install_requires=install_reqs,
+ dependency_links=[],
+ )
\ No newline at end of file
diff --git a/tests/evaluation/eval_tests.py b/tests/evaluation/eval_tests.py
new file mode 100644
index 0000000..6b6c322
--- /dev/null
+++ b/tests/evaluation/eval_tests.py
@@ -0,0 +1,152 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Evaluation metrics tests."""
+
+import pytest
+import numpy as np
+from evaluation.eval_cityscapes import get_energy_distance_components, get_mode_counts, get_pixelwise_mode_counts
+from utils.training_utils import calc_confusion, metrics_from_conf_matrix
+
+
+def nan_save_array_equal(a, b, nan_replacement=-1.):
+ """Replace NANs to savely compare arrays elementwise."""
+ a_nan_ixs = np.where(np.isnan(a))
+ a[a_nan_ixs] = nan_replacement
+
+ b_nan_ixs = np.where(np.isnan(b))
+ b[b_nan_ixs] = nan_replacement
+
+ return (a == b).all()
+
+
+@pytest.mark.parametrize("test_input,expected", [
+    # test_input = [ground-truth seg, predicted seg] on a 10x10 binary grid
+    ([np.zeros(shape=(1,1,10,10)), np.zeros(shape=(1,1,10,10))],
+     {'tp_0': 100, 'tp_1': 0, 'fp_0': 0, 'fp_1': 0, 'fn_0': 0, 'fn_1': 0}),
+    ([np.ones(shape=(1,1,10,10)), np.ones(shape=(1,1,10,10))],
+     {'tp_0': 0, 'tp_1': 100, 'fp_0': 0, 'fp_1': 0, 'fn_0': 0, 'fn_1': 0}),
+    ([np.ones(shape=(1,1,10,10)), np.zeros(shape=(1,1,10,10))],
+     {'tp_0': 0, 'tp_1': 0, 'fp_0': 100, 'fp_1': 0, 'fn_0': 0, 'fn_1': 100}),
+    ([np.concatenate([np.ones(shape=(1,1,10,5)), np.zeros(shape=(1,1,10,5))], axis=-1), np.zeros(shape=(1,1,10,10))],
+     {'tp_0': 50, 'tp_1': 0, 'fp_0': 50, 'fp_1': 0, 'fn_0': 0, 'fn_1': 50}),
+])
+def test_confusion(test_input, expected):
+    """Pixel counts of the confusion matrix for simple binary segmentations."""
+
+    gt_seg_modes = test_input[0]
+    seg_samples = test_input[1]
+
+    conf_matrix = calc_confusion(gt_seg_modes, seg_samples, class_ixs=[0,1])
+
+    # rows index classes; columns appear to be (tp, fp, _, fn) -- confirm in calc_confusion
+    tp_0 = conf_matrix[0,0]
+    tp_1 = conf_matrix[1,0]
+    fp_0 = conf_matrix[0,1]
+    fp_1 = conf_matrix[1,1]
+    fn_0 = conf_matrix[0,3]
+    fn_1 = conf_matrix[1,3]
+
+    assert tp_0 == expected['tp_0']
+    assert tp_1 == expected['tp_1']
+    assert fp_0 == expected['fp_0']
+    assert fp_1 == expected['fp_1']
+    assert fn_0 == expected['fn_0']
+    assert fn_1 == expected['fn_1']
+
+
+@pytest.mark.parametrize("test_input,expected,eval_f", [
+    # iou is NaN for a class absent from both ground truth and prediction
+    ([np.zeros(shape=(1,1,10,10)), np.zeros(shape=(1,1,10,10))], {'iou_0': 1., 'iou_1': np.nan},
+     [lambda x,y: x == y, lambda x,y: np.isnan(x) and np.isnan(y)]),
+    ([np.ones(shape=(1,1,10,10)), np.ones(shape=(1,1,10,10))], {'iou_0': np.nan, 'iou_1': 1.},
+     [lambda x,y: np.isnan(x) and np.isnan(y), lambda x,y: x == y]),
+    ([np.ones(shape=(1,1,10,10)), np.zeros(shape=(1,1,10,10))], {'iou_0': 0., 'iou_1': 0.},
+     [lambda x,y: x == y, lambda x,y: x == y]),
+    ([np.concatenate([np.ones(shape=(1, 1, 10, 5)), np.zeros(shape=(1, 1, 10, 5))], axis=-1),
+      np.zeros(shape=(1, 1, 10, 10))], {'iou_0': 0.5, 'iou_1': 0.},
+     [lambda x,y: x == y, lambda x,y: x == y])
+])
+def test_metrics_from_confusion(test_input, expected, eval_f):
+    """Per-class IoU computed from the confusion matrix."""
+
+    gt_seg_modes = test_input[0]
+    seg_samples = test_input[1]
+
+    conf_matrix = calc_confusion(gt_seg_modes, seg_samples, class_ixs=[0,1])
+    iou = metrics_from_conf_matrix(conf_matrix)['iou']
+
+    # eval_f holds one comparator per class (exact equality or NaN-check)
+    assert eval_f[0](iou[0], expected['iou_0'])
+    assert eval_f[1](iou[1], expected['iou_1'])
+
+
+@pytest.mark.parametrize("test_input,expected,eval_f", [
+    # test_input = [gt seg modes, seg samples, eval_class_ids (list or int count)]
+    ([np.zeros(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10)), [0]], {'YS': 0., 'SS': 0. , 'YY': 0.},
+     3 * [lambda x,y: x == y]),
+    ([np.zeros(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10)), [1]], {'YS': np.nan, 'SS': np.nan , 'YY': np.nan},
+     3 * [lambda x,y: np.isnan(x) and np.isnan(y)]),
+    ([np.ones(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10)), [0]], {'YS': 1., 'SS': 0. , 'YY': np.nan},
+     2 * [lambda x,y: x == y] + [lambda x,y: np.isnan(x) and np.isnan(y)]),
+    ([np.concatenate([np.ones(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10))], axis=0),
+      np.concatenate([np.zeros(shape=(1,1,1,10,10)), np.ones(shape=(1,1,1,10,10))], axis=0), [0]],
+     {'YS': np.array([[[1.],[np.nan]], [[0.],[1.]]]), 'SS': np.array([[[0.],[1.]], [[1.],[np.nan]]]),
+      'YY': np.array([[[np.nan],[1.]], [[1.],[0.]]])},
+     3 * [lambda x,y: nan_save_array_equal(x,y)]),
+    ([np.concatenate([np.ones(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10))], axis=0),
+      np.concatenate([np.zeros(shape=(1,1,1,10,10)), np.ones(shape=(1,1,1,10,10))], axis=0), 2],
+     {'YS': np.array([[[1.,1.], [np.nan,0.]], [[0.,np.nan], [1.,1.]]]),
+      'SS': np.array([[[0.,np.nan], [1.,1.]], [[1.,1.], [np.nan,0.]]]),
+      'YY': np.array([[[np.nan,0.], [1.,1.]], [[1.,1.], [0.,np.nan]]])},
+     3 * [lambda x,y: nan_save_array_equal(x, y)]),
+])
+def test_energy_distance_components(test_input, expected, eval_f):
+    """Distance components between ground-truth modes and samples (YS, SS, YY)."""
+
+    gt_seg_modes = test_input[0]
+    seg_samples = test_input[1]
+    eval_class_ids = test_input[2]
+
+    results = get_energy_distance_components(gt_seg_modes, seg_samples, eval_class_ids=eval_class_ids)
+
+    # YS: gt vs. samples, SS: samples vs. samples, YY: gt vs. gt
+    assert eval_f[0](results['YS'], expected['YS'])
+    assert eval_f[1](results['SS'], expected['SS'])
+    assert eval_f[2](results['YY'], expected['YY'])
+
+
+@pytest.mark.parametrize("test_input,expected,eval_f", [
+    # two segmentations: an all-zero one and a 0.1-valued one
+    (np.concatenate([np.zeros(shape=(1,5,5)), 0.1 * np.ones(shape=(1,5,5))]), [5,0], lambda x,y: (x==y).all())
+])
+def test_get_mode_counts(test_input, expected, eval_f):
+    """Mode counting on a small stack of segmentations."""
+    mode_counts = get_mode_counts(test_input)
+
+    assert eval_f(mode_counts, expected)
+
+
+@pytest.mark.parametrize("test_input,expected,eval_f", [
+    # test_input = [seg, seg_samples, num_classes]
+    ([np.zeros(shape=(1,1,10,10)),
+      np.concatenate([np.ones(shape=(1,1,1,10,5)), np.zeros(shape=(1,1,1,10,5))], axis=-1), 1],
+     [[100, 50, 50]], lambda x,y: (x==y).all()),
+    ([np.concatenate([np.ones(shape=(1,1,10,5)), np.zeros(shape=(1,1,10,5))], axis=-1),
+      np.concatenate([np.ones(shape=(1,1,1,10,10)), np.zeros(shape=(1,1,1,10,10))], axis=0), 2],
+     [[100, 50, 0],[100, 50, 0]], lambda x,y: (x==y).all()),
+    ([np.concatenate([np.ones(shape=(1, 1, 10, 5)), np.zeros(shape=(1, 1, 10, 5))], axis=-1),
+      np.concatenate([3 * np.ones(shape=(1, 1, 1, 10, 10)), 2 * np.ones(shape=(1, 1, 1, 10, 10)),
+                      np.ones(shape=(1, 1, 1, 10, 10)), np.zeros(shape=(1, 1, 1, 10, 10))], axis=0), 2],
+     [[200, 50, 50], [200, 50, 50]], lambda x, y: (x == y).all()),
+])
+def test_pixelwise_mode_counts(test_input, expected, eval_f):
+    """Pixel-wise mode counts using a minimal stub of the training config."""
+
+    # minimal stand-in for the config: random switch probabilities and a
+    # name->trainId mapping with a '_2' (switched) twin per class
+    class cf():
+        def __init__(self, num_classes):
+            self.label_switches = {'class_{}'.format(i): np.random.uniform(0.0,1.0) for i in range(num_classes)}
+            self.name2trainId = {**{'class_{}'.format(i): i for i in range(num_classes)},
+                                 **{'class_{}_2'.format(i): i + num_classes for i in range(num_classes)}}
+    num_classes = test_input[2]
+    pixelwise_mode_counts = get_pixelwise_mode_counts(cf(num_classes), seg=test_input[0], seg_samples=test_input[1])
+
+    assert eval_f(pixelwise_mode_counts, expected)
\ No newline at end of file
diff --git a/training/prob_unet_config.py b/training/prob_unet_config.py
new file mode 100644
index 0000000..50a653a
--- /dev/null
+++ b/training/prob_unet_config.py
@@ -0,0 +1,127 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""CityScapes training config."""
+
+import os
+import numpy as np
+from collections import OrderedDict
+from data.cityscapes.cityscapes_labels import labels as cs_labels_tuple
+
+config_path = os.path.realpath(__file__)  # absolute path of this config, for bookkeeping
+
+#########################################
+#              data-loader              #
+#########################################
+
+data_dir = 'PREPROCESSING_OUTPUT_DIRECTORY_ABSOLUTE_PATH'  # <- fill in before running
+resolution = 'quarter'
+label_density = 'gtFine'
+gt_instances = False
+train_cities = ['aachen', 'bochum', 'bremen', 'cologne', 'dusseldorf', 'erfurt', 'hamburg', 'hanover',
+                'jena', 'krefeld', 'stuttgart', 'strasbourg', 'tubingen', 'weimar', 'zurich']
+val_cities = ['darmstadt', 'monchengladbach', 'ulm', ]
+
+num_classes = 19  # standard Cityscapes train classes; extended below by the switched labels
+batch_size = 10
+pre_crop_size = [256, 512]
+patch_size = [256, 512]
+n_train_batches = None
+n_val_batches = 274 // batch_size
+n_workers = 5
+ignore_label = 255  # label value used to pad segmentation borders (see border_cval_seg below)
+
+# data-augmentation settings (presumably consumed by the data loader's augmenter -- confirm)
+da_kwargs = {
+    'random_crop': True,
+    'rand_crop_dist': (patch_size[0] / 2., patch_size[1] / 2.),
+    'do_elastic_deform': True,
+    'alpha': (0., 800.),
+    'sigma': (25., 35.),
+    'do_rotation': True,
+    'angle_x': (-np.pi / 8., np.pi / 8.),
+    'angle_y': (0., 0.),
+    'angle_z': (0., 0.),
+    'do_scale': True,
+    'scale': (0.8, 1.2),
+    'border_mode_data': 'constant',
+    'border_mode_seg': 'constant',
+    'border_cval_seg': ignore_label,
+    'gamma_retain_stats': True,
+    'gamma_range': (0.7, 1.5),
+    'p_gamma': 0.3
+}
+
+data_format = 'NCHW'
+one_hot_labels = False  # labels arrive as integer maps (see ProbUNet._build)
+
+#########################################
+#            label-switches             #
+#########################################
+
+# RGB color per trainId; black for the ignore label
+color_map = {label.trainId:label.color for label in cs_labels_tuple}
+color_map[255] = (0.,0.,0.)
+
+trainId2name = {labels.trainId: labels.name for labels in cs_labels_tuple}
+name2trainId = {labels.name: labels.trainId for labels in cs_labels_tuple}
+
+# per-class switch probabilities (in 17ths); each switched class gets a '_2' twin
+# whose trainId 19+i is appended behind the 19 original classes
+label_switches = OrderedDict([('sidewalk', 8./17.), ('person', 7./17.), ('car', 6./17.), ('vegetation', 5./17.), ('road', 4./17.)])
+num_classes += len(label_switches)
+switched_Id2name = {19+i:list(label_switches.keys())[i] + '_2' for i in range(len(label_switches))}
+switched_name2Id = {list(label_switches.keys())[i] + '_2':19+i for i in range(len(label_switches))}
+trainId2name = {**trainId2name, **switched_Id2name}
+name2trainId = {**name2trainId, **switched_name2Id}
+
+# distinct display colors for the switched twin classes
+switched_labels2color = {'road_2': (84, 86, 22), 'person_2': (167, 242, 242), 'vegetation_2': (242, 160, 19),
+                         'car_2': (30, 193, 252), 'sidewalk_2': (46, 247, 180)}
+switched_cmap = {switched_name2Id[i]:switched_labels2color[i] for i in switched_name2Id.keys()}
+color_map = {**color_map, **switched_cmap}
+
+#########################################
+#          network & training           #
+#########################################
+
+cuda_visible_devices = '0'
+cpu_device = '/cpu:0'
+gpu_device = '/gpu:0'
+
+# placeholder shapes (batch dim is dynamic)
+network_input_shape = (None, 3) + tuple(patch_size)
+network_output_shape = (None, num_classes) + tuple(patch_size)
+label_shape = (None, 1) + tuple(patch_size)
+loss_mask_shape = label_shape
+
+# channels per encoder/decoder scale (7 scales)
+base_channels = 32
+num_channels = [base_channels, 2*base_channels, 4*base_channels,
+                6*base_channels, 6*base_channels, 6*base_channels, 6*base_channels]
+
+num_convs_per_block = 3
+
+n_training_batches = 240000
+validation = {'n_batches': n_val_batches, 'every_n_batches': 2000}
+
+learning_rate_schedule = 'piecewise_constant'
+learning_rate_kwargs = {'values': [1e-4, 0.5e-4, 1e-5, 0.5e-6],
+                        'boundaries': [80000, 160000, 240000],
+                        'name': 'piecewise_constant_lr_decay'}
+initial_learning_rate = learning_rate_kwargs['values'][0]
+
+# NOTE(review): typo for `regularization_weight`; name kept since consumers may reference it
+regularizarion_weight = 1e-5
+latent_dim = 6
+num_1x1_convs = 3
+beta = 1.0  # KL weight in the elbo (see ProbUNet.elbo)
+analytic_kl = True
+use_posterior_mean = False
+save_every_n_steps = n_training_batches // 3 if n_training_batches >= 100000 else n_training_batches
+disable_progress_bar = False
+
+exp_dir = "EXPERIMENT_OUTPUT_DIRECTORY_ABSOLUTE_PATH"  # <- fill in before running
diff --git a/training/train_prob_unet.py b/training/train_prob_unet.py
new file mode 100644
index 0000000..82c2e92
--- /dev/null
+++ b/training/train_prob_unet.py
@@ -0,0 +1,188 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Probabilistic U-Net training script."""
+
+import tensorflow as tf
+import numpy as np
+import os
+import time
+from tqdm import tqdm
+import shutil
+import logging
+import argparse
+from importlib.machinery import SourceFileLoader
+
+from data.cityscapes.data_loader import get_train_generators
+from model.probabilistic_unet import ProbUNet
+import utils.training_utils as training_utils
+
+
def train(cf):
    """Perform training from scratch.

    Builds the Probabilistic U-Net graph, minimizes the negative ELBO plus an
    L2 regularization term, periodically evaluates on the validation split and
    writes TensorBoard summaries, checkpoints and reconstruction/sample plots
    into cf.exp_dir.

    :param cf: loaded config module (see the __main__ block of this script)
    """

    # do not use all gpus
    os.environ["CUDA_VISIBLE_DEVICES"] = cf.cuda_visible_devices

    # initialize data providers
    data_provider = get_train_generators(cf)
    train_provider = data_provider['train']
    val_provider = data_provider['val']

    prob_unet = ProbUNet(latent_dim=cf.latent_dim, num_channels=cf.num_channels,
                         num_1x1_convs=cf.num_1x1_convs,
                         num_classes=cf.num_classes, num_convs_per_block=cf.num_convs_per_block,
                         initializers={'w': training_utils.he_normal(),
                                       'b': tf.truncated_normal_initializer(stddev=0.001)},
                         regularizers={'w': tf.contrib.layers.l2_regularizer(1.0)})

    # graph inputs: image batch, integer segmentation labels, binary loss mask
    x = tf.placeholder(tf.float32, shape=cf.network_input_shape)
    y = tf.placeholder(tf.uint8, shape=cf.label_shape)
    mask = tf.placeholder(tf.uint8, shape=cf.loss_mask_shape)

    global_step = tf.train.get_or_create_global_step()

    # learning-rate schedule as configured: step-wise constant or exponential decay
    if cf.learning_rate_schedule == 'piecewise_constant':
        learning_rate = tf.train.piecewise_constant(x=global_step, **cf.learning_rate_kwargs)
    else:
        learning_rate = tf.train.exponential_decay(learning_rate=cf.initial_learning_rate, global_step=global_step,
                                                   **cf.learning_rate_kwargs)
    with tf.device(cf.gpu_device):
        # forward pass wires up the network; elbo/sample then reuse its state
        prob_unet(x, y, is_training=True, one_hot_labels=cf.one_hot_labels)
        elbo = prob_unet.elbo(y, reconstruct_posterior_mean=cf.use_posterior_mean, beta=cf.beta, loss_mask=mask,
                              analytic_kl=cf.analytic_kl, one_hot_labels=cf.one_hot_labels)
        # reconstruction logits and fresh samples (private attributes of ProbUNet)
        reconstructed_logits = prob_unet._rec_logits
        sampled_logits = prob_unet.sample()

        reg_loss = cf.regularizarion_weight * tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        loss = -elbo + reg_loss
        rec_loss = prob_unet._rec_loss_mean
        kl = prob_unet._kl

    # placeholders for feeding averaged validation metrics back into summaries
    mean_val_rec_loss = tf.placeholder(tf.float32, shape=(), name="mean_val_rec_loss")
    mean_val_kl = tf.placeholder(tf.float32, shape=(), name="mean_val_kl")

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, global_step=global_step)

    # prepare tf summaries
    train_elbo_summary = tf.summary.scalar('train_elbo', elbo)
    train_kl_summary = tf.summary.scalar('train_kl', kl)
    train_rec_loss_summary = tf.summary.scalar('rec_loss', rec_loss)
    train_loss_summary = tf.summary.scalar('train_loss', loss)
    reg_loss_summary = tf.summary.scalar('train_reg_loss', reg_loss)
    lr_summary = tf.summary.scalar('learning_rate', learning_rate)
    beta_summary = tf.summary.scalar('beta', cf.beta)
    training_summary_op = tf.summary.merge([train_loss_summary, reg_loss_summary, lr_summary, train_elbo_summary,
                                            train_kl_summary, train_rec_loss_summary, beta_summary])
    batches_per_second = tf.placeholder(tf.float32, shape=(), name="batches_per_sec_placeholder")
    timing_summary = tf.summary.scalar('batches_per_sec', batches_per_second)
    val_rec_loss_summary = tf.summary.scalar('val_loss', mean_val_rec_loss)
    val_kl_summary = tf.summary.scalar('val_kl', mean_val_kl)
    validation_summary_op = tf.summary.merge([val_rec_loss_summary, val_kl_summary])

    # NOTE(review): this creates an init op that is never run; variable
    # initialization is handled by MonitoredTrainingSession below, so this
    # line appears to be dead code -- confirm before removing.
    tf.global_variables_initializer()

    # Add ops to save and restore all the variables.
    saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir=cf.exp_dir, save_steps=cf.save_every_n_steps,
                                              saver=tf.train.Saver(save_relative_paths=True))
    # save config alongside the experiment output for reproducibility
    shutil.copyfile(cf.config_path, os.path.join(cf.exp_dir, 'used_config.py'))

    with tf.train.MonitoredTrainingSession(hooks=[saver_hook]) as sess:
        summary_writer = tf.summary.FileWriter(cf.exp_dir, sess.graph)
        logging.info('Model: {}'.format(cf.exp_dir))

        for i in tqdm(range(cf.n_training_batches), disable=cf.disable_progress_bar):

            # one optimization step, timed for the batches/sec summary
            start_time = time.time()
            train_batch = next(train_provider)
            _, train_summary = sess.run([optimizer, training_summary_op],
                                        feed_dict={x: train_batch['data'], y: train_batch['seg'],
                                                   mask: train_batch['loss_mask']})
            summary_writer.add_summary(train_summary, i)
            time_delta = time.time() - start_time
            train_speed = sess.run(timing_summary, feed_dict={batches_per_second: 1. / time_delta})
            summary_writer.add_summary(train_speed, i)

            # validation
            if i % cf.validation['every_n_batches'] == 0:

                # plot reconstructions of the current training batch
                train_rec = sess.run(reconstructed_logits, feed_dict={x: train_batch['data'], y: train_batch['seg']})
                image_path = os.path.join(cf.exp_dir,
                                          'batch_{}_train_reconstructions.png'.format(i // cf.validation['every_n_batches']))
                training_utils.plot_batch(train_batch, train_rec, num_classes=cf.num_classes,
                                          cmap=cf.color_map, out_dir=image_path)

                # accumulate running means of the validation metrics
                running_mean_val_rec_loss = 0.
                running_mean_val_kl = 0.

                for j in range(cf.validation['n_batches']):
                    val_batch = next(val_provider)
                    val_rec, val_sample, val_rec_loss, val_kl =\
                        sess.run([reconstructed_logits, sampled_logits, rec_loss, kl],
                                 feed_dict={x: val_batch['data'], y: val_batch['seg'], mask: val_batch['loss_mask']})
                    running_mean_val_rec_loss += val_rec_loss / cf.validation['n_batches']
                    running_mean_val_kl += val_kl / cf.validation['n_batches']

                    # plot reconstructions & samples for the first val batch only
                    if j == 0:
                        image_path = os.path.join(cf.exp_dir,
                                                  'batch_{}_val_reconstructions.png'.format(i // cf.validation['every_n_batches']))
                        training_utils.plot_batch(val_batch, val_rec, num_classes=cf.num_classes,
                                                  cmap=cf.color_map, out_dir=image_path)
                        image_path = os.path.join(cf.exp_dir,
                                                  'batch_{}_val_samples.png'.format(i // cf.validation['every_n_batches']))

                        # draw 3 additional samples, concatenated in the channel axis
                        for _ in range(3):
                            val_sample_ = sess.run(sampled_logits, feed_dict={x: val_batch['data'], y: val_batch['seg']})
                            val_sample = np.concatenate([val_sample, val_sample_], axis=1)

                        training_utils.plot_batch(val_batch, val_sample, num_classes=cf.num_classes,
                                                  cmap=cf.color_map, out_dir=image_path)

                val_summary = sess.run(validation_summary_op, feed_dict={mean_val_rec_loss: running_mean_val_rec_loss,
                                                                         mean_val_kl: running_mean_val_kl})
                summary_writer.add_summary(val_summary, i)

                if cf.disable_progress_bar:
                    logging.info('Evaluating epoch {}/{}: validation loss={}, kl={}'\
                                 .format(i, cf.n_training_batches, running_mean_val_rec_loss, running_mean_val_kl))

            # NOTE(review): this fetch discards its result (global_step is already
            # incremented by the optimizer's minimize call) -- appears to be a no-op.
            sess.run(global_step)
+
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Training of the Probabilistic U-Net')
    parser.add_argument('-c', '--config', type=str, default='prob_unet_config.py',
                        help='name of the python script defining the training configuration')
    parser.add_argument('-d', '--data_dir', type=str, default='',
                        help="full path to the data, if empty the config's data_dir attribute is used")
    args = parser.parse_args()

    # load config as a module so its attributes are accessible as cf.<name>
    cf = SourceFileLoader('cf', args.config).load_module()
    if args.data_dir != '':
        cf.data_dir = args.data_dir

    # prepare experiment directory
    # NOTE(review): os.mkdir fails if the parent of cf.exp_dir does not exist;
    # os.makedirs would be more forgiving -- confirm intended behavior.
    if not os.path.isdir(cf.exp_dir):
        os.mkdir(cf.exp_dir)

    # log to file and console
    log_path = os.path.join(cf.exp_dir, 'train.log')
    logging.basicConfig(filename=log_path, level=logging.INFO)
    logging.getLogger().addHandler(logging.StreamHandler())
    logging.info('Logging to {}'.format(log_path))
    tf.logging.set_verbosity(tf.logging.INFO)

    train(cf)
\ No newline at end of file
diff --git a/utils/training_utils.py b/utils/training_utils.py
new file mode 100644
index 0000000..3571de6
--- /dev/null
+++ b/utils/training_utils.py
@@ -0,0 +1,249 @@
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Training utilities."""
+
+import tensorflow as tf
+import numpy as np
+import matplotlib
+
+matplotlib.use('agg')
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+from tensorflow.python.ops.init_ops import VarianceScaling
+
+
def ce_loss(labels, logits, n_classes, loss_mask=None, data_format='NCHW', one_hot_labels=True, name='ce_loss'):
    """
    Cross-entropy loss.
    :param labels: 4D tensor
    :param logits: 4D tensor
    :param n_classes: integer for number of classes
    :param loss_mask: binary 4D tensor; pixels marked by 1s are EXCLUDED from the loss
    :param data_format: string, 'NCHW' or 'NCDHW' (anything else is assumed channels-last)
    :param one_hot_labels: bool, indicator for whether labels are to be expected in one-hot representation
    :param name: string, variable-scope name
    :return: dict of (pixel-wise) mean and sum of cross-entropy loss
    """
    with tf.variable_scope(name):
        # permute class channels into last axis so pixels can be flattened to
        # rows of shape [n_classes] below
        if data_format == 'NCHW':
            labels = tf.transpose(labels, [0,2,3,1])
            logits = tf.transpose(logits, [0,2,3,1])
        elif data_format == 'NCDHW':
            labels = tf.transpose(labels, [0,2,3,4,1])
            logits = tf.transpose(logits, [0,2,3,4,1])

        batch_size = tf.cast(tf.shape(labels)[0], tf.float32)

        # flatten to [n_pixels, n_classes], one-hotifying integer labels first if needed
        if one_hot_labels:
            flat_labels = tf.reshape(labels, [-1, n_classes])
        else:
            flat_labels = tf.reshape(labels, [-1])
            flat_labels = tf.one_hot(indices=flat_labels, depth=n_classes, axis=-1)
        flat_logits = tf.reshape(logits, [-1, n_classes])

        # do not compute gradients wrt the labels
        flat_labels = tf.stop_gradient(flat_labels)

        ce_per_pixel = tf.nn.softmax_cross_entropy_with_logits_v2(labels=flat_labels, logits=flat_logits)

        # optional element-wise masking with binary loss mask
        if loss_mask is None:
            # 'sum' is normalized by batch size so it is comparable across batch sizes
            ce_sum = tf.reduce_sum(ce_per_pixel) / batch_size
            ce_mean = tf.reduce_mean(ce_per_pixel)
        else:
            # invert the mask: a 1 in loss_mask means weight 0 for that pixel
            loss_mask_flat = tf.reshape(loss_mask, [-1,])
            loss_mask_flat = (1. - tf.cast(loss_mask_flat, tf.float32))
            ce_sum = tf.reduce_sum(loss_mask_flat * ce_per_pixel) / batch_size
            # mean over the unmasked pixels only
            n_valid_pixels = tf.reduce_sum(loss_mask_flat)
            ce_mean = tf.reduce_sum(loss_mask_flat * ce_per_pixel) / n_valid_pixels

        return {'sum': ce_sum, 'mean': ce_mean}
+
+
def softmax_2_onehot(arr):
    """Convert an array of softmax scores into a one-hot encoding, in place.

    The class axis is assumed to be axis 1; ties resolve to the lowest class
    index (argmax behavior). The input array is overwritten and also returned.
    :param arr: ND array with classes along axis 1
    :return: the same ND array, one-hot encoded
    """
    n_classes = arr.shape[1]
    winners = arr.argmax(axis=1)

    for cls in range(n_classes):
        arr[:, cls] = (winners == cls).astype(np.uint8)
    return arr
+
+
def numpy_one_hot(label_arr, num_classes):
    """One-hotify an integer-labeled numpy array. One-hot encoding is encoded in additional last axis.

    Labels >= num_classes are treated as class 0. The input array is left
    unmodified (the previous implementation zeroed out-of-range labels in
    place, silently mutating the caller's array).
    :param label_arr: ND integer array
    :param num_classes: integer
    :return: (N+1)D float array with a trailing one-hot axis of length num_classes
    """
    label_arr = np.asarray(label_arr)
    # clamp out-of-range labels to 0 on a copy, not on the caller's array
    safe_labels = np.where(label_arr >= num_classes, 0, label_arr)

    res = np.eye(num_classes)[safe_labels.reshape(-1)]
    return res.reshape(list(label_arr.shape) + [num_classes])
+
+
def calc_confusion(labels, samples, class_ixs, loss_mask=None):
    """
    Compute confusion-matrix entries (TP, FP, TN, FN) for each class.
    Assumes classes are given in integer-valued encoding.
    :param labels: 4/5D array
    :param samples: 4/5D array
    :param class_ixs: integer (number of classes, evaluates ids 0..n-1) or list of class ids
    :param loss_mask: optional 4/5D binary array; pixels marked 1 are excluded
    :return: 2D float32 array of shape (num_classes, 4) with columns TP, FP, TN, FN
    :raises AssertionError: if labels and samples differ in shape
    :raises TypeError: if class_ixs is neither int nor list
    """
    # explicit check instead of try/assert with a bare except
    if labels.shape != samples.shape:
        raise AssertionError('shape mismatch {} vs. {}'.format(labels.shape, samples.shape))

    if isinstance(class_ixs, int):
        num_classes = class_ixs
        class_ixs = range(class_ixs)
    elif isinstance(class_ixs, list):
        num_classes = len(class_ixs)
    else:
        raise TypeError('arg class_ixs needs to be int or list, not {}.'.format(type(class_ixs)))

    if loss_mask is None:
        # zero mask keeps every pixel; shape generalized to work for both 4D
        # and 5D input (was hard-coded to 4D before)
        shp = labels.shape
        loss_mask = np.zeros(shape=(shp[0], 1) + tuple(shp[2:]))

    # pixels excluded by the loss mask never count towards any cell
    valid = (loss_mask != 1)

    conf_matrix = np.zeros(shape=(num_classes, 4), dtype=np.float32)
    for i, c in enumerate(class_ixs):

        pred_ = (samples == c)
        labels_ = (labels == c)

        conf_matrix[i, 0] = int((pred_ & labels_ & valid).sum())    # TP
        conf_matrix[i, 1] = int((pred_ & ~labels_ & valid).sum())   # FP
        conf_matrix[i, 2] = int((~pred_ & ~labels_ & valid).sum())  # TN
        conf_matrix[i, 3] = int((~pred_ & labels_ & valid).sum())   # FN

    return conf_matrix
+
+
def metrics_from_conf_matrix(conf_matrix):
    """
    Calculate the per-class IoU from a confusion matrix.
    :param conf_matrix: 2D array of shape (num_classes, 4) holding TP, FP, TN, FN per class
    :return: dict holding 1D-vectors of metrics
    """
    metrics = {'iou': np.zeros(conf_matrix.shape[0], dtype=np.float32)}

    # row-wise: TN is not part of the IoU and is ignored
    for c, (tp, fp, _, fn) in enumerate(conf_matrix):
        union = tp + fp + fn
        # IoU is undefined (NaN) when prediction and ground truth are both empty
        metrics['iou'][c] = tp / union if union != 0 else np.nan

    return metrics
+
+
def he_normal(seed=None):
    """He-normal weight initializer (He et al., http://arxiv.org/abs/1502.01852).

    Draws from a truncated normal centered on 0 with `stddev = sqrt(2 / fan_in)`,
    where `fan_in` is the number of input units of the weight tensor. Mirrors
    the tf.keras implementation:
    https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/python/keras/initializers.py
    :param seed: optional Python integer used to seed the random generator
    :return: a VarianceScaling initializer instance
    """
    return VarianceScaling(scale=2., mode='fan_in', distribution='normal', seed=seed)
+
+
def plot_batch(batch, prediction, cmap, num_classes, out_dir=None, clip_range=True):
    """
    Plots a batch of images, segmentations & samples and optionally saves it to disk.
    :param batch: dict holding 'data' (images) and 'seg' (gt labels) for a batch
    :param prediction: logit prediction of the corresponding batch; multiple
                       predictions may be concatenated along the channel axis
    :param cmap: dictionary as colormap (integer class label -> rgb tuple)
    :param num_classes: number of classes of a single prediction
    :param out_dir: full path to save png image to (nothing is saved if None)
    :param clip_range: clip image values to [0, 1] to suppress matplotlib range warnings
    :return:
    """
    img_arr = batch['data']
    seg_arr = batch['seg']

    # clip on a copy: the previous in-place clipping silently mutated
    # batch['data'], which callers may feed to the network again afterwards
    if clip_range:
        img_arr = np.clip(img_arr, 0., 1.)

    num_predictions = prediction.shape[1] // num_classes
    num_y_tiles = 2 + num_predictions
    batch_size = img_arr.shape[0]

    f = plt.figure(figsize=(batch_size * 4, num_y_tiles * 2))
    gs = gridspec.GridSpec(num_y_tiles, batch_size, wspace=0.0, hspace=0.0)

    for tile in range(batch_size):
        # image row
        ax = plt.subplot(gs[0, tile])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        plt.imshow(np.transpose(img_arr[tile], axes=[1, 2, 0]))

        # (here sampled) gt segmentation row; integer or one-hot encoded
        ax = plt.subplot(gs[1, tile])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if seg_arr.shape[1] == 1:
            gt_seg = np.squeeze(seg_arr[tile], axis=0)
        else:
            gt_seg = np.argmax(seg_arr[tile], axis=0)
        plt.imshow(to_rgb(gt_seg, cmap))

        # multiple predictions can be concatenated in channel axis, iterate all predictions
        for i in range(num_predictions):
            ax = plt.subplot(gs[2 + i, tile])
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            single_prediction = prediction[tile][i * num_classes: (i+1) * num_classes]
            pred_seg = np.argmax(single_prediction, axis=0)
            plt.imshow(to_rgb(pred_seg, cmap))
    if out_dir is not None:
        plt.savefig(out_dir, dpi=200, bbox_inches='tight', pad_inches=0.0)
    plt.close(f)
+
+
def to_rgb(arr, cmap):
    """
    Transform an integer-labeled segmentation map into an rgb image.
    :param arr: img_arr w/o a color-channel
    :param cmap: dictionary mapping from integer class labels to rgb values
    :return: float array of shape arr.shape + (3,); unmapped labels stay black
    """
    rgb = np.zeros(arr.shape + (3,))
    for label, color in cmap.items():
        # boolean-mask assignment; rgb values are scaled from [0, 255] to [0, 1]
        rgb[arr == label] = [color[i] / 255. for i in range(3)]
    return rgb