class hyperopt_Solver(SolverPluginBase, IPlugin):
    """Solver plugin that minimises the configured loss with hyperopt's TPE search."""

    # Trials container of the most recent ``execute_solver`` run.
    trials = None
    # Parameter assignment returned by the most recent ``fmin`` call.
    best = None

    def __init__(self):
        SolverPluginBase.__init__(self)
        LOG.debug("initialized")

    def loss_function(self, params):
        """Evaluate ``self.loss`` for one sample and wrap it as a hyperopt result.

        :param params: parameter sample handed in by hyperopt
        :return: ``{'loss': value, 'status': STATUS_OK}`` on success, or
            ``{'status': STATUS_FAIL}`` when the loss callable raised
        """
        try:
            loss = self.loss(self.data, params)
        except Exception as e:
            LOG.error(f"execution of self.loss(self.data, params) failed due to:\n {e}")
            # No loss value exists on this path; building the result from the
            # unbound local used to raise UnboundLocalError instead of
            # reporting the failed trial.
            return {'status': STATUS_FAIL}
        return {'loss': loss, 'status': STATUS_OK}

    def execute_solver(self, parameter):
        """Run ``fmin`` over the search space *parameter* and store the outcome.

        :param parameter: hyperopt search space passed to ``fmin`` as ``space``
        :raises BrokenPipeError: when ``fmin`` raises; the original exception
            is attached as the cause
        """
        LOG.debug(f"execute_solver using solution space:\n\n\t{pformat(parameter)}\n")
        self.trials = Trials()
        try:
            self.best = fmin(fn=self.loss_function,
                             space=parameter,
                             algo=tpe.suggest,
                             max_evals=ProjectManager.max_iterations,
                             trials=self.trials)
        except Exception as e:
            msg = f"internal error in hyperopt.fmin occured. {e}"
            LOG.error(msg)
            raise BrokenPipeError(msg) from e

    def convert_results(self):
        """Return a printable summary of ``self.best`` without ``None`` entries."""
        solution = {k: v for k, v in self.best.items() if v is not None}
        rows = ["%s \t %s" % (name, str(value)) for name, value in solution.items()]
        return 'Solution Hyperopt Plugin\n========\n' + "\n".join(rows) + "\n"
def reshape(self, orig_img, append_value=-1024, new_shape=(512, 512, 512)):
    """Place *orig_img* at the origin corner of a constant-filled array.

    Works for any number of axes as long as *orig_img* and *new_shape* have
    the same rank.

    :param orig_img: array to embed
    :param append_value: fill value for every position not covered by *orig_img*
    :param new_shape: shape of the returned array
    :return: float array of shape *new_shape*
    :raises ValueError: if the ranks differ or *orig_img* exceeds *new_shape*
        along any axis
    """
    orig_shape = np.shape(orig_img)
    if len(orig_shape) != len(new_shape):
        raise ValueError("cannot embed an array of shape %s into shape %s: rank mismatch"
                         % (orig_shape, tuple(new_shape)))
    if any(have > room for have, room in zip(orig_shape, new_shape)):
        raise ValueError("array of shape %s does not fit into shape %s"
                         % (orig_shape, tuple(new_shape)))

    reshaped_image = np.full(new_shape, append_value, dtype=float)
    # All offsets are zero: the source occupies [0, extent) on every axis.
    region = tuple(slice(0, extent) for extent in orig_shape)
    reshaped_image[region] = orig_img
    return reshaped_image
def create_splits(self, output_dir, image_dir, num_splits=5):
    """Draw random train/val/test partitions and pickle them to ``splits.pkl``.

    Each split holds 50% / 25% / 25% of the ``.npy`` files found in
    *image_dir* (sizes use integer division, so a remainder stays unassigned).
    Entries are file names with the last four characters (".npy") removed.

    :param output_dir: directory receiving ``splits.pkl``; created if missing
    :param image_dir: directory listed through ``self.subfiles``
    :param num_splits: number of independent partitions to draw
    :return: list of dicts with the keys 'train', 'val' and 'test'
    """
    npy_files = self.subfiles(image_dir, suffix=".npy", join=False)

    n_files = len(npy_files)
    trainset_size = n_files * 50 // 100
    valset_size = n_files * 25 // 100
    testset_size = n_files * 25 // 100
    val_end = trainset_size + valset_size
    test_end = val_end + testset_size

    splits = []
    for _ in range(num_splits):
        # One shuffle per split replaces the repeated choice/remove loop and
        # keeps the entries plain ``str`` rather than numpy scalars.
        order = np.random.permutation(n_files)
        patients = [npy_files[i][:-4] for i in order]
        splits.append({
            'train': patients[:trainset_size],
            'val': patients[trainset_size:val_end],
            'test': patients[val_end:test_end],
        })

    # Nothing else creates the split directory, so opening the pickle for
    # writing failed when it did not exist yet.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
        pickle.dump(splits, f)
    return splits
class WrappedDataset(Dataset):
    """Dataset adapter that applies ``transform(**item)`` to every fetched item."""

    def __init__(self, dataset, transform):
        self.transform = transform
        self.dataset = dataset

        # Index access is used only when the source supports it and has not
        # opted into ``next()`` access through ``use_next = True``.
        prefers_next = getattr(dataset, "use_next", False) is True
        self.is_indexable = hasattr(dataset, "__getitem__") and not prefers_next

    def __getitem__(self, index):
        item = self.dataset[index] if self.is_indexable else next(self.dataset)
        return self.transform(**item)

    def __len__(self):
        return int(self.dataset.num_batches)


class MultiThreadedDataLoader(object):
    """Iterate a batch source through a torch ``DataLoader`` with worker processes.

    :param data_loader: batch source wrapped in a ``WrappedDataset``
    :param transform: callable applied as ``transform(**item)``
    :param num_processes: forwarded to ``DataLoader`` as ``num_workers``
    :param kwargs: accepted for call compatibility; not used by this class
    """

    def __init__(self, data_loader, transform, num_processes, **kwargs):
        # Offset added to the worker id when seeding; bumped by ``renew``.
        self.cntr = 1
        self.ds_wrapper = WrappedDataset(data_loader, transform)

        self.generator = DataLoader(self.ds_wrapper, batch_size=1, shuffle=False, sampler=None,
                                    batch_sampler=None, num_workers=num_processes, pin_memory=True,
                                    drop_last=False, worker_init_fn=self.get_worker_init_fn())

        self.num_processes = num_processes
        self.iter = None

    def get_worker_init_fn(self):
        """Return a worker initialiser seeding each worker with ``worker_id + cntr``."""
        def init_fn(worker_id):
            set_seed(worker_id + self.cntr)

        return init_fn

    def __iter__(self):
        self.kill_iterator()
        self.iter = iter(self.generator)
        return self.iter

    def __next__(self):
        if self.iter is None:
            self.iter = iter(self.generator)
        return next(self.iter)

    def renew(self):
        """Drop the current iterator and start a new one with a new seed offset."""
        self.cntr += 1
        self.kill_iterator()
        self.generator.worker_init_fn = self.get_worker_init_fn()
        self.iter = iter(self.generator)

    def restart(self):
        """Intentionally a no-op."""
        pass

    def kill_iterator(self):
        """Best-effort shutdown of the current iterator's worker processes."""
        if self.iter is None:
            return
        try:
            # These are private DataLoader-iterator members and may be absent;
            # a failure here is reported and otherwise ignored.
            self.iter._shutdown_workers()
            for p in self.iter.workers:
                p.terminate()
        except Exception:
            # Narrowed from a bare ``except`` so KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            print("Could not kill Dataloader Iterator")
def create_splits(output_dir, image_dir, num_splits=5):
    """Draw random train/val/test partitions and pickle them to ``splits.pkl``.

    Each split holds 50% / 25% / 25% of the ``.npy`` files found in
    *image_dir* (sizes use integer division, so a remainder stays unassigned).
    Entries are file names with the last four characters (".npy") removed.

    :param output_dir: directory receiving ``splits.pkl``; created if missing
    :param image_dir: directory listed through ``subfiles``
    :param num_splits: number of independent partitions to draw
    :return: list of dicts with the keys 'train', 'val' and 'test'
    """
    npy_files = subfiles(image_dir, suffix=".npy", join=False)

    n_files = len(npy_files)
    trainset_size = n_files * 50 // 100
    valset_size = n_files * 25 // 100
    testset_size = n_files * 25 // 100
    val_end = trainset_size + valset_size
    test_end = val_end + testset_size

    splits = []
    for _ in range(num_splits):
        # One shuffle per split replaces the repeated choice/remove loop and
        # keeps the entries plain ``str`` rather than numpy scalars.
        order = np.random.permutation(n_files)
        patients = [npy_files[i][:-4] for i in order]
        splits.append({
            'train': patients[:trainset_size],
            'val': patients[trainset_size:val_end],
            'test': patients[val_end:test_end],
        })

    # Writing used to fail when the target directory did not exist yet.
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
        pickle.dump(splits, f)
    return splits
def download_dataset(dest_path, dataset, id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C'):
    """Download ``<dataset>.tar`` from Google Drive and unpack it into *dest_path*.

    Extraction is skipped when ``<dest_path>/<dataset>`` already exists.

    :param dest_path: directory receiving the archive and its contents
    :param dataset: dataset name; also the archive's base name
    :param id: Google Drive file id of the archive
    """
    tar_path = os.path.join(dest_path, dataset) + '.tar'
    gdd.download_file_from_google_drive(file_id=id,
                                        dest_path=tar_path, overwrite=False,
                                        unzip=False)

    if exists(os.path.join(dest_path, dataset)):
        print('Data already downloaded. Files are not extracted again.')
        return

    print('Extracting data [STARTED]')
    # The context manager closes the archive handle, which was left open before.
    with tarfile.open(tar_path) as tar:
        # NOTE(security): the archive comes from the network and extractall()
        # trusts its member paths; only use this with a trusted file id.
        tar.extractall(dest_path)
    print('Extracting data [DONE]')
def preprocess_data(root_dir, classes=3):
    """Normalise, pad and store every ``.nii.gz`` image together with its label.

    Images are read from ``<root_dir>/imagesTr`` and labels from
    ``<root_dir>/labelsTr`` (same file name with '_0000' removed). Each image
    is min-max scaled, image and label are embedded into 64x64x64 arrays via
    ``reshape`` and the stacked pair is saved to ``<root_dir>/preprocessed``.
    Per-class voxel counts are printed at the end.

    :param root_dir: dataset directory containing ``imagesTr`` and ``labelsTr``
    :param classes: number of label values (0 .. classes-1) to count
    """
    image_dir = os.path.join(root_dir, 'imagesTr')
    label_dir = os.path.join(root_dir, 'labelsTr')
    output_dir = os.path.join(root_dir, 'preprocessed')

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print('Created' + output_dir + '...')

    class_stats = defaultdict(int)
    total = 0

    nii_files = subfiles(image_dir, suffix=".nii.gz", join=False)
    # Names prefixed with "._" are mapped to the unprefixed name. That can
    # produce duplicates, which are dropped so no file is processed and
    # counted twice.
    nii_files = list(dict.fromkeys(
        name[2:] if name.startswith("._") else name for name in nii_files))

    for f in nii_files:
        image, _ = load(os.path.join(image_dir, f))
        label, _ = load(os.path.join(label_dir, f.replace('_0000', '')))

        print(f)

        for i in range(classes):
            voxels = np.sum(label == i)
            class_stats[i] += voxels
            total += voxels

        lowest = image.min()
        value_range = image.max() - lowest
        if value_range > 0:
            image = (image - lowest) / value_range
        else:
            # A constant image has no range to scale by; avoid dividing by zero.
            image = np.zeros(image.shape)

        image = reshape(image, append_value=0, new_shape=(64, 64, 64))
        label = reshape(label, append_value=0, new_shape=(64, 64, 64))

        result = np.stack((image, label))

        np.save(os.path.join(output_dir, f.split('.')[0] + '.npy'), result)
        print(f)

    print(total)
    for i in range(classes):
        # With nothing counted the share is reported as 0.0 instead of dividing by zero.
        share = class_stats[i] / total if total else 0.0
        print(class_stats[i], share)
def load_dataset(base_dir, pattern='*.npy', keys=None):
    """Collect the ``.npy`` files below *base_dir* whose stem is listed in *keys*.

    :param base_dir: directory tree searched with ``os.walk``
    :param pattern: ``fnmatch`` pattern applied to the file names
    :param keys: container of accepted file names without the last four
        characters. NOTE(review): ``None`` selects no file at all, as before;
        confirm whether "no filter" was intended.
    :return: ``(fls, files_len, dataset)`` where ``fls`` holds the file paths,
        ``files_len`` the size of axis 1 of each array and ``dataset`` one
        index into ``fls`` per file
    """
    fls = []
    files_len = []
    dataset = []

    for root, _dirs, files in os.walk(base_dir):
        for filename in sorted(fnmatch.filter(files, pattern)):
            if keys is None or filename[:-4] not in keys:
                continue

            npy_file = os.path.join(root, filename)
            numpy_array = np.load(npy_file, mmap_mode="r")

            # Index by position in ``fls``. The former counter advanced for
            # skipped files and restarted per directory, so its values did
            # not address ``fls`` once anything was filtered out.
            dataset.append(len(fls))
            fls.append(npy_file)
            files_len.append(numpy_array.shape[1])

    return fls, files_len, dataset
class NumpyDataLoader(SlimDataLoaderBase):
    """Batch source over whole ``.npy`` volumes selected by ``load_dataset``.

    Batches are addressed by index through ``__getitem__``; requests outside
    the available range raise ``StopIteration``.
    """

    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                 seed=None, file_pattern='*.npy', label=1, input=(0,), keys=None):
        loaded = load_dataset(base_dir=base_dir, pattern=file_pattern, keys=keys)
        self.files, self.file_len, self.dataset = loaded
        super(NumpyDataLoader, self).__init__(self.dataset, batch_size, num_batches)

        self.batch_size = batch_size
        # Index access is used for every mode.
        self.use_next = False

        self.data_len = len(self.dataset)
        self.idxs = list(range(0, self.data_len))
        self.num_batches = min((self.data_len // self.batch_size) + 10, num_batches)

        self.input = input
        # A single integer label channel is normalised to a one-element tuple.
        self.label = (label,) if isinstance(label, int) else label

        self.np_data = np.asarray(self.dataset)

    def reshuffle(self):
        """Shuffle the order in which dataset entries are batched."""
        print("Reshuffle...")
        random.shuffle(self.idxs)
        print("Initializing... this might take a while...")

    def generate_train_batch(self):
        """Return a batch built from ``batch_size`` randomly sampled entries."""
        sampled = random.sample(self._data, self.batch_size)
        return self.get_data_from_array(sampled)

    def __len__(self):
        return min(self.data_len // self.batch_size, self.num_batches)

    def __getitem__(self, item):
        """Return batch number *item*, or raise ``StopIteration`` when out of range."""
        data_len = len(self.dataset)

        if item > len(self) or (item * self.batch_size) == data_len:
            raise StopIteration()

        # Window bounds wrap modulo the dataset size; a window that wraps
        # around (stop not after start) is rejected below.
        batch_end = (item + 1) * self.batch_size
        start_idx = (item * self.batch_size) % data_len
        stop_idx = data_len if batch_end == data_len else batch_end % data_len

        if stop_idx <= start_idx:
            raise StopIteration()

        selected = self.idxs[start_idx:stop_idx]
        return self.get_data_from_array(self.np_data[selected])

    def get_data_from_array(self, open_array):
        """Load the files addressed by *open_array* and gather them into a dict.

        :param open_array: iterable of indices into ``self.files``
        :return: dict with 'data', 'fnames', 'idxs' and, when a label channel
            is configured, 'seg'
        """
        data = []
        fnames = []
        idxs = []
        labels = []
        with_labels = self.label is not None

        for idx in open_array:
            path = self.files[idx]
            volume = np.load(path, mmap_mode="r")

            data.append(volume[None, self.input[0]])  # 'None' keeps the dimension
            if with_labels:
                labels.append(volume[None, self.label[0]])  # 'None' keeps the dimension

            fnames.append(path)
            idxs.append(idx)

        ret_dict = {'data': data, 'fnames': fnames, 'idxs': idxs}
        if with_labels:
            ret_dict['seg'] = labels

        return ret_dict
def get_transforms(mode="train", target_size=128):
    """Compose the 3D transform pipeline for *mode*.

    "train" crops, resizes, mirrors and spatially augments; "val" and "test"
    only crop and resize; any other mode gets no spatial step. Every pipeline
    ends with ``NumpyToTensor``.
    """
    if mode == "train":
        pipeline = [
            CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
            MirrorTransform(axes=(2,)),
            SpatialTransform(patch_size=(target_size, target_size, target_size), random_crop=False,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                             do_rotation=True,
                             angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                             scale=(0.9, 1.4),
                             border_mode_data="nearest", border_mode_seg="nearest"),
        ]
    elif mode in ("val", "test"):
        pipeline = [
            CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
        ]
    else:
        pipeline = []

    pipeline.append(NumpyToTensor())
    return Compose(pipeline)
def load_dataset(base_dir, pattern='*.npy', slice_offset=5, keys=None):
    """Collect ``.npy`` files listed in *keys* and enumerate their slices.

    :param base_dir: directory tree searched with ``os.walk``
    :param pattern: ``fnmatch`` pattern applied to the file names
    :param slice_offset: number of slices skipped at both ends of axis 1
    :param keys: container of accepted file names without the last four
        characters. NOTE(review): ``None`` selects no file at all, as before;
        confirm whether "no filter" was intended.
    :return: ``(fls, files_len, slices_ax)`` where ``fls`` holds the file
        paths, ``files_len`` the size of axis 1 of each array and
        ``slices_ax`` the ``(index into fls, slice index)`` pairs
    """
    fls = []
    files_len = []
    slices_ax = []

    for root, _dirs, files in os.walk(base_dir):
        for filename in sorted(fnmatch.filter(files, pattern)):
            if keys is None or filename[:-4] not in keys:
                continue

            npy_file = os.path.join(root, filename)
            numpy_array = np.load(npy_file, mmap_mode="r")

            # Pair slices with the file's position in ``fls``. The former
            # counter advanced for skipped files and restarted per directory,
            # so it did not address ``fls`` once anything was filtered out.
            file_idx = len(fls)
            n_slices = numpy_array.shape[1]
            fls.append(npy_file)
            files_len.append(n_slices)
            slices_ax.extend((file_idx, j) for j in range(slice_offset, n_slices - slice_offset))

    return fls, files_len, slices_ax
class NumpyDataLoader(SlimDataLoaderBase):
    """Batch source over single slices of the ``.npy`` files found by ``load_dataset``.

    Batches are addressed by index through ``__getitem__``; requests outside
    the available range raise ``StopIteration``.
    """

    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                 seed=None, file_pattern='*.npy', label_slice=1, input_slice=(0,), keys=None):
        loaded = load_dataset(base_dir=base_dir, pattern=file_pattern, slice_offset=0, keys=keys)
        self.files, self.file_len, self.slices = loaded
        super(NumpyDataLoader, self).__init__(self.slices, batch_size, num_batches)

        self.batch_size = batch_size
        # Index access is used for every mode.
        self.use_next = False

        self.data_len = len(self.slices)
        self.slice_idxs = list(range(0, self.data_len))
        self.num_batches = min((self.data_len // self.batch_size) + 10, num_batches)

        self.input_slice = input_slice
        # A single integer label channel is normalised to a one-element tuple.
        self.label_slice = (label_slice,) if isinstance(label_slice, int) else label_slice

        self.np_data = np.asarray(self.slices)

    def reshuffle(self):
        """Shuffle the order in which slices are batched."""
        print("Reshuffle...")
        random.shuffle(self.slice_idxs)
        print("Initializing... this might take a while...")

    def generate_train_batch(self):
        """Return a batch built from ``batch_size`` randomly sampled slices."""
        sampled = random.sample(self._data, self.batch_size)
        return self.get_data_from_array(sampled)

    def __len__(self):
        return min(self.data_len // self.batch_size, self.num_batches)

    def __getitem__(self, item):
        """Return batch number *item*, or raise ``StopIteration`` when out of range."""
        data_len = len(self.slices)

        if item > len(self) or (item * self.batch_size) == data_len:
            raise StopIteration()

        # Window bounds wrap modulo the dataset size; a window that wraps
        # around (stop not after start) is rejected below.
        batch_end = (item + 1) * self.batch_size
        start_idx = (item * self.batch_size) % data_len
        stop_idx = data_len if batch_end == data_len else batch_end % data_len

        if stop_idx <= start_idx:
            raise StopIteration()

        selected = self.slice_idxs[start_idx:stop_idx]
        return self.get_data_from_array(self.np_data[selected])

    def get_data_from_array(self, open_array):
        """Load the slices addressed by *open_array* and gather them into a dict.

        :param open_array: iterable of ``(index into self.files, slice index)`` pairs
        :return: dict with 'data', 'fnames', 'slice_idxs' and, when a label
            channel is configured, 'seg'
        """
        data = []
        fnames = []
        slice_idxs = []
        labels = []
        with_labels = self.label_slice is not None

        for entry in open_array:
            path = self.files[entry[0]]
            volume = np.load(path, mmap_mode="r")

            # Fix the requested position on axis 1, keep all channels.
            numpy_slice = volume[:, entry[1]]
            data.append(numpy_slice[None, self.input_slice[0]])  # 'None' keeps the dimension
            if with_labels:
                labels.append(numpy_slice[None, self.label_slice[0]])  # 'None' keeps the dimension

            fnames.append(path)
            slice_idxs.append(entry[1])

        ret_dict = {'data': np.asarray(data), 'fnames': fnames, 'slice_idxs': slice_idxs}
        if with_labels:
            ret_dict['seg'] = np.asarray(labels)

        return ret_dict
def get_transforms(mode="train", target_size=128):
    """Compose the 2D transform pipeline for *mode*.

    "train" resizes, mirrors and spatially augments (no centre crop); "val"
    and "test" crop and resize; any other mode gets no spatial step. Every
    pipeline ends with ``NumpyToTensor``.
    """
    if mode == "train":
        pipeline = [
            ResizeTransform(target_size=(target_size, target_size), order=1),
            MirrorTransform(axes=(1,)),
            SpatialTransform(patch_size=(target_size, target_size), random_crop=False,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                             do_rotation=True, p_rot_per_sample=0.5,
                             angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                             scale=(0.5, 1.9), p_scale_per_sample=0.5,
                             border_mode_data="nearest", border_mode_seg="nearest"),
        ]
    elif mode in ("val", "test"):
        pipeline = [
            CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
        ]
    else:
        pipeline = []

    pipeline.append(NumpyToTensor())
    return Compose(pipeline)
= append_value + x_offset = 0 + y_offset = 0 # (new_shape[1] - orig_img.shape[1]) // 2 + z_offset = 0 # (new_shape[2] - orig_img.shape[2]) // 2 + + reshaped_image[x_offset:orig_img.shape[0]+x_offset, y_offset:orig_img.shape[1]+y_offset, z_offset:orig_img.shape[2]+z_offset] = orig_img + # insert temp_img.min() as background value + + return reshaped_image diff --git a/hyppopy/workflows/unet_usecase/loss_functions/__init__.py b/hyppopy/workflows/unet_usecase/loss_functions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hyppopy/workflows/unet_usecase/loss_functions/dice_loss.py b/hyppopy/workflows/unet_usecase/loss_functions/dice_loss.py new file mode 100644 index 0000000..48c7acc --- /dev/null +++ b/hyppopy/workflows/unet_usecase/loss_functions/dice_loss.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import torch
import numpy as np
from torch import nn


def sum_tensor(input, axes, keepdim=False):
    """Sum ``input`` over every axis listed in ``axes``.

    :param input: tensor to reduce.
    :param axes: iterable of axis indices; duplicates are ignored.
    :param keepdim: keep the reduced axes as size-1 dimensions when True.
    :return: the reduced tensor.
    """
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            input = input.sum(int(ax), keepdim=True)
    else:
        # Reduce the highest axis first so the remaining indices stay valid.
        for ax in sorted(axes, reverse=True):
            input = input.sum(int(ax))
    return input


def mean_tensor(input, axes, keepdim=False):
    """Average ``input`` over every axis listed in ``axes``.

    :param input: tensor to reduce.
    :param axes: iterable of axis indices; duplicates are ignored.
    :param keepdim: keep the reduced axes as size-1 dimensions when True.
    :return: the reduced tensor.
    """
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            input = input.mean(int(ax), keepdim=True)
    else:
        # Reduce the highest axis first so the remaining indices stay valid.
        for ax in sorted(axes, reverse=True):
            input = input.mean(int(ax))
    return input


class SoftDiceLoss(nn.Module):
    """Negative soft Dice score between class scores and integer labels.

    The label map is one-hot encoded along dimension 1 and compared with the
    (optionally non-linearity-transformed) network output.
    """

    def __init__(self, smooth=1., apply_nonlin=None, batch_dice=False, do_bg=True, smooth_in_nom=True,
                 background_weight=1, rebalance_weights=None):
        """
        :param smooth: smoothing constant added to the denominator.
        :param apply_nonlin: optional callable applied to the network output
            before the Dice computation (e.g. a softmax).
        :param batch_dice: if True, Dice is accumulated over the whole batch
            (``soft_dice_per_batch_2``); otherwise per sample (``soft_dice``).
        :param do_bg: if False, channel 0 is dropped from prediction and target.
        :param smooth_in_nom: if True, ``smooth`` is also added to the
            numerator; otherwise the numerator is not smoothed.
        :param background_weight: weight of channel 0; only supported together
            with ``batch_dice=True``.
        :param rebalance_weights: optional per-class factors applied to the
            true-positive and false-negative terms; only supported together
            with ``batch_dice=True``.
        """
        super(SoftDiceLoss, self).__init__()
        if not do_bg:
            assert background_weight == 1, "if there is no bg, then set background weight to 1 you dummy"
        self.rebalance_weights = rebalance_weights
        self.background_weight = background_weight
        self.smooth_in_nom = smooth_in_nom
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.y_onehot = None
        # Numeric smoothing term for the numerator (0 disables it).
        self.nom_smooth = smooth if smooth_in_nom else 0

    def forward(self, x, y):
        """Return the (negative) soft Dice loss.

        :param x: network output of shape (B, C, X, Y(, Z)).
        :param y: label map of shape (B, 1, X, Y(, Z)) or (B, X, Y(, Z)).
        :return: scalar tensor; -1 corresponds to a perfect overlap.
        :raises NotImplementedError: if ``batch_dice`` is False while
            ``background_weight`` or ``rebalance_weights`` is set.
        """
        with torch.no_grad():
            y = y.long()
        shp_x = x.shape
        shp_y = y.shape
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        if len(shp_x) != len(shp_y):
            y = y.view((shp_y[0], 1, *shp_y[1:]))
        # now x and y should have shape (B, C, X, Y(, Z))) and (B, 1, X, Y(, Z))), respectively
        # Allocate the one-hot target directly on the device of the prediction.
        y_onehot = torch.zeros(shp_x, device=x.device)
        y_onehot.scatter_(1, y, 1)
        if not self.do_bg:
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]
        if not self.batch_dice:
            if self.background_weight != 1 or (self.rebalance_weights is not None):
                raise NotImplementedError("nah son")
            # The helpers expect the numeric smoothing value, not the boolean flag.
            return soft_dice(x, y_onehot, self.smooth, self.nom_smooth)
        return soft_dice_per_batch_2(x, y_onehot, self.smooth, self.nom_smooth,
                                     background_weight=self.background_weight,
                                     rebalance_weights=self.rebalance_weights)


def soft_dice_per_batch(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1):
    """Negative soft Dice accumulated over batch and spatial axes.

    :param net_output: predictions of shape (B, C, ...).
    :param gt: one-hot target with the same shape as ``net_output``.
    :param smooth: value added to the denominator.
    :param smooth_in_nom: value added to the numerator.
    :param background_weight: factor applied to the score of channel 0.
    :return: scalar tensor, mean of the weighted negative per-class scores.
    """
    axes = tuple([0] + list(range(2, len(net_output.size()))))
    intersect = sum_tensor(net_output * gt, axes, keepdim=False)
    denom = sum_tensor(net_output + gt, axes, keepdim=False)
    weights = torch.ones(intersect.shape, device=net_output.device)
    weights[0] = background_weight
    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth)) * weights).mean()
    return result


def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None):
    """Negative soft Dice over the batch, written in terms of TP / FP / FN.

    :param net_output: predictions of shape (B, C, ...).
    :param gt: one-hot target with the same shape as ``net_output``.
    :param smooth: value added to the denominator.
    :param smooth_in_nom: value added to the numerator.
    :param background_weight: factor applied to the score of channel 0.
    :param rebalance_weights: optional per-class factors (numpy array, list or
        tensor) multiplied onto the TP and FN terms. If its length differs from
        the number of channels in ``gt`` the first entry is dropped.
    :return: scalar tensor, mean of the weighted negative per-class scores.
    """
    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        rebalance_weights = rebalance_weights[1:]  # this is the case when use_bg=False
    axes = tuple([0] + list(range(2, len(net_output.size()))))
    tp = sum_tensor(net_output * gt, axes, keepdim=False)
    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
    fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
    weights = torch.ones(tp.shape, device=net_output.device)
    weights[0] = background_weight
    if rebalance_weights is not None:
        # as_tensor accepts numpy arrays as well as plain sequences and tensors.
        rebalance_weights = torch.as_tensor(rebalance_weights, dtype=torch.float32, device=net_output.device)
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights
    result = (- ((2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)) * weights).mean()
    return result


def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1.):
    """Negative soft Dice computed per sample and per class, then averaged.

    :param net_output: predictions of shape (B, C, ...).
    :param gt: one-hot target with the same shape as ``net_output``.
    :param smooth: value added to the denominator.
    :param smooth_in_nom: value added to the numerator.
    :return: scalar tensor, mean of the negative per-sample/per-class scores.
    """
    axes = tuple(range(2, len(net_output.size())))
    intersect = sum_tensor(net_output * gt, axes, keepdim=False)
    denom = sum_tensor(net_output + gt, axes, keepdim=False)
    # No class weighting on this path: the previous code multiplied by an
    # undefined name ``weights`` and therefore always raised NameError.
    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
    return result


class MultipleOutputLoss(nn.Module):
    """Weighted sum of one loss applied to several outputs of the same target."""

    def __init__(self, loss, weight_factors=None):
        """
        use this if you have several outputs that should predict the same y
        :param loss: callable ``loss(x_i, y)`` evaluated for each output.
        :param weight_factors: one weight per output; all ones when None.
        """
        super(MultipleOutputLoss, self).__init__()
        self.weight_factors = weight_factors
        self.loss = loss

    def forward(self, x, y):
        """Return ``sum_i weight_i * loss(x[i], y)``.

        :param x: tuple or list of network outputs.
        :param y: the common target.
        """
        assert isinstance(x, (tuple, list)), "x must be either tuple or list"
        if self.weight_factors is None:
            weights = [1] * len(x)
        else:
            weights = self.weight_factors
        total = weights[0] * self.loss(x[0], y)
        for i in range(1, len(x)):
            total = total + weights[i] * self.loss(x[i], y)
        return total


# ---- file: hyppopy/workflows/unet_usecase/networks/RecursiveUNet.py (new file in this diff) ----
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defines the Unet.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck

# recursive implementation of Unet
import torch

from torch import nn


class UNet(nn.Module):
    """Recursive 2D U-Net assembled from nested ``UnetSkipConnectionBlock`` levels."""

    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=4,
                 norm_layer=nn.InstanceNorm2d):
        """
        :param num_classes: number of output channels of the outermost block.
        :param in_channels: number of channels of the input image.
        :param initial_filter_size: feature maps of the outermost level; the
            count doubles with every deeper level.
        :param kernel_size: kernel size forwarded to every block.
        :param num_downs: number of pooled levels below the outermost block.
        :param norm_layer: normalisation layer class forwarded to every block.
        """
        super(UNet, self).__init__()

        # Arguments shared by every level of the network.
        common = dict(num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer)

        # Build from the bottleneck outwards: start with the innermost block ...
        block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs - 1),
                                        out_channels=initial_filter_size * 2 ** num_downs,
                                        innermost=True, **common)
        # ... wrap it in the intermediate levels, halving the filters each time ...
        for level in range(num_downs - 1, 0, -1):
            block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (level - 1),
                                            out_channels=initial_filter_size * 2 ** level,
                                            submodule=block, **common)
        # ... and finish with the outermost block that maps to num_classes.
        self.model = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
                                             submodule=block, outermost=True, **common)

    def forward(self, x):
        """Run the nested blocks on ``x`` and return the outermost block's output."""
        return self.model(x)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One resolution level of the recursive 2D U-Net.

    A non-outermost block pools its input, applies two convolutions, runs the
    optional ``submodule``, upsamples again and returns the input concatenated
    with the result along the channel axis. The outermost block skips pooling
    and ends in a 1x1 convolution to ``num_classes`` channels.

    Note: all 3x3-style convolutions use ``padding=1``, so the spatial size is
    only preserved for ``kernel_size=3``.
    """

    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm2d, use_dropout=False):
        """
        :param in_channels: channels of the tensor entering this block.
        :param out_channels: channels produced by this block's convolutions.
        :param num_classes: output channels of the final 1x1 conv (outermost only).
        :param kernel_size: kernel size of the contracting/expanding convolutions.
        :param submodule: nested block; required unless ``innermost`` is True.
        :param outermost: build the top level (no pooling, final classifier).
        :param innermost: build the bottleneck (no submodule, no expanding convs).
        :param norm_layer: normalisation layer class used in the contracting path.
        :param use_dropout: append ``nn.Dropout(0.5)``; only honoured by
            intermediate blocks.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # downconv
        pool = nn.MaxPool2d(2, stride=2)
        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)

        # upconv
        # conv3/conv4 are still constructed for every block type (as before) so
        # that seeded weight initialisation draws random numbers in the same order.
        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)

        if outermost:
            final = nn.Conv2d(out_channels, num_classes, kernel_size=1)
            model = [conv1, conv2, submodule, conv3, conv4, final]
        else:
            # The transposed conv consumes the out_channels maps produced by the
            # preceding conv. The former hard-coded in_channels*2 only matched
            # when out_channels == 2 * in_channels and crashed otherwise.
            upconv = nn.ConvTranspose2d(out_channels, in_channels, kernel_size=2, stride=2)
            if innermost:
                model = [pool, conv1, conv2, upconv]
            else:
                model = [pool, conv1, conv2, submodule, conv3, conv4, upconv]
                if use_dropout:
                    model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm2d):
        """Return conv -> norm -> LeakyReLU for the contracting path."""
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))
        return layer

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        """Return conv -> LeakyReLU for the expanding path (no normalisation)."""
        layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        return layer

    @staticmethod
    def center_crop(layer, target_width, target_height):
        """Crop the two trailing axes of a 4D tensor centrally to the target size."""
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_width) // 2
        xy2 = (layer_height - target_height) // 2
        return layer[:, :, xy1:(xy1 + target_width), xy2:(xy2 + target_height)]

    def forward(self, x):
        """Return the block output; non-outermost blocks prepend the skip input.

        :param x: tensor indexed as (batch, channels, dim2, dim3).
        :return: ``model(x)`` for the outermost block, otherwise ``x``
            concatenated on axis 1 with ``model(x)`` cropped to ``x``'s size.
        """
        if self.outermost:
            return self.model(x)
        else:
            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3])
            return torch.cat([x, crop], 1)


# ---- file: hyppopy/workflows/unet_usecase/networks/RecursiveUNet3D.py (new file in this diff) ----
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defines the Unet.
# |num_downs|: number of downsamplings in UNet.
# For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1 at the bottleneck

# recursive implementation of Unet
import torch

from torch import nn


class UNet3D(nn.Module):
    """Recursive 3D U-Net assembled from nested ``UnetSkipConnectionBlock`` levels."""

    def __init__(self, num_classes=3, in_channels=1, initial_filter_size=64, kernel_size=3, num_downs=3,
                 norm_layer=nn.InstanceNorm3d):
        """
        :param num_classes: number of output channels of the outermost block.
        :param in_channels: number of channels of the input volume.
        :param initial_filter_size: feature maps of the outermost level; the
            count doubles with every deeper level.
        :param kernel_size: kernel size forwarded to every block.
        :param num_downs: number of pooled levels below the outermost block.
        :param norm_layer: normalisation layer class forwarded to every block.
        """
        super(UNet3D, self).__init__()

        # Arguments shared by every level of the network.
        common = dict(num_classes=num_classes, kernel_size=kernel_size, norm_layer=norm_layer)

        # Build from the bottleneck outwards: start with the innermost block ...
        block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (num_downs - 1),
                                        out_channels=initial_filter_size * 2 ** num_downs,
                                        innermost=True, **common)
        # ... wrap it in the intermediate levels, halving the filters each time ...
        for level in range(num_downs - 1, 0, -1):
            block = UnetSkipConnectionBlock(in_channels=initial_filter_size * 2 ** (level - 1),
                                            out_channels=initial_filter_size * 2 ** level,
                                            submodule=block, **common)
        # ... and finish with the outermost block that maps to num_classes.
        self.model = UnetSkipConnectionBlock(in_channels=in_channels, out_channels=initial_filter_size,
                                             submodule=block, outermost=True, **common)

    def forward(self, x):
        """Run the nested blocks on ``x`` and return the outermost block's output."""
        return self.model(x)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One resolution level of the recursive 3D U-Net.

    A non-outermost block pools its input, applies two convolutions, runs the
    optional ``submodule``, upsamples again and returns the input concatenated
    with the result along the channel axis. The outermost block skips pooling
    and ends in a 1x1x1 convolution to ``num_classes`` channels.

    Note: all convolutions of the contracting/expanding path use ``padding=1``,
    so the spatial size is only preserved for ``kernel_size=3``.
    """

    def __init__(self, in_channels=None, out_channels=None, num_classes=1, kernel_size=3,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.InstanceNorm3d, use_dropout=False):
        """
        :param in_channels: channels of the tensor entering this block.
        :param out_channels: channels produced by this block's convolutions.
        :param num_classes: output channels of the final 1x1x1 conv (outermost only).
        :param kernel_size: kernel size of the contracting/expanding convolutions.
        :param submodule: nested block; required unless ``innermost`` is True.
        :param outermost: build the top level (no pooling, final classifier).
        :param innermost: build the bottleneck (no submodule, no expanding convs).
        :param norm_layer: normalisation layer class used in the contracting path.
        :param use_dropout: append ``nn.Dropout(0.5)``; only honoured by
            intermediate blocks.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # downconv
        pool = nn.MaxPool3d(2, stride=2)
        conv1 = self.contract(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)
        conv2 = self.contract(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, norm_layer=norm_layer)

        # upconv
        # conv3/conv4 are still constructed for every block type (as before) so
        # that seeded weight initialisation draws random numbers in the same order.
        conv3 = self.expand(in_channels=out_channels*2, out_channels=out_channels, kernel_size=kernel_size)
        conv4 = self.expand(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size)

        if outermost:
            final = nn.Conv3d(out_channels, num_classes, kernel_size=1)
            model = [conv1, conv2, submodule, conv3, conv4, final]
        else:
            # The transposed conv consumes the out_channels maps produced by the
            # preceding conv. The former hard-coded in_channels*2 only matched
            # when out_channels == 2 * in_channels and crashed otherwise.
            upconv = nn.ConvTranspose3d(out_channels, in_channels, kernel_size=2, stride=2)
            if innermost:
                model = [pool, conv1, conv2, upconv]
            else:
                model = [pool, conv1, conv2, submodule, conv3, conv4, upconv]
                if use_dropout:
                    model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, norm_layer=nn.InstanceNorm3d):
        """Return conv -> norm -> LeakyReLU for the contracting path."""
        layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, padding=1),
            norm_layer(out_channels),
            nn.LeakyReLU(inplace=True))
        return layer

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        """Return conv -> LeakyReLU for the expanding path (no normalisation)."""
        layer = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        return layer

    @staticmethod
    def center_crop(layer, target_depth, target_width, target_height):
        """Crop the three trailing axes of a 5D tensor centrally to the target size."""
        batch_size, n_channels, layer_depth, layer_width, layer_height = layer.size()
        xy0 = (layer_depth - target_depth) // 2
        xy1 = (layer_width - target_width) // 2
        xy2 = (layer_height - target_height) // 2
        return layer[:, :, xy0:(xy0 + target_depth), xy1:(xy1 + target_width), xy2:(xy2 + target_height)]

    def forward(self, x):
        """Return the block output; non-outermost blocks prepend the skip input.

        :param x: tensor indexed as (batch, channels, dim2, dim3, dim4).
        :return: ``model(x)`` for the outermost block, otherwise ``x``
            concatenated on axis 1 with ``model(x)`` cropped to ``x``'s size.
        """
        if self.outermost:
            return self.model(x)
        else:
            crop = self.center_crop(self.model(x), x.size()[2], x.size()[3], x.size()[4])
            return torch.cat([x, crop], 1)


# ---- file: hyppopy/workflows/unet_usecase/networks/UNET.py (new file in this diff) ----
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2017 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn


class UNet(nn.Module):
    """Explicit four-level 2D U-Net with a segmentation and a reconstruction head."""

    def __init__(self, num_classes, in_channels=1, initial_filter_size=64, kernel_size=3, do_instancenorm=True):
        """
        :param num_classes: output channels of the segmentation head.
        :param in_channels: channels of the input image.
        :param initial_filter_size: feature maps of the first level; doubled on
            every deeper level.
        :param kernel_size: kernel size of the contracting convolutions.
        :param do_instancenorm: insert ``InstanceNorm2d`` into the contracting blocks.
        """
        super().__init__()

        # Feature-map counts of the four levels and of the bottleneck.
        width_1 = initial_filter_size
        width_2 = initial_filter_size*2
        width_3 = initial_filter_size*2**2
        width_4 = initial_filter_size*2**3
        width_5 = initial_filter_size*2**4

        def down(n_in, n_out):
            return self.contract(n_in, n_out, kernel_size, instancenorm=do_instancenorm)

        # Contracting path; a single pooling layer is shared by all levels.
        self.contr_1_1 = down(in_channels, width_1)
        self.contr_1_2 = down(width_1, width_1)
        self.pool = nn.MaxPool2d(2, stride=2)

        self.contr_2_1 = down(width_1, width_2)
        self.contr_2_2 = down(width_2, width_2)

        self.contr_3_1 = down(width_2, width_3)
        self.contr_3_2 = down(width_3, width_3)

        self.contr_4_1 = down(width_3, width_4)
        self.contr_4_2 = down(width_4, width_4)

        # Bottleneck, including the first upsampling step.
        self.center = nn.Sequential(
            nn.Conv2d(width_4, width_5, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(width_5, width_5, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(width_5, width_4, 2, stride=2),
            nn.ReLU(inplace=True),
        )

        # Expanding path; each level consumes upsampled features plus the skip.
        self.expand_4_1 = self.expand(width_5, width_4)
        self.expand_4_2 = self.expand(width_4, width_4)
        self.upscale4 = nn.ConvTranspose2d(width_4, width_3, kernel_size=2, stride=2)

        self.expand_3_1 = self.expand(width_4, width_3)
        self.expand_3_2 = self.expand(width_3, width_3)
        self.upscale3 = nn.ConvTranspose2d(width_3, width_2, 2, stride=2)

        self.expand_2_1 = self.expand(width_3, width_2)
        self.expand_2_2 = self.expand(width_2, width_2)
        self.upscale2 = nn.ConvTranspose2d(width_2, width_1, 2, stride=2)

        self.expand_1_1 = self.expand(width_2, width_1)
        self.expand_1_2 = self.expand(width_1, width_1)
        # Output layer for segmentation
        self.final = nn.Conv2d(width_1, num_classes, kernel_size=1)  # kernel size for final layer = 1, see paper

        self.softmax = torch.nn.Softmax2d()

        # Output layer for "autoencoder-mode"
        self.output_reconstruction_map = nn.Conv2d(width_1, out_channels=1, kernel_size=1)

    @staticmethod
    def contract(in_channels, out_channels, kernel_size=3, instancenorm=True):
        """Return conv -> (optional InstanceNorm2d) -> LeakyReLU."""
        stages = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=1)]
        if instancenorm:
            stages.append(nn.InstanceNorm2d(out_channels))
        stages.append(nn.LeakyReLU(inplace=True))
        return nn.Sequential(*stages)

    @staticmethod
    def expand(in_channels, out_channels, kernel_size=3):
        """Return conv -> LeakyReLU for the expanding path."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1),
            nn.LeakyReLU(inplace=True),
        )

    @staticmethod
    def center_crop(layer, target_width, target_height):
        """Crop the two trailing axes of a 4D tensor centrally to the target size."""
        _, _, layer_width, layer_height = layer.size()
        start_w = (layer_width - target_width) // 2
        start_h = (layer_height - target_height) // 2
        return layer[:, :, start_w:(start_w + target_width), start_h:(start_h + target_height)]

    def forward(self, x, enable_concat=True, print_layer_shapes=False):
        """Run the network.

        :param x: input tensor indexed as (batch, channels, dim2, dim3).
        :param enable_concat: when true the skip features are used and the
            segmentation head produces the output; otherwise the skip features
            are multiplied by zero and the reconstruction head is used.
        :param print_layer_shapes: accepted for compatibility; not used.
        :return: output of the selected head.
        """
        skip_gain = 1 if enable_concat else 0

        skip_1 = self.contr_1_2(self.contr_1_1(x))
        skip_2 = self.contr_2_2(self.contr_2_1(self.pool(skip_1)))
        skip_3 = self.contr_3_2(self.contr_3_1(self.pool(skip_2)))
        skip_4 = self.contr_4_2(self.contr_4_1(self.pool(skip_3)))

        features = self.center(self.pool(skip_4))

        # (skip tensor, first conv, second conv, following upsampling or None)
        decoder = (
            (skip_4, self.expand_4_1, self.expand_4_2, self.upscale4),
            (skip_3, self.expand_3_1, self.expand_3_2, self.upscale3),
            (skip_2, self.expand_2_1, self.expand_2_2, self.upscale2),
            (skip_1, self.expand_1_1, self.expand_1_2, None),
        )
        for skip, conv_a, conv_b, upscale in decoder:
            cropped = self.center_crop(skip, features.size()[2], features.size()[3])
            features = conv_b(conv_a(torch.cat([features, cropped*skip_gain], 1)))
            if upscale is not None:
                features = upscale(features)

        head = self.final if enable_concat else self.output_reconstruction_map
        return head(features)


# ---- file: hyppopy/workflows/unet_usecase/unet_usecase.py (modified in this diff; content shown after the change) ----
# DKFZ
#
#
# Copyright (c) German Cancer Research Center,
# Division of Medical and Biological Informatics.
# All rights reserved.
#
# This software is distributed WITHOUT ANY WARRANTY; without
# even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE.
#
# See LICENSE.txt or http://www.mitk.org for details.
#
# Author: Sven Wanner (s.wanner@dkfz.de)

import os
import torch
import numpy as np
import pandas as pd
from sklearn.svm import SVC
import torch.optim as optim
import torch.nn.functional as F
from .networks.RecursiveUNet import UNet
from .loss_functions.dice_loss import SoftDiceLoss
from sklearn.model_selection import cross_val_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from .datasets.two_dim.NumpyDataLoader import NumpyDataSet

from hyppopy.projectmanager import ProjectManager
from hyppopy.workflows.workflowbase import WorkflowBase
from hyppopy.workflows.dataloader.unetloader import UnetDataLoader


class unet_usecase(WorkflowBase):
    """Workflow that trains a 2D U-Net and reports its validation loss to the solver."""

    def setup(self):
        """Load the dataset description with ``UnetDataLoader`` and pass it to the solver."""
        dl = UnetDataLoader()
        dl.start(data_path=ProjectManager.data_path,
                 data_name=ProjectManager.data_name,
                 image_dir=ProjectManager.image_dir,
                 labels_dir=ProjectManager.labels_dir,
                 split_dir=ProjectManager.split_dir,
                 output_dir=ProjectManager.data_path,
                 num_classes=ProjectManager.num_classes)
        self.solver.set_data(dl.data)

    def blackbox_function(self, data, params):
        """Train a U-Net with the given hyperparameters and return its validation loss.

        :param data: split description; ``data[ProjectManager.fold]`` must
            provide the ``'train'`` and ``'val'`` keys.
        :param params: hyperparameters. ``"learning_rate"`` and ``"n_epochs"``
            are required; ``"batch_size"`` is optional and defaults to 8.
            Integer-valued entries are rounded in place.
        :return: mean over all epochs of the per-epoch mean validation loss
            (Dice loss + cross-entropy).
        """
        # Integer-valued hyperparameters may arrive as floats; round them in
        # place before use.
        for key in ("batch_size", "n_epochs"):
            if key in params:
                params[key] = int(round(params[key]))

        # Honour a sampled batch size; fall back to the former fixed value.
        batch_size = params.get("batch_size", 8)
        patch_size = 64

        tr_keys = data[ProjectManager.fold]['train']
        val_keys = data[ProjectManager.fold]['val']

        data_dir = os.path.join(ProjectManager.data_path, *(ProjectManager.data_name, ProjectManager.preprocessed_dir))

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        train_data_loader = NumpyDataSet(data_dir,
                                         target_size=patch_size,
                                         batch_size=batch_size,
                                         keys=tr_keys)
        val_data_loader = NumpyDataSet(data_dir,
                                       target_size=patch_size,
                                       batch_size=batch_size,
                                       keys=val_keys,
                                       mode="val",
                                       do_reshuffle=False)

        model = UNet(num_classes=ProjectManager.num_classes,
                     in_channels=ProjectManager.in_channels)
        model.to(device)

        # We use a combination of DICE-loss and CE-Loss in this example.
        # This proved good in the medical segmentation decathlon.
        dice_loss = SoftDiceLoss(batch_dice=True)  # the Dice loss expects softmax probabilities
        ce_loss = torch.nn.CrossEntropyLoss()  # no softmax here: CrossEntropyLoss applies it internally

        optimizer = optim.Adam(model.parameters(), lr=params["learning_rate"])
        scheduler = ReduceLROnPlateau(optimizer, 'min')

        losses = []
        print(f"n_epochs {params['n_epochs']}")
        for epoch in range(params["n_epochs"]):
            #### Train ####
            model.train()
            for data_batch in train_data_loader:
                optimizer.zero_grad()

                # Shape of data_batch = [1, b, c, w, h]
                # Desired shape = [b, c, w, h]
                # Move data and target to the GPU
                inputs = data_batch['data'][0].float().to(device)
                # Drop only the channel axis so a batch of one keeps its batch
                # dimension (assumes seg is [b, 1, w, h] like the data - TODO confirm).
                target = data_batch['seg'][0].long().to(device).squeeze(1)

                pred = model(inputs)
                # SoftDiceLoss expects probabilities; the CE loss takes raw scores.
                pred_softmax = F.softmax(pred, dim=1)

                loss = dice_loss(pred_softmax, target) + ce_loss(pred, target)
                loss.backward()
                optimizer.step()
            ###############

            #### Validate ####
            model.eval()
            loss_list = []
            with torch.no_grad():
                for data_batch in val_data_loader:
                    inputs = data_batch['data'][0].float().to(device)
                    target = data_batch['seg'][0].long().to(device).squeeze(1)

                    pred = model(inputs)
                    # Explicit class axis, matching the training loop.
                    pred_softmax = F.softmax(pred, dim=1)

                    loss = dice_loss(pred_softmax, target) + ce_loss(pred, target)
                    loss_list.append(loss.item())

            # An empty list means the validation loader yielded no batch.
            assert loss_list, 'data is None. Please check if your dataloader works properly'
            epoch_loss = np.mean(loss_list)
            scheduler.step(epoch_loss)
            losses.append(epoch_loss)
            ##################

        return np.mean(losses)


# ---- file: requirements_withunet.txt (new file in this diff; entries continue on the following line) ----
# alabaster==0.7.12
# atomicwrites==1.3.0
# attrs==18.2.0
# babel==2.6.0
# backcall==0.1.0
# batchgenerators==0.18.1
# bleach==3.1.0
# bz2file==0.98
# certifi==2018.11.29
# chardet==3.0.4
# Click==7.0
# cloudpickle==0.7.0
# colorama==0.4.1
# colorlover==0.3.0
# cycler==0.10.0
# dask==1.1.1
# decorator==4.3.2
# dicttoxml==1.7.4
# docutils==0.14
# entrypoints==0.3
# Flask==1.0.2
# future==0.17.1
# graphviz==0.10.1
# h5py==2.9.0
# hyperopt==0.1.1
# hyppopy==0.0.1
# idna==2.8
# imagesize==1.1.0
# ipykernel==5.1.0
# ipython==7.2.0
# ipython-genutils==0.2.0
# itsdangerous==1.1.0
# jedi==0.13.2
# Jinja2==2.10
# jsonschema==3.0.0a3
# jupyter-client==5.2.4
# jupyter-core==4.4.0
# kiwisolver==1.0.1
# linecache2==1.0.0
# llvmlite==0.27.0
# MarkupSafe==1.1.0
# matplotlib==3.0.2
# MedPy==0.3.0
# mistune==0.8.4
# more-itertools==5.0.0
# nbconvert==5.3.1
# nbformat==4.4.0
# networkx==2.2
# nibabel==2.3.3
# notebook==5.7.4
# numba==0.42.0
# numpy==1.16.0
# optunity==1.1.1
# packaging==19.0
# pandas==0.24.1
# pandocfilters==1.4.2
# parso==0.3.4
# pickleshare==0.7.5
# pillow==5.4.1
# plotly==3.6.1
# pluggy==0.8.1
# portalocker==1.4.0
# prometheus-client==0.5.0
# prompt-toolkit==2.0.8
# py==1.7.0
# pydicom==1.2.2
# Pygments==2.3.1
# pymongo==3.7.2
# pyparsing==2.3.1
# pypiwin32==223
# pyrsistent==0.14.10
# pytest==4.1.1
# python-dateutil==2.8.0
# python-telegram-bot==10.1.0
# pytz==2018.9
# pywavelets==1.0.1
# pywin32==224
# pywinpty==0.5.5
# pyzmq==17.1.2
# requests==2.21.0
# retrying==1.3.3
# scikit-image==0.14.2
# scikit-learn==0.20.2
# scipy==1.2.0
# seaborn==0.9.0
# Send2Trash==1.5.0
# six==1.12.0
# sklearn==0.0
# snowballstemmer==1.2.1
# Sphinx==1.8.3
# sphinxcontrib-websupport==1.1.0
# terminado==0.8.1
# testpath==0.4.2
+toolz==0.9.0 +torch==1.0.0 +torchfile==0.1.0 +torchvision==0.2.1 +tornado==5.1.1 +traceback2==1.4.0 +traitlets==4.3.2 +trixi==0.1.1.6 +umap-learn==0.3.7 +unittest2==1.1.0 +urllib3==1.24.1 +visdom==0.1.8.8 +wcwidth==0.1.7 +webencodings==0.5.1 +websocket-client==0.54.0 +Werkzeug==0.14.1 +wincertstore==0.2 +xmlrunner==1.7.7 +xmltodict==0.11.0 +Yapsy==1.11.223